diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 8adba8fa25..00ead9a2d5 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -60,6 +60,7 @@ def sliding_window_inference( *args: Any, **kwargs: Any, ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]: + """ Sliding window inference on `inputs` with `predictor`. @@ -134,6 +135,14 @@ def sliding_window_inference( - input must be channel-first and have a batch dim, supports N-D sliding window. """ + + # auto transform (N,D,H,W,C) → (N,C,D,H,W) + if isinstance(inputs, torch.Tensor) and inputs.ndim == 5 and inputs.shape[-1] in (1, 3, 4): + inputs = inputs.permute(0, 4, 1, 2, 3).contiguous() + + + + buffered = buffer_steps is not None and buffer_steps > 0 num_spatial_dims = len(inputs.shape) - 2 if buffered: diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 0802cc3364..841585897f 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -134,6 +134,12 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor Raises: ValueError: when `y_pred` has fewer than three dimensions. """ + + if isinstance(y_pred, torch.Tensor) and y_pred.ndim == 5 and y_pred.shape[-1] in (1, 3, 4): + y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous() + if isinstance(y, torch.Tensor) and y.ndim == 5 and y.shape[-1] in (1, 3, 4): + y = y.permute(0, 4, 1, 2, 3).contiguous() + dims = y_pred.ndimension() if dims < 3: raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.")