Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 54 additions & 11 deletions StyleTransfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,14 @@ def main():
parser = build_parser()
args = parser.parse_args()
content_image = load_img_preprocess(args.content)
style_image = load_img_preprocess(args.style)

#style_image = load_img_preprocess(args.style)
style_paths = args.style
style_images = []

# make an array and add all style images produced from style paths given as input
for i in style_paths:
# call the load and preprocess function to load the style images from the given paths
style_images.append(load_img_preprocess(i))
"""
get the model from keras basically lets us extract the layers
and their corresponding intermediate and batch outputs
Expand Down Expand Up @@ -96,12 +102,18 @@ def main():
for layer in main_model.layers:
layer.trainable = False

style_features = [style_layer[0]
for style_layer in main_model(style_image)[:5]]
style_features = []

for i in style_images:
style_features.append([style_layer[0]
for style_layer in main_model(i)[:5]])

content_features = [content_layer[0]
for content_layer in main_model(content_image)[5:]]

gram_style_features = [gram_matrix(style_feature) for style_feature in style_features]
gram_style_features = []

for i in style_features:
gram_style_features.append([gram_matrix(style_feature) for style_feature in i])

# Set initial image
init_image = load_img_preprocess(args.content)
Expand Down Expand Up @@ -142,8 +154,9 @@ def main():


style_layer_norm = 1/float(5)
for out, inter in zip(gram_style_features, style_out):
stylePoints = stylePoints + style_layer_norm*get_style_loss(inter[0],out)
for grams in gram_style_features:
for out, inter in zip(grams, style_out):
stylePoints = stylePoints + style_layer_norm*get_style_loss(inter[0],out)

contentPoints = contentPoints*contentPoints
stylePoints = stylePoints*stylePoints
Expand All @@ -170,22 +183,52 @@ def main():
plt.imsave(output_name, display_img)
plt.show()


# reference
# citing URL: https://keras.io/applications/#vgg19
def load_img_preprocess(image_path):
# Load the image at `image_path` and return it as a batched array that has
# been passed through tf.keras.applications.vgg19.preprocess_input.
#
# NOTE(review): this text was captured from a diff view; indentation is lost
# and removed/added lines appear to be interleaved (see the duplicated
# `scale = tf.cond(...)` and `VGG_MEAN` assignments). Confirm against the
# actual file before relying on the statement order shown here.
#
# Read sequentially, `img` is re-bound from Image.open(image_path) further
# down, so nothing computed by the TensorFlow pipeline that comes first
# reaches the return value.

# --- TensorFlow pipeline: read, decode, resize, mean-subtract ---
# Read the raw file contents.
img_str = tf.read_file(image_path)

# Decode as JPEG, requesting 3 channels.
img_decode = tf.image.decode_jpeg(img_str, 3)

img = tf.cast(img_decode, tf.float32)

# Target length, in pixels, for the shorter side.
dim =512.0

# NOTE(review): tf.image.decode_jpeg documents its output as
# [height, width, channels], so index 1 would be the width and index 0 the
# height -- the two names below look swapped. TODO confirm.
height = tf.to_float(tf.shape(img)[1])

width = tf.to_float(tf.shape(img)[0])
# Scale factor that maps the smaller of the two sides to `dim`.
scale = tf.cond(tf.greater(height, width), lambda: dim/width, lambda: dim/height)

# Debug output.
print('this is the old height and width ', height, width)
# Same expression as the `scale` assignment above (likely the other side of
# the diff).
scale = tf.cond(tf.greater(height, width), lambda: dim/width , lambda: dim/height)
print('this is the scale ', scale)

newHeight = tf.to_int32(height * scale)
newWidth = tf.to_int32(width * scale)
print('newheight and new width', newHeight, newWidth)

img = tf.image.resize_images(img, [newHeight, newWidth])

# The string literal below holds disabled code (random 224x224 crop and mean
# subtraction); as a bare string expression it has no effect.
"""VGG_MEAN = [123.68, 116.78, 103.94] # This is R-G-B for Imagenet

img = tf.random_crop(img, [224, 224, 3])
means = tf.reshape(tf.constant(VGG_MEAN), [1, 1, 3])
img = img - means
"""
# Add a leading batch axis.
img = np.expand_dims(img, axis=0)
VGG_MEAN = [123.68, 116.78, 103.94] #RGB Values for ImageNet

# Re-assigned with identical values.
VGG_MEAN = [123.68, 116.78, 103.94]

# Subtract the per-channel means, reshaped to [1, 1, 3] so they broadcast
# over the image.
means = tf.reshape(tf.constant(VGG_MEAN), [1, 1, 3])
img = img - means

# --- PIL/Keras pipeline: this is what determines the returned value ---
max_dim = 512
# Re-binds `img`; the tensor built above is discarded from here on.
img = Image.open(image_path)
# Scale so that the longest side becomes max_dim.
norm = max(img.size)
scale = max_dim/norm
img = img.resize((round(img.size[0]*scale), round(img.size[1]*scale)), Image.ANTIALIAS)
# Convert to an array, add a leading batch axis, then apply the VGG19
# preprocessing provided by Keras.
img = kp_image.img_to_array(img)
img = np.expand_dims(img, axis=0)
img = tf.keras.applications.vgg19.preprocess_input(img)

return img


Expand Down