Commit ab96149

Committed Sep 21, 2021
add training scripts
1 parent 578bcfe commit ab96149

File tree: 4 files changed, +437 -0 lines


Dockerfile (+8 lines)
@@ -0,0 +1,8 @@
FROM carlduke/eidos-base:latest
RUN pip3 install gdown albumentations torchstain sklearn tqdm
WORKDIR /app
COPY train.py /app/train.py
COPY utils.py /app/utils.py
COPY unitopatho.py /app/unitopatho.py
COPY data /app/data
ENTRYPOINT ["/app/train.py"]
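
Not part of the commit: a rough sketch of how this image might be built and run. The image tag, host dataset path and W&B key below are placeholders, and the exec-form ENTRYPOINT assumes train.py keeps its executable bit so its python3 shebang is honored.

docker build -t unitopatho-train .
docker run --rm --gpus all \
    -e WANDB_API_KEY=<your-key> \
    -v /path/to/UNITOPATHO:/root/data/UNITOPATHO \
    unitopatho-train --target grade --preprocess rgb

Arguments after the image name are appended to the ENTRYPOINT, so they go straight to train.py's argument parser.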

data/target.jpg (binary file, 52.9 KB)

train.py (+265 lines)
@@ -0,0 +1,265 @@
#!/usr/bin/env python3

import torch
import torch.multiprocessing as mp
import torch.distributed as dist
import torch.nn.functional as F
import torchvision
import albumentations
import albumentations.pytorch
from albumentations.pytorch import ToTensorV2

import numpy as np
import pandas as pd
import cv2
import os
import wandb
import copy

import unitopatho
import utils

import re
import argparse

import torchstain

from tqdm import tqdm
from collections import defaultdict
from sklearn.model_selection import GroupShuffleSplit
from pathlib import Path
from functools import partial
from multiprocessing import Manager

torch.multiprocessing.set_sharing_strategy('file_system')
manager = Manager()


def resnet18(n_classes=2):
    model = torchvision.models.resnet18(pretrained='imagenet')
    model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=n_classes, bias=True)
    return model

def preprocess_df(df, label):
    if label == 'norm':
        df.loc[df.grade == 0, 'grade'] = -1
        df.loc[df.type == 'norm', 'grade'] = 0

    df = df[df.grade >= 0].copy()

    if label != 'both' and label != 'norm':
        df = df[df.type == label].copy()
    return df

def main(config):
    checkpoint = None
    if config.test is not None:
        print('=> Loading saved checkpoint')
        checkpoint = torch.hub.load_state_dict_from_url(f'https://api.wandb.ai/files/eidos/UnitoPath-v1/{config.test}/model.pt',
                                                        map_location='cpu', progress=True, check_hash=False)
        test = config.test
        device = config.device
        p = config.path
        config = checkpoint['config']
        config.test = test
        config.device = device
        config.path = p

    utils.set_seed(config.seed)
    scaler = torch.cuda.amp.GradScaler()

    if config.test is None:
        wandb.init(config=config,
                   project=f'unitopatho')

    path = os.path.join(config.path, str(config.size))
    train_df = pd.read_csv(os.path.join(path, 'train.csv'))
    test_df = pd.read_csv(os.path.join(path, 'test.csv'))

    groupby = config.target + ''
    print('=> Raw data (train)')
    print(train_df.groupby(groupby).count())

    print('\n=> Raw data (test)')
    print(test_df.groupby(groupby).count())

    if config.target == 'grade':
        train_df = preprocess_df(train_df, config.label)
        test_df = preprocess_df(test_df, config.label)

        # balance train_df (sample mean size)
        groups = train_df.groupby('grade').count()
        grade_min = int(groups.image_id.idxmin())
        mean_size = int(train_df.groupby('grade').count().mean()['image_id'])

        train_df = pd.concat((
            train_df[train_df.grade == 0].sample(mean_size, replace=(grade_min==0), random_state=config.seed).copy(),
            train_df[train_df.grade == 1].sample(mean_size, replace=(grade_min==1), random_state=config.seed).copy()
        ))

    else:
        # balance train_df (sample 3rd min_size)
        min_size = np.sort(train_df.groupby(groupby).count()['image_id'])[2]
        train_df = train_df.groupby(groupby).apply(lambda group: group.sample(min_size, replace=len(group) < min_size, random_state=config.seed)).reset_index(drop=True)

    print('\n---- DATA SUMMARY ----')
    print('---------------------------------- Train ----------------------------------')
    print(train_df.groupby(groupby).count())
    print(len(train_df.wsi.unique()), 'WSIs')

    print('\n---------------------------------- Test ----------------------------------')
    print(test_df.groupby(groupby).count())
    print(len(test_df.wsi.unique()), 'WSIs')

    im_mean, im_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] # ImageNet
    norm = dict(
        rgb=dict(mean=im_mean,
                 std=im_std),
        he=dict(mean=im_mean,
                std=im_std),
        gray=dict(mean=[0.5],
                  std=[1.0])
    )

    T_aug = albumentations.Compose([
        albumentations.HorizontalFlip(p=0.5),
        albumentations.VerticalFlip(p=0.5),
        albumentations.Rotate(90, p=0.5)
    ])
    T_jitter = albumentations.ColorJitter()

    mean, std = norm[config.preprocess]['mean'], norm[config.preprocess]['std']
    print('=> mean, std:', mean, std)
    T_tensor = ToTensorV2()
    T_post = albumentations.Compose([
        albumentations.Normalize(mean, std),
        T_tensor
    ])

    print('=> Preparing stain normalizer..')
    he_target = cv2.cvtColor(cv2.imread('data/target.jpg'), cv2.COLOR_BGR2RGB)
    normalizer = torchstain.MacenkoNormalizer(backend='torch')
    normalizer.fit(T_tensor(image=he_target)['image']*255)
    print('=> Done')

    def normalize_he(x):
        if config.preprocess == 'he':
            img = x
            try:
                img = T_tensor(image=img)['image']*255
                img, _, _ = normalizer.normalize(img, stains=False)
                img = img.numpy().astype(np.uint8)
            except Exception as e:
                print('Could not normalize image:', e)
                img = x
            return img
        return x

    def apply_transforms(train, img):
        img = normalize_he(img)
        if train:
            img = T_aug(image=img)['image']
            if config.preprocess == 'rgb':
                img = T_jitter(image=img)['image']
        x = img
        return T_post(image=x)['image']

    T_train = partial(apply_transforms, True)
    T_test = partial(apply_transforms, False)

    datasets_kwargs = {
        'path': path,
        'subsample': config.subsample,
        'target': config.target,
        'gray': config.preprocess == 'gray',
        'mock': config.mock
    }

    train_dataset = unitopatho.UTP(train_df, T=T_train, **datasets_kwargs)
    test_dataset = unitopatho.UTP(test_df, T=T_test, **datasets_kwargs)

    # Final loaders
    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True,
                                               batch_size=config.batch_size,
                                               num_workers=config.n_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False,
                                              batch_size=config.batch_size,
                                              num_workers=config.n_workers,
                                              pin_memory=True)

    n_classes = len(train_df[config.target].unique())
    print(f'=> Training for {n_classes} classes')

    n_channels = {
        'rgb': 3,
        'he': 3,
        'gray': 1
    }

    model = resnet18(n_classes=n_classes)
    model.conv1 = torch.nn.Conv2d(n_channels[config.preprocess], 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'])
    model = model.to(config.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
    criterion = F.cross_entropy

    for epoch in range(config.epochs):
        if config.test is None:
            train_metrics = utils.train(model, train_loader, criterion,
                                        optimizer, config.device, metrics=utils.metrics,
                                        accumulation_steps=config.accumulation_steps, scaler=scaler)
            scheduler.step()

        test_metrics = utils.test(model, test_loader, criterion, config.device, metrics=utils.metrics)

        if config.test is None:
            print(f'Epoch {epoch}: train: {train_metrics}')
            wandb.log({'train': train_metrics,
                       'test': test_metrics})
            torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'config': config},
                       os.path.join(wandb.run.dir, 'model.pt'))

        print(f'test: {test_metrics}')
        if config.test is not None:
            break


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # data config
    parser.add_argument('--path', default=f'{os.path.expanduser("~")}/data/UNITOPATHO', type=str, help='UNITOPATHO dataset path')
    parser.add_argument('--size', default=100, type=int, help='patch size in µm (default 100)')
    parser.add_argument('--subsample', default=-1, type=int, help='subsample size for data (-1 to disable, default -1)')

    # optimizer & network config
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--lr', default=0.0001, type=float, help='learning rate')
    parser.add_argument('--momentum', default=0.99, type=float, help='momentum')
    parser.add_argument('--weight_decay', default=1e-5, type=float, help='weight decay')
    parser.add_argument('--batch_size', default=256, type=int, help='batch size')
    parser.add_argument('--accumulation_steps', default=1, type=int, help='gradient accumulation steps')
    parser.add_argument('--n_workers', default=8, type=int)
    parser.add_argument('--architecture', default='resnet18', help='resnet18, resnet50, densenet121')

    # training config
    parser.add_argument('--preprocess', default='rgb', help='preprocessing type, rgb, he or gray. Default: rgb')
    parser.add_argument('--target', default='grade', help='target attribute: grade, type, top_label (default: grade)')
    parser.add_argument('--label', default='both', type=str, help='only when target=grade; values: ta, tva, norm or both (default: both)')
    parser.add_argument('--test', type=str, help='Run id to test', default=None)

    # misc config
    parser.add_argument('--name', type=str, default=None)
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--mock', action='store_true', dest='mock', help='mock dataset (random noise)')
    parser.add_argument('--seed', type=int, default=42)
    parser.set_defaults(mock=False)

    config = parser.parse_args()
    config.device = torch.device(config.device)

    main(config)
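
For illustration (not part of the commit), the stain-normalization step above reduced to a standalone sketch. It follows torchstain's documented torch-backend usage, feeding a CxHxW float tensor with values in [0, 255] directly instead of going through albumentations; 'patch.png' is a hypothetical input tile.

import cv2
import numpy as np
import torch
import torchstain

def to_chw_tensor(img_rgb):
    # HxWxC uint8 RGB image -> CxHxW float tensor with values in [0, 255]
    return torch.from_numpy(img_rgb.transpose(2, 0, 1)).float()

# Fit the Macenko normalizer on the bundled H&E reference tile
target = cv2.cvtColor(cv2.imread('data/target.jpg'), cv2.COLOR_BGR2RGB)
normalizer = torchstain.MacenkoNormalizer(backend='torch')
normalizer.fit(to_chw_tensor(target))

# Normalize a single patch ('patch.png' is a placeholder path)
patch = cv2.cvtColor(cv2.imread('patch.png'), cv2.COLOR_BGR2RGB)
norm, _, _ = normalizer.normalize(to_chw_tensor(patch), stains=False)
norm = norm.numpy().astype(np.uint8)  # back to a uint8 image, as normalize_he() does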

utils.py (+164 lines)
@@ -0,0 +1,164 @@
import random
import os
import numpy as np
import torch
import wandb
import pandas as pd

from tqdm import tqdm
from pathlib import Path
from sklearn.metrics import balanced_accuracy_score, roc_auc_score, confusion_matrix, recall_score

def ensure_dir(dirname):
    dirname = Path(dirname)
    if not dirname.is_dir():
        dirname.mkdir(parents=True, exist_ok=False)

def set_seed(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    torch.manual_seed(seed)

def binary_accuracy(outputs, labels):
    preds = (torch.sigmoid(outputs) > 0.5).long()
    correct = preds.eq(labels.long()).sum()
    return (correct.float() / float(len(outputs))).item()

def binary_ba(outputs, labels):
    preds = (torch.sigmoid(outputs) > 0.5).long()
    return balanced_accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())

def roc(outputs, labels, average='macro', multi_class='raise'):
    if average is None:
        outputs = torch.softmax(outputs, dim=1)
    else:
        outputs = torch.sigmoid(outputs)
    return {c: r for c,r in enumerate(roc_auc_score(labels.cpu().numpy(), outputs.cpu().numpy(), average=average, multi_class=multi_class))}

def binary_metrics(outputs, labels):
    return dict(
        accuracy=binary_accuracy(outputs, labels),
        ba=binary_ba(outputs, labels),
        roc=roc(outputs, labels)
    )

def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return (preds.eq(labels.long()).sum().float() / labels.shape[0]).item()

def ba(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return balanced_accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())

def class_ba(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    preds = preds.cpu().numpy()
    targets = torch.unique(labels.long()).cpu().numpy()
    labels = labels.long().cpu().numpy()

    class_ba = {}
    for target in targets:
        class_labels = (labels == target).astype(np.uint8)
        class_preds = (preds == target).astype(np.uint8)
        class_ba[int(target)] = balanced_accuracy_score(class_labels, class_preds)

    return class_ba

def recall(outputs, labels, average='binary'):
    _, preds = torch.max(outputs, dim=1)
    return {c: r for c, r in enumerate(recall_score(labels.cpu().numpy(), preds.cpu().numpy(), average=average))}

def cm(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    cm = confusion_matrix(labels.cpu().numpy(), preds.cpu().numpy())
    print(cm)
    return cm

def metrics(outputs, labels):
    return dict(
        accuracy=accuracy(outputs, labels),
        ba=ba(outputs, labels),
        class_ba=class_ba(outputs, labels),
        recall=recall(outputs, labels, average=None),
        #roc=roc(outputs, labels, average=None, multi_class='ovo'),
        cm=wandb.Table(dataframe=pd.DataFrame(cm(outputs, labels)))
    )

def train(model, dataloader, criterion, optimizer, device, metrics, accumulation_steps=1, scaler=None, verbose=True):
    num_samples, tot_loss = 0., 0.
    all_outputs, all_labels = [], []

    model.train()
    itr = tqdm(dataloader, leave=False) if verbose else dataloader
    for step, (data, labels) in enumerate(itr):
        data, labels = data.to(device), labels.to(device)

        outputs, loss = None, None

        if scaler is None:
            with torch.enable_grad():
                outputs = model(data)
                loss = criterion(outputs, labels) / accumulation_steps
        else:
            with torch.cuda.amp.autocast():
                outputs = model(data)
                loss = criterion(outputs, labels) / accumulation_steps

        if scaler is None:
            loss.backward()
        else:
            scaler.scale(loss).backward()

        if (step+1) % accumulation_steps == 0 or step == len(dataloader)-1:
            if scaler is None:
                optimizer.step()
            else:
                scaler.step(optimizer)
                scaler.update()

            optimizer.zero_grad()

        all_outputs.append(outputs.detach())
        all_labels.append(labels.detach())

        batch_size = data.shape[0]
        num_samples += batch_size
        tot_loss += loss.item() * accumulation_steps * batch_size

    all_outputs = torch.cat(all_outputs, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    tracked_metrics = metrics(all_outputs, all_labels)
    tracked_metrics.update({'loss': tot_loss / num_samples})
    return tracked_metrics

def test(model, dataloader, criterion, device, metrics):
    num_samples, tot_loss = 0., 0.
    all_outputs, all_labels = [], []

    model.eval()
    for data, labels in tqdm(dataloader, leave=False):
        data, labels = data.to(device), labels.to(device)

        with torch.no_grad():
            outputs = model(data)
            loss = criterion(outputs, labels)

        all_outputs.append(outputs.detach())
        all_labels.append(labels.detach())

        batch_size = data.shape[0]
        num_samples += batch_size
        tot_loss += loss.item() * batch_size

    all_outputs = torch.cat(all_outputs, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    tracked_metrics = metrics(all_outputs, all_labels)
    tracked_metrics.update({'loss': tot_loss / num_samples})
    return tracked_metrics
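
For illustration (not part of the commit), a minimal sketch of how train(), test() and metrics() fit together, using a toy model and random tensors in place of the UniToPatho loaders; every name below is made up for the example. Note that metrics() builds a wandb.Table, so wandb must be installed, but no active run is required.

import torch
import torch.nn.functional as F
import utils

class RandomPatches(torch.utils.data.Dataset):
    # Stand-in dataset: 64 random 3x224x224 'patches' with binary labels
    def __len__(self):
        return 64
    def __getitem__(self, idx):
        return torch.randn(3, 224, 224), idx % 2

loader = torch.utils.data.DataLoader(RandomPatches(), batch_size=16, shuffle=True)
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 224 * 224, 2))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
device = torch.device('cpu')

utils.set_seed(42)
train_metrics = utils.train(model, loader, F.cross_entropy, optimizer, device,
                            metrics=utils.metrics, accumulation_steps=1)
test_metrics = utils.test(model, loader, F.cross_entropy, device, metrics=utils.metrics)
print(train_metrics['loss'], train_metrics['ba'], test_metrics['ba'])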

0 commit comments
