# train_classifier.py — 199 lines (185 loc), 9.42 KB
# COVID-CT-Mask-Net
# I re-implemented Torchvision's detection library (Faster and Mask R-CNN) as a classifier
# Alex Ter-Sarkisov @ City, University of London
#
import os
import pickle
import sys
import sys
import time
import config_classifier
import cv2
import datasets.dataset_classifier as dataset
# IMPORT LOCAL IMPLEMENTATION OF TORCHVISION'S DETECTION LIBRARY
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import utils
from PIL import Image as PILImage
# IMPORT LOCAL IMPLEMENTATION OF TORCHVISION'S DETECTION LIBRARY
# Faster R-CNN interface
import models.mask_net as mask_net
from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from models.mask_net.rpn import AnchorGenerator
from torch.utils import data
from torchvision import transforms
# main method
def main(config, main_step):
    """Build the COVID-CT-Mask-Net classifier from a pretrained checkpoint and train it.

    Loads either a pretrained segmentation (Mask R-CNN) checkpoint or a pretrained
    classifier checkpoint, reconstructs the model around the saved anchor generator,
    then runs `main_step` once per epoch for train and eval.

    :param config: namespace produced by config_classifier with all run options
    :param main_step: callable executing one epoch (the `step` function below)
    """
    torch.manual_seed(time.time())
    start_time = time.time()
    # accepted option values; fail fast on a bad config
    devices = ['cpu', 'cuda']
    backbones = ['resnet50', 'resnet34', 'resnet18']
    truncation_levels = ['0','1','2']
    assert config.device in devices
    assert config.backbone_name in backbones
    assert config.truncation in truncation_levels
    # unpack the configuration namespace into locals
    start_epoch, pretrained_classifier, pretrained_segment, model_name, num_epochs, save_dir, train_data_dir, val_data_dir, \
    batch_size, device, save_every, lrate, rpn_nms, roi_nms, backbone_name, truncation, roi_batch_size, n_c, s_features = \
        config.start_epoch, config.pretrained_classification_model, \
        config.pretrained_segmentation_model, \
        config.model_name, config.num_epochs, config.save_dir, \
        config.train_data_dir, config.val_data_dir, \
        config.batch_size, config.device, config.save_every, \
        config.lrate, config.rpn_nms_th, config.roi_nms_th, \
        config.backbone_name, config.truncation, \
        config.roi_batch_size, config.num_classes, config.s_features
    # the classifier checkpoint takes precedence when both checkpoints are supplied
    if pretrained_classifier is not None and pretrained_segment is not None:
        print("Not clear which model to use, switching to the classifier")
        pretrained_model = pretrained_classifier
    elif pretrained_classifier is not None and pretrained_segment is None:
        pretrained_model = pretrained_classifier
    else:
        pretrained_model = pretrained_segment
    # fall back to CPU if CUDA was requested but is unavailable
    if device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    ##############################################################################################
    # DATASETS+DATALOADERS
    # Alex: could be added in the config file in the future
    # parameters for the dataset
    # 512x512 is the recommended image size input
    dataset_covid_pars_train_cl = {'stage': 'train', 'data': train_data_dir, 'img_size': (512,512)}
    datapoint_covid_train_cl = dataset.COVID_CT_DATA(**dataset_covid_pars_train_cl)
    #
    dataset_covid_pars_eval_cl = {'stage': 'eval', 'data': val_data_dir, 'img_size': (512,512)}
    datapoint_covid_eval_cl = dataset.COVID_CT_DATA(**dataset_covid_pars_eval_cl)
    #
    dataloader_covid_pars_train_cl = {'shuffle': True, 'batch_size': batch_size}
    dataloader_covid_train_cl = data.DataLoader(datapoint_covid_train_cl, **dataloader_covid_pars_train_cl)
    #
    dataloader_covid_pars_eval_cl = {'shuffle': True, 'batch_size': batch_size}
    dataloader_covid_eval_cl = data.DataLoader(datapoint_covid_eval_cl, **dataloader_covid_pars_eval_cl)
    #
    ##### LOAD PRETRAINED WEIGHTS FROM MASK R-CNN MODEL
    # This must be the full path to the checkpoint with the anchor generator and model weights
    # Assumed that the keys in the checkpoint are model_weights and anchor_generator
    ckpt = torch.load(pretrained_model, map_location=device)
    # keyword arguments
    # box_score_threshold:negative!  (keeps ALL detections so the score batch is always full)
    # set both NMS thresholds to 0.75 to get adjacent RoIs
    # Box detections/image: batch size for the classifier
    #
    covid_mask_net_args = {'num_classes': None, 'min_size': 512, 'max_size': 1024, 'box_detections_per_img': roi_batch_size,
                           'box_nms_thresh': roi_nms, 'box_score_thresh': -0.01, 'rpn_nms_thresh': rpn_nms}
    # copy the anchor generator parameters, create a new one to avoid implementations' clash
    sizes = ckpt['anchor_generator'].sizes
    aspect_ratios = ckpt['anchor_generator'].aspect_ratios
    anchor_generator = AnchorGenerator(sizes, aspect_ratios)
    # out_channels:256, FPN
    # num_classes:3 (1+2)
    box_head = TwoMLPHead(in_channels=256*7*7, representation_size=128)
    box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
    covid_mask_net_args['rpn_anchor_generator'] = anchor_generator
    covid_mask_net_args['box_predictor'] = box_predictor
    covid_mask_net_args['box_head'] = box_head
    covid_mask_net_args['s_representation_size'] = s_features
    # Instantiate the model
    covid_mask_net_model = mask_net.fasterrcnn_resnet_fpn(backbone_name, truncation, **covid_mask_net_args)
    # which parameters to train? (populated by utils.switch_model_on below)
    trained_pars = []
    # if the weights are loaded from the segmentation model:
    # copy only the parameters present in the checkpoint; new heads keep their init
    if pretrained_classifier is None:
        for _n, _par in covid_mask_net_model.state_dict().items():
            if _n in ckpt['model_weights']:
                print('Loading parameter', _n)
                # state_dict() tensors share storage with the model, so copy_ updates in place
                _par.copy_(ckpt['model_weights'][_n])
    # if the weights are loaded from the classification model
    else:
        covid_mask_net_model.load_state_dict(ckpt['model_weights'])
        # resume epoch counter / model name from the checkpoint when available
        if 'epoch' in ckpt.keys():
            start_epoch = int(ckpt['epoch']) + 1
        if 'model_name' in ckpt.keys():
            model_name = str(ckpt['model_name'])
    # Evaluation mode, no labels!
    covid_mask_net_model.eval()
    # set the model to training mode without triggering the 'training' mode of Mask R-CNN
    # set up the optimizer
    utils.switch_model_on(covid_mask_net_model, ckpt, trained_pars)
    utils.set_to_train_mode(covid_mask_net_model)
    print(covid_mask_net_model)
    covid_mask_net_model = covid_mask_net_model.to(device)
    total_trained_pars = sum([x.numel() for x in trained_pars])
    print("Total trained pars {0:d}".format(total_trained_pars))
    optimizer_pars = {'lr': lrate, 'weight_decay': 1e-3}
    optimizer = torch.optim.Adam(trained_pars, **optimizer_pars)
    # resuming a classifier run also restores the optimizer state
    if pretrained_classifier is not None and 'optimizer_state' in ckpt.keys():
        optimizer.load_state_dict(ckpt['optimizer_state'])
    # shift the epoch window so a resumed run still trains num_epochs epochs
    if start_epoch>0:
        num_epochs += start_epoch
    print("Start training, epoch = {:d}".format(start_epoch))
    for e in range(start_epoch, num_epochs):
        # checkpointing happens only in the eval pass, hence save_dir is None for train
        train_loss_epoch = main_step("train", e, dataloader_covid_train_cl, optimizer, device, covid_mask_net_model,
                                     save_every, lrate, model_name, None, None)
        eval_loss_epoch = main_step("eval", e, dataloader_covid_eval_cl, optimizer, device, covid_mask_net_model,
                                    save_every, lrate, model_name, anchor_generator, save_dir)
        print(
            "Epoch {0:d}: train loss = {1:.3f}, validation loss = {2:.3f}".format(e, train_loss_epoch, eval_loss_epoch))
    end_time = time.time()
    print("Training took {0:.1f} seconds".format(end_time - start_time))
