trainfn.py
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import time
import argparse
import pickle

import numpy as np
import torch
import torch.optim as optim
from tensorboardX import SummaryWriter

from fn import config, datacore
from fn.trainer import Trainer
from fn.checkpoints import CheckpointIO

if __name__ == '__main__':
    # Load configuration
    cfg = config.load_config('config/fn.yaml')
    is_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if is_cuda else 'cpu')
    print(device)

    # Set t0
    t0 = time.time()

    # Shorthands
    out_dir = 'out/fn'
    batch_size = cfg['training']['batch_size']
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    # Open the log file only after the output directory exists
    logfile = open(os.path.join(out_dir, 'log.txt'), 'a')
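    # NOTE: the 'training' section of config/fn.yaml must provide at least the
    # keys read in this script (batch_size, print_every, checkpoint_every,
    # validate_every); config.get_dataset / config.get_model read further
    # sections not shown here. A hypothetical sketch with illustrative values:
    #
    #   training:
    #     batch_size: 32
    #     print_every: 10
    #     checkpoint_every: 1000
    #     validate_every: 1000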
    # Datasets and data loaders
    train_dataset = config.get_dataset('train', cfg)
    val_dataset = config.get_dataset('val', cfg)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, num_workers=4, shuffle=True,
        collate_fn=datacore.collate_remove_none,
        worker_init_fn=datacore.worker_init_fn)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, num_workers=4, shuffle=False,
        collate_fn=datacore.collate_remove_none,
        worker_init_fn=datacore.worker_init_fn)

    # Model, optimizer, trainer and checkpoint handling
    model = config.get_model(cfg, device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    trainer = Trainer(model, optimizer, device=device)
    checkpoint_io = CheckpointIO(out_dir, model=model, optimizer=optimizer)
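    # As used below, fn.trainer.Trainer is assumed to expose
    # train_step(batch) -> scalar training loss and
    # evaluate(val_loader) -> a single validation metric (lower is better),
    # while fn.checkpoints.CheckpointIO persists the registered model and
    # optimizer together with any extra keyword arguments passed to save(),
    # and returns them as a dict from load().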
    # Restore a previous checkpoint if one exists
    try:
        load_dict = checkpoint_io.load('model.pt')
    except (FileExistsError, FileNotFoundError):
        load_dict = dict()
    epoch_it = load_dict.get('epoch_it', -1)
    it = load_dict.get('it', -1)
    metric_val_best = load_dict.get('loss_val_best', np.inf)

    logger = SummaryWriter(os.path.join(out_dir, 'logs'))

    # Shorthands
    nparameters = sum(p.numel() for p in model.parameters())
    logfile.write('Total number of parameters: %d\n' % nparameters)
    print_every = cfg['training']['print_every']
    checkpoint_every = cfg['training']['checkpoint_every']
    validate_every = cfg['training']['validate_every']
    # Training loop
    while True:
        epoch_it += 1
        # scheduler.step()
        logfile.flush()
        if epoch_it > 2000:
            logfile.close()
            break

        for batch in train_loader:
            it += 1
            # Skip degenerate batches with a single sample
            if batch['input'].shape[0] == 1:
                continue
            loss = trainer.train_step(batch)
            logger.add_scalar('train/loss', loss, it)

            # Print training progress
            if print_every > 0 and (it % print_every) == 0 and it > 0:
                logfile.write('[Epoch %02d] it=%03d, loss=%.6f\n'
                              % (epoch_it, it, loss))
                print('[Epoch %02d] it=%03d, loss=%.6f'
                      % (epoch_it, it, loss))

            # Save checkpoint
            if checkpoint_every > 0 and (it % checkpoint_every) == 0 and it > 0:
                logfile.write('Saving checkpoint\n')
                checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it,
                                   loss_val_best=metric_val_best)

            # Run validation
            if validate_every > 0 and (it % validate_every) == 0 and it > 0:
                metric_val = trainer.evaluate(val_loader)
                metric_val = metric_val.float()
                logfile.write('Validation metric: %.6f\n' % metric_val)
                if metric_val < metric_val_best:
                    metric_val_best = metric_val
                    logfile.write('New best model (loss %.6f)\n' % metric_val_best)
                    checkpoint_io.save('model_best.pt', epoch_it=epoch_it, it=it,
                                       loss_val_best=metric_val_best)
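
# ----------------------------------------------------------------------
# Restore sketch (not executed; assumes the same fn.* helpers): how a
# separate evaluation script might reload the best checkpoint written
# above. Only the file names and the 'loss_val_best' key mirror the
# save() calls in this script; everything else is illustrative.
#
#   model = config.get_model(cfg, device)
#   checkpoint_io = CheckpointIO(out_dir, model=model)
#   state = checkpoint_io.load('model_best.pt')
#   print('Best validation metric:', state.get('loss_val_best'))
# ----------------------------------------------------------------------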