-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_distill.py
81 lines (66 loc) · 2.96 KB
/
train_distill.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import sys
import os
import math
import time
import cv2
import numpy as np
import torch
from torch import nn
from torch import optim
import deepvac
from deepvac import LOG, DeepvacTrain
from deepvac.experimental.core import DeepvacDistill
from modules.utils_IOU_eval import IOUEval
class ESPNetTrain(DeepvacDistill):
    """Student/teacher distillation trainer for ESPNet built on ``DeepvacDistill``.

    The base training loop is run once per loader in
    ``config.train_loader_list``. Loss bookkeeping, IoU evaluation, LR
    scheduling and loss computation only act while the current loader reports
    ``is_last_loader``; for earlier loaders those hooks return immediately.
    """

    def __init__(self, deepvac_config):
        super().__init__(deepvac_config)
        # Scalar student losses for the current epoch; cleared in preEpoch().
        self.config.epoch_loss = []

    def train(self):
        """Create the IoU evaluators, then run the base train loop for every configured loader."""
        # NOTE(review): the evaluators are created once here and nothing in this
        # file resets them between epochs, so getMetric() may report metrics
        # accumulated over several epochs — confirm against IOUEval.
        self.iou_eval_val = IOUEval(self.config.cls_num)
        self.iou_eval_train = IOUEval(self.config.cls_num)
        for loader in self.config.train_loader_list:
            # The base class reads the active loader from config.train_loader.
            self.config.train_loader = loader
            super().train()

    def postIter(self):
        """Record this iteration's student loss and feed its predictions to the phase's IoU evaluator."""
        if not self.config.train_loader.is_last_loader:
            return
        self.config.epoch_loss.append(self.config.loss.item())
        # Argmax indices along dim 1 of the student's first output
        # (presumably the class dimension — confirm against the model).
        prediction = self.config.output[0].max(1)[1].data
        evaluator = self.iou_eval_train if self.config.phase == 'TRAIN' else self.iou_eval_val
        evaluator.addBatch(prediction, self.config.target.data)

    def preEpoch(self):
        """Start a fresh per-epoch loss list."""
        self.config.epoch_loss = []

    def postEpoch(self):
        """Log the epoch's average loss and mIOU for the current phase and store mIOU in ``config.acc``."""
        if not self.config.train_loader.is_last_loader:
            return
        epoch_loss = self.config.epoch_loss
        # An epoch with no recorded iterations would otherwise divide by zero.
        average_epoch_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else float('nan')
        evaluator = self.iou_eval_train if self.config.phase == 'TRAIN' else self.iou_eval_val
        _overall_acc, _per_class_acc, _per_class_iu, mIOU = evaluator.getMetric()
        self.config.acc = mIOU
        LOG.logI("Epoch : {} Details".format(self.config.epoch))
        LOG.logI("\nEpoch No.: %d\t%s Loss = %.4f\t %s mIOU = %.4f\t" % (self.config.epoch, self.config.phase, average_epoch_loss, self.config.phase, mIOU))

    def doSchedule(self):
        """Step the LR scheduler, but only for the last loader."""
        if not self.config.train_loader.is_last_loader:
            return
        self.config.scheduler.step()

    def doLoss(self):
        """Compute student and teacher losses as the sum of the criterion over each model's first two outputs."""
        # NOTE(review): for non-last loaders config.loss / teacher.loss are not
        # assigned here, so any later use of them sees the previous value —
        # confirm this is intended.
        if not self.config.train_loader.is_last_loader:
            return
        criterion = self.config.criterion
        target = self.config.target
        loss1, loss2 = criterion(self.config.output[0], target), criterion(self.config.output[1], target)
        loss3, loss4 = criterion(self.config.teacher.output[0], target), criterion(self.config.teacher.output[1], target)
        self.config.loss = loss1 + loss2
        self.config.teacher.loss = loss3 + loss4
        LOG.logI('loss1: {}, loss2: {}, loss3: {}, loss4: {}'.format(loss1, loss2, loss3, loss4))

    def doOptimize(self):
        """Run the student optimization step, then step the teacher optimizer every ``nominal_batch_factor`` iterations."""
        # Deliberately starts the MRO lookup *after* DeepvacDistill, bypassing
        # DeepvacDistill.doOptimize; presumably because the teacher step is
        # done explicitly below — confirm against the deepvac version in use.
        super(DeepvacDistill, self).doOptimize()
        if self.config.iter % self.config.nominal_batch_factor != 0:
            return
        self.config.teacher.optimizer.step()
        self.config.teacher.optimizer.zero_grad()
if __name__ == "__main__":
    # Script entry point: load the project config, build the trainer, run it.
    from config import config

    esp_trainer = ESPNetTrain(config)
    esp_trainer()