Deprecate use_batchnorm in favor of generalized use_norm parameter #1095
base: main
Changes from 5 commits. Commits in PR: e26adcd · d65001b · 1b16b25 · 10d496a · 467057a · 1ae11c3 · be22951
segmentation_models_pytorch/base/modules.py
@@ -1,3 +1,6 @@
+from typing import Any, Dict, Tuple, Union
+import warnings
+
 import torch
 import torch.nn as nn

@@ -6,6 +9,74 @@
 except ImportError:
     InPlaceABN = None

+def normalize_use_norm(decoder_use_norm: Union[bool, str, Dict[str, Any]]) -> Dict[str, Any]:
+    if isinstance(decoder_use_norm, str):
+        norm_str = decoder_use_norm.lower()
+        if norm_str == "inplace":
+            decoder_use_norm = {
+                "type": "inplace",
+                "activation": "leaky_relu",
+                "activation_param": 0.0,
+            }
+        elif norm_str in (
+            "batchnorm",
+            "identity",
+            "layernorm",
+            "groupnorm",
+            "instancenorm",
+        ):
+            decoder_use_norm = {"type": norm_str}
+        else:
+            raise ValueError("Unrecognized normalization type string provided")
+    elif isinstance(decoder_use_norm, bool):
+        decoder_use_norm = {"type": "batchnorm" if decoder_use_norm else "identity"}
+    elif not isinstance(decoder_use_norm, dict):
+        raise ValueError("use_norm must be a dictionary, boolean, or string")
+
+    return decoder_use_norm
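For reference, a quick sketch of the coercion rules this helper implements (assuming it ends up importable from `segmentation_models_pytorch.base.modules`, the module the Linknet diff below imports from):

```python
from segmentation_models_pytorch.base.modules import normalize_use_norm

# bool -> default on/off mapping
assert normalize_use_norm(True) == {"type": "batchnorm"}
assert normalize_use_norm(False) == {"type": "identity"}

# str -> {"type": <str>}, except "inplace", which also pins its activation
assert normalize_use_norm("layernorm") == {"type": "layernorm"}
assert normalize_use_norm("inplace") == {
    "type": "inplace",
    "activation": "leaky_relu",
    "activation_param": 0.0,
}

# dict -> passed through unchanged
assert normalize_use_norm({"type": "groupnorm", "num_groups": 8}) == {
    "type": "groupnorm",
    "num_groups": 8,
}
```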
+def normalize_decoder_norm(decoder_use_batchnorm: Union[bool, str, None], decoder_use_norm: Union[bool, str, Dict[str, Any]]) -> Dict[str, Any]:
+    if decoder_use_batchnorm is not None:
+        warnings.warn(
+            "The usage of use_batchnorm is deprecated. Please modify your code for use_norm",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        if decoder_use_batchnorm is True:
+            decoder_use_norm = {"type": "batchnorm"}
+        elif decoder_use_batchnorm is False:
+            decoder_use_norm = {"type": "identity"}
+        elif decoder_use_batchnorm == "inplace":
+            decoder_use_norm = {
+                "type": "inplace",
+                "activation": "leaky_relu",
+                "activation_param": 0.0,
+            }
+        else:
+            raise ValueError("Unrecognized value for use_batchnorm")
+
+    decoder_use_norm = normalize_use_norm(decoder_use_norm)
+    return decoder_use_norm

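A hedged sketch of how the deprecation shim behaves: a legacy `decoder_use_batchnorm` value, when given, wins over `decoder_use_norm` and emits a `DeprecationWarning`:

```python
import warnings

from segmentation_models_pytorch.base.modules import normalize_decoder_norm

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Legacy False overrides the use_norm argument and disables normalization.
    resolved = normalize_decoder_norm(decoder_use_batchnorm=False, decoder_use_norm="batchnorm")

assert resolved == {"type": "identity"}
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# With the legacy argument left at None, use_norm is honored as-is.
assert normalize_decoder_norm(None, "layernorm") == {"type": "layernorm"}
```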
+def get_norm_layer(use_norm: Dict[str, Any], out_channels: int) -> nn.Module:
+    norm_type = use_norm["type"]
+    extra_kwargs = {k: v for k, v in use_norm.items() if k != "type"}
+
+    if norm_type == "inplace":
+        norm = InPlaceABN(out_channels, **extra_kwargs)
+    elif norm_type == "batchnorm":
+        norm = nn.BatchNorm2d(out_channels, **extra_kwargs)
+    elif norm_type == "identity":
+        norm = nn.Identity()
+    elif norm_type == "layernorm":
+        norm = nn.LayerNorm(out_channels, **extra_kwargs)
+    elif norm_type == "groupnorm":
+        # num_groups must be supplied via kwargs, e.g. {"type": "groupnorm", "num_groups": 8}
+        norm = nn.GroupNorm(num_channels=out_channels, **extra_kwargs)
+    elif norm_type == "instancenorm":
+        norm = nn.InstanceNorm2d(out_channels, **extra_kwargs)
+    else:
+        raise ValueError(f"Unrecognized normalization type: {norm_type}")
+
+    return norm
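To illustrate the mapping (a sketch; note the `groupnorm` branch shown in the listing above is a consistency fix, since `normalize_use_norm` accepts `"groupnorm"` and `nn.GroupNorm` requires `num_groups`, making it reachable only via the dict form):

```python
import torch.nn as nn

from segmentation_models_pytorch.base.modules import get_norm_layer

bn = get_norm_layer({"type": "batchnorm"}, out_channels=64)
assert isinstance(bn, nn.BatchNorm2d)

# Extra keys are forwarded verbatim to the layer constructor.
ln = get_norm_layer({"type": "layernorm", "eps": 1e-2}, out_channels=64)
assert isinstance(ln, nn.LayerNorm) and ln.eps == 1e-2

gn = get_norm_layer({"type": "groupnorm", "num_groups": 8}, out_channels=64)
assert isinstance(gn, nn.GroupNorm) and gn.num_groups == 8
```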
 class Conv2dReLU(nn.Sequential):
     def __init__(

@@ -15,12 +86,13 @@ def __init__(
         kernel_size,
         padding=0,
         stride=1,
-        use_batchnorm=True,
+        use_norm="batchnorm",
     ):
-        if use_batchnorm == "inplace" and InPlaceABN is None:
+        use_norm = normalize_use_norm(use_norm)
+        if use_norm["type"] == "inplace" and InPlaceABN is None:
             raise RuntimeError(
-                "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
-                + "To install see: https://github.com/mapillary/inplace_abn"
+                "In order to use `use_batchnorm='inplace'` or `use_norm='inplace'` the inplace_abn package must be installed. "
+                "To install see: https://github.com/mapillary/inplace_abn"
             )

> Review: Also not needed here, should be under the get_norm_layer function.
> Reply: Removed and placed in get_norm_layer.

         conv = nn.Conv2d(

@@ -29,21 +101,16 @@ def __init__(
             kernel_size,
             stride=stride,
             padding=padding,
-            bias=not (use_batchnorm),
+            bias=use_norm["type"] != "inplace",
         )

> Review: We can initialize norm first and use a separate variable.
> Reply: Yep, simpler.

-        relu = nn.ReLU(inplace=True)
+        norm = get_norm_layer(use_norm, out_channels)

-        if use_batchnorm == "inplace":
-            bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
+        if use_norm["type"] == "inplace":

> Review: Same here.
> Reply: Done.

             relu = nn.Identity()
-
-        elif use_batchnorm and use_batchnorm != "inplace":
-            bn = nn.BatchNorm2d(out_channels)
-
         else:
-            bn = nn.Identity()
+            relu = nn.ReLU(inplace=True)

-        super(Conv2dReLU, self).__init__(conv, bn, relu)
+        super(Conv2dReLU, self).__init__(conv, norm, relu)


 class SCSEModule(nn.Module):
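A minimal smoke test of the rewritten block (a sketch, assuming the module path used above): the sequential order is now conv → norm → activation, and legacy-style booleans still work because the constructor normalizes them.

```python
import torch

from segmentation_models_pytorch.base.modules import Conv2dReLU

block = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm={"type": "instancenorm"})
x = torch.randn(2, 3, 32, 32)
assert block(x).shape == (2, 16, 32, 32)

# use_norm=False resolves to {"type": "identity"}, i.e. no normalization.
block_no_norm = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm=False)
```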
segmentation_models_pytorch/decoders/linknet/model.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Dict, Optional, Union

 from segmentation_models_pytorch.base import (
     ClassificationHead,

@@ -7,6 +7,7 @@
 )
 from segmentation_models_pytorch.encoders import get_encoder
 from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
+from segmentation_models_pytorch.base.modules import normalize_decoder_norm

 from .decoder import LinknetDecoder

@@ -29,9 +30,27 @@ class Linknet(SegmentationModel):
             Default is 5
         encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
             other pretrained weights (see table with available weights for each encoder_name)
-        decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
+        decoder_use_batchnorm: (**Deprecated**) If **True**, BatchNorm2d layer between Conv2D and Activation layers
             is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption.
             Available options are **True, False, "inplace"**
+
+            **Note:** Deprecated, prefer using `decoder_use_norm` and set this to None.
+        decoder_use_norm: Specifies normalization between Conv2D and activation.
+            Accepts the following types:
+            - **True**: Defaults to `"batchnorm"`.
+            - **False**: No normalization (`nn.Identity`).
+            - **str**: Specifies normalization type using default parameters. Available values:
+              `"batchnorm"`, `"identity"`, `"layernorm"`, `"groupnorm"`, `"instancenorm"`, `"inplace"`.
+            - **dict**: Fully customizable normalization settings. Structure:
+              ```python
+              {"type": <norm_type>, **kwargs}
+              ```
+              where `norm_type` corresponds to one of the normalization types listed above, and `kwargs` are passed
+              directly to the normalization layer as defined in the PyTorch documentation.
+
+            **Example**:
+              ```python
+              use_norm={"type": "layernorm", "eps": 1e-2}
+              ```

> Review: Thanks for the detailed docstring, really appreciate it!

         in_channels: A number of input channels for the model, default is 3 (RGB images)
         classes: A number of classes for output mask (or you can think as a number of channels of output mask)
         activation: An activation function to apply after the final convolution layer.
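One caveat worth spelling out about the list above (an observation, not text from the PR): every string value can build a layer from defaults except `"groupnorm"`, because `nn.GroupNorm` has no default `num_groups`; that type is only usable via the dict form.

```python
# Workable: num_groups is forwarded to nn.GroupNorm.
decoder_use_norm = {"type": "groupnorm", "num_groups": 8}

# A bare "groupnorm" string would fail at layer construction, since
# nn.GroupNorm(num_channels=...) alone raises a TypeError.
```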
@@ -60,7 +79,8 @@ def __init__(
         encoder_name: str = "resnet34",
         encoder_depth: int = 5,
         encoder_weights: Optional[str] = "imagenet",
-        decoder_use_batchnorm: bool = True,
+        decoder_use_batchnorm: Union[bool, str, None] = None,
+        decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",

> Review: let's remove
> Reply: removed decoder_use_batchnorm and used kwargs. I like this approach.

         in_channels: int = 3,
         classes: int = 1,
         activation: Optional[Union[str, callable]] = None,

@@ -82,11 +102,12 @@ def __init__(
             **kwargs,
         )

+        decoder_use_norm = normalize_decoder_norm(decoder_use_batchnorm, decoder_use_norm)

> Review: here we just resolve the name
> Reply: Tell me if this is what you had in mind

         self.decoder = LinknetDecoder(
             encoder_channels=self.encoder.out_channels,
             n_blocks=encoder_depth,
             prefinal_channels=32,
-            use_batchnorm=decoder_use_batchnorm,
+            use_norm=decoder_use_norm,
         )

         self.segmentation_head = SegmentationHead(
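End to end, the new parameter would be used like this (a sketch mirroring the docstring above; `smp.Linknet` is the public entry point for this model):

```python
import segmentation_models_pytorch as smp

# New style: pass normalization settings directly.
model = smp.Linknet(
    encoder_name="resnet34",
    decoder_use_norm={"type": "layernorm", "eps": 1e-2},
)

# Old style keeps working during the deprecation window, but emits a
# DeprecationWarning and overrides decoder_use_norm.
legacy = smp.Linknet(decoder_use_batchnorm=True)
```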
> Review: I would rather put everything into a single simple function, as follows: (the proposed snippet is not preserved in this view)
> Reply: So no str shortcut for default usage like LayerNorm etc.? I kept it for the moment; tell me if I should remove it. I changed things accordingly and tried to match the proposed flow as closely as possible.
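Since the reviewer's snippet is not preserved, here is a loose, hypothetical reconstruction of what "everything in a single simple function" might mean; note it folds coercion and construction together and drops the str shortcuts, which would explain the "no str for default usage" question in the reply:

```python
# Hypothetical sketch only -- not the reviewer's actual code.
from typing import Any, Dict, Union

import torch.nn as nn


def get_norm_layer(use_norm: Union[bool, Dict[str, Any]], out_channels: int) -> nn.Module:
    # Coerce booleans, then build the layer in one place.
    if isinstance(use_norm, bool):
        use_norm = {"type": "batchnorm" if use_norm else "identity"}

    norm_type = use_norm["type"]
    kwargs = {k: v for k, v in use_norm.items() if k != "type"}

    if norm_type == "batchnorm":
        return nn.BatchNorm2d(out_channels, **kwargs)
    if norm_type == "identity":
        return nn.Identity()
    raise ValueError(f"Unrecognized normalization type: {norm_type}")
```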