-
Notifications
You must be signed in to change notification settings - Fork 123
/
train.py
142 lines (122 loc) · 6.32 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#!/usr/bin/env python
# -*- coding:UTF-8 -*-
# File Name : train.py
# Purpose :
# Creation Date : 09-12-2017
# Last Modified : Fri 19 Jan 2018 10:38:47 AM CST
# Created By : Jeasine Ma [jeasinema[at]gmail[dot]com]
import glob
import argparse
import os
import time
import sys
import tensorflow as tf
from itertools import count
from config import cfg
from model import RPN3D
from utils.kitti_loader import KittiLoader
from train_hook import check_if_should_pause
# Command-line configuration for a training run.
parser = argparse.ArgumentParser(description='training')
_ARG_SPECS = (
    (('-i', '--max-epoch'),
     dict(type=int, nargs='?', default=10, help='max epoch')),
    (('-n', '--tag'),
     dict(type=str, nargs='?', default='default', help='set log tag')),
    (('-b', '--single-batch-size'),
     dict(type=int, nargs='?', default=1, help='set batch size for each gpu')),
    (('-l', '--lr'),
     dict(type=float, nargs='?', default=0.001, help='set learning rate')),
)
for _flags, _options in _ARG_SPECS:
    parser.add_argument(*_flags, **_options)
args = parser.parse_args()

# Per-run output locations, namespaced by the experiment tag.
dataset_dir = './data/object'
log_dir = os.path.join('./log', args.tag)
save_model_dir = os.path.join('./save_model', args.tag)
for _dir in (log_dir, save_model_dir):
    os.makedirs(_dir, exist_ok=True)
def main(_):
    """Train an RPN3D detector on the KITTI dataset.

    Resumes from the latest checkpoint under ``save_model_dir`` when one
    exists, otherwise initializes fresh parameters.  Runs the training loop
    until ``args.max_epoch`` epochs complete, periodically writing scalar
    summaries, image summaries, validation summaries, and checkpoints.
    Exits early (after checkpointing) if an external pause request is
    detected for this run's tag.

    Args:
        _: unused argv argument supplied by ``tf.app.run``.
    """
    # TODO: split file support
    with tf.Graph().as_default():
        # NOTE(review): the validation loader reads the 'testing' split but is
        # built with is_testset=False and require_shuffle=True — confirm this
        # split actually has labels, since validate_step consumes it.
        with KittiLoader(object_dir=os.path.join(dataset_dir, 'training'),
                         queue_size=50, require_shuffle=True, is_testset=False,
                         batch_size=args.single_batch_size * cfg.GPU_USE_COUNT,
                         use_multi_process_num=8,
                         multi_gpu_sum=cfg.GPU_USE_COUNT,
                         aug=True) as train_loader, \
             KittiLoader(object_dir=os.path.join(dataset_dir, 'testing'),
                         queue_size=50, require_shuffle=True, is_testset=False,
                         batch_size=args.single_batch_size * cfg.GPU_USE_COUNT,
                         use_multi_process_num=8,
                         multi_gpu_sum=cfg.GPU_USE_COUNT,
                         aug=False) as valid_loader:
            gpu_options = tf.GPUOptions(
                per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION,
                visible_device_list=cfg.GPU_AVAILABLE,
                allow_growth=True)
            config = tf.ConfigProto(
                gpu_options=gpu_options,
                device_count={"GPU": cfg.GPU_USE_COUNT},
                allow_soft_placement=True,
            )
            with tf.Session(config=config) as sess:
                model = RPN3D(
                    cls=cfg.DETECT_OBJ,
                    single_batch_size=args.single_batch_size,
                    learning_rate=args.lr,
                    max_gradient_norm=5.0,
                    is_train=True,
                    alpha=1.5,
                    beta=1,
                    avail_gpus=cfg.GPU_AVAILABLE.split(',')
                )
                # Parameter init / restore: resume from the latest checkpoint
                # under save_model_dir when one exists.
                if tf.train.get_checkpoint_state(save_model_dir):
                    print("Reading model parameters from %s" % save_model_dir)
                    model.saver.restore(
                        sess, tf.train.latest_checkpoint(save_model_dir))
                else:
                    print("Created model with fresh parameters.")
                    tf.global_variables_initializer().run()

                # Loop bookkeeping.  max(1, ...) guards against a
                # ZeroDivisionError in the `%` tests below when the dataset
                # holds fewer samples than one (or three) full batches.
                iter_per_epoch = max(
                    1,
                    len(train_loader) //
                    (args.single_batch_size * cfg.GPU_USE_COUNT))
                summary_interval = 5
                summary_image_interval = 20
                save_model_interval = max(1, iter_per_epoch // 3)
                validate_interval = 60
                summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

                while model.epoch.eval() < args.max_epoch:
                    # `step` (renamed from `iter`, which shadowed the builtin)
                    # is the global optimization step, read once per loop.
                    step = model.global_step.eval()
                    is_summary = step % summary_interval == 0
                    is_summary_image = step % summary_image_interval == 0
                    is_validate = step % validate_interval == 0
                    if step % save_model_interval == 0:
                        model.saver.save(
                            sess,
                            os.path.join(save_model_dir, 'checkpoint'),
                            global_step=model.global_step)
                    if step % iter_per_epoch == 0:
                        # One epoch's worth of steps elapsed: bump the
                        # persistent epoch counter.
                        sess.run(model.epoch_add_op)
                        print('train {} epoch, total: {}'.format(
                            model.epoch.eval(), args.max_epoch))

                    ret = model.train_step(
                        sess, train_loader.load(), train=True,
                        summary=is_summary)
                    print('train: {}/{} @ epoch:{}/{} loss: {} reg_loss: {} cls_loss: {} {}'.format(
                        step, iter_per_epoch * args.max_epoch,
                        model.epoch.eval(), args.max_epoch,
                        ret[0], ret[1], ret[2], args.tag))
                    if is_summary:
                        summary_writer.add_summary(ret[-1], step)
                    if is_summary_image:
                        ret = model.predict_step(
                            sess, valid_loader.load(), summary=True)
                        summary_writer.add_summary(ret[-1], step)
                    if is_validate:
                        ret = model.validate_step(
                            sess, valid_loader.load(), summary=True)
                        summary_writer.add_summary(ret[-1], step)

                    # Cooperative pause: checkpoint and exit when an external
                    # pause request exists for this tag.
                    if check_if_should_pause(args.tag):
                        model.saver.save(
                            sess,
                            os.path.join(save_model_dir, 'checkpoint'),
                            global_step=model.global_step)
                        print('pause and save model @ {} steps:{}'.format(
                            save_model_dir, model.global_step.eval()))
                        sys.exit(0)

                print('train done. total epoch:{} iter:{}'.format(
                    model.epoch.eval(), model.global_step.eval()))
                # Finally, save the fully trained model.
                model.saver.save(
                    sess,
                    os.path.join(save_model_dir, 'checkpoint'),
                    global_step=model.global_step)
if __name__ == '__main__':
    # TF1-style entry point: tf.app.run parses any TF flags, then calls main.
    tf.app.run(main)