Skip to content

Remove deprecated functionality for v1.5 #8430

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 9, 2025
1 change: 0 additions & 1 deletion docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ Metrics
`Hausdorff distance`
--------------------
.. autofunction:: compute_hausdorff_distance
.. autofunction:: compute_percent_hausdorff_distance

.. autoclass:: HausdorffDistanceMetric
:members:
Expand Down
4 changes: 2 additions & 2 deletions monai/apps/deepgrow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,8 @@ def __call__(self, data):

if np.all(np.less(current_size, self.spatial_size)):
cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size)
box_start = np.array([s.start for s in cropper.slices])
box_end = np.array([s.stop for s in cropper.slices])
box_start = [s.start for s in cropper.slices]
box_end = [s.stop for s in cropper.slices]
else:
cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)

Expand Down
25 changes: 1 addition & 24 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

from monai._version import get_versions
from monai.apps.utils import _basename, download_url, extractall, get_logger
from monai.bundle.config_item import ConfigComponent
from monai.bundle.config_parser import ConfigParser
from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA, merge_kv
from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow
Expand All @@ -48,7 +47,6 @@
from monai.utils import (
IgniteInfo,
check_parent_dir,
deprecated_arg,
ensure_tuple,
get_equivalent_dtype,
min_version,
Expand Down Expand Up @@ -629,9 +627,6 @@ def download(
_check_monai_version(bundle_dir_, name_)


@deprecated_arg("net_name", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.")
@deprecated_arg("net_kwargs", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.")
@deprecated_arg("return_state_dict", since="1.2", removed="1.5")
def load(
name: str,
model: torch.nn.Module | None = None,
Expand All @@ -650,10 +645,7 @@ def load(
workflow_name: str | BundleWorkflow | None = None,
args_file: str | None = None,
copy_model_args: dict | None = None,
return_state_dict: bool = True,
net_override: dict | None = None,
net_name: str | None = None,
**net_kwargs: Any,
) -> object | tuple[torch.nn.Module, dict, dict] | Any:
"""
Load model weights or TorchScript module of a bundle.
Expand Down Expand Up @@ -699,12 +691,7 @@ def load(
workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow".
args_file: a JSON or YAML file to provide default values for all the args in "download" function.
copy_model_args: other arguments for the `monai.networks.copy_model_state` function.
return_state_dict: whether to return state dict, if True, return state_dict, else a corresponding network
from `_workflow.network_def` will be instantiated and load the achieved weights.
net_override: id-value pairs to override the parameters in the network of the bundle, default to `None`.
net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights.
This argument only works when loading weights.
net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`.

Returns:
1. If `load_ts_module` is `False` and `model` is `None`,
Expand All @@ -719,9 +706,6 @@ def load(
when `model` and `net_name` are all `None`.

"""
if return_state_dict and (model is not None or net_name is not None):
warnings.warn("Incompatible values: model and net_name are all specified, return state dict instead.")

bundle_dir_ = _process_bundle_dir(bundle_dir)
net_override = {} if net_override is None else net_override
copy_model_args = {} if copy_model_args is None else copy_model_args
Expand Down Expand Up @@ -757,11 +741,8 @@ def load(
warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.")
model_dict = get_state_dict(model_dict)

if return_state_dict:
return model_dict

_workflow = None
if model is None and net_name is None:
if model is None:
bundle_config_file = bundle_dir_ / name / "configs" / f"{workflow_type}.json"
if bundle_config_file.is_file():
_net_override = {f"network_def#{key}": value for key, value in net_override.items()}
Expand All @@ -781,10 +762,6 @@ def load(
return model_dict
else:
model = _workflow.network_def
elif net_name is not None:
net_kwargs["_target_"] = net_name
configer = ConfigComponent(config=net_kwargs)
model = configer.instantiate() # type: ignore

model.to(device) # type: ignore

Expand Down
28 changes: 1 addition & 27 deletions monai/bundle/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from monai.bundle.properties import InferProperties, MetaProperties, TrainProperties
from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, EXPR_KEY, ID_REF_KEY, ID_SEP_KEY
from monai.config import PathLike
from monai.utils import BundleProperty, BundlePropertyConfig, deprecated_arg, ensure_tuple
from monai.utils import BundleProperty, BundlePropertyConfig, ensure_tuple

__all__ = ["BundleWorkflow", "ConfigWorkflow"]

Expand All @@ -45,10 +45,6 @@ class BundleWorkflow(ABC):
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
default to `None` for only using meta properties.
workflow: specifies the workflow type: "train" or "training" for a training workflow,
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.
properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be
loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,
properties will default to loading from "meta". If `properties_path` is None, default properties
Expand All @@ -65,17 +61,9 @@ class BundleWorkflow(ABC):
supported_train_type: tuple = ("train", "training")
supported_infer_type: tuple = ("infer", "inference", "eval", "evaluation")

@deprecated_arg(
"workflow",
since="1.2",
removed="1.5",
new_name="workflow_type",
msg_suffix="please use `workflow_type` instead.",
)
def __init__(
self,
workflow_type: str | None = None,
workflow: str | None = None,
properties_path: PathLike | None = None,
meta_file: str | Sequence[str] | None = None,
logging_file: str | None = None,
Expand All @@ -102,7 +90,6 @@ def __init__(
)
meta_file = None

workflow_type = workflow if workflow is not None else workflow_type
if workflow_type is not None:
if workflow_type.lower() in self.supported_train_type:
workflow_type = "train"
Expand Down Expand Up @@ -403,10 +390,6 @@ class ConfigWorkflow(BundleWorkflow):
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.
workflow: specifies the workflow type: "train" or "training" for a training workflow,
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.
properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be
loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,
properties will default to loading from "train". If `properties_path` is None, default properties
Expand All @@ -419,13 +402,6 @@ class ConfigWorkflow(BundleWorkflow):

"""

@deprecated_arg(
"workflow",
since="1.2",
removed="1.5",
new_name="workflow_type",
msg_suffix="please use `workflow_type` instead.",
)
def __init__(
self,
config_file: str | Sequence[str],
Expand All @@ -436,11 +412,9 @@ def __init__(
final_id: str = "finalize",
tracking: str | dict | None = None,
workflow_type: str | None = "train",
workflow: str | None = None,
properties_path: PathLike | None = None,
**override: Any,
) -> None:
workflow_type = workflow if workflow is not None else workflow_type
if config_file is not None:
_config_files = ensure_tuple(config_file)
config_root_path = Path(_config_files[0]).parent
Expand Down
2 changes: 1 addition & 1 deletion monai/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .fid import FIDMetric, compute_frechet_distance
from .froc import compute_fp_tp_probs, compute_fp_tp_probs_nd, compute_froc_curve_data, compute_froc_score
from .generalized_dice import GeneralizedDiceScore, compute_generalized_dice
from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance, compute_percent_hausdorff_distance
from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance
from .loss_metric import LossMetric
from .meandice import DiceHelper, DiceMetric, compute_dice
from .meaniou import MeanIoU, compute_iou
Expand Down
17 changes: 4 additions & 13 deletions monai/metrics/generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch

from monai.metrics.utils import do_metric_reduction, ignore_background
from monai.utils import MetricReduction, Weight, deprecated_arg, deprecated_arg_default, look_up_option
from monai.utils import MetricReduction, Weight, deprecated_arg, look_up_option

from .metric import CumulativeIterationMetric

Expand All @@ -37,28 +37,19 @@ class GeneralizedDiceScore(CumulativeIterationMetric):
reduction: Define mode of reduction to the metrics. Available reduction modes:
{``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
Default value is changed from `MetricReduction.MEAN_BATCH` to `MetricReduction.MEAN` in v1.5.0.
Old versions computed `mean` when `mean_batch` was provided due to bug in reduction.
weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
ground truth volume into a weight factor. Defaults to ``"square"``.

Raises:
ValueError: When the `reduction` is not one of MetricReduction enum.
"""

@deprecated_arg_default(
"reduction",
old_default=MetricReduction.MEAN_BATCH,
new_default=MetricReduction.MEAN,
since="1.4.0",
replaced="1.5.0",
msg_suffix=(
"Old versions computed `mean` when `mean_batch` was provided due to bug in reduction, "
"If you want to retain the old behavior (calculating the mean), please explicitly set the parameter to 'mean'."
),
)
def __init__(
self,
include_background: bool = True,
reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
reduction: MetricReduction | str = MetricReduction.MEAN,
weight_type: Weight | str = Weight.SQUARE,
) -> None:
super().__init__()
Expand Down
40 changes: 3 additions & 37 deletions monai/metrics/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,12 @@
import numpy as np
import torch

from monai.metrics.utils import (
do_metric_reduction,
get_edge_surface_distance,
get_surface_distance,
ignore_background,
prepare_spacing,
)
from monai.utils import MetricReduction, convert_data_type, deprecated
from monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing
from monai.utils import MetricReduction, convert_data_type

from .metric import CumulativeIterationMetric

__all__ = ["HausdorffDistanceMetric", "compute_hausdorff_distance", "compute_percent_hausdorff_distance"]
__all__ = ["HausdorffDistanceMetric", "compute_hausdorff_distance"]


class HausdorffDistanceMetric(CumulativeIterationMetric):
Expand Down Expand Up @@ -216,31 +210,3 @@ def _compute_percentile_hausdorff_distance(
if 0 <= percentile <= 100:
return torch.quantile(surface_distance, percentile / 100)
raise ValueError(f"percentile should be a value between 0 and 100, get {percentile}.")


@deprecated(since="1.3.0", removed="1.5.0")
def compute_percent_hausdorff_distance(
edges_pred: np.ndarray,
edges_gt: np.ndarray,
distance_metric: str = "euclidean",
percentile: float | None = None,
spacing: int | float | np.ndarray | Sequence[int | float] | None = None,
) -> float:
"""
This function is used to compute the directed Hausdorff distance.
"""

surface_distance: np.ndarray = get_surface_distance( # type: ignore
edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing
)

# for both pred and gt do not have foreground
if surface_distance.shape == (0,):
return np.nan

if not percentile:
return surface_distance.max() # type: ignore[no-any-return]

if 0 <= percentile <= 100:
return np.percentile(surface_distance, percentile) # type: ignore[no-any-return]
raise ValueError(f"percentile should be a value between 0 and 100, get {percentile}.")
7 changes: 2 additions & 5 deletions monai/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
convert_to_numpy,
convert_to_tensor,
deprecated_arg,
deprecated_arg_default,
ensure_tuple_rep,
look_up_option,
optional_import,
Expand Down Expand Up @@ -131,9 +130,6 @@ def do_metric_reduction(
return f, not_nans


@deprecated_arg_default(
name="always_return_as_numpy", since="1.3.0", replaced="1.5.0", old_default=True, new_default=False
)
@deprecated_arg(
name="always_return_as_numpy",
since="1.5.0",
Expand All @@ -146,7 +142,7 @@ def get_mask_edges(
label_idx: int = 1,
crop: bool = True,
spacing: Sequence | None = None,
always_return_as_numpy: bool = True,
always_return_as_numpy: bool = False,
) -> tuple[NdarrayTensor, NdarrayTensor]:
"""
Compute edges from binary segmentation masks. This
Expand Down Expand Up @@ -175,6 +171,7 @@ def get_mask_edges(
otherwise `scipy`'s binary erosion is used to calculate the edges.
always_return_as_numpy: whether to a numpy array regardless of the input type.
If False, return the same type as inputs.
The default value is changed from `True` to `False` in v1.5.0.
"""
# move in the funciton to avoid using all the GPUs
cucim_binary_erosion, has_cucim_binary_erosion = optional_import("cucim.skimage.morphology", name="binary_erosion")
Expand Down
24 changes: 4 additions & 20 deletions monai/networks/nets/swin_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
from monai.networks.layers import DropPath, trunc_normal_
from monai.utils import ensure_tuple_rep, look_up_option, optional_import
from monai.utils.deprecate_utils import deprecated_arg

rearrange, _ = optional_import("einops", name="rearrange")

Expand All @@ -50,16 +49,8 @@ class SwinUNETR(nn.Module):
<https://arxiv.org/abs/2201.01266>"
"""

@deprecated_arg(
name="img_size",
since="1.3",
removed="1.5",
msg_suffix="The img_size argument is not required anymore and "
"checks on the input size are run during forward().",
)
def __init__(
self,
img_size: Sequence[int] | int,
in_channels: int,
out_channels: int,
patch_size: int = 2,
Expand All @@ -83,10 +74,6 @@ def __init__(
) -> None:
"""
Args:
img_size: spatial dimension of input image.
This argument is only used for checking that the input image size is divisible by the patch size.
The tensor passed to forward() can have a dynamic shape as long as its spatial dimensions are divisible by 2**5.
It will be removed in an upcoming version.
in_channels: dimension of input channels.
out_channels: dimension of output channels.
patch_size: size of the patch token.
Expand All @@ -113,13 +100,13 @@ def __init__(
Examples::

# for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
>>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)
>>> net = SwinUNETR(in_channels=1, out_channels=4, feature_size=48)

# for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
>>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))
>>> net = SwinUNETR(in_channels=4, out_channels=3, depths=(2,4,2,2))

# for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
>>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
>>> net = SwinUNETR(in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)

"""

Expand All @@ -130,12 +117,9 @@ def __init__(

self.patch_size = patch_size

img_size = ensure_tuple_rep(img_size, spatial_dims)
patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
window_size = ensure_tuple_rep(window_size, spatial_dims)

self._check_input_size(img_size)

if not (0 <= drop_rate <= 1):
raise ValueError("dropout rate should be between 0 and 1.")

Expand Down Expand Up @@ -1109,7 +1093,7 @@ def filter_swinunetr(key, value):
from monai.networks.utils import copy_model_state
from monai.networks.nets.swin_unetr import SwinUNETR, filter_swinunetr

model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=3, feature_size=48)
model = SwinUNETR(in_channels=1, out_channels=3, feature_size=48)
resource = (
"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth"
)
Expand Down
Loading
Loading