Deprecate use_batchnorm in favor of generalized use_norm parameter #1095

Open · wants to merge 7 commits into main
95 changes: 81 additions & 14 deletions segmentation_models_pytorch/base/modules.py
@@ -1,3 +1,6 @@
from typing import Any, Dict, Tuple, Union
import warnings

import torch
import torch.nn as nn

@@ -6,6 +9,74 @@
except ImportError:
InPlaceABN = None

def normalize_use_norm(decoder_use_norm: Union[bool, str, Dict[str, Any]]) -> Dict[str, Any]:
if isinstance(decoder_use_norm, str):
norm_str = decoder_use_norm.lower()
if norm_str == "inplace":
decoder_use_norm = {
"type": "inplace",
"activation": "leaky_relu",
"activation_param": 0.0,
}
elif norm_str in (
"batchnorm",
"identity",
"layernorm",
"groupnorm",
"instancenorm",
):
decoder_use_norm = {"type": norm_str}
else:
raise ValueError("Unrecognized normalization type string provided")
elif isinstance(decoder_use_norm, bool):
decoder_use_norm = {"type": "batchnorm" if decoder_use_norm else "identity"}
elif not isinstance(decoder_use_norm, dict):
raise ValueError("use_norm must be a dictionary, boolean, or string")

return decoder_use_norm

def normalize_decoder_norm(decoder_use_batchnorm: Union[bool, str, None], decoder_use_norm: Union[bool, str, Dict[str, Any]]) -> Dict[str, Any]:
if decoder_use_batchnorm is not None:
warnings.warn(
"The usage of use_batchnorm is deprecated. Please modify your code for use_norm",
DeprecationWarning,
stacklevel=2
)
if decoder_use_batchnorm is True:
decoder_use_norm = {"type": "batchnorm"}
elif decoder_use_batchnorm is False:
decoder_use_norm = {"type": "identity"}
elif decoder_use_batchnorm == "inplace":
decoder_use_norm = {
"type": "inplace",
"activation": "leaky_relu",
"activation_param": 0.0,
}
else:
raise ValueError("Unrecognized value for use_batchnorm")

decoder_use_norm = normalize_use_norm(decoder_use_norm)
return decoder_use_norm


def get_norm_layer(use_norm: Dict[str, Any], out_channels: int) -> nn.Module:
norm_type = use_norm["type"]
extra_kwargs = {k: v for k, v in use_norm.items() if k != "type"}

if norm_type == "inplace":
norm = InPlaceABN(out_channels, **extra_kwargs)
elif norm_type == "batchnorm":
norm = nn.BatchNorm2d(out_channels, **extra_kwargs)
elif norm_type == "identity":
norm = nn.Identity()
elif norm_type == "layernorm":
norm = nn.LayerNorm(out_channels, **extra_kwargs)
elif norm_type == "instancenorm":
norm = nn.InstanceNorm2d(out_channels, **extra_kwargs)
else:
raise ValueError(f"Unrecognized normalization type: {norm_type}")

return norm
Collaborator comment:

I would rather put everything into a single simple function as follows:

if use_norm is True:
    params = {"type": "batchnorm"}
elif use_norm is False:
    params = {"type": "identity"}
elif use_norm == "inplace":
    params = {"type": "inplace", "activation": "leaky_relu", "activation_param": 0.0}

if not isinstance(params, dict) -> raise
if "type" not in params -> raise
if params["type"] not in supported_norms -> raise

<dispatch to norm layer>

return norm 

Author reply:

So no str for default usage like LayerNorm etc.? I kept it for the moment. Tell me if I have to remove it.

Changed accordingly and tried to match the proposed flow as closely as possible.
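
For readers following the thread, here is a minimal self-contained sketch of the normalize-validate-dispatch flow discussed above; the function name `build_norm_layer` and the exact error messages are illustrative assumptions, not the code that landed in the PR:

```python
from typing import Any, Dict, Union

import torch.nn as nn

try:
    from inplace_abn import InPlaceABN
except ImportError:
    InPlaceABN = None

SUPPORTED_NORMS = ("inplace", "batchnorm", "identity", "layernorm", "instancenorm", "groupnorm")


def build_norm_layer(use_norm: Union[bool, str, Dict[str, Any]], out_channels: int) -> nn.Module:
    # 1) Normalize shorthand values (bool / plain string) into a parameter dict.
    if use_norm is True:
        params = {"type": "batchnorm"}
    elif use_norm is False:
        params = {"type": "identity"}
    elif use_norm == "inplace":
        params = {"type": "inplace", "activation": "leaky_relu", "activation_param": 0.0}
    elif isinstance(use_norm, str):
        params = {"type": use_norm.lower()}
    else:
        params = use_norm

    # 2) Validate.
    if not isinstance(params, dict):
        raise ValueError("use_norm must be a bool, str, or dict")
    if "type" not in params:
        raise ValueError("use_norm dict must contain a 'type' key")
    if params["type"] not in SUPPORTED_NORMS:
        raise ValueError(f"Unsupported normalization type: {params['type']}")

    # 3) Dispatch to the norm layer; extra keys are forwarded as layer kwargs.
    norm_type = params["type"]
    kwargs = {k: v for k, v in params.items() if k != "type"}
    if norm_type == "identity":
        return nn.Identity()
    if norm_type == "batchnorm":
        return nn.BatchNorm2d(out_channels, **kwargs)
    if norm_type == "layernorm":
        return nn.LayerNorm(out_channels, **kwargs)
    if norm_type == "instancenorm":
        return nn.InstanceNorm2d(out_channels, **kwargs)
    if norm_type == "groupnorm":
        return nn.GroupNorm(num_channels=out_channels, **kwargs)  # caller supplies num_groups
    if InPlaceABN is None:  # norm_type == "inplace"
        raise RuntimeError("inplace_abn must be installed to use use_norm='inplace'")
    return InPlaceABN(out_channels, **kwargs)
```

Keeping the shorthand-to-dict conversion, the validation, and the dispatch in one place means the decoders only ever receive a ready-made nn.Module.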


class Conv2dReLU(nn.Sequential):
def __init__(
@@ -15,12 +86,13 @@ def __init__(
kernel_size,
padding=0,
stride=1,
use_batchnorm=True,
use_norm="batchnorm",
):
if use_batchnorm == "inplace" and InPlaceABN is None:
use_norm = normalize_use_norm(use_norm)
if use_norm["type"] == "inplace" and InPlaceABN is None:
raise RuntimeError(
"In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
+ "To install see: https://github.com/mapillary/inplace_abn"
"In order to use `use_batchnorm='inplace'` or `use_norm='inplace'` the inplace_abn package must be installed. "
"To install see: https://github.com/mapillary/inplace_abn"
Collaborator comment:

Also not needed here, should be under get_norm_layer function

Author reply:

Removed and placed in get_norm_layer

)

conv = nn.Conv2d(
@@ -29,21 +101,16 @@ def __init__(
kernel_size,
stride=stride,
padding=padding,
bias=not (use_batchnorm),
bias=use_norm["type"] != "inplace",
Collaborator comment:

We can initialize norm first and use a separate variable:

is_inplace_batchnorm = norm.__name__ == "InPlaceABN"

Author reply:

Yep, simpler.

)
relu = nn.ReLU(inplace=True)
norm = get_norm_layer(use_norm, out_channels)

if use_batchnorm == "inplace":
bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
if use_norm["type"] == "inplace":
Collaborator comment:

same here

Author reply:

done

relu = nn.Identity()

elif use_batchnorm and use_batchnorm != "inplace":
bn = nn.BatchNorm2d(out_channels)

else:
bn = nn.Identity()
relu = nn.ReLU(inplace=True)

super(Conv2dReLU, self).__init__(conv, bn, relu)
super(Conv2dReLU, self).__init__(conv, norm, relu)
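
Following up on the thread above about initializing the norm first and keying off a single variable, a rough sketch of what that version of the block could look like; the class name Conv2dNormAct is hypothetical, and normalize_use_norm / get_norm_layer are the helpers added earlier in this file:

```python
import torch.nn as nn


class Conv2dNormAct(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1, use_norm="batchnorm"):
        use_norm = normalize_use_norm(use_norm)
        norm = get_norm_layer(use_norm, out_channels)
        # One flag drives both the conv bias and the activation choice.
        is_inplace = norm.__class__.__name__ == "InPlaceABN"

        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not is_inplace,  # mirrors bias=use_norm["type"] != "inplace" in the diff
        )
        # InPlaceABN already fuses a (leaky) ReLU, so skip the separate activation.
        act = nn.Identity() if is_inplace else nn.ReLU(inplace=True)
        super().__init__(conv, norm, act)
```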


class SCSEModule(nn.Module):
38 changes: 28 additions & 10 deletions segmentation_models_pytorch/decoders/linknet/decoder.py
@@ -1,12 +1,17 @@
import torch
import torch.nn as nn

from typing import List, Optional
from typing import Any, Dict, List, Optional, Union
from segmentation_models_pytorch.base import modules


class TransposeX2(nn.Sequential):
def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
def __init__(
self,
in_channels: int,
out_channels: int,
use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
):
super().__init__()
layers = [
nn.ConvTranspose2d(
@@ -15,31 +20,40 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = Tr
nn.ReLU(inplace=True),
]

if use_batchnorm:
layers.insert(1, nn.BatchNorm2d(out_channels))
if use_norm != "identity":
if isinstance(use_norm, dict):
if use_norm.get("type") != "identity":
layers.insert(1, modules.get_norm_layer(use_norm, out_channels))
else:
layers.insert(1, modules.get_norm_layer(use_norm, out_channels))

super().__init__(*layers)


class DecoderBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
def __init__(
self,
in_channels: int,
out_channels: int,
use_norm: Union[bool, str, Dict[str, Any]] = True,
):
super().__init__()

self.block = nn.Sequential(
modules.Conv2dReLU(
in_channels,
in_channels // 4,
kernel_size=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
),
TransposeX2(
in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm
in_channels // 4, in_channels // 4, use_norm=use_norm
),
modules.Conv2dReLU(
in_channels // 4,
out_channels,
kernel_size=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
),
)

@@ -58,7 +72,7 @@ def __init__(
encoder_channels: List[int],
prefinal_channels: int = 32,
n_blocks: int = 5,
use_batchnorm: bool = True,
use_norm: Union[bool, str, Dict[str, Any]] = True,
):
super().__init__()

@@ -71,7 +85,11 @@ def __init__(

self.blocks = nn.ModuleList(
[
DecoderBlock(channels[i], channels[i + 1], use_batchnorm=use_batchnorm)
DecoderBlock(
channels[i],
channels[i + 1],
use_norm=use_norm,
)
for i in range(n_blocks)
]
)
29 changes: 25 additions & 4 deletions segmentation_models_pytorch/decoders/linknet/model.py
@@ -1,4 +1,4 @@
from typing import Any, Optional, Union
from typing import Any, Dict, Optional, Union

from segmentation_models_pytorch.base import (
ClassificationHead,
@@ -7,6 +7,7 @@
)
from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
from segmentation_models_pytorch.base.modules import normalize_decoder_norm

from .decoder import LinknetDecoder

@@ -29,9 +30,27 @@ class Linknet(SegmentationModel):
Default is 5
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
other pretrained weights (see table with available weights for each encoder_name)
decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
decoder_use_batchnorm: (**Deprecated**) If **True**, BatchNorm2d layer between Conv2D and Activation layers
is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption.
Available options are **True, False, "inplace"**

**Note:** Deprecated. Prefer `decoder_use_norm` and leave this set to None.
decoder_use_norm: Specifies normalization between Conv2D and activation.
Accepts the following types:
- **True**: Defaults to `"batchnorm"`.
- **False**: No normalization (`nn.Identity`).
- **str**: Specifies normalization type using default parameters. Available values:
`"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`.
- **dict**: Fully customizable normalization settings. Structure:
```python
{"type": <norm_type>, **kwargs}
```
where `norm_type` corresponds to the normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in the PyTorch documentation.

**Example**:
```python
use_norm={"type": "layernorm", "eps": 1e-2}
```
Collaborator comment:

Thanks for the detailed docstring, really appreciate it!

in_channels: A number of input channels for the model, default is 3 (RGB images)
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
activation: An activation function to apply after the final convolution layer.
@@ -60,7 +79,8 @@ def __init__(
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: Optional[str] = "imagenet",
decoder_use_batchnorm: bool = True,
decoder_use_batchnorm: Union[bool, str, None] = None,
Collaborator comment:

Let's remove decoder_use_batchnorm from the signature and pop it from kwargs later.

Author reply:

Removed decoder_use_batchnorm and used kwargs. I like this approach.
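
A rough sketch of the kwargs-based handling being discussed, assuming normalize_decoder_norm from base/modules.py; the helper name resolve_decoder_norm_kwargs is purely illustrative and not necessarily the final code in the PR:

```python
from typing import Any, Dict, Union

from segmentation_models_pytorch.base.modules import normalize_decoder_norm


def resolve_decoder_norm_kwargs(
    decoder_use_norm: Union[bool, str, Dict[str, Any]],
    kwargs: Dict[str, Any],
) -> Dict[str, Any]:
    # Pop the deprecated keyword out of **kwargs so it never reaches the encoder;
    # normalize_decoder_norm emits the DeprecationWarning when it is not None.
    decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None)
    return normalize_decoder_norm(decoder_use_batchnorm, decoder_use_norm)


# Inside Linknet.__init__(..., decoder_use_norm="batchnorm", **kwargs) one could then write:
# decoder_use_norm = resolve_decoder_norm_kwargs(decoder_use_norm, kwargs)
```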

decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
in_channels: int = 3,
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
@@ -82,11 +102,12 @@ def __init__(
**kwargs,
)

decoder_use_norm = normalize_decoder_norm(decoder_use_batchnorm, decoder_use_norm)
Collaborator comment:

here we just resolve the name

Author reply:

Tell me if this is what you had in mind

self.decoder = LinknetDecoder(
encoder_channels=self.encoder.out_channels,
n_blocks=encoder_depth,
prefinal_channels=32,
use_batchnorm=decoder_use_batchnorm,
use_norm=decoder_use_norm,
)

self.segmentation_head = SegmentationHead(
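
To make the decoder_use_norm options documented in the Linknet docstring above concrete, a hedged usage sketch against the signature proposed in this PR (encoder choice and kwarg values are arbitrary):

```python
import segmentation_models_pytorch as smp

# Default: BatchNorm2d between Conv2d and the activation in every decoder block.
model = smp.Linknet("resnet34", decoder_use_norm="batchnorm")

# No normalization at all.
model = smp.Linknet("resnet34", decoder_use_norm=False)

# Fully customized: extra keys are forwarded to the normalization layer.
model = smp.Linknet("resnet34", decoder_use_norm={"type": "instancenorm", "affine": True})

# Deprecated spelling still works for now, but emits a DeprecationWarning.
model = smp.Linknet("resnet34", decoder_use_batchnorm=True)
```
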
29 changes: 17 additions & 12 deletions segmentation_models_pytorch/decoders/manet/decoder.py
@@ -1,9 +1,9 @@
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import List, Optional

from segmentation_models_pytorch.base import modules as md


@@ -49,7 +49,7 @@ def __init__(
in_channels: int,
skip_channels: int,
out_channels: int,
use_batchnorm: bool = True,
use_norm: Union[bool, str, Dict[str, Any]] = True,
reduction: int = 16,
):
# MFABBlock is just a modified version of SE-blocks, one for skip, one for input
Expand All @@ -60,10 +60,13 @@ def __init__(
in_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
),
md.Conv2dReLU(
in_channels, skip_channels, kernel_size=1, use_batchnorm=use_batchnorm
in_channels,
skip_channels,
kernel_size=1,
use_norm=use_norm,
),
)
reduced_channels = max(1, skip_channels // reduction)
@@ -87,14 +90,14 @@ def __init__(
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
)

def forward(
@@ -119,22 +122,22 @@ def __init__(
in_channels: int,
skip_channels: int,
out_channels: int,
use_batchnorm: bool = True,
use_norm: Union[bool, str, Dict[str, Any]] = True,
):
super().__init__()
self.conv1 = md.Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
)

def forward(
@@ -155,7 +158,7 @@ def __init__(
decoder_channels: List[int],
n_blocks: int = 5,
reduction: int = 16,
use_batchnorm: bool = True,
use_norm: Union[bool, str, Dict[str, Any]] = True,
pab_channels: int = 64,
):
super().__init__()
Expand All @@ -182,7 +185,9 @@ def __init__(
self.center = PABBlock(head_channels, pab_channels=pab_channels)

# combine decoder keyword arguments
kwargs = dict(use_batchnorm=use_batchnorm) # no attention type here
kwargs = dict(
use_norm=use_norm
) # no attention type here
blocks = [
MFABBlock(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs)
if skip_ch > 0