|
| 1 | +#!/usr/bin/env python3 |
| 2 | + |
| 3 | +import torch |
| 4 | +import torch.multiprocessing as mp |
| 5 | +import torch.distributed as dist |
| 6 | +import torch.nn.functional as F |
| 7 | +import torchvision |
| 8 | +import albumentations |
| 9 | +import albumentations.pytorch |
| 10 | +from albumentations.pytorch import ToTensorV2 |
| 11 | + |
| 12 | +import numpy as np |
| 13 | +import pandas as pd |
| 14 | +import cv2 |
| 15 | +import os |
| 16 | +import wandb |
| 17 | +import copy |
| 18 | + |
| 19 | +import unitopatho |
| 20 | +import utils |
| 21 | + |
| 22 | +import re |
| 23 | +import argparse |
| 24 | + |
| 25 | +import torchstain |
| 26 | + |
| 27 | +from tqdm import tqdm |
| 28 | +from collections import defaultdict |
| 29 | +from sklearn.model_selection import GroupShuffleSplit |
| 30 | +from pathlib import Path |
| 31 | +from functools import partial |
| 32 | +from multiprocessing import Manager |
| 33 | + |
# Share tensors between processes through the file system rather than the
# default strategy; this is set once, at import time, for the whole process.
torch.multiprocessing.set_sharing_strategy('file_system')
# Module-level multiprocessing manager, created as an import-time side effect.
# NOTE(review): `manager` is not referenced anywhere else in the visible part
# of this file - confirm whether it is still needed.
manager = Manager()
| 36 | + |
| 37 | + |
def resnet18(n_classes=2):
    """Build a pretrained torchvision ResNet-18 with a new classification head.

    Args:
        n_classes: Number of output units of the replacement fully connected
            layer (default 2).

    Returns:
        The torchvision ResNet-18 module whose ``fc`` layer has been replaced
        by a freshly initialised ``torch.nn.Linear`` with ``n_classes`` outputs.
    """
    # torchvision documents `pretrained` as a boolean flag, so pass True
    # explicitly instead of relying on a truthy string such as 'imagenet'.
    model = torchvision.models.resnet18(pretrained=True)
    # Keep the backbone's feature width, swap only the output dimension.
    model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=n_classes, bias=True)
    return model
| 42 | + |
def preprocess_df(df, label):
    """Select (and optionally relabel) the rows used when the target is the grade.

    Args:
        df: DataFrame with at least ``grade`` and ``type`` columns. The frame
            passed in is left untouched; all work happens on a copy.
        label: Selection mode.
            ``'norm'``: rows whose grade is 0 are marked -1, then rows whose
            type is ``'norm'`` are given grade 0.
            ``'both'``: no relabelling and no type filter.
            anything else: keep only rows whose ``type`` equals ``label``.

    Returns:
        A new DataFrame holding the rows with ``grade >= 0`` (after the
        optional relabelling), further restricted to ``type == label`` unless
        ``label`` is ``'both'`` or ``'norm'``. Original index labels are kept.
    """
    # Work on a private copy so the relabelling below never leaks into the
    # caller's DataFrame.
    df = df.copy()

    if label == 'norm':
        # Order matters: grade-0 rows are first marked for removal, then the
        # 'norm' rows are assigned grade 0 (including any that were just marked).
        df.loc[df['grade'] == 0, 'grade'] = -1
        df.loc[df['type'] == 'norm', 'grade'] = 0

    keep = df['grade'] >= 0
    if label not in ('both', 'norm'):
        keep = keep & (df['type'] == label)

    return df[keep].copy()
