diff --git a/segmentation_models_pytorch/base/modules.py b/segmentation_models_pytorch/base/modules.py
index cbd643b6..15cfdb12 100644
--- a/segmentation_models_pytorch/base/modules.py
+++ b/segmentation_models_pytorch/base/modules.py
@@ -1,3 +1,5 @@
+from typing import Any, Dict, Union
+
 import torch
 import torch.nn as nn
 
@@ -7,43 +9,109 @@
     InPlaceABN = None
 
 
+def get_norm_layer(
+    use_norm: Union[bool, str, Dict[str, Any]], out_channels: int
+) -> nn.Module:
+    supported_norms = ("inplace", "batchnorm", "identity", "layernorm", "instancenorm")
+
+    # Step 1. Convert tot dict representation
+
+    ## Check boolean
+    if use_norm is True:
+        norm_params = {"type": "batchnorm"}
+    elif use_norm is False:
+        norm_params = {"type": "identity"}
+
+    ## Check string
+    elif isinstance(use_norm, str):
+        norm_str = use_norm.lower()
+        if norm_str == "inplace":
+            norm_params = {
+                "type": "inplace",
+                "activation": "leaky_relu",
+                "activation_param": 0.0,
+            }
+        elif norm_str in supported_norms:
+            norm_params = {"type": norm_str}
+        else:
+            raise ValueError(
+                f"Unrecognized normalization type string provided: {use_norm}. Should be in "
+                f"{supported_norms}"
+            )
+
+    ## Check dict
+    elif isinstance(use_norm, dict):
+        norm_params = use_norm
+
+    else:
+        raise ValueError(
+            f"Invalid type for use_norm should either be a bool (batchnorm/identity), "
+            f"a string in {supported_norms}, or a dict like {{'type': 'batchnorm', **kwargs}}"
+        )
+
+    # Step 2. Check if the dict is valid
+    if "type" not in norm_params:
+        raise ValueError(
+            f"Malformed dictionary given in use_norm: {use_norm}. Should contain key 'type'."
+        )
+    if norm_params["type"] not in supported_norms:
+        raise ValueError(
+            f"Unrecognized normalization type string provided: {use_norm}. Should be in {supported_norms}"
+        )
+    if norm_params["type"] == "inplace" and InPlaceABN is None:
+        raise RuntimeError(
+            "In order to use `use_norm='inplace'` the inplace_abn package must be installed. Use:\n"
+            "  $ pip install -U wheel setuptools\n"
+            "  $ pip install inplace_abn --no-build-isolation\n"
+            "Also see: https://github.com/mapillary/inplace_abn"
+        )
+
+    # Step 3. Initialize the norm layer
+    norm_type = norm_params["type"]
+    norm_kwargs = {k: v for k, v in norm_params.items() if k != "type"}
+
+    if norm_type == "inplace":
+        norm = InPlaceABN(out_channels, **norm_kwargs)
+    elif norm_type == "batchnorm":
+        norm = nn.BatchNorm2d(out_channels, **norm_kwargs)
+    elif norm_type == "identity":
+        norm = nn.Identity()
+    elif norm_type == "layernorm":
+        norm = nn.LayerNorm(out_channels, **norm_kwargs)
+    elif norm_type == "instancenorm":
+        norm = nn.InstanceNorm2d(out_channels, **norm_kwargs)
+    else:
+        raise ValueError(f"Unrecognized normalization type: {norm_type}")
+
+    return norm
+
+
 class Conv2dReLU(nn.Sequential):
     def __init__(
         self,
-        in_channels,
-        out_channels,
-        kernel_size,
-        padding=0,
-        stride=1,
-        use_batchnorm=True,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        padding: int = 0,
+        stride: int = 1,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
     ):
-        if use_batchnorm == "inplace" and InPlaceABN is None:
-            raise RuntimeError(
-                "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
-                + "To install see: https://github.com/mapillary/inplace_abn"
-            )
+        norm = get_norm_layer(use_norm, out_channels)
 
+        is_identity = isinstance(norm, nn.Identity)
         conv = nn.Conv2d(
             in_channels,
             out_channels,
             kernel_size,
             stride=stride,
             padding=padding,
-            bias=not (use_batchnorm),
+            bias=is_identity,
         )
-        relu = nn.ReLU(inplace=True)
-
-        if use_batchnorm == "inplace":
-            bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
-            relu = nn.Identity()
 
-        elif use_batchnorm and use_batchnorm != "inplace":
-            bn = nn.BatchNorm2d(out_channels)
-
-        else:
-            bn = nn.Identity()
+        is_inplaceabn = InPlaceABN is not None and isinstance(norm, InPlaceABN)
+        activation = nn.Identity() if is_inplaceabn else nn.ReLU(inplace=True)
 
-        super(Conv2dReLU, self).__init__(conv, bn, relu)
+        super(Conv2dReLU, self).__init__(conv, norm, activation)
 
 
 class SCSEModule(nn.Module):
diff --git a/segmentation_models_pytorch/decoders/linknet/decoder.py b/segmentation_models_pytorch/decoders/linknet/decoder.py
index 8dfd8434..95c7f9f6 100644
--- a/segmentation_models_pytorch/decoders/linknet/decoder.py
+++ b/segmentation_models_pytorch/decoders/linknet/decoder.py
@@ -1,28 +1,33 @@
 import torch
 import torch.nn as nn
 
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Union
 from segmentation_models_pytorch.base import modules
 
 
 class TransposeX2(nn.Sequential):
