-
Notifications
You must be signed in to change notification settings - Fork 8
/
main.py
73 lines (55 loc) · 1.67 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
##################################################
# Imports
##################################################
import json
import pytorch_lightning as pl
# Custom
from config import parse_args
from dataloader import get_dataloaders
from models import cvaecaposr
from utils import get_logger, get_callbacks
# Main function
def main(args, max_epochs=30, gpus=1):
    """Train and/or test the CVAE-CapOSR model according to ``args.mode``.

    In training mode the model is fit on the known-class augmented training
    split, validated on the known-class validation split, and then evaluated
    on the test split. In testing mode the freshly constructed model is
    evaluated on the test split directly.

    Args:
        args: Parsed command-line namespace returned by ``parse_args``. Must
            provide ``mode``; the whole namespace is also forwarded to the
            dataloader, model, callback and logger factories.
        max_epochs: Maximum number of training epochs (training mode only).
            Defaults to 30, the value previously hard-coded.
        gpus: Value forwarded to ``pl.Trainer(gpus=...)``. Defaults to 1, the
            value previously hard-coded.

    Raises:
        ValueError: If ``args.mode`` is not one of 'train', 'training',
            'test' or 'testing'. The check runs before any data is loaded.
    """
    train_modes = ('train', 'training')
    test_modes = ('test', 'testing')
    # Fail fast: validate the mode before the (potentially slow) dataloader
    # and model construction, instead of after it.
    if args.mode not in train_modes + test_modes:
        raise ValueError(f'Error. Mode "{args.mode}" is not supported.')

    # Dataloaders
    dls, data_info = get_dataloaders(args)
    # Model
    model = cvaecaposr.get_model(args, data_info)
    # Callbacks and logger
    callbacks = get_callbacks(args)
    tb_logger = get_logger(args)

    # NOTE(review): the keyword names `train_dataloader` / `test_dataloaders`
    # and the `gpus` argument belong to older PyTorch Lightning 1.x releases;
    # newer versions renamed or removed them — confirm the pinned version.
    if args.mode in train_modes:
        trainer = pl.Trainer(
            max_epochs=max_epochs,
            gpus=gpus,
            callbacks=callbacks,
            num_sanity_val_steps=0,
            logger=tb_logger,
        )
        # Fit
        trainer.fit(
            model,
            train_dataloader=dls['known']['train_aug'],
            val_dataloaders=dls['known']['validation'],
        )
        # Test with model=None so Lightning reloads a saved checkpoint (the
        # best one, per the original intent — assumes a checkpoint callback
        # is among `callbacks`; confirm).
        trainer.test(model=None, test_dataloaders=dls['test'])
    else:
        trainer = pl.Trainer(
            gpus=gpus,
            callbacks=callbacks,
            logger=tb_logger,
        )
        # Test the model exactly as returned by get_model.
        trainer.test(model=model, test_dataloaders=dls['test'])
##################################################
# Main
##################################################
if __name__ == '__main__':
    # Read the command-line configuration.
    args = parse_args()
    # Echo the full resolved configuration so each run's settings are logged.
    config_dump = json.dumps(vars(args), indent=4)
    print(config_dump)
    # Hand off to the train/test driver.
    main(args)