Skip to content

Commit aba1291

Browse files
committed Jun 28, 2022
update
1 parent 3fb86ce commit aba1291

File tree

3 files changed

+181
-96
lines changed

3 files changed

+181
-96
lines changed
 

‎srgan.py

+159-84
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from tensorlayerx.nn import Module
22
import tensorlayerx as tlx
3-
from tensorlayerx.nn import Conv2d, BatchNorm2d,Elementwise,SubpixelConv2d, UpSampling2d, Flatten, Sequential
3+
from tensorlayerx.nn import Conv2d, BatchNorm2d, Elementwise, SubpixelConv2d, UpSampling2d, Flatten, Sequential
44
from tensorlayerx.nn import Linear, MaxPool2d
55

66
W_init = tlx.initializers.TruncatedNormal(stddev=0.02)
@@ -11,10 +11,16 @@ class ResidualBlock(Module):
1111

1212
def __init__(self):
1313
super(ResidualBlock, self).__init__()
14-
self.conv1 = Conv2d(out_channels=64, kernel_size=(3,3), stride=(1,1), act=None, padding='SAME', W_init=W_init, b_init = None)
15-
self.bn1 = BatchNorm2d(num_features=64, act=tlx.ReLU, gamma_init=G_init)
16-
self.conv2 = Conv2d(out_channels=64, kernel_size=(3,3), stride=(1,1), act=None, padding='SAME', W_init=W_init, b_init = None)
17-
self.bn2 = BatchNorm2d(num_features=64, act=None, gamma_init=G_init)
14+
self.conv1 = Conv2d(
15+
out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
16+
data_format='channels_first', b_init=None
17+
)
18+
self.bn1 = BatchNorm2d(num_features=64, act=tlx.ReLU, gamma_init=G_init, data_format='channels_first')
19+
self.conv2 = Conv2d(
20+
out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
21+
data_format='channels_first', b_init=None
22+
)
23+
self.bn2 = BatchNorm2d(num_features=64, act=None, gamma_init=G_init, data_format='channels_first')
1824

1925
def forward(self, x):
2026
z = self.conv1(x)
@@ -24,21 +30,30 @@ def forward(self, x):
2430
x = x + z
2531
return x
2632

33+
2734
class SRGAN_g(Module):
2835
""" Generator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
2936
feature maps (n) and stride (s) feature maps (n) and stride (s)
3037
"""
38+
3139
def __init__(self):
32-
super(SRGAN_g,self).__init__()
33-
self.conv1 = Conv2d(out_channels=64, kernel_size=(3,3), stride=(1,1), act=tlx.ReLU, padding='SAME', W_init=W_init)
40+
super(SRGAN_g, self).__init__()
41+
self.conv1 = Conv2d(
42+
out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME', W_init=W_init,
43+
data_format='channels_first'
44+
)
3445
self.residual_block = self.make_layer()
35-
self.conv2 = Conv2d(out_channels=64, kernel_size=(3,3), stride=(1,1),padding='SAME', W_init=W_init, b_init = None)
36-
self.bn1 = BatchNorm2d(num_features=64, act=None, gamma_init=G_init)
37-
self.conv3 = Conv2d(out_channels=256, kernel_size=(3,3), stride=(1,1),padding='SAME', W_init = W_init)
38-
self.subpiexlconv1 = SubpixelConv2d(scale=2, act = tlx.ReLU)
39-
self.conv4 = Conv2d(out_channels=256, kernel_size=(3,3), stride=(1,1), padding='SAME', W_init=W_init)
40-
self.subpiexlconv2 = SubpixelConv2d(scale=2, act = tlx.ReLU)
41-
self.conv5 = Conv2d(3, kernel_size=(1,1), stride=(1,1), act=tlx.Tanh, padding='SAME', W_init=W_init)
46+
self.conv2 = Conv2d(
47+
out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
48+
data_format='channels_first', b_init=None
49+
)
50+
self.bn1 = BatchNorm2d(num_features=64, act=None, gamma_init=G_init, data_format='channels_first')
51+
self.conv3 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init, data_format='channels_first')
52+
self.subpiexlconv1 = SubpixelConv2d(data_format='channels_first', scale=2, act=tlx.ReLU)
53+
self.conv4 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init, data_format='channels_first')
54+
self.subpiexlconv2 = SubpixelConv2d(data_format='channels_first', scale=2, act=tlx.ReLU)
55+
self.conv5 = Conv2d(3, kernel_size=(1, 1), stride=(1, 1), act=tlx.Tanh, padding='SAME', W_init=W_init, data_format='channels_first')
56+
4257
def make_layer(self):
4358
layer_list = []
4459
for i in range(16):
@@ -61,7 +76,6 @@ def forward(self, x):
6176
return x
6277

