Skip to content

Commit 502e45d

Browse files
varunagrawalsoumith
authored andcommitted
replace Upsample layer with interpolate function (pytorch#424)
1 parent 323079f commit 502e45d

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

fast_neural_style/neural_style/transformer_net.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,14 @@ class UpsampleConvLayer(torch.nn.Module):
8686
def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
8787
super(UpsampleConvLayer, self).__init__()
8888
self.upsample = upsample
89-
if upsample:
90-
self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample)
9189
reflection_padding = kernel_size // 2
9290
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
9391
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
9492

9593
def forward(self, x):
9694
x_in = x
9795
if self.upsample:
98-
x_in = self.upsample_layer(x_in)
96+
x_in = torch.nn.functional.interpolate(x_in, mode='nearest', scale_factor=self.upsample)
9997
out = self.reflection_pad(x_in)
10098
out = self.conv2d(out)
10199
return out

0 commit comments

Comments
 (0)