From 9816be3d86cc04760d2550a25a8aca7f416a9cac Mon Sep 17 00:00:00 2001
From: Jens Maus
Date: Mon, 4 Dec 2023 11:51:43 +0100
Subject: [PATCH] implemented dropout layer support for reducing overfitting.

Either a 'd' or 'D' can now be specified in the model definition's
'layer_order' so that nn.Dropout or spatial nn.Dropout2d layers are added
for potentially improved network regularization. In addition, the dropout
probability can be set via a new 'dropout_prob' parameter, which accepts a
single float or a list/tuple of floats so that a separate probability can
be specified for each convolution of the double convolution step.
---
 pytorch3dunet/unet3d/buildingblocks.py | 58 +++++++++++++++++++-------
 pytorch3dunet/unet3d/model.py          | 27 ++++++++----
 2 files changed, 63 insertions(+), 22 deletions(-)

diff --git a/pytorch3dunet/unet3d/buildingblocks.py b/pytorch3dunet/unet3d/buildingblocks.py
index 517e628b..08ef4308 100644
--- a/pytorch3dunet/unet3d/buildingblocks.py
+++ b/pytorch3dunet/unet3d/buildingblocks.py
@@ -7,7 +7,8 @@
 from pytorch3dunet.unet3d.se import ChannelSELayer3D, ChannelSpatialSELayer3D, SpatialSELayer3D
 
 
-def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding, is3d):
+def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding,
+                dropout_prob, is3d):
     """
     Create a list of modules which together constitute a single conv layer with non-linearity
     and optional batchnorm/groupnorm.
@@ -22,8 +23,11 @@ def create_conv(in_channels, out_channels, kernel_size, order, num_groups, paddi
             'cl' -> conv + LeakyReLU
             'ce' -> conv + ELU
             'bcr' -> batchnorm + conv + ReLU
+            'cbrd' -> conv + batchnorm + ReLU + dropout
+            'cbrD' -> conv + batchnorm + ReLU + dropout2d
         num_groups (int): number of groups for the GroupNorm
         padding (int or tuple): zero-padding added to all three sides of the input
+        dropout_prob (float): dropout probability
         is3d (bool): if True use Conv3d, otherwise use Conv2d
     Return:
         list of tuple (name, module)
@@ -72,8 +76,12 @@ def create_conv(in_channels, out_channels, kernel_size, order, num_groups, paddi
                 if is_before_conv:
                     modules.append(('batchnorm', bn(in_channels)))
                 else:
                     modules.append(('batchnorm', bn(out_channels)))
+        elif char == 'd':
+            modules.append(('dropout', nn.Dropout(p=dropout_prob)))
+        elif char == 'D':
+            modules.append(('dropout2d', nn.Dropout2d(p=dropout_prob)))
         else:
-            raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']")
+            raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c', 'd', 'D']")
 
     return modules
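For illustration, a layer_order such as 'cbrd' now expands to conv + batchnorm + ReLU + dropout inside each SingleConv. Roughly, the module list returned by create_conv would look like the following sketch (channel sizes and probability are made-up example values, not part of the patch):

    import torch.nn as nn

    # approximate result of create_conv(32, 64, 3, 'cbrd', num_groups=8,
    #                                   padding=1, dropout_prob=0.1, is3d=True)
    modules = [
        ('conv', nn.Conv3d(32, 64, kernel_size=3, padding=1, bias=False)),  # 'c' (bias off since 'b' present)
        ('batchnorm', nn.BatchNorm3d(64)),  # 'b' placed after 'c', so it sees out_channels
        ('ReLU', nn.ReLU(inplace=True)),    # 'r'
        ('dropout', nn.Dropout(p=0.1)),     # 'd', the new branch added above
    ]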
@@ -94,13 +102,16 @@ class SingleConv(nn.Sequential):
             'ce' -> conv + ELU
         num_groups (int): number of groups for the GroupNorm
         padding (int or tuple): add zero-padding
+        dropout_prob (float): dropout probability, default 0.1
         is3d (bool): if True use Conv3d, otherwise use Conv2d
     """
 
-    def __init__(self, in_channels, out_channels, kernel_size=3, order='gcr', num_groups=8, padding=1, is3d=True):
+    def __init__(self, in_channels, out_channels, kernel_size=3, order='gcr', num_groups=8,
+                 padding=1, dropout_prob=0.1, is3d=True):
         super(SingleConv, self).__init__()
 
-        for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding, is3d):
+        for name, module in create_conv(in_channels, out_channels, kernel_size, order,
+                                        num_groups, padding, dropout_prob, is3d):
             self.add_module(name, module)
 
 
@@ -125,11 +136,12 @@ class DoubleConv(nn.Sequential):
             'ce' -> conv + ELU
         num_groups (int): number of groups for the GroupNorm
         padding (int or tuple): zero-padding added to all three sides of the input
+        dropout_prob (float or tuple): dropout probability for each convolution, default 0.1
         is3d (bool): if True use Conv3d instead of Conv2d layers
     """
 
-    def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr', num_groups=8, padding=1,
-                 is3d=True):
+    def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr',
+                 num_groups=8, padding=1, dropout_prob=0.1, is3d=True):
         super(DoubleConv, self).__init__()
         if encoder:
             # we're in the encoder path
@@ -143,14 +155,22 @@ def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr
             conv1_in_channels, conv1_out_channels = in_channels, out_channels
             conv2_in_channels, conv2_out_channels = out_channels, out_channels
 
+        # if dropout_prob is a list/tuple, use a separate dropout
+        # probability for each of the two convolutions
+        if isinstance(dropout_prob, (list, tuple)):
+            dropout_prob1 = dropout_prob[0]
+            dropout_prob2 = dropout_prob[1]
+        else:
+            dropout_prob1 = dropout_prob2 = dropout_prob
+
         # conv1
         self.add_module('SingleConv1',
                         SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups,
-                                   padding=padding, is3d=is3d))
+                                   padding=padding, dropout_prob=dropout_prob1, is3d=is3d))
         # conv2
         self.add_module('SingleConv2',
                         SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups,
-                                   padding=padding, is3d=is3d))
+                                   padding=padding, dropout_prob=dropout_prob2, is3d=is3d))
 
 
 class ResNetBlock(nn.Module):
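With the list/tuple handling above, the two convolutions of a DoubleConv can receive different dropout rates; a short sketch (values are illustrative):

    from pytorch3dunet.unet3d.buildingblocks import DoubleConv

    # separate probabilities: SingleConv1 uses p=0.1, SingleConv2 uses p=0.2
    dc = DoubleConv(32, 64, encoder=True, order='cbrd', dropout_prob=(0.1, 0.2))
    # a plain float applies the same probability to both convolutions
    dc = DoubleConv(32, 64, encoder=True, order='cbrd', dropout_prob=0.1)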
@@ -244,12 +264,13 @@ class Encoder(nn.Module):
             in `DoubleConv` module. See `DoubleConv` for more info.
         num_groups (int): number of groups for the GroupNorm
         padding (int or tuple): zero-padding added to all three sides of the input
+        dropout_prob (float or tuple): dropout probability, default 0.1
         is3d (bool): use 3d or 2d convolutions/pooling operation
     """
 
     def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
                  pool_kernel_size=2, pool_type='max', basic_module=DoubleConv, conv_layer_order='gcr',
-                 num_groups=8, padding=1, is3d=True):
+                 num_groups=8, padding=1, dropout_prob=0.1, is3d=True):
         super(Encoder, self).__init__()
         assert pool_type in ['max', 'avg']
         if apply_pooling:
@@ -272,6 +293,7 @@ def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=
                                          order=conv_layer_order,
                                          num_groups=num_groups,
                                          padding=padding,
+                                         dropout_prob=dropout_prob,
                                          is3d=is3d)
 
     def forward(self, x):
@@ -300,10 +322,12 @@ class Decoder(nn.Module):
         num_groups (int): number of groups for the GroupNorm
         padding (int or tuple): zero-padding added to all three sides of the input
         upsample (bool): should the input be upsampled
+        dropout_prob (float or tuple): dropout probability, default 0.1
     """
 
-    def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=(2, 2, 2), basic_module=DoubleConv,
-                 conv_layer_order='gcr', num_groups=8, mode='nearest', padding=1, upsample=True, is3d=True):
+    def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=2, basic_module=DoubleConv,
+                 conv_layer_order='gcr', num_groups=8, padding=1, upsample=True,
+                 dropout_prob=0.1, is3d=True):
         super(Decoder, self).__init__()
 
         if upsample:
@@ -332,6 +356,7 @@ def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=(
                                          order=conv_layer_order,
                                          num_groups=num_groups,
                                          padding=padding,
+                                         dropout_prob=dropout_prob,
                                          is3d=is3d)
 
     def forward(self, encoder_features, x):
@@ -348,8 +373,9 @@ def _joining(encoder_features, x, concat):
         return encoder_features + x
 
 
-def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups,
-                    pool_kernel_size, is3d):
+def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding,
+                    dropout_prob,
+                    layer_order, num_groups, pool_kernel_size, is3d):
     # create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)`
     encoders = []
     for i, out_feature_num in enumerate(f_maps):
@@ -362,6 +388,7 @@ def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_pa
                               conv_kernel_size=conv_kernel_size,
                               num_groups=num_groups,
                               padding=conv_padding,
+                              dropout_prob=dropout_prob,
                               is3d=is3d)
         else:
             encoder = Encoder(f_maps[i - 1], out_feature_num,
@@ -371,6 +398,7 @@ def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_pa
                               num_groups=num_groups,
                               pool_kernel_size=pool_kernel_size,
                               padding=conv_padding,
+                              dropout_prob=dropout_prob,
                               is3d=is3d)
         encoders.append(encoder)
 
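Note that dropout_prob is only threaded through here; each Encoder (and each Decoder below) simply forwards it to its basic_module. A sketch of the updated create_encoders call (argument values are illustrative):

    from pytorch3dunet.unet3d.buildingblocks import DoubleConv, create_encoders

    encoders = create_encoders(in_channels=1, f_maps=[16, 32, 64], basic_module=DoubleConv,
                               conv_kernel_size=3, conv_padding=1, dropout_prob=0.1,
                               layer_order='cbrd', num_groups=8, pool_kernel_size=2, is3d=True)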
@@ -378,7 +406,8 @@ def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_pa
     return nn.ModuleList(encoders)
 
 
-def create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups, is3d):
+def create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order,
+                    num_groups, dropout_prob, is3d):
     # create decoder path consisting of the Decoder modules. The length of the decoder list is equal to `len(f_maps) - 1`
     decoders = []
     reversed_f_maps = list(reversed(f_maps))
@@ -396,6 +425,7 @@ def create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_
                           conv_kernel_size=conv_kernel_size,
                           num_groups=num_groups,
                           padding=conv_padding,
+                          dropout_prob=dropout_prob,
                           is3d=is3d)
         decoders.append(decoder)
     return nn.ModuleList(decoders)
diff --git a/pytorch3dunet/unet3d/model.py b/pytorch3dunet/unet3d/model.py
index 89f6246f..cc4496fe 100644
--- a/pytorch3dunet/unet3d/model.py
+++ b/pytorch3dunet/unet3d/model.py
@@ -32,12 +32,13 @@ class AbstractUNet(nn.Module):
         conv_kernel_size (int or tuple): size of the convolving kernel in the basic_module
         pool_kernel_size (int or tuple): the size of the window
         conv_padding (int or tuple): zero-padding added to all three sides of the input
+        dropout_prob (float or tuple): dropout probability, default: 0.1
         is3d (bool): if True the model is 3D, otherwise 2D, default: True
     """
 
     def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr',
                  num_groups=8, num_levels=4, is_segmentation=True, conv_kernel_size=3, pool_kernel_size=2,
-                 conv_padding=1, is3d=True):
+                 conv_padding=1, dropout_prob=0.1, is3d=True):
         super(AbstractUNet, self).__init__()
 
         if isinstance(f_maps, int):
@@ -49,11 +50,13 @@ def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_map
             assert num_groups is not None, "num_groups must be specified if GroupNorm is used"
 
         # create encoder path
-        self.encoders = create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding, layer_order,
-                                        num_groups, pool_kernel_size, is3d)
+        self.encoders = create_encoders(in_channels, f_maps, basic_module, conv_kernel_size,
+                                        conv_padding, dropout_prob,
+                                        layer_order, num_groups, pool_kernel_size, is3d)
 
         # create decoder path
-        self.decoders = create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups,
+        self.decoders = create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding,
+                                        layer_order, num_groups, dropout_prob,
                                         is3d)
 
         # in the last layer a 1×1 convolution reduces the number of output channels to the number of labels
@@ -110,7 +113,8 @@ class UNet3D(AbstractUNet):
     """
 
     def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
-                 num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1, **kwargs):
+                 num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1,
+                 dropout_prob=0.1, **kwargs):
         super(UNet3D, self).__init__(in_channels=in_channels,
                                      out_channels=out_channels,
                                      final_sigmoid=final_sigmoid,
@@ -121,6 +125,7 @@ def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, lay
                                      num_levels=num_levels,
                                      is_segmentation=is_segmentation,
                                      conv_padding=conv_padding,
+                                     dropout_prob=dropout_prob,
                                      is3d=True)
 
 
@@ -133,7 +138,8 @@ class ResidualUNet3D(AbstractUNet):
     """
 
     def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
-                 num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1, **kwargs):
+                 num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1,
+                 dropout_prob=0.1, **kwargs):
         super(ResidualUNet3D, self).__init__(in_channels=in_channels,
                                              out_channels=out_channels,
                                              final_sigmoid=final_sigmoid,
@@ -144,6 +150,7 @@ def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, lay
                                              num_levels=num_levels,
                                              is_segmentation=is_segmentation,
                                              conv_padding=conv_padding,
+                                             dropout_prob=dropout_prob,
                                              is3d=True)
 
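A quick smoke test of the new keyword on the public models (shapes and values are illustrative only):

    import torch
    from pytorch3dunet.unet3d.model import UNet3D

    model = UNet3D(in_channels=1, out_channels=2, f_maps=16, layer_order='cbrd',
                   dropout_prob=(0.1, 0.2), is_segmentation=False)
    x = torch.randn(1, 1, 32, 64, 64)  # (batch, channel, depth, height, width)
    y = model(x)  # dropout is active in model.train() mode and disabled in model.eval() mode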
@@ -158,7 +165,8 @@ class ResidualUNetSE3D(AbstractUNet):
     """
 
     def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
-                 num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1, **kwargs):
+                 num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1,
+                 dropout_prob=0.1, **kwargs):
         super(ResidualUNetSE3D, self).__init__(in_channels=in_channels,
                                                out_channels=out_channels,
                                                final_sigmoid=final_sigmoid,
@@ -169,6 +177,7 @@ def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, lay
                                                num_levels=num_levels,
                                                is_segmentation=is_segmentation,
                                                conv_padding=conv_padding,
+                                               dropout_prob=dropout_prob,
                                                is3d=True)
 
 
@@ -179,7 +188,8 @@ class UNet2D(AbstractUNet):
     """
 
    def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
-                 num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1, **kwargs):
+                 num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1,
+                 dropout_prob=0.1, **kwargs):
         super(UNet2D, self).__init__(in_channels=in_channels,
                                      out_channels=out_channels,
                                      final_sigmoid=final_sigmoid,
@@ -190,6 +200,7 @@ def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, lay
                                      num_levels=num_levels,
                                      is_segmentation=is_segmentation,
                                      conv_padding=conv_padding,
+                                     dropout_prob=dropout_prob,
                                      is3d=False)
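The spatial variant is selected the same way; note that 'D' maps to nn.Dropout2d in both the 2D and 3D code paths, since the new elif in create_conv does not branch on is3d. A minimal sketch (parameter values are made up):

    from pytorch3dunet.unet3d.model import UNet2D

    # 'D' adds nn.Dropout2d after each conv + batchnorm + ReLU, zeroing whole feature maps
    model_2d = UNet2D(in_channels=3, out_channels=1, layer_order='cbrD', dropout_prob=0.2)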