diff --git a/segmentation_models_pytorch/base/modules.py b/segmentation_models_pytorch/base/modules.py index cbd643b6..15cfdb12 100644 --- a/segmentation_models_pytorch/base/modules.py +++ b/segmentation_models_pytorch/base/modules.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, Union + import torch import torch.nn as nn @@ -7,43 +9,109 @@ InPlaceABN = None +def get_norm_layer( + use_norm: Union[bool, str, Dict[str, Any]], out_channels: int +) -> nn.Module: + supported_norms = ("inplace", "batchnorm", "identity", "layernorm", "instancenorm") + + # Step 1. Convert tot dict representation + + ## Check boolean + if use_norm is True: + norm_params = {"type": "batchnorm"} + elif use_norm is False: + norm_params = {"type": "identity"} + + ## Check string + elif isinstance(use_norm, str): + norm_str = use_norm.lower() + if norm_str == "inplace": + norm_params = { + "type": "inplace", + "activation": "leaky_relu", + "activation_param": 0.0, + } + elif norm_str in supported_norms: + norm_params = {"type": norm_str} + else: + raise ValueError( + f"Unrecognized normalization type string provided: {use_norm}. Should be in " + f"{supported_norms}" + ) + + ## Check dict + elif isinstance(use_norm, dict): + norm_params = use_norm + + else: + raise ValueError( + f"Invalid type for use_norm should either be a bool (batchnorm/identity), " + f"a string in {supported_norms}, or a dict like {{'type': 'batchnorm', **kwargs}}" + ) + + # Step 2. Check if the dict is valid + if "type" not in norm_params: + raise ValueError( + f"Malformed dictionary given in use_norm: {use_norm}. Should contain key 'type'." + ) + if norm_params["type"] not in supported_norms: + raise ValueError( + f"Unrecognized normalization type string provided: {use_norm}. Should be in {supported_norms}" + ) + if norm_params["type"] == "inplace" and InPlaceABN is None: + raise RuntimeError( + "In order to use `use_norm='inplace'` the inplace_abn package must be installed. Use:\n" + " $ pip install -U wheel setuptools\n" + " $ pip install inplace_abn --no-build-isolation\n" + "Also see: https://github.com/mapillary/inplace_abn" + ) + + # Step 3. Initialize the norm layer + norm_type = norm_params["type"] + norm_kwargs = {k: v for k, v in norm_params.items() if k != "type"} + + if norm_type == "inplace": + norm = InPlaceABN(out_channels, **norm_kwargs) + elif norm_type == "batchnorm": + norm = nn.BatchNorm2d(out_channels, **norm_kwargs) + elif norm_type == "identity": + norm = nn.Identity() + elif norm_type == "layernorm": + norm = nn.LayerNorm(out_channels, **norm_kwargs) + elif norm_type == "instancenorm": + norm = nn.InstanceNorm2d(out_channels, **norm_kwargs) + else: + raise ValueError(f"Unrecognized normalization type: {norm_type}") + + return norm + + class Conv2dReLU(nn.Sequential): def __init__( self, - in_channels, - out_channels, - kernel_size, - padding=0, - stride=1, - use_batchnorm=True, + in_channels: int, + out_channels: int, + kernel_size: int, + padding: int = 0, + stride: int = 1, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): - if use_batchnorm == "inplace" and InPlaceABN is None: - raise RuntimeError( - "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. " - + "To install see: https://github.com/mapillary/inplace_abn" - ) + norm = get_norm_layer(use_norm, out_channels) + is_identity = isinstance(norm, nn.Identity) conv = nn.Conv2d( in_channels, out_channels, kernel_size, stride=stride, padding=padding, - bias=not (use_batchnorm), + bias=is_identity, ) - relu = nn.ReLU(inplace=True) - - if use_batchnorm == "inplace": - bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0) - relu = nn.Identity() - elif use_batchnorm and use_batchnorm != "inplace": - bn = nn.BatchNorm2d(out_channels) - - else: - bn = nn.Identity() + is_inplaceabn = InPlaceABN is not None and isinstance(norm, InPlaceABN) + activation = nn.Identity() if is_inplaceabn else nn.ReLU(inplace=True) - super(Conv2dReLU, self).__init__(conv, bn, relu) + super(Conv2dReLU, self).__init__(conv, norm, activation) class SCSEModule(nn.Module): diff --git a/segmentation_models_pytorch/decoders/linknet/decoder.py b/segmentation_models_pytorch/decoders/linknet/decoder.py index 8dfd8434..95c7f9f6 100644 --- a/segmentation_models_pytorch/decoders/linknet/decoder.py +++ b/segmentation_models_pytorch/decoders/linknet/decoder.py @@ -1,28 +1,33 @@ import torch import torch.nn as nn -from typing import List, Optional +from typing import Any, Dict, List, Optional, Union from segmentation_models_pytorch.base import modules class TransposeX2(nn.Sequential): - def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True): + def __init__( + self, + in_channels: int, + out_channels: int, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + ): super().__init__() - layers = [ - nn.ConvTranspose2d( - in_channels, out_channels, kernel_size=4, stride=2, padding=1 - ), - nn.ReLU(inplace=True), - ] - - if use_batchnorm: - layers.insert(1, nn.BatchNorm2d(out_channels)) - - super().__init__(*layers) + conv = nn.ConvTranspose2d( + in_channels, out_channels, kernel_size=4, stride=2, padding=1 + ) + norm = modules.get_norm_layer(use_norm, out_channels) + activation = nn.ReLU(inplace=True) + super().__init__(conv, norm, activation) class DecoderBlock(nn.Module): - def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True): + def __init__( + self, + in_channels: int, + out_channels: int, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + ): super().__init__() self.block = nn.Sequential( @@ -30,16 +35,14 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = Tr in_channels, in_channels // 4, kernel_size=1, - use_batchnorm=use_batchnorm, - ), - TransposeX2( - in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm + use_norm=use_norm, ), + TransposeX2(in_channels // 4, in_channels // 4, use_norm=use_norm), modules.Conv2dReLU( in_channels // 4, out_channels, kernel_size=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ), ) @@ -58,7 +61,7 @@ def __init__( encoder_channels: List[int], prefinal_channels: int = 32, n_blocks: int = 5, - use_batchnorm: bool = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): super().__init__() @@ -71,7 +74,11 @@ def __init__( self.blocks = nn.ModuleList( [ - DecoderBlock(channels[i], channels[i + 1], use_batchnorm=use_batchnorm) + DecoderBlock( + channels[i], + channels[i + 1], + use_norm=use_norm, + ) for i in range(n_blocks) ] ) diff --git a/segmentation_models_pytorch/decoders/linknet/model.py b/segmentation_models_pytorch/decoders/linknet/model.py index 356468ed..be0d01b2 100644 --- a/segmentation_models_pytorch/decoders/linknet/model.py +++ b/segmentation_models_pytorch/decoders/linknet/model.py @@ -1,4 +1,5 @@ -from typing import Any, Optional, Union +import warnings +from typing import Any, Dict, Optional, Union, Callable from segmentation_models_pytorch.base import ( ClassificationHead, @@ -29,9 +30,22 @@ class Linknet(SegmentationModel): Default is 5 encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and other pretrained weights (see table with available weights for each encoder_name) - decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers - is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. - Available options are **True, False, "inplace"** + decoder_use_norm: Specifies normalization between Conv2D and activation. + Accepts the following types: + - **True**: Defaults to `"batchnorm"`. + - **False**: No normalization (`nn.Identity`). + - **str**: Specifies normalization type using default parameters. Available values: + `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`. + - **dict**: Fully customizable normalization settings. Structure: + ```python + {"type": <norm_type>, **kwargs} + ``` + where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation. + + **Example**: + ```python + decoder_use_norm={"type": "layernorm", "eps": 1e-2} + ``` in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. @@ -60,10 +74,10 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: bool = True, + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, callable]] = None, + activation: Optional[Union[str, Callable]] = None, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): @@ -74,6 +88,15 @@ def __init__( "Encoder `{}` is not supported for Linknet".format(encoder_name) ) + decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) + if decoder_use_batchnorm is not None: + warnings.warn( + "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm", + DeprecationWarning, + stacklevel=2, + ) + decoder_use_norm = decoder_use_batchnorm + self.encoder = get_encoder( encoder_name, in_channels=in_channels, @@ -86,7 +109,7 @@ def __init__( encoder_channels=self.encoder.out_channels, n_blocks=encoder_depth, prefinal_channels=32, - use_batchnorm=decoder_use_batchnorm, + use_norm=decoder_use_norm, ) self.segmentation_head = SegmentationHead( diff --git a/segmentation_models_pytorch/decoders/manet/decoder.py b/segmentation_models_pytorch/decoders/manet/decoder.py index 61b1fe57..ae2498c7 100644 --- a/segmentation_models_pytorch/decoders/manet/decoder.py +++ b/segmentation_models_pytorch/decoders/manet/decoder.py @@ -1,9 +1,9 @@ +from typing import Any, Dict, List, Optional, Union + import torch import torch.nn as nn import torch.nn.functional as F -from typing import List, Optional - from segmentation_models_pytorch.base import modules as md @@ -49,7 +49,7 @@ def __init__( in_channels: int, skip_channels: int, out_channels: int, - use_batchnorm: bool = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", reduction: int = 16, ): # MFABBlock is just a modified version of SE-blocks, one for skip, one for input @@ -60,10 +60,13 @@ def __init__( in_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ), md.Conv2dReLU( - in_channels, skip_channels, kernel_size=1, use_batchnorm=use_batchnorm + in_channels, + skip_channels, + kernel_size=1, + use_norm=use_norm, ), ) reduced_channels = max(1, skip_channels // reduction) @@ -87,14 +90,14 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.conv2 = md.Conv2dReLU( out_channels, out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) def forward( @@ -119,7 +122,7 @@ def __init__( in_channels: int, skip_channels: int, out_channels: int, - use_batchnorm: bool = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): super().__init__() self.conv1 = md.Conv2dReLU( @@ -127,14 +130,14 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.conv2 = md.Conv2dReLU( out_channels, out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) def forward( @@ -155,7 +158,7 @@ def __init__( decoder_channels: List[int], n_blocks: int = 5, reduction: int = 16, - use_batchnorm: bool = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", pab_channels: int = 64, ): super().__init__() @@ -182,7 +185,7 @@ def __init__( self.center = PABBlock(head_channels, pab_channels=pab_channels) # combine decoder keyword arguments - kwargs = dict(use_batchnorm=use_batchnorm) # no attention type here + kwargs = dict(use_norm=use_norm) # no attention type here blocks = [ MFABBlock(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs) if skip_ch > 0 diff --git a/segmentation_models_pytorch/decoders/manet/model.py b/segmentation_models_pytorch/decoders/manet/model.py index 6ed59207..a478b5c5 100644 --- a/segmentation_models_pytorch/decoders/manet/model.py +++ b/segmentation_models_pytorch/decoders/manet/model.py @@ -1,4 +1,5 @@ -from typing import Any, List, Optional, Union +import warnings +from typing import Any, Dict, Optional, Union, Sequence, Callable from segmentation_models_pytorch.base import ( ClassificationHead, @@ -29,9 +30,22 @@ class MAnet(SegmentationModel): other pretrained weights (see table with available weights for each encoder_name) decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. Length of the list should be the same as **encoder_depth** - decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers - is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. - Available options are **True, False, "inplace"** + decoder_use_norm: Specifies normalization between Conv2D and activation. + Accepts the following types: + - **True**: Defaults to `"batchnorm"`. + - **False**: No normalization (`nn.Identity`). + - **str**: Specifies normalization type using default parameters. Available values: + `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`. + - **dict**: Fully customizable normalization settings. Structure: + ```python + {"type": <norm_type>, **kwargs} + ``` + where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation. + + **Example**: + ```python + decoder_use_norm={"type": "layernorm", "eps": 1e-2} + ``` decoder_pab_channels: A number of channels for PAB module in decoder. Default is 64. in_channels: A number of input channels for the model, default is 3 (RGB images) @@ -63,17 +77,26 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: bool = True, - decoder_channels: List[int] = (256, 128, 64, 32, 16), + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), decoder_pab_channels: int = 64, in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, callable]] = None, + activation: Optional[Union[str, Callable]] = None, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): super().__init__() + decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) + if decoder_use_batchnorm is not None: + warnings.warn( + "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm", + DeprecationWarning, + stacklevel=2, + ) + decoder_use_norm = decoder_use_batchnorm + self.encoder = get_encoder( encoder_name, in_channels=in_channels, @@ -86,7 +109,7 @@ def __init__( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, n_blocks=encoder_depth, - use_batchnorm=decoder_use_batchnorm, + use_norm=decoder_use_norm, pab_channels=decoder_pab_channels, ) diff --git a/segmentation_models_pytorch/decoders/pspnet/decoder.py b/segmentation_models_pytorch/decoders/pspnet/decoder.py index 42ac42d0..80ad289c 100644 --- a/segmentation_models_pytorch/decoders/pspnet/decoder.py +++ b/segmentation_models_pytorch/decoders/pspnet/decoder.py @@ -1,8 +1,9 @@ +from typing import Any, Dict, List, Tuple, Union + import torch import torch.nn as nn import torch.nn.functional as F -from typing import List, Tuple from segmentation_models_pytorch.base import modules @@ -12,17 +13,17 @@ def __init__( in_channels: int, out_channels: int, pool_size: int, - use_batchnorm: bool = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): super().__init__() if pool_size == 1: - use_batchnorm = False # PyTorch does not support BatchNorm for 1x1 shape + use_norm = "identity" # PyTorch does not support BatchNorm for 1x1 shape self.pool = nn.Sequential( nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)), modules.Conv2dReLU( - in_channels, out_channels, (1, 1), use_batchnorm=use_batchnorm + in_channels, out_channels, kernel_size=1, use_norm=use_norm ), ) @@ -38,7 +39,7 @@ def __init__( self, in_channels: int, sizes: Tuple[int, ...] = (1, 2, 3, 6), - use_batchnorm: bool = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): super().__init__() @@ -48,7 +49,7 @@ def __init__( in_channels, in_channels // len(sizes), size, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) for size in sizes ] @@ -64,7 +65,7 @@ class PSPDecoder(nn.Module): def __init__( self, encoder_channels: List[int], - use_batchnorm: bool = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", out_channels: int = 512, dropout: float = 0.2, ): @@ -73,14 +74,14 @@ def __init__( self.psp = PSPModule( in_channels=encoder_channels[-1], sizes=(1, 2, 3, 6), - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.conv = modules.Conv2dReLU( in_channels=encoder_channels[-1] * 2, out_channels=out_channels, kernel_size=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.dropout = nn.Dropout2d(p=dropout) diff --git a/segmentation_models_pytorch/decoders/pspnet/model.py b/segmentation_models_pytorch/decoders/pspnet/model.py index 8b99b3da..4b2d19f0 100644 --- a/segmentation_models_pytorch/decoders/pspnet/model.py +++ b/segmentation_models_pytorch/decoders/pspnet/model.py @@ -1,4 +1,5 @@ -from typing import Any, Optional, Union +import warnings +from typing import Any, Dict, Optional, Union, Callable from segmentation_models_pytorch.base import ( ClassificationHead, @@ -28,9 +29,22 @@ class PSPNet(SegmentationModel): encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and other pretrained weights (see table with available weights for each encoder_name) psp_out_channels: A number of filters in Spatial Pyramid - psp_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers - is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. - Available options are **True, False, "inplace"** + decoder_use_norm: Specifies normalization between Conv2D and activation. + Accepts the following types: + - **True**: Defaults to `"batchnorm"`. + - **False**: No normalization (`nn.Identity`). + - **str**: Specifies normalization type using default parameters. Available values: + `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`. + - **dict**: Fully customizable normalization settings. Structure: + ```python + {"type": <norm_type>, **kwargs} + ``` + where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation. + + **Example**: + ```python + decoder_use_norm={"type": "layernorm", "eps": 1e-2} + ``` psp_dropout: Spatial dropout rate in [0, 1) used in Spatial Pyramid in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) @@ -62,17 +76,26 @@ def __init__( encoder_weights: Optional[str] = "imagenet", encoder_depth: int = 3, psp_out_channels: int = 512, - psp_use_batchnorm: bool = True, + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", psp_dropout: float = 0.2, in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, callable]] = None, + activation: Optional[Union[str, Callable]] = None, upsampling: int = 8, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): super().__init__() + psp_use_batchnorm = kwargs.pop("psp_use_batchnorm", None) + if psp_use_batchnorm is not None: + warnings.warn( + "The usage of psp_use_batchnorm is deprecated. Please modify your code for decoder_use_norm", + DeprecationWarning, + stacklevel=2, + ) + decoder_use_norm = psp_use_batchnorm + self.encoder = get_encoder( encoder_name, in_channels=in_channels, @@ -83,7 +106,7 @@ def __init__( self.decoder = PSPDecoder( encoder_channels=self.encoder.out_channels, - use_batchnorm=psp_use_batchnorm, + use_norm=decoder_use_norm, out_channels=psp_out_channels, dropout=psp_dropout, ) diff --git a/segmentation_models_pytorch/decoders/segformer/decoder.py b/segmentation_models_pytorch/decoders/segformer/decoder.py index cd160a4c..2bfadfff 100644 --- a/segmentation_models_pytorch/decoders/segformer/decoder.py +++ b/segmentation_models_pytorch/decoders/segformer/decoder.py @@ -50,7 +50,7 @@ def __init__( in_channels=(len(encoder_channels) - 1) * segmentation_channels, out_channels=segmentation_channels, kernel_size=1, - use_batchnorm=True, + use_norm="batchnorm", ) def forward(self, features: List[torch.Tensor]) -> torch.Tensor: diff --git a/segmentation_models_pytorch/decoders/unet/decoder.py b/segmentation_models_pytorch/decoders/unet/decoder.py index 0e4f35fd..cfeb267e 100644 --- a/segmentation_models_pytorch/decoders/unet/decoder.py +++ b/segmentation_models_pytorch/decoders/unet/decoder.py @@ -1,8 +1,9 @@ +from typing import Any, Dict, List, Optional, Sequence, Union + import torch import torch.nn as nn import torch.nn.functional as F -from typing import Optional, Sequence, List from segmentation_models_pytorch.base import modules as md @@ -14,7 +15,7 @@ def __init__( in_channels: int, skip_channels: int, out_channels: int, - use_batchnorm: bool = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", attention_type: Optional[str] = None, interpolation_mode: str = "nearest", ): @@ -25,7 +26,7 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.attention1 = md.Attention( attention_type, in_channels=in_channels + skip_channels @@ -35,7 +36,7 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.attention2 = md.Attention(attention_type, in_channels=out_channels) @@ -63,20 +64,25 @@ def forward( class UnetCenterBlock(nn.Sequential): """Center block of the Unet decoder. Applied to the last feature map of the encoder.""" - def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True): + def __init__( + self, + in_channels: int, + out_channels: int, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + ): conv1 = md.Conv2dReLU( in_channels, out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) conv2 = md.Conv2dReLU( out_channels, out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) super().__init__(conv1, conv2) @@ -93,7 +99,7 @@ def __init__( encoder_channels: Sequence[int], decoder_channels: Sequence[int], n_blocks: int = 5, - use_batchnorm: bool = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", attention_type: Optional[str] = None, add_center_block: bool = False, interpolation_mode: str = "nearest", @@ -120,7 +126,9 @@ def __init__( if add_center_block: self.center = UnetCenterBlock( - head_channels, head_channels, use_batchnorm=use_batchnorm + head_channels, + head_channels, + use_norm=use_norm, ) else: self.center = nn.Identity() @@ -134,7 +142,7 @@ def __init__( block_in_channels, block_skip_channels, block_out_channels, - use_batchnorm=use_batchnorm, + use_norm=use_norm, attention_type=attention_type, interpolation_mode=interpolation_mode, ) diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py index 4b30527d..22d7db11 100644 --- a/segmentation_models_pytorch/decoders/unet/model.py +++ b/segmentation_models_pytorch/decoders/unet/model.py @@ -1,4 +1,5 @@ -from typing import Any, Optional, Union, Callable, Sequence +import warnings +from typing import Any, Dict, Optional, Union, Callable, Sequence from segmentation_models_pytorch.base import ( ClassificationHead, @@ -39,9 +40,22 @@ class Unet(SegmentationModel): other pretrained weights (see table with available weights for each encoder_name) decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. Length of the list should be the same as **encoder_depth** - decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers - is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. - Available options are **True, False, "inplace"** + decoder_use_norm: Specifies normalization between Conv2D and activation. + Accepts the following types: + - **True**: Defaults to `"batchnorm"`. + - **False**: No normalization (`nn.Identity`). + - **str**: Specifies normalization type using default parameters. Available values: + `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`. + - **dict**: Fully customizable normalization settings. Structure: + ```python + {"type": <norm_type>, **kwargs} + ``` + where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation. + + **Example**: + ```python + decoder_use_norm={"type": "layernorm", "eps": 1e-2} + ``` decoder_attention_type: Attention module used in decoder of the model. Available options are **None** and **scse** (https://arxiv.org/abs/1808.08127). decoder_interpolation_mode: Interpolation mode used in decoder of the model. Available options are @@ -95,7 +109,7 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: bool = True, + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, decoder_interpolation_mode: str = "nearest", @@ -107,6 +121,15 @@ def __init__( ): super().__init__() + decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) + if decoder_use_batchnorm is not None: + warnings.warn( + "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm", + DeprecationWarning, + stacklevel=2, + ) + decoder_use_norm = decoder_use_batchnorm + self.encoder = get_encoder( encoder_name, in_channels=in_channels, @@ -116,11 +139,12 @@ def __init__( ) add_center_block = encoder_name.startswith("vgg") + self.decoder = UnetDecoder( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, n_blocks=encoder_depth, - use_batchnorm=decoder_use_batchnorm, + use_norm=decoder_use_norm, add_center_block=add_center_block, attention_type=decoder_attention_type, interpolation_mode=decoder_interpolation_mode, diff --git a/segmentation_models_pytorch/decoders/unetplusplus/decoder.py b/segmentation_models_pytorch/decoders/unetplusplus/decoder.py index 3282849f..e09327ac 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/decoder.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/decoder.py @@ -2,7 +2,7 @@ import torch.nn as nn import torch.nn.functional as F -from typing import Optional, List +from typing import Any, Dict, List, Optional, Union, Sequence from segmentation_models_pytorch.base import modules as md @@ -13,7 +13,7 @@ def __init__( in_channels: int, skip_channels: int, out_channels: int, - use_batchnorm: bool = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", attention_type: Optional[str] = None, ): super().__init__() @@ -22,7 +22,7 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.attention1 = md.Attention( attention_type, in_channels=in_channels + skip_channels @@ -32,7 +32,7 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.attention2 = md.Attention(attention_type, in_channels=out_channels) @@ -50,20 +50,25 @@ def forward( class CenterBlock(nn.Sequential): - def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True): + def __init__( + self, + in_channels: int, + out_channels: int, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + ): conv1 = md.Conv2dReLU( in_channels, out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) conv2 = md.Conv2dReLU( out_channels, out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) super().__init__(conv1, conv2) @@ -71,10 +76,10 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = Tr class UnetPlusPlusDecoder(nn.Module): def __init__( self, - encoder_channels: List[int], - decoder_channels: List[int], + encoder_channels: Sequence[int], + decoder_channels: Sequence[int], n_blocks: int = 5, - use_batchnorm: bool = True, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", attention_type: Optional[str] = None, center: bool = False, ): @@ -97,13 +102,18 @@ def __init__( self.out_channels = decoder_channels if center: self.center = CenterBlock( - head_channels, head_channels, use_batchnorm=use_batchnorm + head_channels, + head_channels, + use_norm=use_norm, ) else: self.center = nn.Identity() # combine decoder keyword arguments - kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) + kwargs = dict( + use_norm=use_norm, + attention_type=attention_type, + ) blocks = {} for layer_idx in range(len(self.in_channels) - 1): diff --git a/segmentation_models_pytorch/decoders/unetplusplus/model.py b/segmentation_models_pytorch/decoders/unetplusplus/model.py index 5c3d3a91..be0f8f83 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/model.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/model.py @@ -1,4 +1,5 @@ -from typing import Any, List, Optional, Union +import warnings +from typing import Any, Dict, Sequence, Optional, Union, Callable from segmentation_models_pytorch.base import ( ClassificationHead, @@ -28,9 +29,22 @@ class UnetPlusPlus(SegmentationModel): other pretrained weights (see table with available weights for each encoder_name) decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. Length of the list should be the same as **encoder_depth** - decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers - is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. - Available options are **True, False, "inplace"** + decoder_use_norm: Specifies normalization between Conv2D and activation. + Accepts the following types: + - **True**: Defaults to `"batchnorm"`. + - **False**: No normalization (`nn.Identity`). + - **str**: Specifies normalization type using default parameters. Available values: + `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`. + - **dict**: Fully customizable normalization settings. Structure: + ```python + {"type": <norm_type>, **kwargs} + ``` + where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation. + + **Example**: + ```python + decoder_use_norm={"type": "layernorm", "eps": 1e-2} + ``` decoder_attention_type: Attention module used in decoder of the model. Available options are **None** and **scse** (https://arxiv.org/abs/1808.08127). in_channels: A number of input channels for the model, default is 3 (RGB images) @@ -64,12 +78,12 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: bool = True, - decoder_channels: List[int] = (256, 128, 64, 32, 16), + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, callable]] = None, + activation: Optional[Union[str, Callable]] = None, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): @@ -80,6 +94,15 @@ def __init__( "UnetPlusPlus is not support encoder_name={}".format(encoder_name) ) + decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) + if decoder_use_batchnorm is not None: + warnings.warn( + "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm", + DeprecationWarning, + stacklevel=2, + ) + decoder_use_norm = decoder_use_batchnorm + self.encoder = get_encoder( encoder_name, in_channels=in_channels, @@ -92,7 +115,7 @@ def __init__( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, n_blocks=encoder_depth, - use_batchnorm=decoder_use_batchnorm, + use_norm=decoder_use_norm, center=True if encoder_name.startswith("vgg") else False, attention_type=decoder_attention_type, ) diff --git a/segmentation_models_pytorch/decoders/upernet/decoder.py b/segmentation_models_pytorch/decoders/upernet/decoder.py index 99c74fb1..810778f3 100644 --- a/segmentation_models_pytorch/decoders/upernet/decoder.py +++ b/segmentation_models_pytorch/decoders/upernet/decoder.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, Union, Sequence + import torch import torch.nn as nn import torch.nn.functional as F @@ -8,10 +10,10 @@ class PSPModule(nn.Module): def __init__( self, - in_channels, - out_channels, - sizes=(1, 2, 3, 6), - use_batchnorm=True, + in_channels: int, + out_channels: int, + sizes: Sequence[int] = (1, 2, 3, 6), + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): super().__init__() self.blocks = nn.ModuleList( @@ -22,7 +24,7 @@ def __init__( in_channels, in_channels // len(sizes), kernel_size=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ), ) for size in sizes @@ -32,7 +34,7 @@ def __init__( in_channels=in_channels * 2, out_channels=out_channels, kernel_size=1, - use_batchnorm=True, + use_norm="batchnorm", ) def forward(self, x): @@ -48,14 +50,19 @@ def forward(self, x): class FPNBlock(nn.Module): - def __init__(self, skip_channels, pyramid_channels, use_batchnorm=True): + def __init__( + self, + skip_channels: int, + pyramid_channels: int, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + ): super().__init__() self.skip_conv = ( md.Conv2dReLU( skip_channels, pyramid_channels, kernel_size=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) if skip_channels != 0 else nn.Identity() @@ -73,10 +80,11 @@ def forward(self, x, skip): class UPerNetDecoder(nn.Module): def __init__( self, - encoder_channels, - encoder_depth=5, - pyramid_channels=256, - segmentation_channels=64, + encoder_channels: Sequence[int], + encoder_depth: int = 5, + pyramid_channels: int = 256, + segmentation_channels: int = 64, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): super().__init__() @@ -94,7 +102,7 @@ def __init__( in_channels=encoder_channels[0], out_channels=pyramid_channels, sizes=(1, 2, 3, 6), - use_batchnorm=True, + use_norm=use_norm, ) # FPN Module @@ -107,7 +115,7 @@ def __init__( out_channels=segmentation_channels, kernel_size=3, padding=1, - use_batchnorm=True, + use_norm=use_norm, ) def forward(self, features): diff --git a/segmentation_models_pytorch/decoders/upernet/model.py b/segmentation_models_pytorch/decoders/upernet/model.py index 7ffeee5b..6ad5afd5 100644 --- a/segmentation_models_pytorch/decoders/upernet/model.py +++ b/segmentation_models_pytorch/decoders/upernet/model.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Dict, Optional, Union, Callable from segmentation_models_pytorch.base import ( ClassificationHead, @@ -25,6 +25,22 @@ class UPerNet(SegmentationModel): other pretrained weights (see table with available weights for each encoder_name) decoder_pyramid_channels: A number of convolution filters in Feature Pyramid, default is 256 decoder_segmentation_channels: A number of convolution filters in segmentation blocks, default is 64 + decoder_use_norm: Specifies normalization between Conv2D and activation. + Accepts the following types: + - **True**: Defaults to `"batchnorm"`. + - **False**: No normalization (`nn.Identity`). + - **str**: Specifies normalization type using default parameters. Available values: + `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`. + - **dict**: Fully customizable normalization settings. Structure: + ```python + {"type": <norm_type>, **kwargs} + ``` + where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation. + + **Example**: + ```python + use_norm={"type": "layernorm", "eps": 1e-2} + ``` in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. @@ -58,9 +74,10 @@ def __init__( encoder_weights: Optional[str] = "imagenet", decoder_pyramid_channels: int = 256, decoder_segmentation_channels: int = 64, + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, callable]] = None, + activation: Optional[Union[str, Callable]] = None, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): @@ -79,6 +96,7 @@ def __init__( encoder_depth=encoder_depth, pyramid_channels=decoder_pyramid_channels, segmentation_channels=decoder_segmentation_channels, + use_norm=decoder_use_norm, ) self.segmentation_head = SegmentationHead( diff --git a/tests/base/test_modules.py b/tests/base/test_modules.py new file mode 100644 index 00000000..5afa8e4f --- /dev/null +++ b/tests/base/test_modules.py @@ -0,0 +1,64 @@ +import pytest +from torch import nn +from segmentation_models_pytorch.base.modules import Conv2dReLU + + +def test_conv2drelu_batchnorm(): + module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="batchnorm") + + assert isinstance(module[0], nn.Conv2d) + assert isinstance(module[1], nn.BatchNorm2d) + assert isinstance(module[2], nn.ReLU) + + +def test_conv2drelu_batchnorm_with_keywords(): + module = Conv2dReLU( + 3, + 16, + kernel_size=3, + padding=1, + use_norm={"type": "batchnorm", "momentum": 1e-4, "affine": False}, + ) + + assert isinstance(module[0], nn.Conv2d) + assert isinstance(module[1], nn.BatchNorm2d) + assert module[1].momentum == 1e-4 and module[1].affine is False + assert isinstance(module[2], nn.ReLU) + + +def test_conv2drelu_identity(): + module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="identity") + + assert isinstance(module[0], nn.Conv2d) + assert isinstance(module[1], nn.Identity) + assert isinstance(module[2], nn.ReLU) + + +def test_conv2drelu_layernorm(): + module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="layernorm") + + assert isinstance(module[0], nn.Conv2d) + assert isinstance(module[1], nn.LayerNorm) + assert isinstance(module[2], nn.ReLU) + + +def test_conv2drelu_instancenorm(): + module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="instancenorm") + + assert isinstance(module[0], nn.Conv2d) + assert isinstance(module[1], nn.InstanceNorm2d) + assert isinstance(module[2], nn.ReLU) + + +def test_conv2drelu_inplace(): + try: + from inplace_abn import InPlaceABN + except ImportError: + pytest.skip("InPlaceABN is not installed") + + module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="inplace") + + assert len(module) == 3 + assert isinstance(module[0], nn.Conv2d) + assert isinstance(module[1], InPlaceABN) + assert isinstance(module[2], nn.Identity) diff --git a/tests/encoders/test_batchnorm_deprecation.py b/tests/encoders/test_batchnorm_deprecation.py new file mode 100644 index 00000000..ff53563f --- /dev/null +++ b/tests/encoders/test_batchnorm_deprecation.py @@ -0,0 +1,54 @@ +import pytest + +import torch + +import segmentation_models_pytorch as smp +from tests.utils import check_two_models_strictly_equal + + +@pytest.mark.parametrize("model_name", ["unet", "unetplusplus", "linknet", "manet"]) +@pytest.mark.parametrize("decoder_option", [True, False, "inplace"]) +def test_seg_models_before_after_use_norm(model_name, decoder_option): + torch.manual_seed(42) + with pytest.warns(DeprecationWarning): + model_decoder_batchnorm = smp.create_model( + model_name, + "mobilenet_v2", + encoder_weights=None, + decoder_use_batchnorm=decoder_option, + ) + model_decoder_norm = smp.create_model( + model_name, + "mobilenet_v2", + encoder_weights=None, + decoder_use_norm=decoder_option, + ) + + model_decoder_norm.load_state_dict(model_decoder_batchnorm.state_dict()) + + check_two_models_strictly_equal( + model_decoder_batchnorm, model_decoder_norm, torch.rand(1, 3, 224, 224) + ) + + +@pytest.mark.parametrize("decoder_option", [True, False, "inplace"]) +def test_pspnet_before_after_use_norm(decoder_option): + torch.manual_seed(42) + with pytest.warns(DeprecationWarning): + model_decoder_batchnorm = smp.create_model( + "pspnet", + "mobilenet_v2", + encoder_weights=None, + psp_use_batchnorm=decoder_option, + ) + model_decoder_norm = smp.create_model( + "pspnet", + "mobilenet_v2", + encoder_weights=None, + decoder_use_norm=decoder_option, + ) + model_decoder_norm.load_state_dict(model_decoder_batchnorm.state_dict()) + + check_two_models_strictly_equal( + model_decoder_batchnorm, model_decoder_norm, torch.rand(1, 3, 224, 224) + ) diff --git a/tests/models/base.py b/tests/models/base.py index f7492986..b96e76e8 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -282,4 +282,4 @@ def test_torch_script(self): eager_output = model(sample) self.assertEqual(scripted_output.shape, eager_output.shape) - torch.testing.assert_close(scripted_output, eager_output) + torch.testing.assert_close(scripted_output, eager_output, rtol=1e-3, atol=1e-3) diff --git a/tests/utils.py b/tests/utils.py index 6e201f1d..1e97b40b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -58,3 +58,23 @@ def check_run_test_on_diff_or_main(filepath_patterns: List[str]): return True return False + + +def check_two_models_strictly_equal( + model_a: torch.nn.Module, model_b: torch.nn.Module, input_data: torch.Tensor +) -> None: + for (k1, v1), (k2, v2) in zip( + model_a.state_dict().items(), model_b.state_dict().items() + ): + assert k1 == k2, f"Key mismatch: {k1} != {k2}" + torch.testing.assert_close( + v1, v2, msg=f"Tensor mismatch at key '{k1}':\n{v1} !=\n{v2}" + ) + + model_a.eval() + model_b.eval() + with torch.inference_mode(): + output_a = model_a(input_data) + output_b = model_b(input_data) + + torch.testing.assert_close(output_a, output_b)