
Deprecate use_batchnorm in favor of generalized use_norm parameter #1095

Open · wants to merge 6 commits into main

Conversation

@GuillaumeErhard commented Mar 19, 2025

Hello,

A bit late to implement my suggestion from #983, but here it is. Open to any suggestions to make it clearer or simpler, and to any tests you would like added.

Closes #983

codecov bot commented Mar 20, 2025

Codecov Report

Attention: Patch coverage is 60.34483% with 23 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| segmentation_models_pytorch/base/modules.py | 46.51% | 23 Missing ⚠️ |

| Files with missing lines | Coverage Δ |
|---|---|
| ...ntation_models_pytorch/decoders/linknet/decoder.py | 100.00% <100.00%> (ø) |
| ...mentation_models_pytorch/decoders/linknet/model.py | 94.73% <100.00%> (ø) |
| ...mentation_models_pytorch/decoders/manet/decoder.py | 97.75% <100.00%> (ø) |
| ...egmentation_models_pytorch/decoders/manet/model.py | 100.00% <100.00%> (ø) |
| ...gmentation_models_pytorch/decoders/unet/decoder.py | 91.37% <100.00%> (ø) |
| segmentation_models_pytorch/decoders/unet/model.py | 100.00% <100.00%> (ø) |
| ...on_models_pytorch/decoders/unetplusplus/decoder.py | 92.85% <100.00%> (ø) |
| ...tion_models_pytorch/decoders/unetplusplus/model.py | 95.00% <100.00%> (ø) |
| segmentation_models_pytorch/base/modules.py | 50.45% <46.51%> (-4.95%) ⬇️ |

@qubvel (Collaborator) left a comment:

Hi @GuillaumeErhard, huge thanks for taking the time to work on this PR! Here are a few comments.

Comment on lines 23 to 68:

```python
if use_batchnorm is not None:
    warnings.warn(
        "The usage of use_batchnorm is deprecated. Please modify your code for use_norm",
        DeprecationWarning,
    )
    if use_batchnorm is True:
        use_norm = {"type": "batchnorm"}
    elif use_batchnorm is False:
        use_norm = {"type": "identity"}
    elif use_batchnorm == "inplace":
        use_norm = {
            "type": "inplace",
            "activation": "leaky_relu",
            "activation_param": 0.0,
        }
    else:
        raise ValueError("Unrecognized value for use_batchnorm")

if isinstance(use_norm, str):
    norm_str = use_norm.lower()
    if norm_str == "inplace":
        use_norm = {
            "type": "inplace",
            "activation": "leaky_relu",
            "activation_param": 0.0,
        }
    elif norm_str in (
        "batchnorm",
        "identity",
        "layernorm",
        "groupnorm",
        "instancenorm",
    ):
        use_norm = {"type": norm_str}
    else:
        raise ValueError("Unrecognized normalization type string provided")
elif isinstance(use_norm, bool):
    use_norm = {"type": "batchnorm" if use_norm else "identity"}
elif not isinstance(use_norm, dict):
    raise ValueError("use_norm must be a dictionary, boolean, or string")

if use_norm["type"] == "inplace" and InPlaceABN is None:
    raise RuntimeError(
        "In order to use `use_batchnorm='inplace'` or `use_norm='inplace'` the inplace_abn package must be installed. "
        "To install see: https://github.com/mapillary/inplace_abn"
    )
```
Collaborator: Let's have a separate function `get_norm_layer` that validates the input params and returns the norm layer.

Author: Good catch, much simpler. Made the changes.

Comment on lines 185 to 188:

```python
if __name__ == "__main__":
    print(Conv2dReLU(3, 12, 4))
    print(Conv2dReLU(3, 12, 4, use_norm={"type": "batchnorm"}))
    print(Conv2dReLU(3, 12, 4, use_norm={"type": "layernorm", "eps": 1e-3}))
```
Collaborator: This should be removed; instead, it would be nice to add a test.

Author: Yep, added some tests around `Conv2dReLU`.

Comment on lines 13 to 14:

```python
use_batchnorm: Union[bool, str, None] = True,
use_norm: Union[bool, str, Dict[str, Any]] = True,
```
Collaborator: We should have only `use_norm` here; `use_batchnorm` should be replaced at the top level (e.g. the Unet model).

Author: Made a proposal.

Comment on lines 31 to 37:

```python
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    use_batchnorm: Union[bool, str, None] = True,
    use_norm: Union[bool, str, Dict[str, Any]] = True,
):
```
Collaborator: Same here; let's move it to the model and leave only `use_norm`.

Author: Made a proposal.

@GuillaumeErhard (Author) commented Mar 21, 2025

I am still missing tests that check before/after equivalency when creating a model with decoder_use_batchnorm versus decoder_use_norm. I will do this tomorrow (tests/encoders/test_batchnorm_deprecation.py).

But I hit an error with the test that I did not anticipate (thank goodness for test-driven development): GroupNorm takes num_groups and num_channels parameters instead of a simple out_channels. Right now I don't see a good way to make it work, so I am thinking of dropping support for it. If you have any ideas, I am open to them.
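For illustration, a minimal editor's sketch of the signature mismatch described above (not code from the PR):

```python
import torch.nn as nn

out_channels = 64

# Most 2d norm layers can be constructed from the channel count alone:
batch_norm = nn.BatchNorm2d(out_channels)
instance_norm = nn.InstanceNorm2d(out_channels)

# GroupNorm cannot: it requires num_groups in addition to num_channels, so a
# generic get_norm_layer(use_norm, out_channels) has no obvious default for it.
group_norm = nn.GroupNorm(num_groups=8, num_channels=out_channels)
```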

Fix issue in Linknet
Handle norm and relu on its own
Align pspnet
Add use_norm in upernet
Remove groupnorm possibility
Update doc
@GuillaumeErhard (Author) commented:
Added the tests so that we get the same model before and after the changes, which showed that Linknet needed additional changes; those are done now.

I kept the behavior that every element in model/decoder still works with use_norm as str / bool / dict for the moment, in case people use them directly. For that I need to sanitize both at the model level and in Conv2dReLU. Tell me if you want me to change the decoders to dict only.

I also saw that pspnet used batchnorm as well, but through psp_use_batchnorm, and aligned it with use_norm. By the way, maybe line 20 of pspnet/decoder.py could be changed to a norm instead of identity?

Removed groupnorm for the moment; open to any suggestion that would make it possible (see the problem described above).

Added use_norm support in upernet as well. Is the behavior on line 35 expected, i.e. always batchnorm? Or should use_norm be passed there too?

Make the warning visible by changing the filter, and add a test for it
Fix the before/after test so that tensor values are compared, not just shapes
@qubvel (Collaborator) left a comment:

Thanks for addressing the comments; please see the review 🤗

Comment on lines 12 to 79:

```python
def normalize_use_norm(decoder_use_norm: Union[bool, str, Dict[str, Any]]) -> Dict[str, Any]:
    if isinstance(decoder_use_norm, str):
        norm_str = decoder_use_norm.lower()
        if norm_str == "inplace":
            decoder_use_norm = {
                "type": "inplace",
                "activation": "leaky_relu",
                "activation_param": 0.0,
            }
        elif norm_str in (
            "batchnorm",
            "identity",
            "layernorm",
            "groupnorm",
            "instancenorm",
        ):
            decoder_use_norm = {"type": norm_str}
        else:
            raise ValueError("Unrecognized normalization type string provided")
    elif isinstance(decoder_use_norm, bool):
        decoder_use_norm = {"type": "batchnorm" if decoder_use_norm else "identity"}
    elif not isinstance(decoder_use_norm, dict):
        raise ValueError("use_norm must be a dictionary, boolean, or string")

    return decoder_use_norm


def normalize_decoder_norm(
    decoder_use_batchnorm: Union[bool, str, None],
    decoder_use_norm: Union[bool, str, Dict[str, Any]],
) -> Dict[str, Any]:
    if decoder_use_batchnorm is not None:
        warnings.warn(
            "The usage of use_batchnorm is deprecated. Please modify your code for use_norm",
            DeprecationWarning,
            stacklevel=2,
        )
        if decoder_use_batchnorm is True:
            decoder_use_norm = {"type": "batchnorm"}
        elif decoder_use_batchnorm is False:
            decoder_use_norm = {"type": "identity"}
        elif decoder_use_batchnorm == "inplace":
            decoder_use_norm = {
                "type": "inplace",
                "activation": "leaky_relu",
                "activation_param": 0.0,
            }
        else:
            raise ValueError("Unrecognized value for use_batchnorm")

    decoder_use_norm = normalize_use_norm(decoder_use_norm)
    return decoder_use_norm


def get_norm_layer(use_norm: Dict[str, Any], out_channels: int) -> nn.Module:
    norm_type = use_norm["type"]
    extra_kwargs = {k: v for k, v in use_norm.items() if k != "type"}

    if norm_type == "inplace":
        norm = InPlaceABN(out_channels, **extra_kwargs)
    elif norm_type == "batchnorm":
        norm = nn.BatchNorm2d(out_channels, **extra_kwargs)
    elif norm_type == "identity":
        norm = nn.Identity()
    elif norm_type == "layernorm":
        norm = nn.LayerNorm(out_channels, **extra_kwargs)
    elif norm_type == "instancenorm":
        norm = nn.InstanceNorm2d(out_channels, **extra_kwargs)
    else:
        raise ValueError(f"Unrecognized normalization type: {norm_type}")

    return norm
```
Collaborator:
I would rather put everything into a single simple function as follows:

```python
if use_norm is True:
    params = {"type": "batchnorm"}
elif use_norm is False:
    params = {"type": "identity"}
elif use_norm == "inplace":
    params = {"type": "inplace", "activation": "leaky_relu", "activation_param": 0.0}

if not isinstance(params, dict):
    raise ValueError(...)
if "type" not in params:
    raise ValueError(...)
if params["type"] not in supported_norms:
    raise ValueError(...)

# <dispatch to norm layer>

return norm
```

Author:
So no str shortcut for default usage like layernorm etc.? I kept it for the moment; tell me if I should remove it.

Changed it accordingly and tried to match the proposed flow as closely as possible.

Comment on lines 91 to 95:

```python
use_norm = normalize_use_norm(use_norm)
if use_norm["type"] == "inplace" and InPlaceABN is None:
    raise RuntimeError(
        "In order to use `use_batchnorm='inplace'` or `use_norm='inplace'` the inplace_abn package must be installed. "
        "To install see: https://github.com/mapillary/inplace_abn"
    )
```
Collaborator: Also not needed here; it should live under the `get_norm_layer` function.

Author: Removed and placed in `get_norm_layer`.

```diff
@@ -29,21 +101,16 @@ def __init__(
             kernel_size,
             stride=stride,
             padding=padding,
-            bias=not (use_batchnorm),
+            bias=use_norm["type"] != "inplace",
```
Collaborator: We can initialize the norm first and use a separate variable:

```python
is_inplace_batchnorm = norm.__name__ == "InPlaceABN"
```

Author: Yep, simpler.
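An editor's sketch of the resulting pattern, assuming `get_norm_layer` is importable from `segmentation_models_pytorch.base.modules` (names and wiring are illustrative, not the PR's exact code):

```python
import torch.nn as nn

from segmentation_models_pytorch.base.modules import get_norm_layer  # assumed location


def make_conv_norm(in_channels: int, out_channels: int, kernel_size: int, use_norm: dict):
    # Build the norm layer first, then derive the conv bias from it.
    norm = get_norm_layer(use_norm, out_channels)

    # For an instance, the class name lives on type(norm), not on norm itself.
    is_inplace = type(norm).__name__ == "InPlaceABN"

    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=kernel_size // 2,
        bias=not is_inplace,  # mirrors the PR's bias=use_norm["type"] != "inplace"
    )
    return conv, norm
```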


```diff
-        if use_batchnorm == "inplace":
+        if use_norm["type"] == "inplace":
             bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
```
Collaborator: Same here.

Author: Done.

```diff
@@ -60,7 +79,8 @@ def __init__(
         encoder_name: str = "resnet34",
         encoder_depth: int = 5,
         encoder_weights: Optional[str] = "imagenet",
-        decoder_use_batchnorm: bool = True,
+        decoder_use_batchnorm: Union[bool, str, None] = None,
```
Collaborator: Let's remove `decoder_use_batchnorm` from the signature and pop it from kwargs later.

Author: Removed `decoder_use_batchnorm` and used kwargs; I like this approach.
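An editor's sketch of the kwargs-pop pattern being discussed (the class and attribute names are illustrative, not the PR's exact code):

```python
import warnings
from typing import Any, Dict, Union


class ExampleModel:  # illustrative stand-in for a model such as Unet
    def __init__(
        self,
        decoder_use_norm: Union[bool, str, Dict[str, Any]] = True,
        **kwargs: Any,
    ) -> None:
        # The deprecated argument no longer appears in the signature,
        # but is still honored when callers pass it.
        decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None)
        if decoder_use_batchnorm is not None:
            warnings.warn(
                "decoder_use_batchnorm is deprecated; use decoder_use_norm instead",
                DeprecationWarning,
                stacklevel=2,
            )
            decoder_use_norm = decoder_use_batchnorm
        self.decoder_use_norm = decoder_use_norm
```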

```diff
@@ -82,11 +102,12 @@ def __init__(
             **kwargs,
         )

+        decoder_use_norm = normalize_decoder_norm(decoder_use_batchnorm, decoder_use_norm)
```
Collaborator: Here we just resolve the name.

Author: Tell me if this is what you had in mind.

Comment on lines 85 to 86:

```python
decoder_use_batchnorm: Union[bool, str, None] = None,
decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
```
Collaborator: Let's keep the default

```python
decoder_use_norm: Union[bool, str, Dict[str, Any]] = True
```

and remove `decoder_use_batchnorm` entirely from the signature; it will be passed in kwargs and popped from there.

Author: I am not really sold on `True` as the default; `"batchnorm"` seems more meaningful, and the bool is more of a historic usage. Changed it to `True` on the model; tell me if I should propagate it to all the decoders.

Removed `decoder_use_batchnorm` and used kwargs; I like this approach.

Collaborator:

> "batchnorm" seems more meaningful and the bool more of an historic usage.

Ok, sounds good!

Comment on lines 4 to 10:

```python
def test_conv2drelu_batchnorm():
    module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="batchnorm")

    expected = (
        "Conv2dReLU(\n (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))"
        "\n (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)"
        "\n (2): ReLU(inplace=True)\n)"
    )
    assert repr(module) == expected
```
Collaborator: Not super robust; formatting can change in any version. Let's use

```python
assert isinstance(module[1], nn.BatchNorm2d)
```

here and below, please.

Author: Changed the asserts.
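An editor's sketch of what such assertions can look like, assuming `Conv2dReLU` subclasses `nn.Sequential` with conv, norm, and activation in that order (not the PR's exact test):

```python
import torch.nn as nn

from segmentation_models_pytorch.base.modules import Conv2dReLU  # assumed location


def test_conv2drelu_batchnorm_layers():
    module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="batchnorm")
    assert isinstance(module[0], nn.Conv2d)
    assert isinstance(module[1], nn.BatchNorm2d)
    assert isinstance(module[2], nn.ReLU)
```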

Comment on lines 27 to 40
@pytest.mark.parametrize("decoder_option", [True, False, "inplace"])
def test_pspnet_before_after_use_norm(decoder_option):
torch.manual_seed(42)
with pytest.warns(DeprecationWarning):
model_decoder_batchnorm = create_model(
"pspnet",
"mobilenet_v2",
None,
psp_use_batchnorm=decoder_option
)
torch.manual_seed(42)
model_decoder_norm = create_model("pspnet", "mobilenet_v2", None, psp_use_batchnorm=None, decoder_use_norm=decoder_option)

check_two_models_strictly_equal(model_decoder_batchnorm, model_decoder_norm)
Collaborator: Nice!

tests/utils.py (outdated), comment on lines 63 to 67:

```python
def check_two_models_strictly_equal(model_a: torch.nn.Module, model_b: torch.nn.Module) -> None:
    for (k1, v1), (k2, v2) in zip(
        model_a.state_dict().items(), model_b.state_dict().items()
    ):
        assert k1 == k2, f"Key mismatch: {k1} != {k2}"
        assert (v1 == v2).all(), f"Tensor mismatch at key '{k1}':\n{v1} !=\n{v2}"
```
Collaborator: I would also add a forward pass here and compare logits.

Author: Added an `input_data` parameter to check that.
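An editor's sketch of the extended helper (the PR's exact implementation may differ): compare state dicts strictly, then run a forward pass on `input_data` and compare the outputs.

```python
import torch


def check_two_models_strictly_equal(
    model_a: torch.nn.Module,
    model_b: torch.nn.Module,
    input_data: torch.Tensor,
) -> None:
    # Strict parameter/buffer comparison, key by key.
    for (k1, v1), (k2, v2) in zip(
        model_a.state_dict().items(), model_b.state_dict().items()
    ):
        assert k1 == k2, f"Key mismatch: {k1} != {k2}"
        torch.testing.assert_close(v1, v2, msg=f"Tensor mismatch at key '{k1}'")

    # Forward-pass comparison on the same input.
    model_a.eval()
    model_b.eval()
    with torch.inference_mode():
        torch.testing.assert_close(model_a(input_data), model_b(input_data))
```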

@qubvel (Collaborator) left a comment:

Thanks, just a few minor comments and we are good to merge!

```python
elif isinstance(use_norm, dict):
    norm_params = use_norm
else:
    raise ValueError(
        "use_norm must be a dictionary, boolean, or string. Please refer to the documentation."
    )
```
Collaborator: Let's have a more descriptive error here; specify what kind of string is accepted and what the dict structure should be.
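An editor's sketch of what a more descriptive error could look like (wording is illustrative, not the PR's final message; `supported_norms` is the tuple from the snippet below):

```python
raise ValueError(
    f"Invalid use_norm: {use_norm!r}. Expected a bool, one of the strings "
    f"{supported_norms}, or a dict of the form "
    "{'type': <norm_type>, **kwargs}, e.g. {'type': 'layernorm', 'eps': 1e-2}."
)
```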

```python
):
    norm_params = {"type": norm_str}
else:
    raise ValueError(
        f"Unrecognized normalization type string provided: {use_norm}. Should be in {supported_norms}"
    )
```
Collaborator: Nice!

Comment on lines +33 to +48:

decoder_use_norm: Specifies normalization between Conv2D and activation.
Accepts the following types:
- **True**: Defaults to `"batchnorm"`.
- **False**: No normalization (`nn.Identity`).
- **str**: Specifies normalization type using default parameters. Available values:
  `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`.
- **dict**: Fully customizable normalization settings. Structure:
  ```python
  {"type": <norm_type>, **kwargs}
  ```
  where `norm_type` corresponds to the normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in the PyTorch documentation.

**Example**:
```python
use_norm={"type": "layernorm", "eps": 1e-2}
```
Collaborator: Thanks for the detailed docstring, really appreciate it!

Collaborator: Suggested change to the docstring example:

```diff
-use_norm={"type": "layernorm", "eps": 1e-2}
+decoder_use_norm={"type": "layernorm", "eps": 1e-2}
```

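Finally, an editor's usage sketch of the new parameter (assuming the library's usual entry point; defaults in the merged version may differ):

```python
import segmentation_models_pytorch as smp

# String shorthand with default parameters:
model = smp.Unet(encoder_name="resnet34", decoder_use_norm="instancenorm")

# Fully customizable dict form:
model = smp.Unet(
    encoder_name="resnet34",
    decoder_use_norm={"type": "layernorm", "eps": 1e-2},
)
```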
Successfully merging this pull request may close issue #983: Open other normalization function.