
Commit aa357f6

v_1.2.0

1 parent b601146

File tree

5 files changed: +590, -262 lines


pro_gan_pytorch/CustomLayers.py

Lines changed: 392 additions & 0 deletions
@@ -0,0 +1,392 @@
""" Module containing custom layers """
import torch as th
import copy


# extending Conv2D and Deconv2D layers for equalized learning rate logic
class _equalized_conv2d(th.nn.Module):
    """ conv2d with the concept of equalized learning rate """

    def __init__(self, c_in, c_out, k_size, stride=1, pad=0, initializer='kaiming', bias=True):
        """
        constructor for the class
        :param c_in: input channels
        :param c_out: output channels
        :param k_size: kernel size (h, w) should be a tuple or a single integer
        :param stride: stride for conv
        :param pad: padding
        :param initializer: initializer. one of kaiming or xavier
        :param bias: whether to use bias or not
        """
        super(_equalized_conv2d, self).__init__()
        # the internal conv carries no bias of its own; the explicit bias
        # parameter below is added in forward when requested
        self.conv = th.nn.Conv2d(c_in, c_out, k_size, stride, pad, bias=False)
        if initializer == 'kaiming':
            th.nn.init.kaiming_normal_(self.conv.weight, a=th.nn.init.calculate_gain('conv2d'))
        elif initializer == 'xavier':
            th.nn.init.xavier_normal_(self.conv.weight)

        self.use_bias = bias

        self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0))
        self.scale = (th.mean(self.conv.weight.data ** 2)) ** 0.5
        self.conv.weight.data.copy_(self.conv.weight.data / self.scale)

    def forward(self, x):
        """
        forward pass of the network
        :param x: input
        :return: y => output
        """
        try:
            dev_scale = self.scale.to(x.get_device())
        except RuntimeError:
            dev_scale = self.scale
        x = self.conv(x.mul(dev_scale))
        if self.use_bias:
            return x + self.bias.view(1, -1, 1, 1).expand_as(x)
        return x


class _equalized_deconv2d(th.nn.Module):
    """ Transpose convolution using the equalized learning rate """

    def __init__(self, c_in, c_out, k_size, stride=1, pad=0, initializer='kaiming', bias=True):
        """
        constructor for the class
        :param c_in: input channels
        :param c_out: output channels
        :param k_size: kernel size
        :param stride: stride for convolution transpose
        :param pad: padding
        :param initializer: initializer. one of kaiming or xavier
        :param bias: whether to use bias or not
        """
        super(_equalized_deconv2d, self).__init__()
        self.deconv = th.nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad, bias=False)
        if initializer == 'kaiming':
            th.nn.init.kaiming_normal_(self.deconv.weight, a=th.nn.init.calculate_gain('conv2d'))
        elif initializer == 'xavier':
            th.nn.init.xavier_normal_(self.deconv.weight)

        self.use_bias = bias

        self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0))
        self.scale = (th.mean(self.deconv.weight.data ** 2)) ** 0.5
        self.deconv.weight.data.copy_(self.deconv.weight.data / self.scale)

    def forward(self, x):
        """
        forward pass of the layer
        :param x: input
        :return: y => output
        """
        try:
            dev_scale = self.scale.to(x.get_device())
        except RuntimeError:
            dev_scale = self.scale

        x = self.deconv(x.mul(dev_scale))
        if self.use_bias:
            return x + self.bias.view(1, -1, 1, 1).expand_as(x)
        return x


class _equalized_linear(th.nn.Module):
    """ Linear layer using equalized learning rate """

    def __init__(self, c_in, c_out, initializer='kaiming'):
        """
        Linear layer from pytorch extended to include equalized learning rate
        :param c_in: number of input channels
        :param c_out: number of output channels
        :param initializer: initializer to be used: one of "kaiming" or "xavier"
        """
        super(_equalized_linear, self).__init__()
        self.linear = th.nn.Linear(c_in, c_out, bias=False)
        if initializer == 'kaiming':
            th.nn.init.kaiming_normal_(self.linear.weight,
                                       a=th.nn.init.calculate_gain('linear'))
        elif initializer == 'xavier':
            th.nn.init.xavier_normal_(self.linear.weight)

        self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0))
        self.scale = (th.mean(self.linear.weight.data ** 2)) ** 0.5
        self.linear.weight.data.copy_(self.linear.weight.data / self.scale)

    def forward(self, x):
        """
        forward pass of the layer
        :param x: input
        :return: y => output
        """
        try:
            dev_scale = self.scale.to(x.get_device())
        except RuntimeError:
            dev_scale = self.scale
        x = self.linear(x.mul(dev_scale))
        return x + self.bias.view(1, -1).expand_as(x)


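# Illustrative usage sketch for the equalized layers above. The function name
# and the shapes are assumptions chosen for the example, not part of the module.
# The weights are stored pre-divided by their RMS (self.scale) and the input is
# multiplied by that same scale at every forward pass, so the effective mapping
# is unchanged while the optimizer sees unit-variance parameters.
def _demo_equalized_conv2d():
    conv = _equalized_conv2d(16, 32, (3, 3), pad=1)
    x = th.randn(4, 16, 8, 8)   # a small batch of random feature maps
    y = conv(x)                 # runtime weight scaling happens inside forward
    assert y.shape == (4, 32, 8, 8)
    return y

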
# ==========================================================
# Layers required for Building The generator and
# discriminator
# ==========================================================
class GenInitialBlock(th.nn.Module):
    """ Module implementing the initial block of the input """

    def __init__(self, in_channels, use_eql):
        """
        constructor for the class
        :param in_channels: number of input channels to the block
        :param use_eql: whether to use equalized learning rate
        """
        from torch.nn import LeakyReLU
        from torch.nn.functional import local_response_norm

        super(GenInitialBlock, self).__init__()

        if use_eql:
            self.conv_1 = _equalized_deconv2d(in_channels, in_channels, (4, 4), bias=True)
            self.conv_2 = _equalized_conv2d(in_channels, in_channels, (3, 3),
                                            pad=1, bias=True)

        else:
            from torch.nn import Conv2d, ConvTranspose2d
            self.conv_1 = ConvTranspose2d(in_channels, in_channels, (4, 4), bias=True)
            self.conv_2 = Conv2d(in_channels, in_channels, (3, 3), padding=1, bias=True)

        # Pixelwise feature vector normalization operation
        self.pixNorm = lambda x: local_response_norm(x, 2 * x.shape[1], alpha=2 * x.shape[1],
                                                     beta=0.5, k=1e-8)

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the block
        :param x: input to the module
        :return: y => output
        """
        # convert the latent vector into a 1x1 spatial feature map:
        y = th.unsqueeze(th.unsqueeze(x, -1), -1)

        # perform the forward computations:
        y = self.lrelu(self.conv_1(y))
        y = self.lrelu(self.conv_2(y))

        # apply pixel norm
        y = self.pixNorm(y)

        return y


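# The pixNorm lambda above (also used in GenGeneralConvBlock below) builds the
# pixelwise feature-vector normalization out of local_response_norm. For
# reference, a direct formulation of the normalization described in the ProGAN
# paper is sketched here; the helper name and eps value are assumptions, and it
# is not claimed to be numerically identical to the local_response_norm trick.
def _pixel_norm_explicit(x, eps=1e-8):
    # divide each pixel's feature vector by its root-mean-square over channels
    return x / th.sqrt(th.mean(x ** 2, dim=1, keepdim=True) + eps)

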
class GenGeneralConvBlock(th.nn.Module):
    """ Module implementing a general convolutional block """

    def __init__(self, in_channels, out_channels, use_eql):
        """
        constructor for the class
        :param in_channels: number of input channels to the block
        :param out_channels: number of output channels required
        :param use_eql: whether to use equalized learning rate
        """
        from torch.nn import LeakyReLU, Upsample
        from torch.nn.functional import local_response_norm

        super(GenGeneralConvBlock, self).__init__()

        self.upsample = Upsample(scale_factor=2)

        if use_eql:
            self.conv_1 = _equalized_conv2d(in_channels, out_channels, (3, 3),
                                            pad=1, bias=True)
            self.conv_2 = _equalized_conv2d(out_channels, out_channels, (3, 3),
                                            pad=1, bias=True)
        else:
            from torch.nn import Conv2d
            self.conv_1 = Conv2d(in_channels, out_channels, (3, 3),
                                 padding=1, bias=True)
            self.conv_2 = Conv2d(out_channels, out_channels, (3, 3),
                                 padding=1, bias=True)

        # Pixelwise feature vector normalization operation
        self.pixNorm = lambda x: local_response_norm(x, 2 * x.shape[1], alpha=2 * x.shape[1],
                                                     beta=0.5, k=1e-8)

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the block
        :param x: input
        :return: y => output
        """
        y = self.upsample(x)
        y = self.pixNorm(self.lrelu(self.conv_1(y)))
        y = self.pixNorm(self.lrelu(self.conv_2(y)))

        return y


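# Illustrative usage sketch for GenGeneralConvBlock; the channel counts and
# spatial size are assumptions. The block upsamples by a factor of 2 and then
# refines with two padded 3x3 convolutions, so an (N, C_in, H, W) input becomes
# (N, C_out, 2H, 2W).
def _demo_gen_block():
    block = GenGeneralConvBlock(in_channels=64, out_channels=32, use_eql=True)
    x = th.randn(2, 64, 8, 8)
    y = block(x)
    assert y.shape == (2, 32, 16, 16)
    return y

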
class MinibatchStdDev(th.nn.Module):
    """ Minibatch standard deviation layer for the discriminator """

    def __init__(self, averaging='all'):
        """
        constructor for the class
        :param averaging: the averaging mode used for calculating the MinibatchStdDev
        """
        super(MinibatchStdDev, self).__init__()

        # lower case the passed parameter
        self.averaging = averaging.lower()

        if 'group' in self.averaging:
            self.n = int(self.averaging[5:])
        else:
            assert self.averaging in \
                   ['all', 'flat', 'spatial', 'none', 'gpool'], \
                   'Invalid averaging mode %s' % self.averaging

        # calculate the std_dev in such a way that it doesn't result in 0
        # otherwise 0 norm operation's gradient is nan
        self.adjusted_std = lambda x, **kwargs: th.sqrt(
            th.mean((x - th.mean(x, **kwargs)) ** 2, **kwargs) + 1e-8)

    def forward(self, x):
        """
        forward pass of the Layer
        :param x: input
        :return: y => output
        """
        shape = list(x.size())
        target_shape = copy.deepcopy(shape)

        # compute the std's over the minibatch
        vals = self.adjusted_std(x, dim=0, keepdim=True)

        # perform averaging
        if self.averaging == 'all':
            target_shape[1] = 1
            vals = th.mean(vals, dim=1, keepdim=True)

        elif self.averaging == 'spatial':
            if len(shape) == 4:
                vals = th.mean(th.mean(vals, 2, keepdim=True), 3, keepdim=True)

        elif self.averaging == 'none':
            target_shape = [target_shape[0]] + [s for s in target_shape[1:]]

        elif self.averaging == 'gpool':
            if len(shape) == 4:
                vals = th.mean(th.mean(th.mean(x, 2, keepdim=True),
                                       3, keepdim=True), 0, keepdim=True)
        elif self.averaging == 'flat':
            target_shape[1] = 1
            vals = th.FloatTensor([self.adjusted_std(x)])

        else:  # self.averaging == 'group'
            target_shape[1] = self.n
            # group the channels and average the statistic within each group
            vals = vals.view(self.n, shape[1] // self.n, shape[2], shape[3])
            vals = th.mean(vals, 0, keepdim=True).view(1, self.n, 1, 1)

        # spatial replication of the computed statistic
        vals = vals.expand(*target_shape)

        # concatenate the constant feature map to the input
        y = th.cat([x, vals], 1)

        # return the computed value
        return y


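# Illustrative usage sketch for MinibatchStdDev; shapes are assumptions. With
# the default 'all' averaging the layer appends a single extra feature map
# holding the averaged standard deviation computed across the minibatch, so the
# channel count grows by one.
def _demo_minibatch_stddev():
    layer = MinibatchStdDev(averaging='all')
    x = th.randn(8, 128, 4, 4)
    y = layer(x)
    assert y.shape == (8, 129, 4, 4)
    return y

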
class DisFinalBlock(th.nn.Module):
    """ Final block for the Discriminator """

    def __init__(self, in_channels, use_eql):
        """
        constructor of the class
        :param in_channels: number of input channels
        :param use_eql: whether to use equalized learning rate
        """
        from torch.nn import LeakyReLU

        super(DisFinalBlock, self).__init__()

        # declare the required modules for forward pass
        self.batch_discriminator = MinibatchStdDev()
        if use_eql:
            self.conv_1 = _equalized_conv2d(in_channels + 1, in_channels, (3, 3), pad=1)
            self.conv_2 = _equalized_conv2d(in_channels, in_channels, (4, 4))
            # final conv layer emulates a fully connected layer
            self.conv_3 = _equalized_conv2d(in_channels, 1, (1, 1))
        else:
            from torch.nn import Conv2d
            self.conv_1 = Conv2d(in_channels + 1, in_channels, (3, 3), padding=1)
            self.conv_2 = Conv2d(in_channels, in_channels, (4, 4))
            # final conv layer emulates a fully connected layer
            self.conv_3 = Conv2d(in_channels, 1, (1, 1))

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the FinalBlock
        :param x: input
        :return: y => output
        """
        # minibatch_std_dev layer
        y = self.batch_discriminator(x)

        # define the computations
        y = self.lrelu(self.conv_1(y))
        y = self.lrelu(self.conv_2(y))

        # fully connected layer
        y = self.lrelu(self.conv_3(y))  # final fully connected layer

        # flatten the output raw discriminator scores
        return y.view(-1)


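# Illustrative usage sketch for DisFinalBlock; the channel count is an
# assumption, while the 4x4 spatial input size is implied by the (4, 4)
# convolution in conv_2. The block returns one scalar score per sample.
def _demo_dis_final_block():
    block = DisFinalBlock(in_channels=128, use_eql=True)
    x = th.randn(8, 128, 4, 4)
    scores = block(x)
    assert scores.shape == (8,)
    return scores

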
class DisGeneralConvBlock(th.nn.Module):
    """ General block in the discriminator """

    def __init__(self, in_channels, out_channels, use_eql):
        """
        constructor of the class
        :param in_channels: number of input channels
        :param out_channels: number of output channels
        :param use_eql: whether to use equalized learning rate
        """
        from torch.nn import AvgPool2d, LeakyReLU

        super(DisGeneralConvBlock, self).__init__()

        if use_eql:
            self.conv_1 = _equalized_conv2d(in_channels, in_channels, (3, 3), pad=1)
            self.conv_2 = _equalized_conv2d(in_channels, out_channels, (3, 3), pad=1)
        else:
            from torch.nn import Conv2d
            self.conv_1 = Conv2d(in_channels, in_channels, (3, 3), padding=1)
            self.conv_2 = Conv2d(in_channels, out_channels, (3, 3), padding=1)

        self.downSampler = AvgPool2d(2)

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the module
        :param x: input
        :return: y => output
        """
        # define the computations
        y = self.lrelu(self.conv_1(x))
        y = self.lrelu(self.conv_2(y))
        y = self.downSampler(y)

        return y
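
A minimal sketch of how these blocks could be chained into a tiny three-stage generator/discriminator pair (purely illustrative: the stage count, channel widths, and latent size are assumptions, the toRGB/fromRGB layers of a full ProGAN are omitted, and the package is assumed importable as pro_gan_pytorch):

import torch as th
from pro_gan_pytorch.CustomLayers import (GenInitialBlock, GenGeneralConvBlock,
                                          DisGeneralConvBlock, DisFinalBlock)

latent_size = 128

# generator side: latent vector -> 4x4 -> 8x8 -> 16x16 feature maps
gen_blocks = th.nn.ModuleList([
    GenInitialBlock(latent_size, use_eql=True),           # -> (N, 128, 4, 4)
    GenGeneralConvBlock(latent_size, 64, use_eql=True),   # -> (N, 64, 8, 8)
    GenGeneralConvBlock(64, 32, use_eql=True),            # -> (N, 32, 16, 16)
])

# discriminator side mirrors it back down to a single score per sample
dis_blocks = th.nn.ModuleList([
    DisGeneralConvBlock(32, 64, use_eql=True),            # 16x16 -> 8x8
    DisGeneralConvBlock(64, latent_size, use_eql=True),   # 8x8 -> 4x4
    DisFinalBlock(latent_size, use_eql=True),             # -> (N,)
])

z = th.randn(4, latent_size)
y = z
for block in gen_blocks:
    y = block(y)
for block in dis_blocks:
    y = block(y)
print(y.shape)  # torch.Size([4])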
