Skip to content

Commit

Permalink
Fix "The image generated by eval.py is black" problem.
Browse files Browse the repository at this point in the history
before: nan -> relu -> nan
now   : nan -> relu(custom) -> 0

- added some comments for a better understanding.
  • Loading branch information
LanKo-X committed May 20, 2017
1 parent beb864e commit f2149fa
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 21 deletions.
26 changes: 20 additions & 6 deletions eval.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# coding: utf-8
from __future__ import print_function
import tensorflow as tf
from nets import nets_factory
from preprocessing import preprocessing_factory
import reader
import model
Expand All @@ -18,6 +17,8 @@


def main(_):

# Get image's height and width.
height = 0
width = 0
with open(FLAGS.image_file, 'rb') as img:
Expand All @@ -32,28 +33,41 @@ def main(_):

with tf.Graph().as_default():
with tf.Session().as_default() as sess:

# Read image data.
image_preprocessing_fn, _ = preprocessing_factory.get_preprocessing(
FLAGS.loss_model,
is_training=False)
image = reader.get_image(FLAGS.image_file, height, width, image_preprocessing_fn)

# Add batch dimension
image = tf.expand_dims(image, 0)

generated = model.net(image, training=False)
generated = tf.cast(generated, tf.uint8)

# Remove batch dimension
generated = tf.squeeze(generated, [0])

# Restore model variables.
saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1)
sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
# Use absolute path
FLAGS.model_file = os.path.abspath(FLAGS.model_file)
saver.restore(sess, FLAGS.model_file)

start_time = time.time()
generated = sess.run(generated)
generated = tf.cast(generated, tf.uint8)
end_time = time.time()
tf.logging.info('Elapsed time: %fs' % (end_time - start_time))
# Make sure 'generated' directory exists.
generated_file = 'generated/res.jpg'
if os.path.exists('generated') is False:
os.makedirs('generated')

# Generate and write image data to file.
with open(generated_file, 'wb') as img:
start_time = time.time()
img.write(sess.run(tf.image.encode_jpeg(generated)))
end_time = time.time()
tf.logging.info('Elapsed time: %fs' % (end_time - start_time))

tf.logging.info('Done. Please check %s.' % generated_file)


Expand Down
14 changes: 13 additions & 1 deletion losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,19 @@ def get_style_features(FLAGS):
FLAGS.loss_model,
is_training=False)

# Get the style image data
size = FLAGS.image_size
img_bytes = tf.read_file(FLAGS.style_image)
if FLAGS.style_image.lower().endswith('png'):
image = tf.image.decode_png(img_bytes)
else:
image = tf.image.decode_jpeg(img_bytes)
# image = _aspect_preserving_resize(image, size)
images = tf.stack([image_preprocessing_fn(image, size, size)])

# Add the batch dimension
images = tf.expand_dims(image_preprocessing_fn(image, size, size), 0)
# images = tf.stack([image_preprocessing_fn(image, size, size)])

_, endpoints_dict = network_fn(images, spatial_squeeze=False)
features = []
for layer in FLAGS.style_layers:
Expand All @@ -52,16 +57,23 @@ def get_style_features(FLAGS):
features.append(feature)

with tf.Session() as sess:
# Restore variables for loss network.
init_func = utils._get_init_fn(FLAGS)
init_func(sess)

# Make sure the 'generated' directory is exists.
if os.path.exists('generated') is False:
os.makedirs('generated')
# Indicate cropped style image path
save_file = 'generated/target_style_' + FLAGS.naming + '.jpg'
# Write preprocessed style image to indicated path
with open(save_file, 'wb') as f:
target_image = image_unprocessing_fn(images[0, :])
value = tf.image.encode_jpeg(tf.cast(target_image, tf.uint8))
f.write(sess.run(value))
tf.logging.info('Target style pattern is saved to: %s.' % save_file)

# Return the features those layers are use for measuring style loss.
return sess.run(features)


Expand Down
33 changes: 20 additions & 13 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def conv2d(x, input_filters, output_filters, kernel, strides, mode='REFLECT'):
with tf.variable_scope('conv') as scope:
with tf.variable_scope('conv'):

shape = [kernel, kernel, input_filters, output_filters]
weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight')
Expand All @@ -11,7 +11,7 @@ def conv2d(x, input_filters, output_filters, kernel, strides, mode='REFLECT'):


def conv2d_transpose(x, input_filters, output_filters, kernel, strides):
with tf.variable_scope('conv_transpose') as scope:
with tf.variable_scope('conv_transpose'):

