This repository has been archived by the owner on May 21, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 51
/
layers.py
73 lines (55 loc) · 2.15 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from keras.engine.topology import Layer
import numpy as np
import tensorflow as tf
# Registry of the custom layer classes defined in this module, keyed by class
# name. Each class adds itself right after its definition.
# NOTE(review): presumably meant to be passed as `custom_objects` when loading
# a saved model -- confirm against the callers.
custom_layers = dict()
class ImageRescale(Layer):
    """Resize a 4-D image batch by a constant factor.

    Axes 1 and 2 of the input are treated as height and width and are
    resized with ``tf.image.resize_images``; axes 0 and 3 are passed through
    unchanged.

    Args:
        scale: Multiplicative factor applied to height and width. A scalar
            applies to both axes; a length-2 sequence is applied per axis
            as ``(height_factor, width_factor)``.
        method: Interpolation method forwarded to
            ``tf.image.resize_images``. Defaults to bicubic.
        trainable: Forwarded to ``Layer``; the layer has no weights.
        **kwargs: Forwarded to ``Layer.__init__``.
    """

    def __init__(self, scale, method=tf.image.ResizeMethod.BICUBIC,
                 trainable=False, **kwargs):
        self.scale = scale
        self.method = method
        super().__init__(trainable=trainable, **kwargs)

    def _axis_scales(self):
        """Return the (height, width) factors as two plain Python numbers."""
        # broadcast_to keeps the numpy broadcasting semantics of
        # `shape[[1, 2]] * scale`, so a scalar or a per-axis pair both work.
        return np.broadcast_to(np.asarray(self.scale), (2,)).tolist()

    def compute_size(self, shape):
        """Return the resized ``(height, width)`` for a 4-D ``shape``.

        Each known dimension is multiplied by its factor and truncated to a
        Python ``int``. A dimension that is ``None`` (unknown until run
        time) stays ``None`` instead of raising ``TypeError``.
        """
        return tuple(
            None if dim is None else int(dim * factor)
            for dim, factor in zip(shape[1:3], self._axis_scales())
        )

    def call(self, x):
        """Resize ``x`` to the scaled spatial size."""
        size = self.compute_size(x.shape.as_list())
        if None in size:
            # Height and/or width is not statically known (e.g. the model was
            # built with Input(shape=(None, None, C))): derive the target
            # size from the run-time shape, truncating like compute_size.
            factors = tf.constant(self._axis_scales(), dtype=tf.float64)
            spatial = tf.cast(tf.shape(x)[1:3], tf.float64)
            size = tf.cast(spatial * factors, tf.int32)
        return tf.image.resize_images(x, size, method=self.method)

    def compute_output_shape(self, input_shape):
        """Return ``(batch, new_height, new_width, channels)``."""
        size = self.compute_size(input_shape)
        return (input_shape[0], *size, input_shape[3])

    def get_config(self):
        """Return the base layer config extended with ``scale`` and ``method``."""
        config = super().get_config()
        config['scale'] = self.scale
        config['method'] = self.method
        return config


custom_layers['ImageRescale'] = ImageRescale
class Conv2DSubPixel(Layer):
    """Sub-pixel convolution layer.

    Rearranges a ``(batch, H, W, r * r * C)`` tensor into
    ``(batch, H * r, W * r, C)``, where ``r`` is ``scale``.

    See https://arxiv.org/abs/1609.05158

    Args:
        scale: Integer upscaling factor ``r``. The number of input channels
            must be divisible by ``r ** 2``.
        trainable: Forwarded to ``Layer``; the layer has no weights.
        **kwargs: Forwarded to ``Layer.__init__``.
    """

    def __init__(self, scale, trainable=False, **kwargs):
        self.scale = scale
        super().__init__(trainable=trainable, **kwargs)

    def call(self, t):
        """Shuffle channel blocks of ``t`` into the spatial dimensions."""
        r = self.scale
        static_shape = t.shape.as_list()
        # compute_output_shape also validates the channel count.
        C = self.compute_output_shape(static_shape)[-1]
        H, W = static_shape[1:3]
        if H is None or W is None:
            # Spatial size is only known at run time (e.g. the model was
            # built with Input(shape=(None, None, C))): fall back to the
            # dynamic shape for whichever dimension is missing.
            dynamic_shape = tf.shape(t)
            if H is None:
                H = dynamic_shape[1]
            if W is None:
                W = dynamic_shape[2]
        t = tf.reshape(t, [-1, H, W, r, r, C])
        # Here we are different from Equation 4 from the paper. That equation
        # is equivalent to switching 3 and 4 in `perm`. But I feel my
        # implementation is more natural.
        t = tf.transpose(t, perm=[0, 1, 3, 2, 4, 5])  # S, H, r, W, r, C
        t = tf.reshape(t, [-1, H * r, W * r, C])
        return t

    def compute_output_shape(self, input_shape):
        """Return ``(batch, H * r, W * r, channels // r ** 2)``.

        Unknown (``None``) height or width stays ``None`` in the result.

        Raises:
            ValueError: If the channel count is unknown or not divisible by
                ``scale ** 2``.
        """
        r = self.scale
        H, W, rrC = input_shape[1:]
        # Explicit raise instead of `assert`, which is stripped under -O.
        if rrC is None or rrC % (r ** 2) != 0:
            raise ValueError(
                'Conv2DSubPixel needs a known channel count divisible by '
                'scale ** 2 = {}, got {}'.format(r ** 2, rrC))
        out_H = None if H is None else int(H) * r
        out_W = None if W is None else int(W) * r
        return (input_shape[0], out_H, out_W, int(rrC) // (r ** 2))

    def get_config(self):
        """Return the base layer config extended with ``scale``."""
        config = super().get_config()
        config['scale'] = self.scale
        return config


custom_layers['Conv2DSubPixel'] = Conv2DSubPixel