
Commit

Fix 'ValueError: requested an output size...' in the `_output_padding` method in transposed convolution and allow arbitrary patch sizes to be passed to ResidualUNet
wolny committed Dec 20, 2023
1 parent 94ccbd3 commit c888f41
Showing 3 changed files with 80 additions and 25 deletions.
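
For context, here is a minimal sketch (not part of the commit) of the failure the first change addresses. With kernel_size=3, stride=2 and padding=1, a transposed convolution can only produce outputs of size 2n-1 or 2n from an input of size n, so requesting the odd-sized encoder output after pooling triggers the 'ValueError: requested an output size...' from `_output_padding`; interpolating afterwards fixes the mismatch:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.rand(1, 1, 65, 65)                       # odd-sized input patch
pooled = nn.MaxPool2d(2)(x)                        # -> (1, 1, 32, 32)
deconv = nn.ConvTranspose2d(1, 1, kernel_size=3, stride=2, padding=1, bias=False)

# deconv(pooled, output_size=x.size()[-2:]) would raise the ValueError above,
# since only sizes 63 and 64 are reachable from a 32-pixel input
y = deconv(pooled)                                 # -> (1, 1, 63, 63)
if y.size()[-2:] != x.size()[-2:]:
    y = F.interpolate(y, size=x.size()[-2:])       # resize to match the skip connection
assert y.size()[-2:] == x.size()[-2:]
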
54 changes: 38 additions & 16 deletions pytorch3dunet/unet3d/buildingblocks.py
@@ -148,9 +148,9 @@ def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr',
             # we're in the encoder path
             conv1_in_channels = in_channels
             if upscale == 1:
                 conv1_out_channels = out_channels
             else:
                 conv1_out_channels = out_channels // 2
             if conv1_out_channels < in_channels:
                 conv1_out_channels = in_channels
             conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels
@@ -162,10 +162,10 @@ def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr',
         # check if dropout_prob is a tuple and if so
         # split it for different dropout probabilities for each convolution.
         if isinstance(dropout_prob, list) or isinstance(dropout_prob, tuple):
             dropout_prob1 = dropout_prob[0]
             dropout_prob2 = dropout_prob[1]
         else:
             dropout_prob1 = dropout_prob2 = dropout_prob

         # conv1
         self.add_module('SingleConv1',
@@ -339,7 +339,7 @@ def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=2
                  conv_layer_order='gcr', num_groups=8, padding=1, upsample='default',
                  dropout_prob=0.1, is3d=True):
         super(Decoder, self).__init__()

         # perform concat joining by default
         concat = True

@@ -349,13 +349,13 @@ def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=2
         if upsample is not None and upsample != 'none':
             if upsample == 'default':
                 if basic_module == DoubleConv:
                     upsample = 'nearest'  # use nearest neighbor interpolation for upsampling
                     concat = True  # use concat joining
                     adapt_channels = False  # don't adapt channels
                 elif basic_module == ResNetBlock or basic_module == ResNetBlockSE:
                     upsample = 'deconv'  # use deconvolution upsampling
                     concat = False  # use summation joining
                     adapt_channels = True  # adapt channels after joining

             # perform deconvolution upsampling if mode is deconv
             if upsample == 'deconv':
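
The comments in this hunk summarize the two joining strategies the Decoder picks between; as an illustrative sketch (not repo code):

import torch

enc = torch.rand(1, 8, 64, 64)              # skip-connection features from the encoder
dec = torch.rand(1, 8, 64, 64)              # upsampled decoder features

concat_join = torch.cat([enc, dec], dim=1)  # DoubleConv path: channel counts add up (16 here)
sum_join = enc + dec                        # ResNet path: channel counts must match (8 here)
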
@@ -509,14 +509,36 @@ class TransposeConvUpsampling(AbstractUpsampling):
         is3d (bool): if True use ConvTranspose3d, otherwise use ConvTranspose2d
     """

-    def __init__(self, in_channels=None, out_channels=None, kernel_size=3, scale_factor=2, is3d=True):
+    class Upsample(nn.Module):
+        """
+        Workaround for the 'ValueError: requested an output size...' raised in the `_output_padding`
+        method of transposed convolution: performs the transposed convolution, then interpolates to
+        the correct size if necessary.
+        """
+
+        def __init__(self, conv_transposed, is3d):
+            super().__init__()
+            self.conv_transposed = conv_transposed
+            self.is3d = is3d
+
+        def forward(self, x, size):
+            x = self.conv_transposed(x)
+            if self.is3d:
+                output_size = x.size()[-3:]
+            else:
+                output_size = x.size()[-2:]
+            if output_size != size:
+                return F.interpolate(x, size=size)
+            return x
+
+    def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=2, is3d=True):
         # make sure that the output size reverses the MaxPool3d from the corresponding encoder
         if is3d is True:
-            upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor,
-                                          padding=1, bias=False)
+            conv_transposed = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size,
+                                                 stride=scale_factor, padding=1, bias=False)
         else:
-            upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor,
-                                          padding=1, bias=False)
+            conv_transposed = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
+                                                 stride=scale_factor, padding=1, bias=False)
+        upsample = self.Upsample(conv_transposed, is3d)
         super().__init__(upsample)


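A quick sanity check of the patched upsampling (a sketch, assuming the package is importable and that `AbstractUpsampling.forward(encoder_features, x)` resizes `x` to the encoder's spatial size, as elsewhere in this repo):

import torch
from pytorch3dunet.unet3d.buildingblocks import TransposeConvUpsampling

up = TransposeConvUpsampling(in_channels=16, out_channels=8, kernel_size=3, scale_factor=2, is3d=True)
encoder_features = torch.rand(1, 8, 33, 65, 65)  # odd-sized skip connection
x = torch.rand(1, 16, 16, 32, 32)                # pooled decoder features
y = up(encoder_features, x)                      # deconv -> (31, 63, 63), then interpolate -> (33, 65, 65)
assert y.size()[-3:] == encoder_features.size()[-3:]
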
24 changes: 24 additions & 0 deletions pytorch3dunet/unet3d/model.py
@@ -218,6 +218,30 @@ def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
                                      is3d=False)


+class ResidualUNet2D(AbstractUNet):
+    """
+    Residual 2D U-Net model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
+    """
+
+    def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
+                 num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1,
+                 conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs):
+        super(ResidualUNet2D, self).__init__(in_channels=in_channels,
+                                             out_channels=out_channels,
+                                             final_sigmoid=final_sigmoid,
+                                             basic_module=ResNetBlock,
+                                             f_maps=f_maps,
+                                             layer_order=layer_order,
+                                             num_groups=num_groups,
+                                             num_levels=num_levels,
+                                             is_segmentation=is_segmentation,
+                                             conv_padding=conv_padding,
+                                             conv_upscale=conv_upscale,
+                                             upsample=upsample,
+                                             dropout_prob=dropout_prob,
+                                             is3d=False)


 def get_model(model_config):
     model_class = get_class(model_config['name'], modules=[
         'pytorch3dunet.unet3d.model'
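
The new model can be smoke-tested on an odd-sized patch, mirroring the test added below (a minimal sketch):

import torch
from pytorch3dunet.unet3d.model import ResidualUNet2D

model = ResidualUNet2D(in_channels=1, out_channels=1, f_maps=16, final_sigmoid=True)
model.eval()
x = torch.rand(1, 1, 65, 65)  # arbitrary patch sizes should now work
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([1, 1, 65, 65])
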
27 changes: 18 additions & 9 deletions tests/test_models.py
@@ -1,14 +1,14 @@
 import torch

 from pytorch3dunet.unet3d.buildingblocks import ResNetBlock
-from pytorch3dunet.unet3d.model import UNet2D, UNet3D, ResidualUNet3D, ResidualUNetSE3D
+from pytorch3dunet.unet3d.model import UNet2D, UNet3D, ResidualUNet3D, ResidualUNetSE3D, ResidualUNet2D


 class TestModel:
     def test_unet2d(self):
         model = UNet2D(1, 1, f_maps=16, final_sigmoid=True)
         model.eval()
-        x = torch.rand(1, 1, 64, 64)
+        x = torch.rand(1, 1, 65, 65)
         with torch.no_grad():
             y = model(x)

@@ -17,26 +17,26 @@ def test_unet2d(self):
     def test_unet3d(self):
         model = UNet3D(1, 1, f_maps=16, final_sigmoid=True)
         model.eval()
-        x = torch.rand(1, 1, 32, 64, 64)
+        x = torch.rand(1, 1, 33, 65, 65)
         with torch.no_grad():
             y = model(x)

         assert torch.all(0 <= y) and torch.all(y <= 1)

     def test_resnet_block1(self):
-        blk = ResNetBlock(32, 64, is3d=False, order='cgr')
+        blk = ResNetBlock(33, 64, is3d=False, order='cgr')
         blk.eval()
-        x = torch.rand(1, 32, 64, 64)
+        x = torch.rand(1, 33, 65, 65)

         with torch.no_grad():
             y = blk(x)

         assert torch.all(0 <= y)

     def test_resnet_block2(self):
-        blk = ResNetBlock(32, 32, is3d=False, order='cgr')
+        blk = ResNetBlock(33, 32, is3d=False, order='cgr')
         blk.eval()
-        x = torch.rand(1, 32, 64, 64)
+        x = torch.rand(1, 33, 65, 65)

         with torch.no_grad():
             y = blk(x)
@@ -46,15 +46,24 @@ def test_resnet_block2(self):
     def test_resunet3d(self):
         model = ResidualUNet3D(1, 1, f_maps=16, final_sigmoid=True)
         model.eval()
-        x = torch.rand(1, 1, 32, 64, 64)
+        x = torch.rand(1, 1, 33, 65, 65)
         y = model(x)

         assert torch.all(0 <= y) and torch.all(y <= 1)

+    def test_resunet2d(self):
+        model = ResidualUNet2D(1, 1, f_maps=16, final_sigmoid=True)
+        model.eval()
+        x = torch.rand(1, 1, 65, 65)
+        with torch.no_grad():
+            y = model(x)
+
+        assert torch.all(0 <= y) and torch.all(y <= 1)
+
     def test_resunetSE3d(self):
         model = ResidualUNetSE3D(1, 1, f_maps=16, final_sigmoid=True)
         model.eval()
-        x = torch.rand(1, 1, 32, 64, 64)
+        x = torch.rand(1, 1, 33, 65, 65)
         y = model(x)

         assert torch.all(0 <= y) and torch.all(y <= 1)