-    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
+    ):
         super().__init__()
-        layers = [
-            nn.ConvTranspose2d(
-                in_channels, out_channels, kernel_size=4, stride=2, padding=1
-            ),
-            nn.ReLU(inplace=True),
-        ]
-
-        if use_batchnorm:
-            layers.insert(1, nn.BatchNorm2d(out_channels))
-
-        super().__init__(*layers)
+        conv = nn.ConvTranspose2d(
+            in_channels, out_channels, kernel_size=4, stride=2, padding=1
+        )
+        norm = modules.get_norm_layer(use_norm, out_channels)
+        activation = nn.ReLU(inplace=True)
+        super().__init__(conv, norm, activation)
 
 
 class DecoderBlock(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
+    ):
         super().__init__()
 
         self.block = nn.Sequential(
@@ -30,16 +35,14 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = Tr
                 in_channels,
                 in_channels // 4,
                 kernel_size=1,
-                use_batchnorm=use_batchnorm,
-            ),
-            TransposeX2(
-                in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm
+                use_norm=use_norm,
             ),
+            TransposeX2(in_channels // 4, in_channels // 4, use_norm=use_norm),
             modules.Conv2dReLU(
                 in_channels // 4,
                 out_channels,
                 kernel_size=1,
-                use_batchnorm=use_batchnorm,
+                use_norm=use_norm,
             ),
         )
 
@@ -58,7 +61,7 @@ def __init__(
         encoder_channels: List[int],
         prefinal_channels: int = 32,
         n_blocks: int = 5,
-        use_batchnorm: bool = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
     ):
         super().__init__()
 