| 53 | + |
def main(config):
    """Train a ResNet-18 classifier on UNITOPATHO patches, or evaluate a saved run.

    When ``config.test`` is set, the checkpoint of that run id is downloaded,
    ``config`` is replaced by the configuration stored in the checkpoint
    (keeping ``test``, ``device`` and ``path`` from the caller), and a single
    evaluation pass over the test set is performed. Otherwise a wandb run is
    started and the model is trained for ``config.epochs`` epochs, with an
    evaluation pass and a checkpoint save after every epoch.

    Args:
        config: Namespace-like object. Attributes read here: ``test``,
            ``device``, ``path``, ``seed``, ``size``, ``target``, ``label``,
            ``preprocess``, ``subsample``, ``mock``, ``batch_size``,
            ``n_workers``, ``lr``, ``weight_decay``, ``epochs`` and
            ``accumulation_steps``.
    """
    checkpoint = None
    if config.test is not None:
        print('=> Loading saved checkpoint')
        # NOTE(review): the checkpoint also carries a pickled `config` object,
        # so loading it deserialises remote data - confirm the source is trusted.
        checkpoint = torch.hub.load_state_dict_from_url(f'https://api.wandb.ai/files/eidos/UnitoPath-v1/{config.test}/model.pt',
                                                        map_location='cpu', progress=True, check_hash=False)
        # Keep the caller's run id, device and data path, but take every other
        # setting from the configuration saved with the checkpoint.
        test = config.test
        device = config.device
        p = config.path
        config = checkpoint['config']
        config.test = test
        config.device = device
        config.path = p

    utils.set_seed(config.seed)
    # Created in both modes; it is only handed to utils.train below.
    scaler = torch.cuda.amp.GradScaler()

    # Experiment tracking is only started for training runs.
    if config.test is None:
        wandb.init(config=config,
                   project=f'unitopatho')

    # The CSV files are read from <path>/<size>/.
    path = os.path.join(config.path, str(config.size))
    train_df = pd.read_csv(os.path.join(path, 'train.csv'))
    test_df = pd.read_csv(os.path.join(path, 'test.csv'))

    # Name of the column used for the per-class summaries and balancing.
    groupby = config.target + ''
    print('=> Raw data (train)')
    print(train_df.groupby(groupby).count())

    print('\n=> Raw data (test)')
    print(test_df.groupby(groupby).count())

    if config.target == 'grade':
        train_df = preprocess_df(train_df, config.label)
        test_df = preprocess_df(test_df, config.label)

        # balance train_df (sample mean size)
        # grade_min is the grade with the fewest image_id entries; only that
        # grade is sampled with replacement, i.e. oversampled up to mean_size.
        groups = train_df.groupby('grade').count()
        grade_min = int(groups.image_id.idxmin())
        mean_size = int(train_df.groupby('grade').count().mean()['image_id'])

        # Only grades 0 and 1 are sampled; rows with any other grade are not
        # carried over into the balanced training frame.
        train_df = pd.concat((
            train_df[train_df.grade == 0].sample(mean_size, replace=(grade_min==0), random_state=config.seed).copy(),
            train_df[train_df.grade == 1].sample(mean_size, replace=(grade_min==1), random_state=config.seed).copy()
        ))

    else:
        # balance train_df (sample 3rd min_size)
        # Index 2 of the sorted class counts: this requires at least three
        # classes. Classes smaller than min_size are sampled with replacement.
        min_size = np.sort(train_df.groupby(groupby).count()['image_id'])[2]
        train_df = train_df.groupby(groupby).apply(lambda group: group.sample(min_size, replace=len(group) < min_size, random_state=config.seed)).reset_index(drop=True)

    print('\n---- DATA SUMMARY ----')
    print('---------------------------------- Train ----------------------------------')
    print(train_df.groupby(groupby).count())
    print(len(train_df.wsi.unique()), 'WSIs')

    print('\n---------------------------------- Test ----------------------------------')
    print(test_df.groupby(groupby).count())
    print(len(test_df.wsi.unique()), 'WSIs')

    # Per-preprocessing normalisation statistics; 'rgb' and 'he' share the
    # same values, 'gray' uses a single channel.
    im_mean, im_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] # ImageNet
    norm = dict(
        rgb=dict(mean=im_mean,
                 std=im_std),
        he=dict(mean=im_mean,
                std=im_std),
        gray=dict(mean=[0.5],
                  std=[1.0])
    )

    # Geometric augmentation (training only); each transform fires with p=0.5.
    # The positional 90 passed to Rotate is presumably the rotation limit in
    # degrees - TODO confirm against the albumentations version in use.
    T_aug = albumentations.Compose([
        albumentations.HorizontalFlip(p=0.5),
        albumentations.VerticalFlip(p=0.5),
        albumentations.Rotate(90, p=0.5)
    ])
    # Colour jitter with library defaults; applied only for 'rgb' training data.
    T_jitter = albumentations.ColorJitter()

    mean, std = norm[config.preprocess]['mean'], norm[config.preprocess]['std']
    print('=> mean, std:', mean, std)
    # Final step of every pipeline: normalise, then convert to a tensor.
    T_tensor = ToTensorV2()
    T_post = albumentations.Compose([
        albumentations.Normalize(mean, std),
        T_tensor
    ])

    # The stain normaliser is fitted for every preprocessing mode, although
    # it is only used when config.preprocess == 'he'.
    # NOTE(review): the target image path is relative to the working directory;
    # cv2.imread yields None for an unreadable file, which would make cvtColor
    # fail - confirm the file is always present.
    print('=> Preparing stain normalizer..')
    he_target = cv2.cvtColor(cv2.imread('data/target.jpg'), cv2.COLOR_BGR2RGB)
    normalizer = torchstain.MacenkoNormalizer(backend='torch')
    normalizer.fit(T_tensor(image=he_target)['image']*255)
    print('=> Done')

    def normalize_he(x):
        """Stain-normalise ``x`` when preprocessing is 'he'; otherwise return it unchanged."""
        if config.preprocess == 'he':
            img = x
            try:
                img = T_tensor(image=img)['image']*255
                img, _, _ = normalizer.normalize(img, stains=False)
                img = img.numpy().astype(np.uint8)
            except Exception as e:
                # Best effort: report the failure and fall back to the
                # un-normalised input instead of aborting the data loading.
                print('Could not normalize image:', e)
                img = x
            return img
        return x

    def apply_transforms(train, img):
        """Run the full per-image pipeline; augmentation is applied only when ``train`` is true."""
        img = normalize_he(img)
        if train:
            img = T_aug(image=img)['image']
            if config.preprocess == 'rgb':
                img = T_jitter(image=img)['image']
        x = img
        return T_post(image=x)['image']

    # Single-argument callables handed to the datasets as their transform.
    T_train = partial(apply_transforms, True)
    T_test = partial(apply_transforms, False)

    # Keyword arguments shared by the train and test datasets.
    datasets_kwargs = {
        'path': path,
        'subsample': config.subsample,
        'target': config.target,
        'gray': config.preprocess == 'gray',
        'mock': config.mock
    }

    train_dataset = unitopatho.UTP(train_df, T=T_train, **datasets_kwargs)
    test_dataset = unitopatho.UTP(test_df, T=T_test, **datasets_kwargs)

    # Final loaders
    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True,
                                               batch_size=config.batch_size,
                                               num_workers=config.n_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False,
                                              batch_size=config.batch_size,
                                              num_workers=config.n_workers,
                                              pin_memory=True)

    # The number of classes is taken from the (balanced) training frame.
    n_classes = len(train_df[config.target].unique())
    print(f'=> Training for {n_classes} classes')

    # Input channels expected by the first convolution per preprocessing mode.
    n_channels = {
        'rgb': 3,
        'he': 3,
        'gray': 1
    }

    model = resnet18(n_classes=n_classes)
    # conv1 is replaced by a newly constructed layer in every mode, so the
    # first-layer weights of the model returned by resnet18() are discarded.
    model.conv1 = torch.nn.Conv2d(n_channels[config.preprocess], 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    if checkpoint is not None:
        model.load_state_dict(checkpoint['model'])
    model = model.to(config.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    # Learning rate is multiplied by 0.1 every 20 scheduler steps.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
    criterion = F.cross_entropy

    for epoch in range(config.epochs):
        # Training step (skipped entirely in test mode).
        if config.test is None:
            train_metrics = utils.train(model, train_loader, criterion,
                                        optimizer, config.device, metrics=utils.metrics,
                                        accumulation_steps=config.accumulation_steps, scaler=scaler)
            scheduler.step()

        test_metrics = utils.test(model, test_loader, criterion, config.device, metrics=utils.metrics)

        if config.test is None:
            print(f'Epoch {epoch}: train: {train_metrics}')
            wandb.log({'train': train_metrics,
                       'test': test_metrics})
            # Written to the same model.pt path in the wandb run directory
            # after every epoch.
            torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'config': config},
                       os.path.join(wandb.run.dir, 'model.pt'))

        print(f'test: {test_metrics}')
        # Test mode needs exactly one evaluation pass.
        if config.test is not None:
            break
