train.py
from __future__ import print_function
from __future__ import division

import os
from loss import dummy_loss
import models
import numpy as np
import argparse
import time
import img_utils
import warnings

from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from keras import backend as K

# Only supports Theano for now
K.set_image_dim_ordering("th")
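# With the Theano ("th") ordering, image tensors are channels-first: (batch, channels, height, width).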
parser = argparse.ArgumentParser(description='Fast Neural style transfer with Keras.')

parser.add_argument('style_reference_image_path', metavar='ref', type=str,
                    help='Path to the style reference image.')
parser.add_argument("data_path", type=str, help="Path to training images")
parser.add_argument("validation_img", type=str, default=None, help='Path to validation image')
parser.add_argument("--content_weight", type=float, default=10., help='Content weight')
parser.add_argument("--style_weight", type=float, default=1., help='Style weight')
parser.add_argument("--tv_weight", type=float, default=8.5e-5, help='Total Variation weight')
parser.add_argument("--image_size", dest="img_size", default=256, type=int, help='Output image size')
parser.add_argument("--epochs", default=1, type=int, help='Number of epochs')
parser.add_argument("--nb_imgs", default=80000, type=int, help='Number of images per epoch')
parser.add_argument("--model_depth", default="shallow", type=str, help='Can be one of "shallow" or "deep"')
parser.add_argument("--model_width", default="thin", type=str, help='Can be one of "thin" or "wide"')
parser.add_argument("--pool_type", default="max", type=str, help='Pooling type ("max" or "ave")')
parser.add_argument("--kernel_size", default=3, type=int, help='Kernel size')
parser.add_argument("--val_checkpoint", type=int, default=-1, help='Produce a validation image every N iterations to check the output of the network')

args = parser.parse_args()
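# The style name, derived from the reference image filename, is used to name the saved model and weight files.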
style_reference_image_path = args.style_reference_image_path
style_name = os.path.splitext(os.path.basename(style_reference_image_path))[0]
validation_img_path = args.validation_img

warnings.warn("Due to recent changes in how regularizers are handled in Keras, the code has been updated to support the new method.\n"
              "Please update your keras to the master branch to train properly.")
''' Attributes '''

# Dimensions of the input image
img_width = img_height = int(args.img_size)  # Image size needs to be the same for the gram matrix

nb_epoch = int(args.epochs)
num_iter = int(args.nb_imgs)  # Should be equal to the number of training images
train_batchsize = 1  # Using a batch size >= 2 results in unstable training

content_weight = float(args.content_weight)
style_weight = float(args.style_weight)
tv_weight = float(args.tv_weight)

val_checkpoint = int(args.val_checkpoint)
if val_checkpoint == -1:
    val_checkpoint = num_iter // 200  # Assuming the full MS COCO dataset (~80k samples), validate every 400 samples

kernel_size = int(args.kernel_size)

pool_type = str(args.pool_type)
assert pool_type in ["max", "ave"], 'Pool Type must be either "max" or "ave"'
pool_type = 1 if pool_type == "ave" else 0

iteration = 0

model_depth = str(args.model_depth).lower()
assert model_depth in ["shallow", "deep"], 'model_depth must be one of "shallow" or "deep"'

model_width = str(args.model_width).lower()
assert model_width in ["thin", "wide"], 'model_width must be one of "thin" or "wide"'

size_multiple = 4 if model_depth == "shallow" else 8
''' Model '''
if not os.path.exists("models/"):
    os.makedirs("models/")

FastNet = models.FastStyleNet(img_width=img_width, img_height=img_height, kernel_size=kernel_size, pool_type=pool_type,
                              style_weight=style_weight, content_weight=content_weight, tv_weight=tv_weight,
                              model_width=model_width, model_depth=model_depth,
                              save_fastnet_model="models/%s.h5" % style_name)

model = FastNet.create_model(style_name=None, train_mode=True, style_image_path=style_reference_image_path)

optimizer = Adam(beta_1=0.99)
model.compile(optimizer, dummy_loss)  # Dummy loss is used since we are learning from regularizers
print('Finished compiling fastnet model.')
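# Data pipeline: rescale pixel values to [0, 1] and stream training images one at a time from data_path.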
datagen = ImageDataGenerator(rescale=1. / 255)

if K.image_dim_ordering() == "th":
    dummy_y = np.zeros((train_batchsize, 3, img_height, img_width))  # Dummy output, not used since we use regularizers to train
else:
    dummy_y = np.zeros((train_batchsize, img_height, img_width, 3))  # Dummy output, not used since we use regularizers to train

prev_improvement = -1
early_stop = False
validation_fastnet = None
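# Training loop: the style, content and total variation losses are attached to the model as regularizers,
# so each fit() call on a single image minimizes those regularization terms; dummy_y is ignored.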
for i in range(nb_epoch):
    print()
    print("Epoch : %d" % (i + 1))

    for x in datagen.flow_from_directory(args.data_path, class_mode=None, batch_size=train_batchsize,
                                         target_size=(img_width, img_height), shuffle=False):
        try:
            t1 = time.time()

            hist = model.fit([x, x.copy()], dummy_y, batch_size=train_batchsize, nb_epoch=1, verbose=0)
            iteration += train_batchsize

            loss = hist.history['loss'][0]

            if prev_improvement == -1:
                prev_improvement = loss

            improvement = (prev_improvement - loss) / prev_improvement * 100
            prev_improvement = loss

            t2 = time.time()
            print("Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | Loss : %d" %
                  (iteration, num_iter, improvement, t2 - t1, loss))
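            # Validation checkpoint: save the current weights and run them through a separate
            # inference-only FastStyleNet to render the validation image with the current style.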
            if iteration % val_checkpoint == 0:
                print("Producing validation image...")

                # This ensures that the image height and width are even numbers
                x = img_utils.preprocess_image(validation_img_path, resize=False)
                x /= 255.

                width, height = x.shape[2], x.shape[3]

                iter_path = style_name + "_epoch_%d_at_iteration_%d" % (i + 1, iteration)
                FastNet.save_fastnet_weights(iter_path, directory="val_weights/")

                path = "val_weights/fastnet_" + iter_path + ".h5"

                if validation_fastnet is None:
                    validation_fastnet = models.FastStyleNet(width, height, kernel_size, pool_type,
                                                             model_width=model_width, model_depth=model_depth)
                    validation_fastnet.create_model(validation_path=path)
                    validation_fastnet.model.compile(optimizer, dummy_loss)
                else:
                    validation_fastnet.model.load_weights(path)

                y_pred = validation_fastnet.fastnet_predict(x)
                y_pred = y_pred[0, :, :, :]
                y_pred = y_pred.transpose((1, 2, 0))

                print("Mean per channel : ", np.mean(y_pred, axis=(0, 1)))

                y_pred = np.clip(y_pred, 0, 255).astype('uint8')

                path = "val_epoch_%d_at_iteration_%d.png" % (i + 1, iteration)
                img_utils.save_result(y_pred, path, directory="val_imgs/")
                path = "val_imgs/" + path

                print("Validation image saved at : %s" % path)
            if iteration >= num_iter:
                break

        except KeyboardInterrupt:
            print("Keyboard interrupt detected. Stopping early.")
            early_stop = True
            break

    iteration = 0

    if early_stop:
        break
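# Save the final trained weights for this style.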
FastNet.save_fastnet_weights(style_name, directory="weights/")