6378

64-
6579
class SRGAN_g2(Module):
6680
""" Generator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
6781
feature maps (n) and stride (s) feature maps (n) and stride (s)
@@ -70,22 +84,34 @@ class SRGAN_g2(Module):
7084
7185
Use Resize Conv
7286
"""
87+
7388
def __init__(self):
74-
super(SRGAN_g2,self).__init__()
75-
self.conv1 = Conv2d(out_channels=64, kernel_size=(3,3), stride=(1,1), act=None, padding='SAME', W_init=W_init)
89+
super(SRGAN_g2, self).__init__()
90+
self.conv1 = Conv2d(
91+
out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
92+
data_format='channels_first'
93+
)
7694
self.residual_block = self.make_layer()
77-
self.conv2 = Conv2d(out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
78-
b_init=None)
79-
self.bn1 = BatchNorm2d(act=None, gamma_init=G_init)
80-
self.upsample1 = UpSampling2d(scale=(2,2), method='bilinear')
81-
self.conv3 = Conv2d(out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
82-
b_init=None)
83-
self.bn2 = BatchNorm2d(act= tlx.ReLU, gamma_init=G_init)
84-
self.upsample2 = UpSampling2d(scale=(4,4),method='bilinear')
85-
self.conv4 = Conv2d(out_channels=32, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
86-
b_init=None)
87-
self.bn3 = BatchNorm2d(act = tlx.ReLU, gamma_init=G_init)
88-
self.conv5 = Conv2d(out_channels=3, kernel_size=(1,1), stride=(1,1), act = tlx.Tanh, padding='SAME', W_init=W_init)
95+
self.conv2 = Conv2d(
96+
out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
97+
data_format='channels_first', b_init=None
98+
)
99+
self.bn1 = BatchNorm2d(act=None, gamma_init=G_init, data_format='channels_first')
100+
self.upsample1 = UpSampling2d(data_format='channels_first', scale=(2, 2), method='bilinear')
101+
self.conv3 = Conv2d(
102+
out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
103+
data_format='channels_first', b_init=None
104+
)
105+
self.bn2 = BatchNorm2d(act=tlx.ReLU, gamma_init=G_init, data_format='channels_first')
106+
self.upsample2 = UpSampling2d(data_format='channels_first', scale=(4, 4), method='bilinear')
107+
self.conv4 = Conv2d(
108+
out_channels=32, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
109+
data_format='channels_first', b_init=None
110+
)
111+
self.bn3 = BatchNorm2d(act=tlx.ReLU, gamma_init=G_init, data_format='channels_first')
112+
self.conv5 = Conv2d(
113+
out_channels=3, kernel_size=(1, 1), stride=(1, 1), act=tlx.Tanh, padding='SAME', W_init=W_init
114+
)
89115

90116
def make_layer(self):
91117
layer_list = []
@@ -109,27 +135,53 @@ def forward(self, x):
109135
x = self.conv5(x)
110136
return x
111137

