Implemented dropout layer support for reducing overfitting.
Either 'd' or 'D' can now be specified in a model definition's 'layer_order'
so that nn.Dropout or spatial nn.Dropout2d layers are added for potentially
improved network regularization. In addition, the dropout probability can be
set via a new 'dropout_prob' parameter, which accepts a single float or a
list/tuple of floats, allowing a separate dropout probability to be specified
for each convolution in the double-convolution step, as sketched below.
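
For example, a model could opt in to the new layers like this (an illustrative
sketch; channel counts and probabilities are arbitrary, not part of this
commit):

    from pytorch3dunet.unet3d.model import UNet3D

    # 'cbrd' = conv + batchnorm + ReLU + nn.Dropout ('cbrD' would add
    # nn.Dropout2d instead); a two-element list gives each convolution in
    # the DoubleConv block its own dropout probability
    model = UNet3D(in_channels=1, out_channels=2,
                   layer_order='cbrd',
                   dropout_prob=[0.1, 0.2])
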
jens-maus committed Dec 4, 2023
1 parent a296dc9 commit 9816be3
Showing 2 changed files with 63 additions and 22 deletions.
58 changes: 44 additions & 14 deletions pytorch3dunet/unet3d/buildingblocks.py
@@ -7,7 +7,8 @@
from pytorch3dunet.unet3d.se import ChannelSELayer3D, ChannelSpatialSELayer3D, SpatialSELayer3D


- def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding, is3d):
+ def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding,
+ dropout_prob, is3d):
"""
Create a list of modules which together constitute a single conv layer with non-linearity
and optional batchnorm/groupnorm.
@@ -22,8 +23,11 @@ def create_conv(...)
'cl' -> conv + LeakyReLU
'ce' -> conv + ELU
'bcr' -> batchnorm + conv + ReLU
+ 'cbrd' -> conv + batchnorm + ReLU + dropout
+ 'cbrD' -> conv + batchnorm + ReLU + dropout2d
num_groups (int): number of groups for the GroupNorm
padding (int or tuple): zero-padding added to all three sides of the input
+ dropout_prob (float): dropout probability
is3d (bool): if True use Conv3d, otherwise use Conv2d
Return:
list of tuple (name, module)
@@ -72,8 +76,12 @@ def create_conv(...)
modules.append(('batchnorm', bn(in_channels)))
else:
modules.append(('batchnorm', bn(out_channels)))
+ elif char == 'd':
+ modules.append(('dropout', nn.Dropout(p=dropout_prob)))
+ elif char == 'D':
+ modules.append(('dropout2d', nn.Dropout2d(p=dropout_prob)))
else:
- raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']")
+ raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c', 'd', 'D']")

return modules
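
With the new codes in place, create_conv simply appends the chosen dropout
module to the (name, module) list it builds. A quick sketch of what it yields
for 'cbrd' (argument values are illustrative):

    from pytorch3dunet.unet3d.buildingblocks import create_conv

    # expected pairs: ('conv', Conv3d), ('batchnorm', BatchNorm3d),
    # ('ReLU', ReLU), ('dropout', Dropout)
    for name, module in create_conv(in_channels=16, out_channels=32, kernel_size=3,
                                    order='cbrd', num_groups=8, padding=1,
                                    dropout_prob=0.1, is3d=True):
        print(name, module)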

@@ -94,13 +102,16 @@ class SingleConv(nn.Sequential):
'ce' -> conv + ELU
num_groups (int): number of groups for the GroupNorm
padding (int or tuple): add zero-padding
+ dropout_prob (float): dropout probability, default 0.1
is3d (bool): if True use Conv3d, otherwise use Conv2d
"""

- def __init__(self, in_channels, out_channels, kernel_size=3, order='gcr', num_groups=8, padding=1, is3d=True):
+ def __init__(self, in_channels, out_channels, kernel_size=3, order='gcr', num_groups=8,
+ padding=1, dropout_prob=0.1, is3d=True):
super(SingleConv, self).__init__()

- for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding, is3d):
+ for name, module in create_conv(in_channels, out_channels, kernel_size, order,
+ num_groups, padding, dropout_prob, is3d):
self.add_module(name, module)


@@ -125,11 +136,12 @@ class DoubleConv(nn.Sequential):
'ce' -> conv + ELU
num_groups (int): number of groups for the GroupNorm
padding (int or tuple): zero-padding added to all three sides of the input
+ dropout_prob (float or tuple): dropout probability for each convolution, default 0.1
is3d (bool): if True use Conv3d instead of Conv2d layers
"""

- def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr', num_groups=8, padding=1,
- is3d=True):
+ def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr',
+ num_groups=8, padding=1, dropout_prob=0.1, is3d=True):
super(DoubleConv, self).__init__()
if encoder:
# we're in the encoder path
@@ -143,14 +155,22 @@ def __init__(...)
conv1_in_channels, conv1_out_channels = in_channels, out_channels
conv2_in_channels, conv2_out_channels = out_channels, out_channels

+ # if dropout_prob is a list/tuple, split it so that each
+ # convolution gets its own dropout probability
+ if isinstance(dropout_prob, list) or isinstance(dropout_prob, tuple):
+ dropout_prob1 = dropout_prob[0]
+ dropout_prob2 = dropout_prob[1]
+ else:
+ dropout_prob1 = dropout_prob2 = dropout_prob

# conv1
self.add_module('SingleConv1',
SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups,
- padding=padding, is3d=is3d))
+ padding=padding, dropout_prob=dropout_prob1, is3d=is3d))
# conv2
self.add_module('SingleConv2',
SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups,
- padding=padding, is3d=is3d))
+ padding=padding, dropout_prob=dropout_prob2, is3d=is3d))


class ResNetBlock(nn.Module):
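
The split above is what lets a two-element dropout_prob address each half of
the double convolution; a minimal sketch (sizes and probabilities are
illustrative):

    from pytorch3dunet.unet3d.buildingblocks import DoubleConv

    # the first convolution gets p=0.1, the second p=0.2
    block = DoubleConv(in_channels=32, out_channels=64, encoder=True,
                       order='cbrd', dropout_prob=(0.1, 0.2), is3d=True)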
@@ -244,12 +264,13 @@ class Encoder(nn.Module):
in `DoubleConv` module. See `DoubleConv` for more info.
num_groups (int): number of groups for the GroupNorm
padding (int or tuple): zero-padding added to all three sides of the input
+ dropout_prob (float or tuple): dropout probability, default 0.1
is3d (bool): use 3d or 2d convolutions/pooling operation
"""

def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
pool_kernel_size=2, pool_type='max', basic_module=DoubleConv, conv_layer_order='gcr',
- num_groups=8, padding=1, is3d=True):
+ num_groups=8, padding=1, dropout_prob=0.1, is3d=True):
super(Encoder, self).__init__()
assert pool_type in ['max', 'avg']
if apply_pooling:
Expand All @@ -272,6 +293,7 @@ def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=
order=conv_layer_order,
num_groups=num_groups,
padding=padding,
+ dropout_prob=dropout_prob,
is3d=is3d)

def forward(self, x):
@@ -300,10 +322,12 @@ class Decoder(nn.Module):
num_groups (int): number of groups for the GroupNorm
padding (int or tuple): zero-padding added to all three sides of the input
upsample (bool): should the input be upsampled
+ dropout_prob (float or tuple): dropout probability, default 0.1
"""

- def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=(2, 2, 2), basic_module=DoubleConv,
- conv_layer_order='gcr', num_groups=8, mode='nearest', padding=1, upsample=True, is3d=True):
+ def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=2, basic_module=DoubleConv,
+ conv_layer_order='gcr', num_groups=8, padding=1, upsample=True,
+ dropout_prob=0.1, is3d=True):
super(Decoder, self).__init__()

if upsample:
@@ -332,6 +356,7 @@ def __init__(...)
order=conv_layer_order,
num_groups=num_groups,
padding=padding,
+ dropout_prob=dropout_prob,
is3d=is3d)

def forward(self, encoder_features, x):
@@ -348,8 +373,9 @@ def _joining(encoder_features, x, concat):
return encoder_features + x


- def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups,
- pool_kernel_size, is3d):
+ def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding,
+ dropout_prob,
+ layer_order, num_groups, pool_kernel_size, is3d):
# create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)`
encoders = []
for i, out_feature_num in enumerate(f_maps):
@@ -362,6 +388,7 @@ def create_encoders(...)
conv_kernel_size=conv_kernel_size,
num_groups=num_groups,
padding=conv_padding,
+ dropout_prob=dropout_prob,
is3d=is3d)
else:
encoder = Encoder(f_maps[i - 1], out_feature_num,
@@ -371,14 +398,16 @@ def create_encoders(...)
num_groups=num_groups,
pool_kernel_size=pool_kernel_size,
padding=conv_padding,
+ dropout_prob=dropout_prob,
is3d=is3d)

encoders.append(encoder)

return nn.ModuleList(encoders)


- def create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups, is3d):
+ def create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order,
+ num_groups, dropout_prob, is3d):
# create decoder path consisting of the Decoder modules. The length of the decoder list is equal to `len(f_maps) - 1`
decoders = []
reversed_f_maps = list(reversed(f_maps))
@@ -396,6 +425,7 @@ def create_decoders(...)
conv_kernel_size=conv_kernel_size,
num_groups=num_groups,
padding=conv_padding,
+ dropout_prob=dropout_prob,
is3d=is3d)
decoders.append(decoder)
return nn.ModuleList(decoders)
27 changes: 19 additions & 8 deletions pytorch3dunet/unet3d/model.py
@@ -32,12 +32,13 @@ class AbstractUNet(nn.Module):
conv_kernel_size (int or tuple): size of the convolving kernel in the basic_module
pool_kernel_size (int or tuple): the size of the window
conv_padding (int or tuple): zero-padding added to all three sides of the input
+ dropout_prob (float or tuple): dropout probability, default: 0.1
is3d (bool): if True the model is 3D, otherwise 2D, default: True
"""

def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr',
num_groups=8, num_levels=4, is_segmentation=True, conv_kernel_size=3, pool_kernel_size=2,
- conv_padding=1, is3d=True):
+ conv_padding=1, dropout_prob=0.1, is3d=True):
super(AbstractUNet, self).__init__()

if isinstance(f_maps, int):
@@ -49,11 +50,13 @@ def __init__(...)
assert num_groups is not None, "num_groups must be specified if GroupNorm is used"

# create encoder path
- self.encoders = create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding, layer_order,
- num_groups, pool_kernel_size, is3d)
+ self.encoders = create_encoders(in_channels, f_maps, basic_module, conv_kernel_size,
+ conv_padding, dropout_prob,
+ layer_order, num_groups, pool_kernel_size, is3d)

# create decoder path
- self.decoders = create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups,
+ self.decoders = create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding,
+ layer_order, num_groups, dropout_prob,
is3d)

# in the last layer a 1×1 convolution reduces the number of output channels to the number of labels
@@ -110,7 +113,8 @@ class UNet3D(AbstractUNet):
"""

def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
- num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1, **kwargs):
+ num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1,
+ dropout_prob=0.1, **kwargs):
super(UNet3D, self).__init__(in_channels=in_channels,
out_channels=out_channels,
final_sigmoid=final_sigmoid,
@@ -121,6 +125,7 @@ def __init__(...)
num_levels=num_levels,
is_segmentation=is_segmentation,
conv_padding=conv_padding,
+ dropout_prob=dropout_prob,
is3d=True)


@@ -133,7 +138,8 @@ class ResidualUNet3D(AbstractUNet):
"""

def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
- num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1, **kwargs):
+ num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1,
+ dropout_prob=0.1, **kwargs):
super(ResidualUNet3D, self).__init__(in_channels=in_channels,
out_channels=out_channels,
final_sigmoid=final_sigmoid,
@@ -144,6 +150,7 @@ def __init__(...)
num_levels=num_levels,
is_segmentation=is_segmentation,
conv_padding=conv_padding,
+ dropout_prob=dropout_prob,
is3d=True)


@@ -158,7 +165,8 @@ class ResidualUNetSE3D(AbstractUNet):
"""

def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
- num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1, **kwargs):
+ num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1,
+ dropout_prob=0.1, **kwargs):
super(ResidualUNetSE3D, self).__init__(in_channels=in_channels,
out_channels=out_channels,
final_sigmoid=final_sigmoid,
@@ -169,6 +177,7 @@ def __init__(...)
num_levels=num_levels,
is_segmentation=is_segmentation,
conv_padding=conv_padding,
+ dropout_prob=dropout_prob,
is3d=True)


@@ -179,7 +188,8 @@ class UNet2D(AbstractUNet):
"""

def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
- num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1, **kwargs):
+ num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1,
+ dropout_prob=0.1, **kwargs):
super(UNet2D, self).__init__(in_channels=in_channels,
out_channels=out_channels,
final_sigmoid=final_sigmoid,
@@ -190,6 +200,7 @@ def __init__(...)
num_levels=num_levels,
is_segmentation=is_segmentation,
conv_padding=conv_padding,
+ dropout_prob=dropout_prob,
is3d=False)
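
For the 2D variant the uppercase 'D' code selects nn.Dropout2d, which drops
whole feature maps rather than individual activations; a sketch (values are
illustrative):

    from pytorch3dunet.unet3d.model import UNet2D

    # 'cbrD' = conv + batchnorm + ReLU + nn.Dropout2d (channel-wise dropout)
    model = UNet2D(in_channels=3, out_channels=1,
                   layer_order='cbrD',
                   dropout_prob=0.1)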