shape = [kernel, kernel, output_filters, input_filters]
weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight')
Expand All @@ -32,7 +32,7 @@ def resize_conv2d(x, input_filters, output_filters, kernel, strides, training):
through tf.image.resize_images, but we only know that for fixed image size, so we
plumb through a "training" argument
'''
with tf.variable_scope('conv_transpose') as scope:
with tf.variable_scope('conv_transpose'):
height = x.get_shape()[1].value if training else tf.shape(x)[1]
width = x.get_shape()[2].value if training else tf.shape(x)[2]

Expand Down Expand Up @@ -75,10 +75,17 @@ def population_statistics():
return tf.cond(training, batch_statistics, population_statistics)


def relu(input):
relu = tf.nn.relu(input)
# convert nan to zero
nan_to_zero = tf.where(tf.is_nan(relu), tf.zeros_like(relu), relu)
return nan_to_zero


def residual(x, filters, kernel, strides):
with tf.variable_scope('residual') as scope:
with tf.variable_scope('residual'):
conv1 = conv2d(x, filters, filters, kernel, strides)
conv2 = conv2d(tf.nn.relu(conv1), filters, filters, kernel, strides)
conv2 = conv2d(relu(conv1), filters, filters, kernel, strides)

residual = x + conv2

Expand All @@ -90,11 +97,11 @@ def net(image, training):
image = tf.pad(image, [[0, 0], [10, 10], [10, 10], [0, 0]], mode='REFLECT')

with tf.variable_scope('conv1'):
conv1 = tf.nn.relu(instance_norm(conv2d(image, 3, 32, 9, 1)))
conv1 = relu(instance_norm(conv2d(image, 3, 32, 9, 1)))
with tf.variable_scope('conv2'):
conv2 = tf.nn.relu(instance_norm(conv2d(conv1, 32, 64, 3, 2)))
conv2 = relu(instance_norm(conv2d(conv1, 32, 64, 3, 2)))
with tf.variable_scope('conv3'):
conv3 = tf.nn.relu(instance_norm(conv2d(conv2, 64, 128, 3, 2)))
conv3 = relu(instance_norm(conv2d(conv2, 64, 128, 3, 2)))
with tf.variable_scope('res1'):
res1 = residual(conv3, 128, 3, 1)
with tf.variable_scope('res2'):
Expand All @@ -107,13 +114,13 @@ def net(image, training):
res5 = residual(res4, 128, 3, 1)
# print(res5.get_shape())
with tf.variable_scope('deconv1'):
# deconv1 = tf.nn.relu(instance_norm(conv2d_transpose(res5, 128, 64, 3, 2)))
deconv1 = tf.nn.relu(instance_norm(resize_conv2d(res5, 128, 64, 3, 2, training)))
# deconv1 = relu(instance_norm(conv2d_transpose(res5, 128, 64, 3, 2)))
deconv1 = relu(instance_norm(resize_conv2d(res5, 128, 64, 3, 2, training)))
with tf.variable_scope('deconv2'):
# deconv2 = tf.nn.relu(instance_norm(conv2d_transpose(deconv1, 64, 32, 3, 2)))
deconv2 = tf.nn.relu(instance_norm(resize_conv2d(deconv1, 64, 32, 3, 2, training)))
# deconv2 = relu(instance_norm(conv2d_transpose(deconv1, 64, 32, 3, 2)))
deconv2 = relu(instance_norm(resize_conv2d(deconv1, 64, 32, 3, 2, training)))
with tf.variable_scope('deconv3'):
# deconv_test = tf.nn.relu(instance_norm(conv2d(deconv2, 32, 32, 2, 1)))
# deconv_test = relu(instance_norm(conv2d(deconv2, 32, 32, 2, 1)))
deconv3 = tf.nn.tanh(instance_norm(conv2d(deconv2, 32, 3, 9, 1)))

y = (deconv3 + 1) * 127.5
Expand Down
13 changes: 12 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def parse_args():

def main(FLAGS):
style_features_t = losses.get_style_features(FLAGS)

# Make sure the training path exists.
training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
if not(os.path.exists(training_path)):
os.makedirs(training_path)
Expand All @@ -46,6 +48,8 @@ def main(FLAGS):
]
processed_generated = tf.stack(processed_generated)
_, endpoints_dict = network_fn(tf.concat([processed_generated, processed_images], 0), spatial_squeeze=False)

# Log the structure of loss network
tf.logging.info('Loss network layers(You can define them in "content_layers" and "style_layers"):')
for key in endpoints_dict:
tf.logging.info(key)
Expand All @@ -57,6 +61,7 @@ def main(FLAGS):

loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss

# Add Summary for visualization in tensorboard.
"""Add Summary"""
tf.summary.scalar('losses/content_loss', content_loss)
tf.summary.scalar('losses/style_loss', style_loss)
Expand All @@ -66,6 +71,7 @@ def main(FLAGS):
tf.summary.scalar('weighted_losses/weighted_style_loss', style_loss * FLAGS.style_weight)
tf.summary.scalar('weighted_losses/weighted_regularizer_loss', tv_loss * FLAGS.tv_weight)
tf.summary.scalar('total_loss', loss)

for layer in FLAGS.style_layers:
tf.summary.scalar('style_losses/' + layer, style_loss_summary[layer])
tf.summary.image('generated', generated)
Expand All @@ -78,21 +84,26 @@ def main(FLAGS):

"""Prepare to Train"""
global_step = tf.Variable(0, name="global_step", trainable=False)

variable_to_train = []
for variable in tf.trainable_variables():
if not(variable.name.startswith(FLAGS.loss_model)):
variable_to_train.append(variable)

train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step, var_list=variable_to_train)

variables_to_restore = []
for v in tf.global_variables():
if not(v.name.startswith(FLAGS.loss_model)):
variables_to_restore.append(v)
saver = tf.train.Saver(variables_to_restore, write_version=tf.train.SaverDef.V1)

sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

# Restore variables for loss network.
init_func = utils._get_init_fn(FLAGS)
init_func(sess)

# Restore variables for training model if the checkpoint file exists.
last_file = tf.train.latest_checkpoint(training_path)
if last_file:
tf.logging.info('Restoring model from {}'.format(last_file))
Expand Down

0 comments on commit f2149fa

Please sign in to comment.