138+
112139
class SRGAN_d2(Module):
113140
""" Discriminator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
114141
feature maps (n) and stride (s) feature maps (n) and stride (s)
115142
"""
143+
116144
def __init__(self, ):
117145
super(SRGAN_d2, self).__init__()
118-
self.conv1 = Conv2d(out_channels=64, kernel_size=(3,3), stride=(1,1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME', W_init=W_init)
119-
self.conv2 = Conv2d(out_channels=64, kernel_size=(3,3), stride=(2,2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME', W_init=W_init, b_init=None)
120-
self.bn1 = BatchNorm2d( gamma_init=G_init)
121-
self.conv3 = Conv2d(out_channels=128, kernel_size=(3,3), stride=(1,1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME', W_init=W_init, b_init=None)
122-
self.bn2 = BatchNorm2d( gamma_init=G_init)
123-
self.conv4 = Conv2d(out_channels=128, kernel_size=(3,3), stride=(2,2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME', W_init=W_init, b_init=None)
124-
self.bn3 = BatchNorm2d(gamma_init=G_init)
125-
self.conv5 = Conv2d(out_channels=256, kernel_size=(3,3), stride=(1,1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME', W_init=W_init, b_init=None)
126-
self.bn4 = BatchNorm2d( gamma_init=G_init)
127-
self.conv6 = Conv2d(out_channels=256, kernel_size=(3,3), stride=(2,2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME', W_init=W_init, b_init=None)
128-
self.bn5 = BatchNorm2d( gamma_init=G_init)
129-
self.conv7 = Conv2d(out_channels=512, kernel_size=(3,3), stride=(1,1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME', W_init=W_init, b_init=None)
130-
self.bn6 = BatchNorm2d( gamma_init=G_init)
131-
self.conv8 = Conv2d(out_channels=512, kernel_size=(3,3), stride=(2,2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME', W_init=W_init, b_init=None)
132-
self.bn7 = BatchNorm2d( gamma_init=G_init)
146+
self.conv1 = Conv2d(
147+
out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
148+
W_init=W_init, data_format='channels_first'
149+
)
150+
self.conv2 = Conv2d(
151+
out_channels=64, kernel_size=(3, 3), stride=(2, 2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
152+
W_init=W_init, data_format='channels_first', b_init=None
153+
)
154+
self.bn1 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
155+
self.conv3 = Conv2d(
156+
out_channels=128, kernel_size=(3, 3), stride=(1, 1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
157+
W_init=W_init, data_format='channels_first', b_init=None
158+
)
159+
self.bn2 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
160+
self.conv4 = Conv2d(
161+
out_channels=128, kernel_size=(3, 3), stride=(2, 2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
162+
W_init=W_init, data_format='channels_first', b_init=None
163+
)
164+
self.bn3 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
165+
self.conv5 = Conv2d(
166+
out_channels=256, kernel_size=(3, 3), stride=(1, 1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
167+
W_init=W_init, data_format='channels_first', b_init=None
168+
)
169+
self.bn4 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
170+
self.conv6 = Conv2d(
171+
out_channels=256, kernel_size=(3, 3), stride=(2, 2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
172+
W_init=W_init, data_format='channels_first', b_init=None
173+
)
174+
self.bn5 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
175+
self.conv7 = Conv2d(
176+
out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
177+
W_init=W_init, data_format='channels_first', b_init=None
178+
)
179+
self.bn6 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
180+
self.conv8 = Conv2d(
181+
out_channels=512, kernel_size=(3, 3), stride=(2, 2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
182+
W_init=W_init, data_format='channels_first', b_init=None
183+
)
184+
self.bn7 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
133185
self.flat = Flatten()
134186
self.dense1 = Linear(out_features=1024, act=tlx.LeakyReLU(negative_slope=0.2))
135187
self.dense2 = Linear(out_features=1)
@@ -158,43 +210,67 @@ def forward(self, x):
158210
return n, logits
159211

160212

161-
162-
163213
class SRGAN_d(Module):
164214

165-
def __init__(self, dim = 64):
166-
super(SRGAN_d,self).__init__()
167-
self.conv1 = Conv2d(out_channels=dim, kernel_size=(4,4), stride=(2,2), act=tlx.LeakyReLU, padding='SAME', W_init=W_init)
168-
self.conv2 = Conv2d(out_channels=dim * 2, kernel_size=(4,4), stride=(2,2), act=None, padding='SAME', W_init=W_init, b_init=None)
169-
self.bn1 = BatchNorm2d(num_features=dim * 2, act=tlx.LeakyReLU, gamma_init=G_init)
170-
self.conv3 = Conv2d(out_channels=dim * 4, kernel_size=(4,4), stride=(2,2), act=None, padding='SAME', W_init=W_init, b_init=None)
171-
self.bn2 = BatchNorm2d(num_features=dim * 4,act=tlx.LeakyReLU, gamma_init=G_init)
172-
self.conv4 = Conv2d(out_channels=dim * 8, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME',W_init=W_init, b_init=None)
173-
self.bn3 = BatchNorm2d(num_features=dim * 8, act=tlx.LeakyReLU, gamma_init=G_init)
174-
self.conv5 = Conv2d(out_channels=dim * 16, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME',
175-
W_init=W_init, b_init=None)
176-
self.bn4 = BatchNorm2d(num_features=dim * 16, act=tlx.LeakyReLU, gamma_init=G_init)
177-
self.conv6 = Conv2d(out_channels=dim * 32, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME',
178-
W_init=W_init, b_init=None)
179-
self.bn5 = BatchNorm2d(num_features=dim * 32,act=tlx.LeakyReLU, gamma_init=G_init)
180-
self.conv7 = Conv2d(out_channels=dim * 16, kernel_size=(1, 1), stride=(1, 1), act=None, padding='SAME',
181-
W_init=W_init, b_init=None)
182-
self.bn6 = BatchNorm2d(num_features=dim * 16,act=tlx.LeakyReLU, gamma_init=G_init)
183-
self.conv8 = Conv2d(out_channels=dim * 8, kernel_size=(1, 1), stride=(1, 1), act=None, padding='SAME',
184-
W_init=W_init, b_init=None)
185-
self.bn7 = BatchNorm2d(num_features=dim * 8,act=None, gamma_init=G_init)
186-
self.conv9 = Conv2d(out_channels=dim * 2, kernel_size=(1, 1), stride=(1, 1), act=None, padding='SAME',
187-
W_init=W_init, b_init=None)
188-
self.bn8 = BatchNorm2d(num_features=dim * 2,act=tlx.LeakyReLU, gamma_init=G_init)
189-
self.conv10 = Conv2d(out_channels=dim * 2, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME',
190-
W_init=W_init, b_init=None)
191-
self.bn9 = BatchNorm2d(num_features=dim * 2,act=tlx.LeakyReLU, gamma_init=G_init)
192-
self.conv11 = Conv2d(out_channels=dim * 8, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME',
193-
W_init=W_init, b_init=None)
194-
self.bn10 = BatchNorm2d(num_features=dim * 8, gamma_init=G_init)
215+
def __init__(self, dim=64):
216+
super(SRGAN_d, self).__init__()
217+
self.conv1 = Conv2d(
218+
out_channels=dim, kernel_size=(4, 4), stride=(2, 2), act=tlx.LeakyReLU, padding='SAME', W_init=W_init,
219+
data_format='channels_first'
220+
)
221+
self.conv2 = Conv2d(
222+
out_channels=dim * 2, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
223+
data_format='channels_first', b_init=None
224+
)
225+
self.bn1 = BatchNorm2d(num_features=dim * 2, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
226+
self.conv3 = Conv2d(
227+
out_channels=dim * 4, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
228+
data_format='channels_first', b_init=None
229+
)
230+
self.bn2 = BatchNorm2d(num_features=dim * 4, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
231+
self.conv4 = Conv2d(
232+
out_channels=dim * 8, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
233+
data_format='channels_first', b_init=None
234+
)
235+
self.bn3 = BatchNorm2d(num_features=dim * 8, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
236+
self.conv5 = Conv2d(
237+
out_channels=dim * 16, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
238+
data_format='channels_first', b_init=None
239+
)
240+
self.bn4 = BatchNorm2d(num_features=dim * 16, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
241+
self.conv6 = Conv2d(
242+
out_channels=dim * 32, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
243+
data_format='channels_first', b_init=None
244+
)
245+
self.bn5 = BatchNorm2d(num_features=dim * 32, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
246+
self.conv7 = Conv2d(
247+
out_channels=dim * 16, kernel_size=(1, 1), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
248+
data_format='channels_first', b_init=None
249+
)
250+
self.bn6 = BatchNorm2d(num_features=dim * 16, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
251+
self.conv8 = Conv2d(
252+
out_channels=dim * 8, kernel_size=(1, 1), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
253+
data_format='channels_first', b_init=None
254+
)
255+
self.bn7 = BatchNorm2d(num_features=dim * 8, act=None, gamma_init=G_init, data_format='channels_first')
256+
self.conv9 = Conv2d(
257+
out_channels=dim * 2, kernel_size=(1, 1), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
258+
data_format='channels_first', b_init=None
259+
)
260+
self.bn8 = BatchNorm2d(num_features=dim * 2, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
261+
self.conv10 = Conv2d(
262+
out_channels=dim * 2, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
263+
data_format='channels_first', b_init=None
264+
)
265+
self.bn9 = BatchNorm2d(num_features=dim * 2, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
266+
self.conv11 = Conv2d(
267+
out_channels=dim * 8, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
268+
data_format='channels_first', b_init=None
269+
)
270+
self.bn10 = BatchNorm2d(num_features=dim * 8, gamma_init=G_init, data_format='channels_first')
195271
self.add = Elementwise(combine_fn=tlx.add, act=tlx.LeakyReLU)
196272
self.flat = Flatten()
197-
self.dense = Linear(out_features=1, W_init=W_init)
273+
self.dense = Linear(out_features=1, W_init=W_init)
198274

199275
def forward(self, x):
200276

@@ -227,37 +303,36 @@ def forward(self, x):
227303
return x
228304

229305

230-
231306
class Vgg19_simple_api(Module):
232307

233308
def __init__(self):
234-
super(Vgg19_simple_api,self).__init__()
309+
super(Vgg19_simple_api, self).__init__()
235310
""" conv1 """
236311
self.conv1 = Conv2d(out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
237312
self.conv2 = Conv2d(out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
238-
self.maxpool1 = MaxPool2d(kernel_size=(2,2), stride=(2,2), padding='SAME')
313+
self.maxpool1 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding='SAME')
239314
""" conv2 """
240315
self.conv3 = Conv2d(out_channels=128, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
241316
self.conv4 = Conv2d(out_channels=128, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
242-
self.maxpool2 = MaxPool2d(kernel_size=(2,2), stride=(2,2), padding='SAME')
317+
self.maxpool2 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding='SAME')
243318
""" conv3 """
244319
self.conv5 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
245320
self.conv6 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
246321
self.conv7 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
247322
self.conv8 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
248-
self.maxpool3 = MaxPool2d(kernel_size=(2,2), stride=(2,2), padding='SAME')
323+
self.maxpool3 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding='SAME')
249324
""" conv4 """
250325
self.conv9 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
251326
self.conv10 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
252327
self.conv11 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
253328
self.conv12 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
254-
self.maxpool4 = MaxPool2d(kernel_size=(2,2), stride=(2,2), padding='SAME') # (batch_size, 14, 14, 512)
329+
self.maxpool4 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding='SAME') # (batch_size, 14, 14, 512)
255330
""" conv5 """
256331
self.conv13 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
257332
self.conv14 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
258333
self.conv15 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
259334
self.conv16 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
260-
self.maxpool5 = MaxPool2d(kernel_size=(2,2), stride=(2,2), padding='SAME') # (batch_size, 7, 7, 512)
335+
self.maxpool5 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding='SAME') # (batch_size, 7, 7, 512)
261336
""" fc 6~8 """
262337
self.flat = Flatten()
263338
self.dense1 = Linear(out_features=4096, act=tlx.ReLU)

‎train.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
import os
2-
os.environ['TL_BACKEND'] = 'tensorflow' # Just modify this line, easily change to any framework! PyTorch will be coming soon!
2+
# os.environ['TL_BACKEND'] = 'tensorflow' # Just modify this line, easily switch to any framework! PyTorch will be coming soon!
33
# os.environ['TL_BACKEND'] = 'mindspore'
44
# os.environ['TL_BACKEND'] = 'paddle'
5+
os.environ['TL_BACKEND'] = 'torch'
56
import time
67
import numpy as np
78
import tensorlayerx as tlx
89
from tensorlayerx.dataflow import Dataset, DataLoader
910
from srgan import SRGAN_g, SRGAN_d
1011
from config import config
11-
from tensorlayerx.vision.transforms import Compose, RandomCrop, Normalize, RandomFlipHorizontal, Resize
12+
from tensorlayerx.vision.transforms import Compose, RandomCrop, Normalize, RandomFlipHorizontal, Resize, HWC2CHW
1213
import vgg
1314
from tensorlayerx.model import TrainOneStep
1415
from tensorlayerx.nn import Module
1516
import cv2
17+
tlx.set_device('GPU')
1618

1719
###====================== HYPER-PARAMETERS ===========================###
1820
batch_size = 8
@@ -28,14 +30,16 @@
2830
RandomCrop(size=(384, 384)),
2931
RandomFlipHorizontal(),
3032
])
31-
nor = Normalize(mean=(127.5), std=(127.5), data_format='HWC')
33+
nor = Compose([Normalize(mean=(127.5), std=(127.5), data_format='HWC'),
34+
HWC2CHW()])
3235
lr_transform = Resize(size=(96, 96))
3336

37+
train_hr_imgs = tlx.vision.load_images(path=config.TRAIN.hr_img_path, n_threads = 32)
3438

3539
class TrainData(Dataset):
3640

3741
def __init__(self, hr_trans=hr_transform, lr_trans=lr_transform):
38-
self.train_hr_imgs = tlx.vision.load_images(path=config.TRAIN.hr_img_path)
42+
self.train_hr_imgs = train_hr_imgs
3943
self.hr_trans = hr_trans
4044
self.lr_trans = lr_trans
4145

@@ -104,12 +108,12 @@ def forward(self, lr, hr):
104108

105109
G = SRGAN_g()
106110
D = SRGAN_d()
107-
VGG = vgg.VGG19(pretrained=False, end_with='pool4', mode='dynamic')
111+
VGG = vgg.VGG19(pretrained=True, end_with='pool4', mode='dynamic')
108112
# automatic init layers weights shape with input tensor.
109113
# Calculating and filling 'in_channels' of each layer is a very troublesome thing.
110114
# So, just use 'init_build' with input shape. 'in_channels' of each layer will be automatically set.
111-
G.init_build(tlx.nn.Input(shape=(8, 96, 96, 3)))
112-
D.init_build(tlx.nn.Input(shape=(8, 384, 384, 3)))
115+
G.init_build(tlx.nn.Input(shape=(8, 3, 96, 96)))
116+
D.init_build(tlx.nn.Input(shape=(8, 3, 384, 384)))
113117

114118

115119
def train():
@@ -176,15 +180,17 @@ def evaluate():
176180

177181

178182
valid_lr_img_tensor = np.asarray(valid_lr_img_tensor, dtype=np.float32)
183+
valid_lr_img_tensor = np.transpose(valid_lr_img_tensor,axes=[2, 0, 1])
179184
valid_lr_img_tensor = valid_lr_img_tensor[np.newaxis, :, :, :]
180185
valid_lr_img_tensor= tlx.ops.convert_to_tensor(valid_lr_img_tensor)
181186
size = [valid_lr_img.shape[0], valid_lr_img.shape[1]]
182187

183188
out = tlx.ops.convert_to_numpy(G(valid_lr_img_tensor))
184189
out = np.asarray((out + 1) * 127.5, dtype=np.uint8)
190+
out = np.transpose(out[0], axes=[1, 2, 0])
185191
print("LR size: %s / generated HR size: %s" % (size, out.shape)) # LR size: (339, 510, 3) / gen HR size: (1, 1356, 2040, 3)
186192
print("[*] save images")
187-
tlx.vision.save_image(out[0], file_name='valid_gen.png', path=save_dir)
193+
tlx.vision.save_image(out, file_name='valid_gen.png', path=save_dir)
188194
tlx.vision.save_image(valid_lr_img, file_name='valid_lr.png', path=save_dir)
189195
tlx.vision.save_image(valid_hr_img, file_name='valid_hr.png', path=save_dir)
190196
out_bicu = cv2.resize(valid_lr_img, dsize = [size[1] * 4, size[0] * 4], interpolation = cv2.INTER_CUBIC)

‎vgg.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def forward(self, inputs):
103103
"""
104104

105105
# inputs = inputs * 255 - np.array([123.68, 116.779, 103.939], dtype=np.float32).reshape([1, 1, 1, 3])
106-
inputs = inputs * 255. - tlx.convert_to_tensor(np.array([123.68, 116.779, 103.939], dtype=np.float32))
106+
inputs = inputs * 255. - tlx.convert_to_tensor(np.array([123.68, 116.779, 103.939], dtype=np.float32).reshape(-1,1,1))
107107
out = self.make_layer(inputs)
108108
return out
109109

@@ -126,18 +126,18 @@ def make_layers(config, batch_norm=False, end_with='outputs'):
126126
layer_list.append(
127127
Conv2d(
128128
out_channels=n_filter, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME',
129-
in_channels=in_channels, name=layer_name
129+
in_channels=in_channels, name=layer_name, data_format='channels_first'
130130
)
131131
)
132132
if batch_norm:
133-
layer_list.append(BatchNorm(num_features=n_filter))
133+
layer_list.append(BatchNorm(num_features=n_filter, data_format='channels_first'))
134134
if layer_name == end_with:
135135
is_end = True
136136
break
137137
else:
138138
layer_name = layer_names[layer_group_idx]
139139
if layer_group == 'M':
140-
layer_list.append(MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding='SAME', name=layer_name))
140+
layer_list.append(MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding='SAME', name=layer_name, data_format='channels_first'))
141141
elif layer_group == 'O':
142142
layer_list.append(Linear(out_features=1000, in_features=4096, name=layer_name))
143143
elif layer_group == 'F':
@@ -175,6 +175,10 @@ def restore_model(model, layer_type):
175175
if len(model.all_weights) == len(weights):
176176
break
177177
# assign weight values
178+
if tlx.BACKEND != 'tensorflow':
179+
for i in range(len(weights)):
180+
if len(weights[i].shape) == 4:
181+
weights[i] = np.transpose(weights[i], axes=[3, 2, 0, 1])
178182
assign_weights(weights, model)
179183
del weights
180184

0 commit comments

Comments
 (0)
Please sign in to comment.