-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_model.py
109 lines (90 loc) · 3.64 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import argparse
import copy
import random
import h5py
import time
import cPickle as pk
import numpy as np
from datetime import datetime
from os import makedirs, remove
from os.path import join, exists, abspath, dirname, basename, isfile
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torchvision import datasets, models, transforms
from sklearn.preprocessing import label_binarize
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.metrics import roc_curve, auc
from contact_dataset import ContactDataset
def train_model(dataloaders,
                dataset_sizes,
                phase_names,
                model_ft,
                criterion,
                optimizer_ft,
                exp_lr_scheduler,
                num_epochs,
                use_gpu):
    """Train ``model_ft`` and keep the state from the best validation epoch.

    Parameters
    ----------
    dataloaders : dict
        Maps each phase name (e.g. 'train', 'val') to a torch DataLoader
        yielding (inputs, labels) batches.
    dataset_sizes : dict
        Maps each phase name to the number of samples in that phase.
    phase_names : iterable of str
        Phases run each epoch; must include 'val' (used for model selection).
    model_ft : torch.nn.Module
        The model being optimized (modified in place).
    criterion : callable
        Loss function, e.g. ``nn.CrossEntropyLoss()``.
    optimizer_ft : torch.optim.Optimizer
        Optimizer over ``model_ft``'s parameters (modified in place).
    exp_lr_scheduler : torch.optim.lr_scheduler._LRScheduler
        Learning-rate scheduler, stepped once per epoch.
    num_epochs : int
        Number of epochs to run.
    use_gpu : bool
        If True, move each batch to the GPU with ``.cuda()``.

    Returns
    -------
    tuple
        ``(model_ft, optimizer_ft, best_epoch, loss_record, acc_record)``.
        ``model_ft``/``optimizer_ft`` carry the state from the epoch with
        the lowest validation loss; ``loss_record``/``acc_record`` map each
        phase name to a per-epoch numpy array.
    """
    since = time.time()
    best_model = copy.deepcopy(model_ft.state_dict())
    memo_optimizer = copy.deepcopy(optimizer_ft.state_dict())
    # float('inf') rather than an arbitrary 9999.0 sentinel, so the first
    # validation loss always becomes the incumbent best.
    lowest_loss = float('inf')
    best_epoch = -1
    acc_record = {x: np.zeros(num_epochs) for x in phase_names}
    loss_record = {x: np.zeros(num_epochs) for x in phase_names}
    for epoch in range(num_epochs):
        print(' - (train_model.py) Epoch {0:d} / {1:d}'.format(
            epoch + 1, num_epochs))
        # Each epoch has training and validation phases
        for phase in phase_names:
            is_train = phase == 'train'
            model_ft.train(is_train)  # training vs. evaluate mode
            running_loss = 0.0
            running_corrects = 0.
            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                # Variable wrapping is obsolete; tensors autograd directly.
                if use_gpu:
                    inputs = inputs.cuda()
                    labels = labels.cuda()
                # zero the parameter gradients
                optimizer_ft.zero_grad()
                # forward; track gradients only during training so that
                # validation does not build autograd graphs
                with torch.set_grad_enabled(is_train):
                    outputs = model_ft(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    # backward + optimize only in the training phase
                    if is_train:
                        loss.backward()
                        optimizer_ft.step()
                # statistics: .item() replaces loss.data[0], which raises
                # IndexError on 0-dim tensors in PyTorch >= 0.5
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()
            if is_train:
                # Step the scheduler AFTER this epoch's optimizer updates,
                # as required by modern PyTorch (step order warning).
                exp_lr_scheduler.step()
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            loss_record[phase][epoch] = epoch_loss
            acc_record[phase][epoch] = epoch_acc
            print(' - {0:s} loss: {1:.4f}, accuracy: {2:.2f}%'.format(
                phase, epoch_loss, epoch_acc * 100))
        # Model selection: snapshot the state with the lowest val loss.
        validation_loss = loss_record['val'][epoch]
        if validation_loss < lowest_loss:
            lowest_loss = validation_loss
            best_epoch = epoch
            best_model = copy.deepcopy(model_ft.state_dict())
            memo_optimizer = copy.deepcopy(optimizer_ft.state_dict())
    time_elapsed = time.time() - since
    print(' - (train_model.py) Training took {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # Load best model weights
    model_ft.load_state_dict(best_model)
    optimizer_ft.load_state_dict(memo_optimizer)
    return model_ft, optimizer_ft, best_epoch, loss_record, acc_record