forked from yixinliu233/ARC
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
107 lines (94 loc) · 3.65 KB
/
main.py
File metadata and controls
107 lines (94 loc) · 3.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import argparse
from utils import *
import warnings
from train_test import ARCDetector
import numpy as np
def set_seed(seed):
    """Seed every random number generator this script relies on.

    Covers Python's ``random`` module, NumPy, and PyTorch (CPU plus every
    visible CUDA device), and forces deterministic cuDNN kernels so that
    repeated trials with the same seed produce identical results.

    Args:
        seed (int): Seed value applied to all RNGs.
    """
    random.seed(seed)
    # Fix: NumPy's RNG was previously left unseeded even though numpy is
    # imported and the data pipeline may sample with it — TODO confirm
    # whether Dataset.few_shot draws from np.random.
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
# Keep trial logs readable by muting library warnings.
warnings.filterwarnings("ignore")

# ---- Command-line configuration -------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--trials', type=int, default=5)
parser.add_argument('--model', type=str, default='ARC')
parser.add_argument('--shot', type=int, default=10)
parser.add_argument('--json_dir', type=str, default='./params')
args = parser.parse_known_args()[0]

# Disjoint dataset splits: train on one family of graphs, evaluate
# cross-dataset generalization on the other.
datasets_test = ['cora', 'citeseer', 'ACM', 'BlogCatalog', 'Facebook', 'weibo', 'Reddit', 'Amazon', 'cs', 'photo', 'tolokers', 'tfinance']
datasets_train = ['pubmed', 'Flickr', 'questions', 'YelpChi']

model = args.model
model_result = {'name': model}
print('Training on {} datasets:'.format(len(datasets_train)), datasets_train)
print('Test on {} datasets:'.format(len(datasets_test)), datasets_test)

train_config = {
    'device': 'cuda:0',
    'epochs': 40,
    'testdsets': datasets_test,
}

# Shared feature dimensionality for every dataset.
dims = 64
data_train = [Dataset(dims, ds_name) for ds_name in datasets_train]
data_test = [Dataset(dims, ds_name) for ds_name in datasets_test]  # CPU

# Prefer tuned hyper-parameters from disk; otherwise fall back to defaults.
model_config = read_json(model, args.shot, args.json_dir)
if model_config is None:
    model_config = {
        "model": "ARC",
        "lr": 1e-5,
        "drop_rate": 0,
        "h_feats": 1024,
        "num_prompt": 10,
        "num_hops": 2,
        "weight_decay": 5e-5,
        "in_feats": 64,
        "num_layers": 4,
        "activation": "ELU"
    }
    print('use default model config')
else:
    print('use saved best model config')
print(model_config)

# Pre-propagate features over num_hops neighborhoods for every dataset.
for ds in data_train:
    ds.propagated(model_config['num_hops'])
for ds in data_test:
    ds.propagated(model_config['num_hops'])

model_config['model'] = model
model_config['in_feats'] = dims

# Per-dataset score lists accumulated across trials.
auc_dict = {}
pre_dict = {}

for trial in range(args.trials):
    seed = trial
    set_seed(seed)
    print("Model {}, Trial {}".format(model, seed))
    train_config['seed'] = seed

    # Resample the few-shot labeled nodes for each evaluation dataset.
    for ds in data_test:
        ds.few_shot(args.shot)

    data = {'train': data_train, 'test': data_test}
    detector = ARCDetector(train_config, model_config, data)
    test_score_list = detector.train()

    # Fold this trial's scores into the per-dataset accumulators.
    for test_data_name, test_score in test_score_list.items():
        auc_dict.setdefault(test_data_name, []).append(test_score['AUROC'])
        pre_dict.setdefault(test_data_name, []).append(test_score['AUPRC'])
        print(f'Test on {test_data_name}, AUC is {auc_dict[test_data_name]}')

# Summarize each dataset's trials as mean +- std (population std, ddof=0).
auc_mean_dict = {ds_name: np.mean(vals) for ds_name, vals in auc_dict.items()}
auc_std_dict = {ds_name: np.std(vals) for ds_name, vals in auc_dict.items()}
pre_mean_dict = {ds_name: np.mean(vals) for ds_name, vals in pre_dict.items()}
pre_std_dict = {ds_name: np.std(vals) for ds_name, vals in pre_dict.items()}

# Report the aggregated metrics per evaluation dataset.
for test_data_name in auc_mean_dict:
    str_result = 'AUROC:{:.4f}+-{:.4f}, AUPRC:{:.4f}+-{:.4f}'.format(
        auc_mean_dict[test_data_name],
        auc_std_dict[test_data_name],
        pre_mean_dict[test_data_name],
        pre_std_dict[test_data_name])
    print('-' * 50 + test_data_name + '-' * 50)
    print('str_result', str_result)