def step(stage, e, dataloader, optimizer, device, model, save_every, lrate, model_name, anchors, save_dir):
    """Run one epoch of training or evaluation and return the mean batch loss.

    Each image in a batch is fed through the model individually; the per-image
    'final_scores' logits are stacked and scored against the batch labels with
    binary cross-entropy. In "train" stage the loss is backpropagated; in "eval"
    stage the model is additionally checkpointed every `save_every` epochs.

    :param stage: "train" or "eval"
    :param e: zero-based epoch index
    :param dataloader: yields (X, y) batches of images and binary label vectors
    :param optimizer: optimizer over the trainable parameters
    :param device: torch.device the model lives on
    :param model: COVID-Mask-Net model; model([image]) returns [{'final_scores': ...}]
    :param save_every: checkpoint period in epochs (eval stage only)
    :param lrate: learning rate, stored in the checkpoint for resumption
    :param model_name: checkpoint name prefix (None -> default prefix)
    :param anchors: anchor generator stored in the checkpoint (eval stage only)
    :param save_dir: checkpoint directory (eval stage only)
    :return: epoch loss averaged over batches (float)
    """
    epoch_loss = 0.0
    for b in dataloader:
        X, y = b
        if device == torch.device('cuda'):
            X, y = X.to(device), y.to(device)
        # some batches are less than batch_size
        batch_s = X.size()[0]
        batch_scores = []
        # input all images in the batch into COVID-Mask-Net to get B scores
        # (one image at a time: the detection interface expects a list of images)
        for img_idx in range(batch_s):
            image = [X[img_idx]]  # remove the batch dimension
            predict_scores = model(image)
            batch_scores.append(predict_scores[0]['final_scores'])
        # batchify scores/image and compute binary cross-entropy loss
        batch_scores = torch.stack(batch_scores)
        batch_loss = F.binary_cross_entropy_with_logits(batch_scores, y)
        if stage == "train":
            # zero_grad is only needed when gradients are actually produced
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
        epoch_loss += batch_loss.item()
    epoch_loss = epoch_loss / len(dataloader)
    # checkpoint at the end of every save_every-th eval epoch
    if not (e + 1) % save_every and stage == "eval":
        model.eval()
        state = {'epoch': str(e + 1), 'model_weights': model.state_dict(),
                 'optimizer_state': optimizer.state_dict(), 'lrate': lrate, 'anchor_generator': anchors,
                 'model_name': model_name}
        if model_name is None:
            torch.save(state, os.path.join(save_dir, "covid_ct_mask_net_ckpt_" + str(e + 1) + ".pth"))
        else:
            torch.save(state, os.path.join(save_dir, model_name + "_ckpt_" + str(e + 1) + ".pth"))
        utils.set_to_train_mode(model)
    return epoch_loss
# run the training
if __name__ == '__main__':
    config_train = config_classifier.get_config_pars_classifier("trainval")
    # at least one checkpoint (classifier or segmentation) is mandatory: the model
    # is always reconstructed around a saved anchor generator and weights
    if config_train.pretrained_classification_model is None and config_train.pretrained_segmentation_model is None:
        print("You must have at least one pretrained model!")
        # non-zero exit status: this is a fatal misconfiguration, not a success
        # (the original exited with 0, which shell pipelines read as success)
        sys.exit(1)
    else:
        main(config_train, step)