#Common imports
import os
import sys
import numpy as np
import argparse
import copy
import random
import json
import pickle
#Sklearn
import sklearn
from sklearn.manifold import TSNE
#Pytorch
import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils
#robustdg
from utils.helper import *
from utils.match_function import *
# Input Parsing
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='rot_mnist',
help='Datasets: rot_mnist; fashion_mnist; pacs')
parser.add_argument('--method_name', type=str, default='erm_match',
help='Training Algorithm: erm_match; matchdg_ctr; matchdg_erm')
parser.add_argument('--model_name', type=str, default='resnet18',
help='Architecture of the model to be trained')
parser.add_argument('--train_domains', nargs='+', type=str, default=["15", "30", "45", "60", "75"],
help='List of train domains')
parser.add_argument('--test_domains', nargs='+', type=str, default=["0", "90"],
help='List of test domains')
parser.add_argument('--out_classes', type=int, default=10,
help='Total number of classes in the dataset')
parser.add_argument('--img_c', type=int, default= 1,
help='Number of channels of the image in dataset')
parser.add_argument('--img_h', type=int, default= 224,
help='Height of the image in dataset')
parser.add_argument('--img_w', type=int, default= 224,
help='Width of the image in dataset')
parser.add_argument('--fc_layer', type=int, default= 1,
help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')
parser.add_argument('--match_layer', type=str, default='logit_match',
help='rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--pos_metric', type=str, default='l2',
help='Cost function to evaluate the distance between two representations; Options: l1; l2; cos')
parser.add_argument('--rep_dim', type=int, default=250,
help='Representation dimension for contrastive learning')
parser.add_argument('--pre_trained',type=int, default=0,
help='0: No Pretrained Architecture; 1: Pretrained Architecture')
parser.add_argument('--perfect_match', type=int, default=1,
help='0: No perfect match known (PACS); 1: perfect match known (MNIST)')
parser.add_argument('--opt', type=str, default='sgd',
help='Optimizer Choice: sgd; adam')
parser.add_argument('--weight_decay', type=float, default=5e-4,
help='Weight Decay in SGD')
parser.add_argument('--lr', type=float, default=0.01,
help='Learning rate for training the model')
parser.add_argument('--batch_size', type=int, default=16,
help='Batch size for training the model')
parser.add_argument('--epochs', type=int, default=15,
help='Total number of epochs for training the model')
parser.add_argument('--penalty_s', type=int, default=-1,
help='Epoch threshold after which the Matching Loss is optimised')
parser.add_argument('--penalty_irm', type=float, default=0.0,
help='Penalty weight for IRM invariant classifier loss')
parser.add_argument('--penalty_aug', type=float, default=1.0,
help='Penalty weight for Augmentation in Hybrid approach loss')
parser.add_argument('--penalty_ws', type=float, default=0.1,
help='Penalty weight for Matching Loss')
parser.add_argument('--penalty_diff_ctr',type=float, default=1.0,
help='Penalty weight for Contrastive Loss')
parser.add_argument('--tau', type=float, default=0.05,
help='Temperature hyperparameter for the NT-Xent contrastive loss')
parser.add_argument('--match_flag', type=int, default=0,
help='0: No Update to Match Strategy; 1: Updates to Match Strategy')
parser.add_argument('--match_case', type=float, default=1.0,
help='0: Random Match; 1: Perfect Match; 0.x: x% correct Match')
parser.add_argument('--match_interrupt', type=int, default=5,
help='Number of epochs before inferring the match strategy')
parser.add_argument('--ctr_abl', type=int, default=0,
help='0: Randomization till the class level; 1: Complete randomization')
parser.add_argument('--match_abl', type=int, default=0,
help='0: Randomization till the class level; 1: Complete randomization')
parser.add_argument('--n_runs', type=int, default=3,
help='Number of iterations to repeat the training process')
parser.add_argument('--n_runs_matchdg_erm', type=int, default=1,
help='Number of iterations to repeat training process for matchdg_erm')
parser.add_argument('--ctr_model_name', type=str, default='resnet18',
help='(For matchdg_ctr phase) Architecture of the model to be trained')
parser.add_argument('--ctr_match_layer', type=str, default='logit_match',
help='(For matchdg_ctr phase) rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--ctr_match_flag', type=int, default=1,
help='(For matchdg_ctr phase) 0: No Update to Match Strategy; 1: Updates to Match Strategy')
parser.add_argument('--ctr_match_case', type=float, default=0.01,
help='(For matchdg_ctr phase) 0: Random Match; 1: Perfect Match; 0.x: x% correct Match')
parser.add_argument('--ctr_match_interrupt', type=int, default=5,
help='(For matchdg_ctr phase) Number of epochs before inferring the match strategy')
parser.add_argument('--mnist_seed', type=int, default=0,
help='Change between 0-6 for different subsets of the MNIST and Fashion-MNIST datasets')
parser.add_argument('--retain', type=float, default=0,
help='0: Train from scratch in MatchDG Phase 2; 2: Finetune from the MatchDG Phase 1 model in MatchDG Phase 2')
parser.add_argument('--cuda_device', type=int, default=0,
help='Select the cuda device by id among the available devices' )
parser.add_argument('--os_env', type=int, default=0,
help='0: Code execution on local server/machine; 1: Code execution in docker/clusters' )
#Differential Privacy
parser.add_argument('--dp_noise', type=int, default=0,
help='0: No DP noise; 1: Add DP noise')
parser.add_argument('--dp_epsilon', type=float, default=1.0,
help='Epsilon value for Differential Privacy')
# Special case: evaluate results under the DP setup with infinite epsilon
parser.add_argument('--dp_attach_opt', type=int, default=1,
help='0: Infinite Epsilon; 1: Finite Epsilon')
#MMD, DANN
parser.add_argument('--d_steps_per_g_step', type=int, default=1)
parser.add_argument('--grad_penalty', type=float, default=0.0)
parser.add_argument('--conditional', type=int, default=1)
parser.add_argument('--gaussian', type=int, default=1)
#Slab Dataset
parser.add_argument('--slab_data_dim', type=int, default= 2,
help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--slab_noise', type=float, default=0.1)
#Differentiate between resnet, lenet, domainbed cases of mnist
parser.add_argument('--mnist_case', type=str, default='resnet18',
help='MNIST Dataset Case: resnet18; lenet; domainbed')
parser.add_argument('--mnist_aug', type=int, default=0,
help='MNIST Data Augmentation: 0 (MNIST, FMNIST Privacy Evaluation); 1 (FMNIST)')
#Multiple random matches
parser.add_argument('--total_matches_per_point', type=int, default=1,
help='Multiple random matches')
# Evaluation specific
parser.add_argument('--test_metric', type=str, default='match_score',
help='Evaluation Metrics: acc; match_score; t_sne; mia')
parser.add_argument('--acc_data_case', type=str, default='test',
help='Dataset Train/Val/Test for the accuracy evaluation metric')
parser.add_argument('--top_k', type=int, default=10,
help='Top K matches to consider for the match score evaluation metric')
parser.add_argument('--match_func_aug_case', type=int, default=0,
help='0: Evaluate match func on train domains; 1: Evaluate match func on self augmentations')
parser.add_argument('--match_func_data_case', type=str, default='val',
help='Dataset Train/Val/Test for the match score evaluation metric')
args = parser.parse_args()
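
# Example invocation (illustrative flag values; all flags are defined above):
#   python train.py --dataset_name rot_mnist --method_name erm_match \
#                   --match_case 1.0 --penalty_ws 0.1 --epochs 25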
#GPU
cuda= torch.device("cuda:" + str(args.cuda_device))
if torch.cuda.is_available():
    kwargs = {'num_workers': 0, 'pin_memory': False}
else:
    kwargs= {}
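# kwargs carry DataLoader options (num_workers, pin_memory); they are presumably forwarded to the
# torch DataLoader constructed inside get_dataloader.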

# List of train and test domains
train_domains= args.train_domains
test_domains= args.test_domains

# Per-run accuracy trackers
final_accuracy_target_val=[]
final_accuracy_source_val=[]

if args.os_env:
    res_dir= os.getenv('PT_OUTPUT_DIR') + '/'
else:
    res_dir= 'results/'
if args.dp_noise:
    base_res_dir=(
        res_dir + args.dataset_name + '/' + 'dp_' + str(args.dp_epsilon) + '_' + args.method_name + '/' + args.match_layer
        + '/' + 'train_' + str(args.train_domains)
    )
else:
    base_res_dir=(
        res_dir + args.dataset_name + '/' + args.method_name + '/' + args.match_layer
        + '/' + 'train_' + str(args.train_domains)
    )

#TODO: Handle slab noise case in helper functions
if args.dataset_name == 'slab':
    base_res_dir= base_res_dir + '/slab_noise_' + str(args.slab_noise)

if not os.path.exists(base_res_dir):
    os.makedirs(base_res_dir)
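# With the default arguments, base_res_dir resolves to:
#   results/rot_mnist/erm_match/logit_match/train_['15', '30', '45', '60', '75']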

# Execute the method for multiple runs (total of args.n_runs)
for run in range(args.n_runs):

    print('Run', run)

    # Seed for reproducibility
    random.seed(run*10)
    np.random.seed(run*10)
    torch.manual_seed(run*10)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(run*10)

    # DataLoader
    train_dataset= get_dataloader( args, run, train_domains, 'train', 0, kwargs )
    if args.method_name == 'matchdg_ctr':
        val_dataset= get_dataloader( args, run, train_domains, 'val', 1, kwargs )
    else:
        val_dataset= get_dataloader( args, run, train_domains, 'val', 0, kwargs )
    test_dataset= get_dataloader( args, run, test_domains, 'test', 0, kwargs )
    # print('Train Domains, Domain Size, BaseDomainIdx, Total Domains: ', train_domains, total_domains, domain_size, training_list_size)

    # Import the module as per the current training method
    if args.method_name == 'erm_match' or args.method_name == 'mask_linear' or args.method_name == 'dp_erm':
        from algorithms.erm_match import ErmMatch
        train_method= ErmMatch(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda
        )
    elif args.method_name == 'matchdg_ctr':
        from algorithms.match_dg import MatchDG
        ctr_phase=1
        train_method= MatchDG(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda, ctr_phase
        )
    elif args.method_name == 'matchdg_erm':
        from algorithms.match_dg import MatchDG
        ctr_phase=0
        train_method= MatchDG(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda, ctr_phase
        )
    elif args.method_name == 'hybrid':
        from algorithms.hybrid import Hybrid
        train_method= Hybrid(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda
        )
    elif args.method_name == 'erm':
        from algorithms.erm import Erm
        train_method= Erm(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda
        )
    elif args.method_name == 'irm':
        from algorithms.irm import Irm
        train_method= Irm(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda
        )
    elif args.method_name == 'dro':
        from algorithms.dro import DRO
        train_method= DRO(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda
        )
    elif args.method_name == 'csd':
        from algorithms.csd import CSD
        train_method= CSD(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda
        )
    elif args.method_name == 'mmd':
        from algorithms.mmd import MMD
        train_method= MMD(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda
        )
    elif args.method_name == 'dann':
        from algorithms.dann import DANN
        train_method= DANN(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            run, cuda
        )
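
    # Note: an unrecognized --method_name leaves train_method undefined, so train_method.train() below
    # raises a NameError on the first run.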

    # Train the method: it saves the model weights post training and evaluates the test accuracy
    train_method.train()

    # Final reported accuracy
    if args.method_name != 'matchdg_ctr':
        final_acc= np.max(train_method.final_acc)
        final_accuracy_target_val.append( final_acc )

        idx= np.argmax(train_method.val_acc)
        final_acc= train_method.final_acc[idx]
        final_accuracy_source_val.append( final_acc )
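
# Summary across runs: source validation selects the epoch with the best source-domain validation
# accuracy, while target validation takes the best test accuracy over epochs (oracle-style selection).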
if args.method_name != 'matchdg_ctr':
    print('\n')
    print('Done for the Model..')
    print('Final Test Accuracy (Source Validation)', np.mean(final_accuracy_source_val), np.std(final_accuracy_source_val) )
    print('Final Test Accuracy (Target Validation)', np.mean(final_accuracy_target_val), np.std(final_accuracy_target_val) )
    print('\n')