train.py
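"""Train a copy-paste anomaly representation learner on an MVTec-style dataset.

The script builds an MVTecDataset with copy-paste augmentation, optimizes an
AnomolyRepresentationLearner with CopyPasteLoss, and logs the training loss
to TensorBoard via SummaryWriter.
"""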
import argparse

import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from copypaste.data import MVTecDataset
from copypaste.losses import CopyPasteLoss
from copypaste.models import AnomolyRepresentationLearner
from copypaste.transforms import FancyPCA

def train_epoch(train_dict, writer, epoch, n_batches):
    """Run one pass over the training data and return the average batch loss."""
    model = train_dict['model']
    optimizer = train_dict['optimizer']
    dl = train_dict['dataloader']
    criterion = train_dict['criterion']
    running_loss = 0.0
    for i, data in enumerate(dl):
        img, cp = data
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize: embed the normal image and its
        # copy-paste-augmented counterpart, then score the pair
        pred_norm = model(img)
        pred_cp = model(cp)
        loss = criterion(pred_norm, pred_cp)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        writer.add_scalar('Loss/train', batch_loss, epoch * n_batches + i)
        running_loss += batch_loss
    return running_loss / len(dl)

def parse_cli_arguments() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    # train setup
    parser.add_argument("--dataset",
                        help="the intermediate dataset will be stored/loaded here",
                        default='/Users/marcobertolini/Documents/Arbeit/data/pathology/tggates/segmentation/')
    parser.add_argument("--workers", default=0, type=int, help="the number of workers used by the loaders")
    parser.add_argument('--save_dir', action='store', dest='save_dir', default="results", type=str,
                        help="Where the results should be saved. Defaults to results/")
    # model hyperparams
    parser.add_argument('--lr', action='store', dest='lr', default=0.01, type=float,
                        help="Initial learning rate for training. Defaults to 0.01.")
    parser.add_argument("--batch_size", default=8, type=int, help="the batch size for the loaders")
    parser.add_argument("--n_epochs", default=10, type=int, help="train for this many epochs")
    parser.add_argument('--transform', dest='use_transform', action='store_true',
                        help="Use geometric transformations")
    # parse and return the command line arguments
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_cli_arguments()
    # path = "/Users/jake/bayer/copypaste/data/toothbrush/train"
    options = {
        'area_ratio_range': (0.01, 0.05)
    }
    # FancyPCA augmentation is only applied when --transform is passed
    transforms = [FancyPCA()] if args.use_transform else []
    ds = MVTecDataset(
        args.dataset,
        copypaste=True,
        transforms=transforms, **options
    )
    dl = DataLoader(
        ds,
        batch_size=args.batch_size,
        num_workers=args.workers
    )
    model = AnomolyRepresentationLearner()
    criterion = CopyPasteLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    training_dict = {
        'dataloader': dl,
        'model': model,
        'criterion': criterion,
        'optimizer': optimizer,
    }
    n_batches = len(dl)
    writer = SummaryWriter()
    for epoch in range(args.n_epochs):
        avg_train_loss = train_epoch(training_dict, writer, epoch, n_batches)
        print(f'Epoch {epoch + 1}/{args.n_epochs}: avg train loss {avg_train_loss:.4f}')
    print('Finished Training')
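
# Example invocation (a sketch; the dataset path is hypothetical and should
# point at a local MVTec-style category directory such as .../toothbrush/train):
#
#   python train.py --dataset /path/to/mvtec/toothbrush/train \
#       --batch_size 8 --lr 0.01 --n_epochs 10 --workers 2 --transform
#
# TensorBoard logs are written to the default ./runs directory created by
# SummaryWriter; view them with `tensorboard --logdir runs`.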