@@ -71,7 +74,11 @@ def __init__(
 
         self.blocks = nn.ModuleList(
             [
-                DecoderBlock(channels[i], channels[i + 1], use_batchnorm=use_batchnorm)
+                DecoderBlock(
+                    channels[i],
+                    channels[i + 1],
+                    use_norm=use_norm,
+                )
                 for i in range(n_blocks)
             ]
         )
diff --git a/segmentation_models_pytorch/decoders/linknet/model.py b/segmentation_models_pytorch/decoders/linknet/model.py
index 356468ed..be0d01b2 100644
--- a/segmentation_models_pytorch/decoders/linknet/model.py
+++ b/segmentation_models_pytorch/decoders/linknet/model.py
@@ -1,4 +1,5 @@
-from typing import Any, Optional, Union
+import warnings
+from typing import Any, Dict, Optional, Union, Callable
 
 from segmentation_models_pytorch.base import (
     ClassificationHead,
@@ -29,9 +30,22 @@ class Linknet(SegmentationModel):
             Default is 5
         encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
             other pretrained weights (see table with available weights for each encoder_name)
-        decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
-            is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption.
-            Available options are **True, False, "inplace"**
+        decoder_use_norm:     Specifies normalization between Conv2D and activation.
+            Accepts the following types:
+            - **True**: Defaults to `"batchnorm"`.
+            - **False**: No normalization (`nn.Identity`).
+            - **str**: Specifies normalization type using default parameters. Available values:
+              `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`.
+            - **dict**: Fully customizable normalization settings. Structure:
+              ```python
+              {"type": <norm_type>, **kwargs}
+              ```
+              where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation.
+
+            **Example**:
+            ```python
+            decoder_use_norm={"type": "layernorm", "eps": 1e-2}
+            ```
         in_channels: A number of input channels for the model, default is 3 (RGB images)
         classes: A number of classes for output mask (or you can think as a number of channels of output mask)
         activation: An activation function to apply after the final convolution layer.
@@ -60,10 +74,10 @@ def __init__(
         encoder_name: str = "resnet34",
         encoder_depth: int = 5,
         encoder_weights: Optional[str] = "imagenet",
-        decoder_use_batchnorm: bool = True,
+        decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
         in_channels: int = 3,
         classes: int = 1,
-        activation: Optional[Union[str, callable]] = None,
+        activation: Optional[Union[str, Callable]] = None,
         aux_params: Optional[dict] = None,
         **kwargs: dict[str, Any],
     ):
@@ -74,6 +88,15 @@ def __init__(
                 "Encoder `{}` is not supported for Linknet".format(encoder_name)
             )
 
+        decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None)
+        if decoder_use_batchnorm is not None:
+            warnings.warn(
+                "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            decoder_use_norm = decoder_use_batchnorm
+
         self.encoder = get_encoder(
             encoder_name,
             in_channels=in_channels,
@@ -86,7 +109,7 @@ def __init__(
             encoder_channels=self.encoder.out_channels,
             n_blocks=encoder_depth,
             prefinal_channels=32,
-            use_batchnorm=decoder_use_batchnorm,
+            use_norm=decoder_use_norm,
         )
 
         self.segmentation_head = SegmentationHead(
diff --git a/segmentation_models_pytorch/decoders/manet/decoder.py b/segmentation_models_pytorch/decoders/manet/decoder.py
index 61b1fe57..ae2498c7 100644
--- a/segmentation_models_pytorch/decoders/manet/decoder.py
+++ b/segmentation_models_pytorch/decoders/manet/decoder.py
@@ -1,9 +1,9 @@
+from typing import Any, Dict, List, Optional, Union
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from typing import List, Optional
-
 from segmentation_models_pytorch.base import modules as md
 
 
@@ -49,7 +49,7 @@ def __init__(
         in_channels: int,
         skip_channels: int,
         out_channels: int,
-        use_batchnorm: bool = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
         reduction: int = 16,
     ):
         # MFABBlock is just a modified version of SE-blocks, one for skip, one for input
@@ -60,10 +60,13 @@ def __init__(
                 in_channels,
                 kernel_size=3,
                 padding=1,
-                use_batchnorm=use_batchnorm,
+                use_norm=use_norm,
             ),
             md.Conv2dReLU(
-                in_channels, skip_channels, kernel_size=1, use_batchnorm=use_batchnorm
+                in_channels,
+                skip_channels,
+                kernel_size=1,
+                use_norm=use_norm,
             ),
         )
         reduced_channels = max(1, skip_channels // reduction)
@@ -87,14 +90,14 @@ def __init__(
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
         self.conv2 = md.Conv2dReLU(
             out_channels,
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
 
     def forward(
@@ -119,7 +122,7 @@ def __init__(
         in_channels: int,
         skip_channels: int,
         out_channels: int,
-        use_batchnorm: bool = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
     ):
         super().__init__()
         self.conv1 = md.Conv2dReLU(
@@ -127,14 +130,14 @@ def __init__(
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
         self.conv2 = md.Conv2dReLU(
             out_channels,
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
 
     def forward(
@@ -155,7 +158,7 @@ def __init__(
         decoder_channels: List[int],
         n_blocks: int = 5,
         reduction: int = 16,
-        use_batchnorm: bool = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
         pab_channels: int = 64,
     ):
         super().__init__()
@@ -182,7 +185,7 @@ def __init__(
         self.center = PABBlock(head_channels, pab_channels=pab_channels)
 
         # combine decoder keyword arguments
-        kwargs = dict(use_batchnorm=use_batchnorm)  # no attention type here
+        kwargs = dict(use_norm=use_norm)  # no attention type here
         blocks = [
             MFABBlock(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs)
             if skip_ch > 0
diff --git a/segmentation_models_pytorch/decoders/manet/model.py b/segmentation_models_pytorch/decoders/manet/model.py
index 6ed59207..a478b5c5 100644
--- a/segmentation_models_pytorch/decoders/manet/model.py
+++ b/segmentation_models_pytorch/decoders/manet/model.py
@@ -1,4 +1,5 @@
-from typing import Any, List, Optional, Union
+import warnings
+from typing import Any, Dict, Optional, Union, Sequence, Callable
 
 from segmentation_models_pytorch.base import (
     ClassificationHead,
@@ -29,9 +30,22 @@ class MAnet(SegmentationModel):
             other pretrained weights (see table with available weights for each encoder_name)
         decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder.
             Length of the list should be the same as **encoder_depth**
-        decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
-            is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption.
-            Available options are **True, False, "inplace"**
+        decoder_use_norm: Specifies normalization between Conv2D and activation.
+            Accepts the following types:
+            - **True**: Defaults to `"batchnorm"`.
+            - **False**: No normalization (`nn.Identity`).
+            - **str**: Specifies normalization type using default parameters. Available values:
+              `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`.
+            - **dict**: Fully customizable normalization settings. Structure:
+              ```python
+              {"type": <norm_type>, **kwargs}
+              ```
+              where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation.
+
+            **Example**:
+            ```python
+            decoder_use_norm={"type": "layernorm", "eps": 1e-2}
+            ```
         decoder_pab_channels: A number of channels for PAB module in decoder.
             Default is 64.
         in_channels: A number of input channels for the model, default is 3 (RGB images)
@@ -63,17 +77,26 @@ def __init__(
         encoder_name: str = "resnet34",
         encoder_depth: int = 5,
         encoder_weights: Optional[str] = "imagenet",
-        decoder_use_batchnorm: bool = True,
-        decoder_channels: List[int] = (256, 128, 64, 32, 16),
+        decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
+        decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
         decoder_pab_channels: int = 64,
         in_channels: int = 3,
         classes: int = 1,
-        activation: Optional[Union[str, callable]] = None,
+        activation: Optional[Union[str, Callable]] = None,
         aux_params: Optional[dict] = None,
         **kwargs: dict[str, Any],
     ):
         super().__init__()
 
+        decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None)
+        if decoder_use_batchnorm is not None:
+            warnings.warn(
+                "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            decoder_use_norm = decoder_use_batchnorm
+
         self.encoder = get_encoder(
             encoder_name,
             in_channels=in_channels,
@@ -86,7 +109,7 @@ def __init__(
             encoder_channels=self.encoder.out_channels,
             decoder_channels=decoder_channels,
             n_blocks=encoder_depth,
-            use_batchnorm=decoder_use_batchnorm,
+            use_norm=decoder_use_norm,
             pab_channels=decoder_pab_channels,
         )
 
diff --git a/segmentation_models_pytorch/decoders/pspnet/decoder.py b/segmentation_models_pytorch/decoders/pspnet/decoder.py
index 42ac42d0..80ad289c 100644
--- a/segmentation_models_pytorch/decoders/pspnet/decoder.py
+++ b/segmentation_models_pytorch/decoders/pspnet/decoder.py
@@ -1,8 +1,9 @@
+from typing import Any, Dict, List, Tuple, Union
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from typing import List, Tuple
 from segmentation_models_pytorch.base import modules
 
 
@@ -12,17 +13,17 @@ def __init__(
         in_channels: int,
         out_channels: int,
         pool_size: int,
-        use_batchnorm: bool = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
     ):
         super().__init__()
 
         if pool_size == 1:
-            use_batchnorm = False  # PyTorch does not support BatchNorm for 1x1 shape
+            use_norm = "identity"  # PyTorch does not support BatchNorm for 1x1 shape
 
         self.pool = nn.Sequential(
             nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)),
             modules.Conv2dReLU(
-                in_channels, out_channels, (1, 1), use_batchnorm=use_batchnorm
+                in_channels, out_channels, kernel_size=1, use_norm=use_norm
             ),
         )
 
@@ -38,7 +39,7 @@ def __init__(
         self,
         in_channels: int,
         sizes: Tuple[int, ...] = (1, 2, 3, 6),
-        use_batchnorm: bool = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
     ):
         super().__init__()
 
@@ -48,7 +49,7 @@ def __init__(
                     in_channels,
                     in_channels // len(sizes),
                     size,
-                    use_batchnorm=use_batchnorm,
+                    use_norm=use_norm,
                 )
                 for size in sizes
             ]
@@ -64,7 +65,7 @@ class PSPDecoder(nn.Module):
     def __init__(
         self,
         encoder_channels: List[int],
-        use_batchnorm: bool = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
         out_channels: int = 512,
         dropout: float = 0.2,
     ):
@@ -73,14 +74,14 @@ def __init__(
         self.psp = PSPModule(
             in_channels=encoder_channels[-1],
             sizes=(1, 2, 3, 6),
-            use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
 
         self.conv = modules.Conv2dReLU(
             in_channels=encoder_channels[-1] * 2,
             out_channels=out_channels,
             kernel_size=1,
-            use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
 
         self.dropout = nn.Dropout2d(p=dropout)
diff --git a/segmentation_models_pytorch/decoders/pspnet/model.py b/segmentation_models_pytorch/decoders/pspnet/model.py
index 8b99b3da..4b2d19f0 100644
--- a/segmentation_models_pytorch/decoders/pspnet/model.py
+++ b/segmentation_models_pytorch/decoders/pspnet/model.py
@@ -1,4 +1,5 @@
-from typing import Any, Optional, Union
+import warnings
+from typing import Any, Dict, Optional, Union, Callable
 
 from segmentation_models_pytorch.base import (
     ClassificationHead,
@@ -28,9 +29,22 @@ class PSPNet(SegmentationModel):
         encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
             other pretrained weights (see table with available weights for each encoder_name)
         psp_out_channels: A number of filters in Spatial Pyramid
-        psp_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
-            is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption.
-            Available options are **True, False, "inplace"**
+        decoder_use_norm:     Specifies normalization between Conv2D and activation.
+            Accepts the following types:
+            - **True**: Defaults to `"batchnorm"`.
+            - **False**: No normalization (`nn.Identity`).
+            - **str**: Specifies normalization type using default parameters. Available values:
+              `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`.
+            - **dict**: Fully customizable normalization settings. Structure:
+              ```python
+              {"type": <norm_type>, **kwargs}
+              ```
+              where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation.
+
+            **Example**:
+            ```python
+            decoder_use_norm={"type": "layernorm", "eps": 1e-2}
+            ```
         psp_dropout: Spatial dropout rate in [0, 1) used in Spatial Pyramid
         in_channels: A number of input channels for the model, default is 3 (RGB images)
         classes: A number of classes for output mask (or you can think as a number of channels of output mask)
@@ -62,17 +76,26 @@ def __init__(
         encoder_weights: Optional[str] = "imagenet",
         encoder_depth: int = 3,
         psp_out_channels: int = 512,
-        psp_use_batchnorm: bool = True,
+        decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
         psp_dropout: float = 0.2,
         in_channels: int = 3,
         classes: int = 1,
-        activation: Optional[Union[str, callable]] = None,
+        activation: Optional[Union[str, Callable]] = None,
         upsampling: int = 8,
         aux_params: Optional[dict] = None,
         **kwargs: dict[str, Any],
     ):
         super().__init__()
 
+        psp_use_batchnorm = kwargs.pop("psp_use_batchnorm", None)
+        if psp_use_batchnorm is not None:
+            warnings.warn(
+                "The usage of psp_use_batchnorm is deprecated. Please modify your code for decoder_use_norm",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            decoder_use_norm = psp_use_batchnorm
+
         self.encoder = get_encoder(
             encoder_name,
             in_channels=in_channels,
@@ -83,7 +106,7 @@ def __init__(
 
         self.decoder = PSPDecoder(
             encoder_channels=self.encoder.out_channels,
-            use_batchnorm=psp_use_batchnorm,
+            use_norm=decoder_use_norm,
             out_channels=psp_out_channels,
             dropout=psp_dropout,
         )
diff --git a/segmentation_models_pytorch/decoders/segformer/decoder.py b/segmentation_models_pytorch/decoders/segformer/decoder.py
index cd160a4c..2bfadfff 100644
--- a/segmentation_models_pytorch/decoders/segformer/decoder.py
+++ b/segmentation_models_pytorch/decoders/segformer/decoder.py
@@ -50,7 +50,7 @@ def __init__(
             in_channels=(len(encoder_channels) - 1) * segmentation_channels,
             out_channels=segmentation_channels,
             kernel_size=1,
-            use_batchnorm=True,
+            use_norm="batchnorm",
         )
 
     def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
diff --git a/segmentation_models_pytorch/decoders/unet/decoder.py b/segmentation_models_pytorch/decoders/unet/decoder.py
index 0e4f35fd..cfeb267e 100644
--- a/segmentation_models_pytorch/decoders/unet/decoder.py
+++ b/segmentation_models_pytorch/decoders/unet/decoder.py
@@ -1,8 +1,9 @@
+from typing import Any, Dict, List, Optional, Sequence, Union
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from typing import Optional, Sequence, List
 from segmentation_models_pytorch.base import modules as md
 
 
@@ -14,7 +15,7 @@ def __init__(
         in_channels: int,
         skip_channels: int,
         out_channels: int,
-        use_batchnorm: bool = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
         attention_type: Optional[str] = None,
         interpolation_mode: str = "nearest",
     ):
@@ -25,7 +26,7 @@ def __init__(
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
         self.attention1 = md.Attention(
             attention_type, in_channels=in_channels + skip_channels
@@ -35,7 +36,7 @@ def __init__(
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
         self.attention2 = md.Attention(attention_type, in_channels=out_channels)
 
@@ -63,20 +64,25 @@ def forward(
 class UnetCenterBlock(nn.Sequential):
     """Center block of the Unet decoder. Applied to the last feature map of the encoder."""
 
-    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
+    ):
         conv1 = md.Conv2dReLU(
             in_channels,
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
         conv2 = md.Conv2dReLU(
             out_channels,
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
         super().__init__(conv1, conv2)
 
@@ -93,7 +99,7 @@ def __init__(
         encoder_channels: Sequence[int],
         decoder_channels: Sequence[int],
         n_blocks: int = 5,
-        use_batchnorm: bool = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
         attention_type: Optional[str] = None,
         add_center_block: bool = False,
         interpolation_mode: str = "nearest",
@@ -120,7 +126,9 @@ def __init__(
 
         if add_center_block:
             self.center = UnetCenterBlock(
-                head_channels, head_channels, use_batchnorm=use_batchnorm
+                head_channels,
+                head_channels,
+                use_norm=use_norm,
             )
         else:
             self.center = nn.Identity()
@@ -134,7 +142,7 @@ def __init__(
                 block_in_channels,
                 block_skip_channels,
                 block_out_channels,
-                use_batchnorm=use_batchnorm,
+                use_norm=use_norm,
                 attention_type=attention_type,
                 interpolation_mode=interpolation_mode,
             )
diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py
index 4b30527d..22d7db11 100644
--- a/segmentation_models_pytorch/decoders/unet/model.py
+++ b/segmentation_models_pytorch/decoders/unet/model.py
@@ -1,4 +1,5 @@
-from typing import Any, Optional, Union, Callable, Sequence
+import warnings
+from typing import Any, Dict, Optional, Union, Callable, Sequence
 
 from segmentation_models_pytorch.base import (
     ClassificationHead,
@@ -39,9 +40,22 @@ class Unet(SegmentationModel):
             other pretrained weights (see table with available weights for each encoder_name)
         decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder.
             Length of the list should be the same as **encoder_depth**
-        decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
-            is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption.
-            Available options are **True, False, "inplace"**
+        decoder_use_norm:     Specifies normalization between Conv2D and activation.
+            Accepts the following types:
+            - **True**: Defaults to `"batchnorm"`.
+            - **False**: No normalization (`nn.Identity`).
+            - **str**: Specifies normalization type using default parameters. Available values:
+              `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`.
+            - **dict**: Fully customizable normalization settings. Structure:
+              ```python
+              {"type": <norm_type>, **kwargs}
+              ```
+              where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation.
+
+            **Example**:
+            ```python
+            decoder_use_norm={"type": "layernorm", "eps": 1e-2}
+            ```
         decoder_attention_type: Attention module used in decoder of the model. Available options are
             **None** and **scse** (https://arxiv.org/abs/1808.08127).
         decoder_interpolation_mode: Interpolation mode used in decoder of the model. Available options are
@@ -95,7 +109,7 @@ def __init__(
         encoder_name: str = "resnet34",
         encoder_depth: int = 5,
         encoder_weights: Optional[str] = "imagenet",
-        decoder_use_batchnorm: bool = True,
+        decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
         decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
         decoder_attention_type: Optional[str] = None,
         decoder_interpolation_mode: str = "nearest",
@@ -107,6 +121,15 @@ def __init__(
     ):
         super().__init__()
 
+        decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None)
+        if decoder_use_batchnorm is not None:
+            warnings.warn(
+                "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            decoder_use_norm = decoder_use_batchnorm
+
         self.encoder = get_encoder(
             encoder_name,
             in_channels=in_channels,
@@ -116,11 +139,12 @@ def __init__(
         )
 
         add_center_block = encoder_name.startswith("vgg")
+
         self.decoder = UnetDecoder(
             encoder_channels=self.encoder.out_channels,
             decoder_channels=decoder_channels,
             n_blocks=encoder_depth,
-            use_batchnorm=decoder_use_batchnorm,
+            use_norm=decoder_use_norm,
             add_center_block=add_center_block,
             attention_type=decoder_attention_type,
             interpolation_mode=decoder_interpolation_mode,
diff --git a/segmentation_models_pytorch/decoders/unetplusplus/decoder.py b/segmentation_models_pytorch/decoders/unetplusplus/decoder.py
index 3282849f..e09327ac 100644
--- a/segmentation_models_pytorch/decoders/unetplusplus/decoder.py
+++ b/segmentation_models_pytorch/decoders/unetplusplus/decoder.py
@@ -2,7 +2,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from typing import Optional, List
+from typing import Any, Dict, List, Optional, Union, Sequence
 
 from segmentation_models_pytorch.base import modules as md
 
@@ -13,7 +13,7 @@ def __init__(
         in_channels: int,
         skip_channels: int,
         out_channels: int,
-        use_batchnorm: bool = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
         attention_type: Optional[str] = None,
     ):
         super().__init__()
@@ -22,7 +22,7 @@ def __init__(
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
         self.attention1 = md.Attention(
             attention_type, in_channels=in_channels + skip_channels
@@ -32,7 +32,7 @@ def __init__(
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
         self.attention2 = md.Attention(attention_type, in_channels=out_channels)
 
@@ -50,20 +50,25 @@ def forward(
 
 
 class CenterBlock(nn.Sequential):
-    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
+    ):
         conv1 = md.Conv2dReLU(
             in_channels,
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
         conv2 = md.Conv2dReLU(
             out_channels,
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
         super().__init__(conv1, conv2)
 
@@ -71,10 +76,10 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = Tr
 class UnetPlusPlusDecoder(nn.Module):
     def __init__(
         self,
-        encoder_channels: List[int],
-        decoder_channels: List[int],
+        encoder_channels: Sequence[int],
+        decoder_channels: Sequence[int],
         n_blocks: int = 5,
-        use_batchnorm: bool = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
         attention_type: Optional[str] = None,
         center: bool = False,
     ):
@@ -97,13 +102,18 @@ def __init__(
         self.out_channels = decoder_channels
         if center:
             self.center = CenterBlock(
-                head_channels, head_channels, use_batchnorm=use_batchnorm
+                head_channels,
+                head_channels,
+                use_norm=use_norm,
             )
         else:
             self.center = nn.Identity()
 
         # combine decoder keyword arguments
-        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
+        kwargs = dict(
+            use_norm=use_norm,
+            attention_type=attention_type,
+        )
 
         blocks = {}
         for layer_idx in range(len(self.in_channels) - 1):
diff --git a/segmentation_models_pytorch/decoders/unetplusplus/model.py b/segmentation_models_pytorch/decoders/unetplusplus/model.py
index 5c3d3a91..be0f8f83 100644
--- a/segmentation_models_pytorch/decoders/unetplusplus/model.py
+++ b/segmentation_models_pytorch/decoders/unetplusplus/model.py
@@ -1,4 +1,5 @@
-from typing import Any, List, Optional, Union
+import warnings
+from typing import Any, Dict, Sequence, Optional, Union, Callable
 
 from segmentation_models_pytorch.base import (
     ClassificationHead,
@@ -28,9 +29,22 @@ class UnetPlusPlus(SegmentationModel):
             other pretrained weights (see table with available weights for each encoder_name)
         decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder.
             Length of the list should be the same as **encoder_depth**
-        decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
-            is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption.
-            Available options are **True, False, "inplace"**
+        decoder_use_norm:     Specifies normalization between Conv2D and activation.
+            Accepts the following types:
+            - **True**: Defaults to `"batchnorm"`.
+            - **False**: No normalization (`nn.Identity`).
+            - **str**: Specifies normalization type using default parameters. Available values:
+              `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`.
+            - **dict**: Fully customizable normalization settings. Structure:
+              ```python
+              {"type": <norm_type>, **kwargs}
+              ```
+              where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation.
+
+            **Example**:
+            ```python
+            decoder_use_norm={"type": "layernorm", "eps": 1e-2}
+            ```
         decoder_attention_type: Attention module used in decoder of the model.
             Available options are **None** and **scse** (https://arxiv.org/abs/1808.08127).
         in_channels: A number of input channels for the model, default is 3 (RGB images)
@@ -64,12 +78,12 @@ def __init__(
         encoder_name: str = "resnet34",
         encoder_depth: int = 5,
         encoder_weights: Optional[str] = "imagenet",
-        decoder_use_batchnorm: bool = True,
-        decoder_channels: List[int] = (256, 128, 64, 32, 16),
+        decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
+        decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
         decoder_attention_type: Optional[str] = None,
         in_channels: int = 3,
         classes: int = 1,
-        activation: Optional[Union[str, callable]] = None,
+        activation: Optional[Union[str, Callable]] = None,
         aux_params: Optional[dict] = None,
         **kwargs: dict[str, Any],
     ):
@@ -80,6 +94,15 @@ def __init__(
                 "UnetPlusPlus is not support encoder_name={}".format(encoder_name)
             )
 
+        decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None)
+        if decoder_use_batchnorm is not None:
+            warnings.warn(
+                "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            decoder_use_norm = decoder_use_batchnorm
+
         self.encoder = get_encoder(
             encoder_name,
             in_channels=in_channels,
@@ -92,7 +115,7 @@ def __init__(
             encoder_channels=self.encoder.out_channels,
             decoder_channels=decoder_channels,
             n_blocks=encoder_depth,
-            use_batchnorm=decoder_use_batchnorm,
+            use_norm=decoder_use_norm,
             center=True if encoder_name.startswith("vgg") else False,
             attention_type=decoder_attention_type,
         )
diff --git a/segmentation_models_pytorch/decoders/upernet/decoder.py b/segmentation_models_pytorch/decoders/upernet/decoder.py
index 99c74fb1..810778f3 100644
--- a/segmentation_models_pytorch/decoders/upernet/decoder.py
+++ b/segmentation_models_pytorch/decoders/upernet/decoder.py
@@ -1,3 +1,5 @@
+from typing import Any, Dict, Union, Sequence
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -8,10 +10,10 @@
 class PSPModule(nn.Module):
     def __init__(
         self,
-        in_channels,
-        out_channels,
-        sizes=(1, 2, 3, 6),
-        use_batchnorm=True,
+        in_channels: int,
+        out_channels: int,
+        sizes: Sequence[int] = (1, 2, 3, 6),
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
     ):
         super().__init__()
         self.blocks = nn.ModuleList(
@@ -22,7 +24,7 @@ def __init__(
                         in_channels,
                         in_channels // len(sizes),
                         kernel_size=1,
-                        use_batchnorm=use_batchnorm,
+                        use_norm=use_norm,
                     ),
                 )
                 for size in sizes
@@ -32,7 +34,7 @@ def __init__(
             in_channels=in_channels * 2,
             out_channels=out_channels,
             kernel_size=1,
-            use_batchnorm=True,
+            use_norm="batchnorm",
         )
 
     def forward(self, x):
@@ -48,14 +50,19 @@ def forward(self, x):
 
 
 class FPNBlock(nn.Module):
-    def __init__(self, skip_channels, pyramid_channels, use_batchnorm=True):
+    def __init__(
+        self,
+        skip_channels: int,
+        pyramid_channels: int,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
+    ):
         super().__init__()
         self.skip_conv = (
             md.Conv2dReLU(
                 skip_channels,
                 pyramid_channels,
                 kernel_size=1,
-                use_batchnorm=use_batchnorm,
+                use_norm=use_norm,
             )
             if skip_channels != 0
             else nn.Identity()
@@ -73,10 +80,11 @@ def forward(self, x, skip):
 class UPerNetDecoder(nn.Module):
     def __init__(
         self,
-        encoder_channels,
-        encoder_depth=5,
-        pyramid_channels=256,
-        segmentation_channels=64,
+        encoder_channels: Sequence[int],
+        encoder_depth: int = 5,
+        pyramid_channels: int = 256,
+        segmentation_channels: int = 64,
+        use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
     ):
         super().__init__()
 
@@ -94,7 +102,7 @@ def __init__(
             in_channels=encoder_channels[0],
             out_channels=pyramid_channels,
             sizes=(1, 2, 3, 6),
-            use_batchnorm=True,
+            use_norm=use_norm,
         )
 
         # FPN Module
@@ -107,7 +115,7 @@ def __init__(
             out_channels=segmentation_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=True,
+            use_norm=use_norm,
         )
 
     def forward(self, features):
diff --git a/segmentation_models_pytorch/decoders/upernet/model.py b/segmentation_models_pytorch/decoders/upernet/model.py
index 7ffeee5b..6ad5afd5 100644
--- a/segmentation_models_pytorch/decoders/upernet/model.py
+++ b/segmentation_models_pytorch/decoders/upernet/model.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Dict, Optional, Union, Callable
 
 from segmentation_models_pytorch.base import (
     ClassificationHead,
@@ -25,6 +25,22 @@ class UPerNet(SegmentationModel):
             other pretrained weights (see table with available weights for each encoder_name)
         decoder_pyramid_channels: A number of convolution filters in Feature Pyramid, default is 256
         decoder_segmentation_channels: A number of convolution filters in segmentation blocks, default is 64
+        decoder_use_norm: Specifies normalization between Conv2D and activation.
+            Accepts the following types:
+            - **True**: Defaults to `"batchnorm"`.
+            - **False**: No normalization (`nn.Identity`).
+            - **str**: Specifies normalization type using default parameters. Available values:
+              `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`.
+            - **dict**: Fully customizable normalization settings. Structure:
+              ```python
+              {"type": <norm_type>, **kwargs}
+              ```
+              where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation.
+
+            **Example**:
+            ```python
+            use_norm={"type": "layernorm", "eps": 1e-2}
+            ```
         in_channels: A number of input channels for the model, default is 3 (RGB images)
         classes: A number of classes for output mask (or you can think as a number of channels of output mask)
         activation: An activation function to apply after the final convolution layer.
@@ -58,9 +74,10 @@ def __init__(
         encoder_weights: Optional[str] = "imagenet",
         decoder_pyramid_channels: int = 256,
         decoder_segmentation_channels: int = 64,
+        decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
         in_channels: int = 3,
         classes: int = 1,
-        activation: Optional[Union[str, callable]] = None,
+        activation: Optional[Union[str, Callable]] = None,
         aux_params: Optional[dict] = None,
         **kwargs: dict[str, Any],
     ):
@@ -79,6 +96,7 @@ def __init__(
             encoder_depth=encoder_depth,
             pyramid_channels=decoder_pyramid_channels,
             segmentation_channels=decoder_segmentation_channels,
+            use_norm=decoder_use_norm,
         )
 
         self.segmentation_head = SegmentationHead(
diff --git a/tests/base/test_modules.py b/tests/base/test_modules.py
new file mode 100644
index 00000000..5afa8e4f
--- /dev/null
+++ b/tests/base/test_modules.py
@@ -0,0 +1,64 @@
+import pytest
+from torch import nn
+from segmentation_models_pytorch.base.modules import Conv2dReLU
+
+
+def test_conv2drelu_batchnorm():
+    module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="batchnorm")
+
+    assert isinstance(module[0], nn.Conv2d)
+    assert isinstance(module[1], nn.BatchNorm2d)
+    assert isinstance(module[2], nn.ReLU)
+
+
+def test_conv2drelu_batchnorm_with_keywords():
+    module = Conv2dReLU(
+        3,
+        16,
+        kernel_size=3,
+        padding=1,
+        use_norm={"type": "batchnorm", "momentum": 1e-4, "affine": False},
+    )
+
+    assert isinstance(module[0], nn.Conv2d)
+    assert isinstance(module[1], nn.BatchNorm2d)
+    assert module[1].momentum == 1e-4 and module[1].affine is False
+    assert isinstance(module[2], nn.ReLU)
+
+
+def test_conv2drelu_identity():
+    module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="identity")
+
+    assert isinstance(module[0], nn.Conv2d)
+    assert isinstance(module[1], nn.Identity)
+    assert isinstance(module[2], nn.ReLU)
+
+
+def test_conv2drelu_layernorm():
+    module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="layernorm")
+
+    assert isinstance(module[0], nn.Conv2d)
+    assert isinstance(module[1], nn.LayerNorm)
+    assert isinstance(module[2], nn.ReLU)
+
+
+def test_conv2drelu_instancenorm():
+    module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="instancenorm")
+
+    assert isinstance(module[0], nn.Conv2d)
+    assert isinstance(module[1], nn.InstanceNorm2d)
+    assert isinstance(module[2], nn.ReLU)
+
+
+def test_conv2drelu_inplace():
+    try:
+        from inplace_abn import InPlaceABN
+    except ImportError:
+        pytest.skip("InPlaceABN is not installed")
+
+    module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="inplace")
+
+    assert len(module) == 3
+    assert isinstance(module[0], nn.Conv2d)
+    assert isinstance(module[1], InPlaceABN)
+    assert isinstance(module[2], nn.Identity)
diff --git a/tests/encoders/test_batchnorm_deprecation.py b/tests/encoders/test_batchnorm_deprecation.py
new file mode 100644
index 00000000..ff53563f
--- /dev/null
+++ b/tests/encoders/test_batchnorm_deprecation.py
@@ -0,0 +1,54 @@
+import pytest
+
+import torch
+
+import segmentation_models_pytorch as smp
+from tests.utils import check_two_models_strictly_equal
+
+
+@pytest.mark.parametrize("model_name", ["unet", "unetplusplus", "linknet", "manet"])
+@pytest.mark.parametrize("decoder_option", [True, False, "inplace"])
+def test_seg_models_before_after_use_norm(model_name, decoder_option):
+    torch.manual_seed(42)
+    with pytest.warns(DeprecationWarning):
+        model_decoder_batchnorm = smp.create_model(
+            model_name,
+            "mobilenet_v2",
+            encoder_weights=None,
+            decoder_use_batchnorm=decoder_option,
+        )
+    model_decoder_norm = smp.create_model(
+        model_name,
+        "mobilenet_v2",
+        encoder_weights=None,
+        decoder_use_norm=decoder_option,
+    )
+
+    model_decoder_norm.load_state_dict(model_decoder_batchnorm.state_dict())
+
+    check_two_models_strictly_equal(
+        model_decoder_batchnorm, model_decoder_norm, torch.rand(1, 3, 224, 224)
+    )
+
+
+@pytest.mark.parametrize("decoder_option", [True, False, "inplace"])
+def test_pspnet_before_after_use_norm(decoder_option):
+    torch.manual_seed(42)
+    with pytest.warns(DeprecationWarning):
+        model_decoder_batchnorm = smp.create_model(
+            "pspnet",
+            "mobilenet_v2",
+            encoder_weights=None,
+            psp_use_batchnorm=decoder_option,
+        )
+    model_decoder_norm = smp.create_model(
+        "pspnet",
+        "mobilenet_v2",
+        encoder_weights=None,
+        decoder_use_norm=decoder_option,
+    )
+    model_decoder_norm.load_state_dict(model_decoder_batchnorm.state_dict())
+
+    check_two_models_strictly_equal(
+        model_decoder_batchnorm, model_decoder_norm, torch.rand(1, 3, 224, 224)
+    )
diff --git a/tests/models/base.py b/tests/models/base.py
index f7492986..b96e76e8 100644
--- a/tests/models/base.py
+++ b/tests/models/base.py
@@ -282,4 +282,4 @@ def test_torch_script(self):
             eager_output = model(sample)
 
         self.assertEqual(scripted_output.shape, eager_output.shape)
-        torch.testing.assert_close(scripted_output, eager_output)
+        torch.testing.assert_close(scripted_output, eager_output, rtol=1e-3, atol=1e-3)
diff --git a/tests/utils.py b/tests/utils.py
index 6e201f1d..1e97b40b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -58,3 +58,23 @@ def check_run_test_on_diff_or_main(filepath_patterns: List[str]):
                 return True
 
     return False
+
+
+def check_two_models_strictly_equal(
+    model_a: torch.nn.Module, model_b: torch.nn.Module, input_data: torch.Tensor
+) -> None:
+    for (k1, v1), (k2, v2) in zip(
+        model_a.state_dict().items(), model_b.state_dict().items()
+    ):
+        assert k1 == k2, f"Key mismatch: {k1} != {k2}"
+        torch.testing.assert_close(
+            v1, v2, msg=f"Tensor mismatch at key '{k1}':\n{v1} !=\n{v2}"
+        )
+
+    model_a.eval()
+    model_b.eval()
+    with torch.inference_mode():
+        output_a = model_a(input_data)
+        output_b = model_b(input_data)
+
+    torch.testing.assert_close(output_a, output_b)