Merge pull request #98 from hzdr-MedImaging/feature-conv-upscale-num
Implement an optional conv_upscale parameter in AbstractUNet to allow defining at which convolution of a DoubleConv block the upscaling of feature maps should be performed.
wolny authored Dec 4, 2023
2 parents 673a6f4 + 0156980 commit 5f50cb0
Showing 2 changed files with 23 additions and 10 deletions.
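In short, the new parameter selects which of the two convolutions inside a DoubleConv block increases the channel count on the encoder path. A sketch of the resulting channel progression, derived from the DoubleConv hunk below (the concrete numbers are illustrative, not taken from the diff):

    # Channel progression of an encoder DoubleConv with in_channels=32, out_channels=64:
    #
    # upscale=2 (default, previous behaviour):
    #     conv1: 32 -> 32   # out_channels // 2, clamped to at least in_channels
    #     conv2: 32 -> 64   # the second convolution performs the upscaling
    #
    # upscale=1 (new option):
    #     conv1: 32 -> 64   # the first convolution performs the upscaling
    #     conv2: 64 -> 64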
16 changes: 12 additions & 4 deletions pytorch3dunet/unet3d/buildingblocks.py
@@ -136,17 +136,21 @@ class DoubleConv(nn.Sequential):
             'ce' -> conv + ELU
         num_groups (int): number of groups for the GroupNorm
         padding (int or tuple): zero-padding added to all three sides of the input
+        upscale (int): index of the convolution (1 or 2) that upscales the number of feature maps in the encoder; default: 2
         dropout_prob (float or tuple): dropout probability for each convolution, default 0.1
         is3d (bool): if True use Conv3d instead of Conv2d layers
     """

     def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr',
-                 num_groups=8, padding=1, dropout_prob=0.1, is3d=True):
+                 num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True):
         super(DoubleConv, self).__init__()
         if encoder:
             # we're in the encoder path
             conv1_in_channels = in_channels
-            conv1_out_channels = out_channels // 2
+            if upscale == 1:
+                conv1_out_channels = out_channels
+            else:
+                conv1_out_channels = out_channels // 2
             if conv1_out_channels < in_channels:
                 conv1_out_channels = in_channels
             conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels
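A minimal usage sketch of this change (assuming the package is importable under the file path shown above):

    import torch
    from pytorch3dunet.unet3d.buildingblocks import DoubleConv

    # With upscale=1, conv1 maps in_channels -> out_channels directly;
    # with the default upscale=2, conv2 performs the upscaling instead.
    block = DoubleConv(in_channels=32, out_channels=64, encoder=True, upscale=1)
    x = torch.randn(1, 32, 16, 64, 64)   # (batch, channels, depth, height, width)
    print(block(x).shape)                # torch.Size([1, 64, 16, 64, 64])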
@@ -264,13 +268,14 @@ class Encoder(nn.Module):
             in `DoubleConv` module. See `DoubleConv` for more info.
         num_groups (int): number of groups for the GroupNorm
         padding (int or tuple): zero-padding added to all three sides of the input
+        upscale (int): index of the convolution (1 or 2) that upscales the number of feature maps in the encoder; only used when basic_module is DoubleConv; default: 2
         dropout_prob (float or tuple): dropout probability, default 0.1
         is3d (bool): use 3d or 2d convolutions/pooling operation
     """

     def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
                  pool_kernel_size=2, pool_type='max', basic_module=DoubleConv, conv_layer_order='gcr',
-                 num_groups=8, padding=1, dropout_prob=0.1, is3d=True):
+                 num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True):
         super(Encoder, self).__init__()
         assert pool_type in ['max', 'avg']
         if apply_pooling:
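The new keyword can also be passed straight to an Encoder; a hedged sketch, assuming the default max pooling shown in the signature above is applied before the convolutions:

    import torch
    from pytorch3dunet.unet3d.buildingblocks import Encoder

    # upscale=1 moves the channel upscaling to the first convolution
    # of the underlying DoubleConv.
    encoder = Encoder(in_channels=32, out_channels=64, upscale=1)
    x = torch.randn(1, 32, 16, 64, 64)   # (batch, channels, depth, height, width)
    print(encoder(x).shape)              # torch.Size([1, 64, 8, 32, 32]) -- pooling halves spatial dims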
@@ -293,6 +298,7 @@ def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
                                          order=conv_layer_order,
                                          num_groups=num_groups,
                                          padding=padding,
+                                         upscale=upscale,
                                          dropout_prob=dropout_prob,
                                          is3d=is3d)

@@ -395,7 +401,7 @@ def _joining(encoder_features, x, concat):


 def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding,
-                    dropout_prob,
+                    conv_upscale, dropout_prob,
                     layer_order, num_groups, pool_kernel_size, is3d):
     # create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)`
     encoders = []
Expand All @@ -409,6 +415,7 @@ def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_pa
conv_kernel_size=conv_kernel_size,
num_groups=num_groups,
padding=conv_padding,
upscale=conv_upscale,
dropout_prob=dropout_prob,
is3d=is3d)
else:
@@ -419,6 +426,7 @@ def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding,
                               num_groups=num_groups,
                               pool_kernel_size=pool_kernel_size,
                               padding=conv_padding,
+                              upscale=conv_upscale,
                               dropout_prob=dropout_prob,
                               is3d=is3d)

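Note that callers of create_encoders now pass the value positionally between conv_padding and dropout_prob. A hedged sketch of the updated call (keyword arguments used here for clarity; per the comment in the diff, one Encoder is built per entry in f_maps):

    from pytorch3dunet.unet3d.buildingblocks import DoubleConv, create_encoders

    encoders = create_encoders(in_channels=1, f_maps=[32, 64, 128, 256],
                               basic_module=DoubleConv, conv_kernel_size=3,
                               conv_padding=1, conv_upscale=1, dropout_prob=0.1,
                               layer_order='gcr', num_groups=8,
                               pool_kernel_size=2, is3d=True)
    print(len(encoders))  # 4 -- depth of the encoder path equals len(f_maps)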
17 changes: 11 additions & 6 deletions pytorch3dunet/unet3d/model.py
@@ -32,6 +32,7 @@ class AbstractUNet(nn.Module):
         conv_kernel_size (int or tuple): size of the convolving kernel in the basic_module
         pool_kernel_size (int or tuple): the size of the window
         conv_padding (int or tuple): zero-padding added to all three sides of the input
+        conv_upscale (int): index of the convolution (1 or 2) that upscales the number of feature maps in the encoder; only used when basic_module is DoubleConv; default: 2
         upsample (str): algorithm used for decoder upsampling:
             InterpolateUpsampling: 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'
             TransposeConvUpsampling: 'deconv'
Expand All @@ -43,7 +44,7 @@ class AbstractUNet(nn.Module):

def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr',
num_groups=8, num_levels=4, is_segmentation=True, conv_kernel_size=3, pool_kernel_size=2,
conv_padding=1, upsample='default', dropout_prob=0.1, is3d=True):
conv_padding=1, conv_upscale=2, upsample='default', dropout_prob=0.1, is3d=True):
super(AbstractUNet, self).__init__()

if isinstance(f_maps, int):
Expand All @@ -56,7 +57,7 @@ def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_map

# create encoder path
self.encoders = create_encoders(in_channels, f_maps, basic_module, conv_kernel_size,
conv_padding, dropout_prob,
conv_padding, conv_upscale, dropout_prob,
layer_order, num_groups, pool_kernel_size, is3d)

# create decoder path
@@ -119,7 +120,7 @@ class UNet3D(AbstractUNet):

     def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
                  num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1,
-                 upsample='default', dropout_prob=0.1, **kwargs):
+                 conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs):
         super(UNet3D, self).__init__(in_channels=in_channels,
                                      out_channels=out_channels,
                                      final_sigmoid=final_sigmoid,
Expand All @@ -130,6 +131,7 @@ def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, lay
num_levels=num_levels,
is_segmentation=is_segmentation,
conv_padding=conv_padding,
conv_upscale=conv_upscale,
upsample=upsample,
dropout_prob=dropout_prob,
is3d=True)
Expand All @@ -145,7 +147,7 @@ class ResidualUNet3D(AbstractUNet):

def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1,
upsample='default', dropout_prob=0.1, **kwargs):
conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs):
super(ResidualUNet3D, self).__init__(in_channels=in_channels,
out_channels=out_channels,
final_sigmoid=final_sigmoid,
Expand All @@ -156,6 +158,7 @@ def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, lay
num_levels=num_levels,
is_segmentation=is_segmentation,
conv_padding=conv_padding,
conv_upscale=conv_upscale,
upsample=upsample,
dropout_prob=dropout_prob,
is3d=True)
Expand All @@ -173,7 +176,7 @@ class ResidualUNetSE3D(AbstractUNet):

def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1,
upsample='default', dropout_prob=0.1, **kwargs):
conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs):
super(ResidualUNetSE3D, self).__init__(in_channels=in_channels,
out_channels=out_channels,
final_sigmoid=final_sigmoid,
Expand All @@ -184,6 +187,7 @@ def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, lay
num_levels=num_levels,
is_segmentation=is_segmentation,
conv_padding=conv_padding,
conv_upscale=conv_upscale,
upsample=upsample,
dropout_prob=dropout_prob,
is3d=True)
Expand All @@ -197,7 +201,7 @@ class UNet2D(AbstractUNet):

def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1,
upsample='default', dropout_prob=0.1, **kwargs):
conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs):
super(UNet2D, self).__init__(in_channels=in_channels,
out_channels=out_channels,
final_sigmoid=final_sigmoid,
Expand All @@ -208,6 +212,7 @@ def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, lay
num_levels=num_levels,
is_segmentation=is_segmentation,
conv_padding=conv_padding,
conv_upscale=conv_upscale,
upsample=upsample,
dropout_prob=dropout_prob,
is3d=False)
Expand Down
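Finally, a hedged end-to-end sketch of the new keyword on one of the concrete models; the input shape is chosen so that the three rounds of pooling at the default num_levels=4 divide evenly:

    import torch
    from pytorch3dunet.unet3d.model import UNet3D

    # conv_upscale=1 upscales the feature maps at the first convolution of every
    # encoder DoubleConv; conv_upscale=2 reproduces the previous behaviour.
    model = UNet3D(in_channels=1, out_channels=2, f_maps=32, conv_upscale=1)
    model.eval()
    x = torch.randn(1, 1, 32, 64, 64)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # torch.Size([1, 2, 32, 64, 64])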
