train.py
import os
import glob
import socket
import logging
import sys

# Fork-specific path hack: prepend a local site-packages directory so the
# tensorflow/neuralgym imports below are resolved from this environment first.
sys.path.insert(0, '/home/lrm/CaRINA/Inpainting/depth_ipntg/lib/python3.5/site-packages')

import tensorflow as tf
import neuralgym as ng

from inpaint_model import InpaintCAModel

logger = logging.getLogger()

def multigpu_graph_def(model, data, config, gpu_id=0, loss_type='g'):
    """Build the loss graph for one GPU tower and return the requested loss."""
    with tf.device('/cpu:0'):
        images = data.data_pipeline(config.BATCH_SIZE)
    if gpu_id == 0 and loss_type == 'g':
        # Only the first generator tower writes summaries.
        _, _, losses = model.build_graph_with_losses(
            images, config, summary=True, reuse=True)
    else:
        _, _, losses = model.build_graph_with_losses(
            images, config, reuse=True)
    if loss_type == 'g':
        return losses['g_loss']
    elif loss_type == 'd':
        return losses['d_loss']
    else:
        raise ValueError('loss type is not supported.')

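# Note: multigpu_graph_def is not called directly in this script; it is passed
# to the neuralgym trainers below as `graph_def` (together with
# `graph_def_kwargs`), which use it to build the generator/discriminator loss
# graphs during training.
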
if __name__ == "__main__":
    config = ng.Config('inpaint.yml')
    if config.GPU_ID != -1:
        ng.set_gpus(config.GPU_ID)
    else:
        ng.get_gpus(config.NUM_GPUS)

    # training data
    with open(config.DATA_FLIST[config.DATASET][0]) as f:
        fnames = f.read().splitlines()
    data = ng.data.DataFromFNames(
        fnames, config.IMG_SHAPES, random_crop=config.RANDOM_CROP)
    images = data.data_pipeline(config.BATCH_SIZE)

    # main model
    model = InpaintCAModel()
    g_vars, d_vars, losses = model.build_graph_with_losses(
        images, config=config)

    # validation images
    if config.VAL:
        with open(config.DATA_FLIST[config.DATASET][1]) as f:
            val_fnames = f.read().splitlines()
        # progress monitor by visualizing static images
        for i in range(config.STATIC_VIEW_SIZE):
            static_fnames = val_fnames[i:i+1]
            static_images = ng.data.DataFromFNames(
                static_fnames, config.IMG_SHAPES, nthreads=1,
                random_crop=config.RANDOM_CROP).data_pipeline(1)
            static_inpainted_images = model.build_static_infer_graph(
                static_images, config, name='static_view/%d' % i)

    # training settings
    lr = tf.get_variable(
        'lr', shape=[], trainable=False,
        initializer=tf.constant_initializer(1e-4))
    d_optimizer = tf.train.AdamOptimizer(lr, beta1=0.5, beta2=0.9)
    g_optimizer = d_optimizer

    # gradient processor
    if config.GRADIENT_CLIP:
        gradient_processor = lambda grad_var: (
            tf.clip_by_average_norm(grad_var[0], config.GRADIENT_CLIP_VALUE),
            grad_var[1])
    else:
        gradient_processor = None

    # log dir
    log_prefix = 'model_logs/' + '_'.join([
        ng.date_uid(), socket.gethostname(), config.DATASET,
        'MASKED' if config.GAN_WITH_MASK else 'NORMAL',
        config.GAN, config.LOG_DIR])

    # train discriminator with secondary trainer, should initialize before
    # primary trainer.
    discriminator_training_callback = ng.callbacks.SecondaryTrainer(
        pstep=1,
        optimizer=d_optimizer,
        var_list=d_vars,
        max_iters=5,
        graph_def=multigpu_graph_def,
        graph_def_kwargs={
            'model': model, 'data': data, 'config': config, 'loss_type': 'd'},
    )

    # train generator with primary trainer
    trainer = ng.train.Trainer(
        optimizer=g_optimizer,
        var_list=g_vars,
        max_iters=config.MAX_ITERS,
        graph_def=multigpu_graph_def,
        grads_summary=config.GRADS_SUMMARY,
        gradient_processor=gradient_processor,
        graph_def_kwargs={
            'model': model, 'data': data, 'config': config, 'loss_type': 'g'},
        spe=config.TRAIN_SPE,
        log_dir=log_prefix,
    )

    # add all callbacks
    if not config.PRETRAIN_COARSE_NETWORK:
        trainer.add_callbacks(discriminator_training_callback)
    trainer.add_callbacks([
        ng.callbacks.WeightsViewer(),
        ng.callbacks.ModelRestorer(
            trainer.context['saver'],
            dump_prefix='model_logs/' + config.MODEL_RESTORE + '/snap',
            optimistic=True),
        ng.callbacks.ModelSaver(
            config.TRAIN_SPE, trainer.context['saver'], log_prefix + '/snap'),
        ng.callbacks.SummaryWriter(
            config.VAL_PSTEPS // 1, trainer.context['summary_writer'],
            tf.summary.merge_all()),
    ])

    # launch training
    trainer.train()
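
# Illustrative sketch (not part of the repo) of the inpaint.yml keys this
# script reads via ng.Config. The key names below are exactly those accessed
# above; the values are placeholders and must be adapted to your data/setup:
#
#   DATASET: celebahq                 # placeholder dataset name
#   DATA_FLIST:
#     celebahq: [data/train_flist.txt, data/val_flist.txt]
#   GPU_ID: -1
#   NUM_GPUS: 1
#   BATCH_SIZE: 16
#   IMG_SHAPES: [256, 256, 3]
#   RANDOM_CROP: True
#   VAL: True
#   STATIC_VIEW_SIZE: 8
#   VAL_PSTEPS: 1000
#   GRADIENT_CLIP: False
#   GRADIENT_CLIP_VALUE: 0.1
#   GAN: wgan_gp
#   GAN_WITH_MASK: False
#   LOG_DIR: full_model
#   MAX_ITERS: 1000000
#   TRAIN_SPE: 1000
#   GRADS_SUMMARY: False
#   PRETRAIN_COARSE_NETWORK: False
#   MODEL_RESTORE: ''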