| 229 | + |
| 230 | + |
if __name__ == '__main__':
    # Command-line options as (flag, add_argument keyword arguments) pairs,
    # listed in the order they should appear in the --help output.
    cli_options = [
        # data config
        ('--path', dict(type=str, default=f'{os.path.expanduser("~")}/data/UNITOPATHO',
                        help='UNITOPATHO dataset path')),
        ('--size', dict(type=int, default=100, help='patch size in µm (default 100)')),
        ('--subsample', dict(type=int, default=-1,
                             help='subsample size for data (-1 to disable, default -1)')),

        # optimizer & network config
        ('--epochs', dict(type=int, default=50)),
        ('--lr', dict(type=float, default=0.0001, help='learning rate')),
        ('--momentum', dict(type=float, default=0.99, help='momentum')),
        ('--weight_decay', dict(type=float, default=1e-5, help='weight decay')),
        ('--batch_size', dict(type=int, default=256, help='batch size')),
        ('--accumulation_steps', dict(type=int, default=1, help='gradient accumulation steps')),
        ('--n_workers', dict(type=int, default=8)),
        ('--architecture', dict(default='resnet18', help='resnet18, resnet50, densenet121')),

        # training config
        ('--preprocess', dict(default='rgb',
                              help='preprocessing type, rgb, he or gray. Default: rgb')),
        ('--target', dict(default='grade',
                          help='target attribute: grade, type, top_label (default: grade)')),
        ('--label', dict(type=str, default='both',
                         help='only when target=grade; values: ta, tva, norm or both (default: both)')),
        ('--test', dict(type=str, default=None, help='Run id to test')),

        # misc config
        ('--name', dict(type=str, default=None)),
        ('--device', dict(type=str, default='cuda')),
        ('--mock', dict(action='store_true', default=False, help='mock dataset (random noise)')),
        ('--seed', dict(type=int, default=42)),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    config = parser.parse_args()
    # Turn the device string into a torch.device before handing the
    # configuration to the training/evaluation entry point.
    config.device = torch.device(config.device)

    main(config)
0 commit comments