From e0cda55d4ad86efc74e912c023d9bcc5f6d12608 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 8 Aug 2025 09:19:36 +0200 Subject: [PATCH 01/10] added list extend to MultiSampleTrait --- monai/transforms/transform.py | 102 +++++++++++++++++++++++++--------- 1 file changed, 77 insertions(+), 25 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 1a365b8d8e..e7b8268432 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -90,12 +90,22 @@ def _apply_transform( """ from monai.transforms.lazy.functional import apply_pending_transforms_in_order - data = apply_pending_transforms_in_order(transform, data, lazy, overrides, logger_name) + data = apply_pending_transforms_in_order( + transform, data, lazy, overrides, logger_name + ) if isinstance(data, tuple) and unpack_parameters: - return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data) + return ( + transform(*data, lazy=lazy) + if isinstance(transform, LazyTrait) + else transform(*data) + ) - return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data) + return ( + transform(data, lazy=lazy) + if isinstance(transform, LazyTrait) + else transform(data) + ) def apply_transform( @@ -143,31 +153,49 @@ def apply_transform( try: map_items_ = int(map_items) if isinstance(map_items, bool) else map_items if isinstance(data, (list, tuple)) and map_items_ > 0: - return [ - apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) - for item in data - ] - return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats) + res = [] + for item in data: + res_item = _apply_transform( + transform, item, unpack_items, lazy, overrides, log_stats + ) + if isinstance(res_item, list | tuple): + res.extend(res_item) + else: + res.append(res_item) + return res + return _apply_transform( + transform, data, unpack_items, lazy, overrides, log_stats + ) except Exception as e: # if in debug mode, don't swallow exception so that the breakpoint # appears where the exception was raised. 
if MONAIEnvVars.debug(): raise - if log_stats is not False and not isinstance(transform, transforms.compose.Compose): + if log_stats is not False and not isinstance( + transform, transforms.compose.Compose + ): # log the input data information of exact transform in the transform chain if isinstance(log_stats, str): - datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False, name=log_stats) + datastats = transforms.utility.array.DataStats( + data_shape=False, value_range=False, name=log_stats + ) else: - datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False) + datastats = transforms.utility.array.DataStats( + data_shape=False, value_range=False + ) logger = logging.getLogger(datastats._logger_name) - logger.error(f"\n=== Transform input info -- {type(transform).__name__} ===") + logger.error( + f"\n=== Transform input info -- {type(transform).__name__} ===" + ) if isinstance(data, (list, tuple)): data = data[0] def _log_stats(data, prefix: str | None = "Data"): if isinstance(data, (np.ndarray, torch.Tensor)): # log data type, shape, range for array - datastats(img=data, data_shape=True, value_range=True, prefix=prefix) + datastats( + img=data, data_shape=True, value_range=True, prefix=prefix + ) else: # log data type and value for other metadata datastats(img=data, data_value=True, prefix=prefix) @@ -194,7 +222,9 @@ class Randomizable(ThreadUnsafe, RandomizableTrait): R: np.random.RandomState = np.random.RandomState() - def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Randomizable: + def set_random_state( + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> Randomizable: """ Set the random state locally, to control the randomness, the derived classes should use :py:attr:`self.R` instead of `np.random` to introduce random @@ -212,14 +242,20 @@ def set_random_state(self, seed: int | None = None, state: np.random.RandomState """ if seed is not None: - _seed = np.int64(id(seed) if not isinstance(seed, (int, np.integer)) else seed) - _seed = _seed % MAX_SEED # need to account for Numpy2.0 which doesn't silently convert to int64 + _seed = np.int64( + id(seed) if not isinstance(seed, (int, np.integer)) else seed + ) + _seed = ( + _seed % MAX_SEED + ) # need to account for Numpy2.0 which doesn't silently convert to int64 self.R = np.random.RandomState(_seed) return self if state is not None: if not isinstance(state, np.random.RandomState): - raise TypeError(f"state must be None or a np.random.RandomState but is {type(state).__name__}.") + raise TypeError( + f"state must be None or a np.random.RandomState but is {type(state).__name__}." + ) self.R = state return self @@ -238,7 +274,9 @@ def randomize(self, data: Any) -> None: Raises: NotImplementedError: When the subclass does not override this method. """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + raise NotImplementedError( + f"Subclass {self.__class__.__name__} must implement this method." + ) class Transform(ABC): @@ -294,7 +332,9 @@ def __call__(self, data: Any): NotImplementedError: When the subclass does not override this method. """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + raise NotImplementedError( + f"Subclass {self.__class__.__name__} must implement this method." 
+ ) class LazyTransform(Transform, LazyTrait): @@ -397,11 +437,15 @@ def __call__(self, data): def __new__(cls, *args, **kwargs): if config.USE_META_DICT: # call_update after MapTransform.__call__ - cls.__call__ = transforms.attach_hook(cls.__call__, MapTransform.call_update, "post") # type: ignore + cls.__call__ = transforms.attach_hook( + cls.__call__, MapTransform.call_update, "post" + ) # type: ignore if hasattr(cls, "inverse"): # inverse_update before InvertibleTransform.inverse - cls.inverse: Any = transforms.attach_hook(cls.inverse, transforms.InvertibleTransform.inverse_update) + cls.inverse: Any = transforms.attach_hook( + cls.inverse, transforms.InvertibleTransform.inverse_update + ) return Transform.__new__(cls) def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: @@ -412,7 +456,9 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No raise ValueError("keys must be non empty.") for key in self.keys: if not isinstance(key, Hashable): - raise TypeError(f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}.") + raise TypeError( + f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}." + ) def call_update(self, data): """ @@ -432,7 +478,9 @@ def call_update(self, data): for k in dict_i: if not isinstance(dict_i[k], MetaTensor): continue - list_d[idx] = transforms.sync_meta_info(k, dict_i, t=not isinstance(self, transforms.InvertD)) + list_d[idx] = transforms.sync_meta_info( + k, dict_i, t=not isinstance(self, transforms.InvertD) + ) return list_d[0] if is_dict else list_d @abstractmethod @@ -460,9 +508,13 @@ def __call__(self, data): An updated dictionary version of ``data`` by applying the transform. """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + raise NotImplementedError( + f"Subclass {self.__class__.__name__} must implement this method." + ) - def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None) -> Generator: + def key_iterator( + self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None + ) -> Generator: """ Iterate across keys and optionally extra iterables. If key is missing, exception is raised if `allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped. 
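The core of this first patch is the new loop in `apply_transform`: instead of recursing and appending whatever comes back, each item's result is spliced into the output list when it is itself a list or tuple, so multi-sample transforms no longer produce nested lists. Below is a minimal, self-contained sketch of that flattening logic; it uses a hypothetical `apply_to_items` helper and a toy transform rather than MONAI's actual code.

```python
# Sketch only (not MONAI code): when a transform returns several samples for
# one input item, the samples are spliced into the result list instead of
# being nested one level deeper.

def apply_to_items(transform, items):
    res = []
    for item in items:
        out = transform(item)
        if isinstance(out, (list, tuple)):
            res.extend(out)   # multi-sample result: splice the samples in
        else:
            res.append(out)   # single result: keep one entry per input item
    return res

def two_crops(x):
    # stand-in for a MultiSampleTrait transform such as RandSpatialCropSamples
    return [f"{x}_crop0", f"{x}_crop1"]

print(apply_to_items(two_crops, ["img_a", "img_b"]))
# ['img_a_crop0', 'img_a_crop1', 'img_b_crop0', 'img_b_crop1']
```

Later patches in the series (06 onwards) tighten the condition so that items which were already lists keep their nesting and the `map_items` recursion is preserved.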
From 1ad24af4315caba871b1bc1604951518755fa784 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 8 Aug 2025 09:19:36 +0200 Subject: [PATCH 02/10] DCO Remediation Commit for Lukas Folle I, Lukas Folle , hereby add my Signed-off-by to this commit: e0cda55d4ad86efc74e912c023d9bcc5f6d12608 Signed-off-by: Lukas Folle --- monai/transforms/transform.py | 102 +++++++++++++++++++++++++--------- 1 file changed, 77 insertions(+), 25 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 1a365b8d8e..e7b8268432 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -90,12 +90,22 @@ def _apply_transform( """ from monai.transforms.lazy.functional import apply_pending_transforms_in_order - data = apply_pending_transforms_in_order(transform, data, lazy, overrides, logger_name) + data = apply_pending_transforms_in_order( + transform, data, lazy, overrides, logger_name + ) if isinstance(data, tuple) and unpack_parameters: - return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data) + return ( + transform(*data, lazy=lazy) + if isinstance(transform, LazyTrait) + else transform(*data) + ) - return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data) + return ( + transform(data, lazy=lazy) + if isinstance(transform, LazyTrait) + else transform(data) + ) def apply_transform( @@ -143,31 +153,49 @@ def apply_transform( try: map_items_ = int(map_items) if isinstance(map_items, bool) else map_items if isinstance(data, (list, tuple)) and map_items_ > 0: - return [ - apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) - for item in data - ] - return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats) + res = [] + for item in data: + res_item = _apply_transform( + transform, item, unpack_items, lazy, overrides, log_stats + ) + if isinstance(res_item, list | tuple): + res.extend(res_item) + else: + res.append(res_item) + return res + return _apply_transform( + transform, data, unpack_items, lazy, overrides, log_stats + ) except Exception as e: # if in debug mode, don't swallow exception so that the breakpoint # appears where the exception was raised. 
if MONAIEnvVars.debug(): raise - if log_stats is not False and not isinstance(transform, transforms.compose.Compose): + if log_stats is not False and not isinstance( + transform, transforms.compose.Compose + ): # log the input data information of exact transform in the transform chain if isinstance(log_stats, str): - datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False, name=log_stats) + datastats = transforms.utility.array.DataStats( + data_shape=False, value_range=False, name=log_stats + ) else: - datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False) + datastats = transforms.utility.array.DataStats( + data_shape=False, value_range=False + ) logger = logging.getLogger(datastats._logger_name) - logger.error(f"\n=== Transform input info -- {type(transform).__name__} ===") + logger.error( + f"\n=== Transform input info -- {type(transform).__name__} ===" + ) if isinstance(data, (list, tuple)): data = data[0] def _log_stats(data, prefix: str | None = "Data"): if isinstance(data, (np.ndarray, torch.Tensor)): # log data type, shape, range for array - datastats(img=data, data_shape=True, value_range=True, prefix=prefix) + datastats( + img=data, data_shape=True, value_range=True, prefix=prefix + ) else: # log data type and value for other metadata datastats(img=data, data_value=True, prefix=prefix) @@ -194,7 +222,9 @@ class Randomizable(ThreadUnsafe, RandomizableTrait): R: np.random.RandomState = np.random.RandomState() - def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Randomizable: + def set_random_state( + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> Randomizable: """ Set the random state locally, to control the randomness, the derived classes should use :py:attr:`self.R` instead of `np.random` to introduce random @@ -212,14 +242,20 @@ def set_random_state(self, seed: int | None = None, state: np.random.RandomState """ if seed is not None: - _seed = np.int64(id(seed) if not isinstance(seed, (int, np.integer)) else seed) - _seed = _seed % MAX_SEED # need to account for Numpy2.0 which doesn't silently convert to int64 + _seed = np.int64( + id(seed) if not isinstance(seed, (int, np.integer)) else seed + ) + _seed = ( + _seed % MAX_SEED + ) # need to account for Numpy2.0 which doesn't silently convert to int64 self.R = np.random.RandomState(_seed) return self if state is not None: if not isinstance(state, np.random.RandomState): - raise TypeError(f"state must be None or a np.random.RandomState but is {type(state).__name__}.") + raise TypeError( + f"state must be None or a np.random.RandomState but is {type(state).__name__}." + ) self.R = state return self @@ -238,7 +274,9 @@ def randomize(self, data: Any) -> None: Raises: NotImplementedError: When the subclass does not override this method. """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + raise NotImplementedError( + f"Subclass {self.__class__.__name__} must implement this method." + ) class Transform(ABC): @@ -294,7 +332,9 @@ def __call__(self, data: Any): NotImplementedError: When the subclass does not override this method. """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + raise NotImplementedError( + f"Subclass {self.__class__.__name__} must implement this method." 
+ ) class LazyTransform(Transform, LazyTrait): @@ -397,11 +437,15 @@ def __call__(self, data): def __new__(cls, *args, **kwargs): if config.USE_META_DICT: # call_update after MapTransform.__call__ - cls.__call__ = transforms.attach_hook(cls.__call__, MapTransform.call_update, "post") # type: ignore + cls.__call__ = transforms.attach_hook( + cls.__call__, MapTransform.call_update, "post" + ) # type: ignore if hasattr(cls, "inverse"): # inverse_update before InvertibleTransform.inverse - cls.inverse: Any = transforms.attach_hook(cls.inverse, transforms.InvertibleTransform.inverse_update) + cls.inverse: Any = transforms.attach_hook( + cls.inverse, transforms.InvertibleTransform.inverse_update + ) return Transform.__new__(cls) def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: @@ -412,7 +456,9 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No raise ValueError("keys must be non empty.") for key in self.keys: if not isinstance(key, Hashable): - raise TypeError(f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}.") + raise TypeError( + f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}." + ) def call_update(self, data): """ @@ -432,7 +478,9 @@ def call_update(self, data): for k in dict_i: if not isinstance(dict_i[k], MetaTensor): continue - list_d[idx] = transforms.sync_meta_info(k, dict_i, t=not isinstance(self, transforms.InvertD)) + list_d[idx] = transforms.sync_meta_info( + k, dict_i, t=not isinstance(self, transforms.InvertD) + ) return list_d[0] if is_dict else list_d @abstractmethod @@ -460,9 +508,13 @@ def __call__(self, data): An updated dictionary version of ``data`` by applying the transform. """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + raise NotImplementedError( + f"Subclass {self.__class__.__name__} must implement this method." + ) - def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None) -> Generator: + def key_iterator( + self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None + ) -> Generator: """ Iterate across keys and optionally extra iterables. If key is missing, exception is raised if `allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped. 
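This remediation commit carries the same diff as patch 01; the two patches that follow (03 and 04) adjust only the type handling of the new loop: `res` gains an explicit annotation and `isinstance(res_item, list | tuple)` becomes `isinstance(res_item, (list, tuple))`. The sketch below, illustrative only, shows why the tuple form is the portable spelling: the `list | tuple` union expression is a PEP 604 feature accepted only on Python 3.10+, which is presumably what the reported type errors were about.

```python
from __future__ import annotations

from typing import Any

def collect(item: Any, res: list[Any]) -> None:
    # The tuple form works on every Python version; writing
    # `isinstance(item, list | tuple)` fails before Python 3.10 because
    # the PEP 604 union syntax did not exist yet.
    if isinstance(item, (list, tuple)):
        res.extend(item)
    else:
        res.append(item)

out: list[Any] = []
collect(("a", "b"), out)
collect("c", out)
print(out)  # ['a', 'b', 'c']
```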
From eeb7e12d604a4c46f5b44d7484b40aa9cac6e9d3 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 8 Aug 2025 09:35:57 +0200 Subject: [PATCH 03/10] fixed type errors --- monai/transforms/transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index e7b8268432..73c4093792 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -153,12 +153,12 @@ def apply_transform( try: map_items_ = int(map_items) if isinstance(map_items, bool) else map_items if isinstance(data, (list, tuple)) and map_items_ > 0: - res = [] + res: list[ReturnType] = [] for item in data: res_item = _apply_transform( transform, item, unpack_items, lazy, overrides, log_stats ) - if isinstance(res_item, list | tuple): + if isinstance(res_item, (list, tuple)): res.extend(res_item) else: res.append(res_item) From c011103a67f54b7994f035d61dc4edc6e1fefb5a Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 8 Aug 2025 09:35:57 +0200 Subject: [PATCH 04/10] DCO Remediation Commit for Lukas Folle I, Lukas Folle , hereby add my Signed-off-by to this commit: e0cda55d4ad86efc74e912c023d9bcc5f6d12608 Signed-off-by: Lukas Folle --- monai/transforms/transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index e7b8268432..73c4093792 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -153,12 +153,12 @@ def apply_transform( try: map_items_ = int(map_items) if isinstance(map_items, bool) else map_items if isinstance(data, (list, tuple)) and map_items_ > 0: - res = [] + res: list[ReturnType] = [] for item in data: res_item = _apply_transform( transform, item, unpack_items, lazy, overrides, log_stats ) - if isinstance(res_item, list | tuple): + if isinstance(res_item, (list, tuple)): res.extend(res_item) else: res.append(res_item) From 77c138d4751826f810a674744c9f949ee48e6f0d Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 8 Aug 2025 09:54:20 +0200 Subject: [PATCH 05/10] DCO Remediation Commit for Lukas Folle I, Lukas Folle , hereby add my Signed-off-by to this commit: e0cda55d4ad86efc74e912c023d9bcc5f6d12608 Signed-off-by: Lukas Folle --- monai/transforms/transform.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 73c4093792..0c9d8c3cdf 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -155,9 +155,7 @@ def apply_transform( if isinstance(data, (list, tuple)) and map_items_ > 0: res: list[ReturnType] = [] for item in data: - res_item = _apply_transform( - transform, item, unpack_items, lazy, overrides, log_stats - ) + res_item = _apply_transform(transform, item, unpack_items, lazy, overrides, log_stats) if isinstance(res_item, (list, tuple)): res.extend(res_item) else: From 7df8cb919c43f8343c76b2dd750c3ad832ccdf1b Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 8 Aug 2025 11:46:18 +0200 Subject: [PATCH 06/10] avoided breaking map_item functionality DCO Remediation Commit for Lukas Folle I, Lukas Folle , hereby add my Signed-off-by to this commit: e0cda55d4ad86efc74e912c023d9bcc5f6d12608 Signed-off-by: Lukas Folle --- monai/transforms/transform.py | 102 ++++++++++------------------------ 1 file changed, 30 insertions(+), 72 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 0c9d8c3cdf..65ef429e33 100644 --- a/monai/transforms/transform.py 
+++ b/monai/transforms/transform.py @@ -90,22 +90,12 @@ def _apply_transform( """ from monai.transforms.lazy.functional import apply_pending_transforms_in_order - data = apply_pending_transforms_in_order( - transform, data, lazy, overrides, logger_name - ) + data = apply_pending_transforms_in_order(transform, data, lazy, overrides, logger_name) if isinstance(data, tuple) and unpack_parameters: - return ( - transform(*data, lazy=lazy) - if isinstance(transform, LazyTrait) - else transform(*data) - ) + return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data) - return ( - transform(data, lazy=lazy) - if isinstance(transform, LazyTrait) - else transform(data) - ) + return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data) def apply_transform( @@ -155,45 +145,38 @@ def apply_transform( if isinstance(data, (list, tuple)) and map_items_ > 0: res: list[ReturnType] = [] for item in data: - res_item = _apply_transform(transform, item, unpack_items, lazy, overrides, log_stats) - if isinstance(res_item, (list, tuple)): - res.extend(res_item) + res_item = apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) + # Only extend if we're at the leaf level (map_items_ == 1) and the transform + # actually returned a list (not preserving nested structure) + if isinstance(res_item, list) and map_items_ == 1: + if not isinstance(item, (list, tuple)): + res.extend(res_item) + else: + res.append(res_item) else: res.append(res_item) return res - return _apply_transform( - transform, data, unpack_items, lazy, overrides, log_stats - ) + return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats) except Exception as e: # if in debug mode, don't swallow exception so that the breakpoint # appears where the exception was raised. 
if MONAIEnvVars.debug(): raise - if log_stats is not False and not isinstance( - transform, transforms.compose.Compose - ): + if log_stats is not False and not isinstance(transform, transforms.compose.Compose): # log the input data information of exact transform in the transform chain if isinstance(log_stats, str): - datastats = transforms.utility.array.DataStats( - data_shape=False, value_range=False, name=log_stats - ) + datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False, name=log_stats) else: - datastats = transforms.utility.array.DataStats( - data_shape=False, value_range=False - ) + datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False) logger = logging.getLogger(datastats._logger_name) - logger.error( - f"\n=== Transform input info -- {type(transform).__name__} ===" - ) + logger.error(f"\n=== Transform input info -- {type(transform).__name__} ===") if isinstance(data, (list, tuple)): data = data[0] def _log_stats(data, prefix: str | None = "Data"): if isinstance(data, (np.ndarray, torch.Tensor)): # log data type, shape, range for array - datastats( - img=data, data_shape=True, value_range=True, prefix=prefix - ) + datastats(img=data, data_shape=True, value_range=True, prefix=prefix) else: # log data type and value for other metadata datastats(img=data, data_value=True, prefix=prefix) @@ -220,9 +203,7 @@ class Randomizable(ThreadUnsafe, RandomizableTrait): R: np.random.RandomState = np.random.RandomState() - def set_random_state( - self, seed: int | None = None, state: np.random.RandomState | None = None - ) -> Randomizable: + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Randomizable: """ Set the random state locally, to control the randomness, the derived classes should use :py:attr:`self.R` instead of `np.random` to introduce random @@ -240,20 +221,14 @@ def set_random_state( """ if seed is not None: - _seed = np.int64( - id(seed) if not isinstance(seed, (int, np.integer)) else seed - ) - _seed = ( - _seed % MAX_SEED - ) # need to account for Numpy2.0 which doesn't silently convert to int64 + _seed = np.int64(id(seed) if not isinstance(seed, (int, np.integer)) else seed) + _seed = _seed % MAX_SEED # need to account for Numpy2.0 which doesn't silently convert to int64 self.R = np.random.RandomState(_seed) return self if state is not None: if not isinstance(state, np.random.RandomState): - raise TypeError( - f"state must be None or a np.random.RandomState but is {type(state).__name__}." - ) + raise TypeError(f"state must be None or a np.random.RandomState but is {type(state).__name__}.") self.R = state return self @@ -272,9 +247,7 @@ def randomize(self, data: Any) -> None: Raises: NotImplementedError: When the subclass does not override this method. """ - raise NotImplementedError( - f"Subclass {self.__class__.__name__} must implement this method." - ) + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") class Transform(ABC): @@ -330,9 +303,7 @@ def __call__(self, data: Any): NotImplementedError: When the subclass does not override this method. """ - raise NotImplementedError( - f"Subclass {self.__class__.__name__} must implement this method." 
- ) + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") class LazyTransform(Transform, LazyTrait): @@ -435,15 +406,11 @@ def __call__(self, data): def __new__(cls, *args, **kwargs): if config.USE_META_DICT: # call_update after MapTransform.__call__ - cls.__call__ = transforms.attach_hook( - cls.__call__, MapTransform.call_update, "post" - ) # type: ignore + cls.__call__ = transforms.attach_hook(cls.__call__, MapTransform.call_update, "post") # type: ignore if hasattr(cls, "inverse"): # inverse_update before InvertibleTransform.inverse - cls.inverse: Any = transforms.attach_hook( - cls.inverse, transforms.InvertibleTransform.inverse_update - ) + cls.inverse: Any = transforms.attach_hook(cls.inverse, transforms.InvertibleTransform.inverse_update) return Transform.__new__(cls) def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: @@ -454,9 +421,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No raise ValueError("keys must be non empty.") for key in self.keys: if not isinstance(key, Hashable): - raise TypeError( - f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}." - ) + raise TypeError(f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}.") def call_update(self, data): """ @@ -476,9 +441,7 @@ def call_update(self, data): for k in dict_i: if not isinstance(dict_i[k], MetaTensor): continue - list_d[idx] = transforms.sync_meta_info( - k, dict_i, t=not isinstance(self, transforms.InvertD) - ) + list_d[idx] = transforms.sync_meta_info(k, dict_i, t=not isinstance(self, transforms.InvertD)) return list_d[0] if is_dict else list_d @abstractmethod @@ -506,13 +469,9 @@ def __call__(self, data): An updated dictionary version of ``data`` by applying the transform. """ - raise NotImplementedError( - f"Subclass {self.__class__.__name__} must implement this method." - ) + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - def key_iterator( - self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None - ) -> Generator: + def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None) -> Generator: """ Iterate across keys and optionally extra iterables. If key is missing, exception is raised if `allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped. @@ -532,8 +491,7 @@ def key_iterator( yield (key,) + tuple(_ex_iters) if extra_iterables else key elif not self.allow_missing_keys: raise KeyError( - f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data" - " and allow_missing_keys==False." + f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data and allow_missing_keys==False." 
) def first_key(self, data: dict[Hashable, Any]): From be4601826787f1991b34369bb3a678d5da252c55 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Fri, 8 Aug 2025 13:26:34 +0200 Subject: [PATCH 07/10] fixed wrong type annotation DCO Remediation Commit for Lukas Folle I, Lukas Folle , hereby add my Signed-off-by to this commit: e0cda55d4ad86efc74e912c023d9bcc5f6d12608 Signed-off-by: Lukas Folle --- monai/transforms/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 65ef429e33..d9a16d53e7 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -143,7 +143,7 @@ def apply_transform( try: map_items_ = int(map_items) if isinstance(map_items, bool) else map_items if isinstance(data, (list, tuple)) and map_items_ > 0: - res: list[ReturnType] = [] + res: list[Any] = [] for item in data: res_item = apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) # Only extend if we're at the leaf level (map_items_ == 1) and the transform From 2d5877455022444675e200afee36077a709fa784 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Tue, 16 Sep 2025 11:36:53 +0200 Subject: [PATCH 08/10] added test for many multisample transforms; refactored code --- monai/transforms/transform.py | 7 ++----- tests/transforms/compose/test_compose.py | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index d9a16d53e7..05a08e0743 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -148,11 +148,8 @@ def apply_transform( res_item = apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) # Only extend if we're at the leaf level (map_items_ == 1) and the transform # actually returned a list (not preserving nested structure) - if isinstance(res_item, list) and map_items_ == 1: - if not isinstance(item, (list, tuple)): - res.extend(res_item) - else: - res.append(res_item) + if isinstance(res_item, list) and map_items_ == 1 and not isinstance(item, (list, tuple)): + res.extend(res_item) else: res.append(res_item) return res diff --git a/tests/transforms/compose/test_compose.py b/tests/transforms/compose/test_compose.py index e6727c976f..9abf635b13 100644 --- a/tests/transforms/compose/test_compose.py +++ b/tests/transforms/compose/test_compose.py @@ -282,6 +282,28 @@ def test_flatten_and_len(self): def test_backwards_compatible_imports(self): from monai.transforms.transform import MapTransform, RandomizableTransform, Transform # noqa: F401 + def test_list_extend_multi_sample_trait(self): + from monai.transforms import CenterSpatialCrop, RandSpatialCropSamples + + center_crop = CenterSpatialCrop([128, 128]) + multi_sample_transform = RandSpatialCropSamples([64, 64], 1) + + img = torch.zeros([1, 512, 512]) + + assert execute_compose(img, [center_crop]).shape == torch.Size([1, 128, 128]) + single_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, center_crop]) + assert ( + isinstance(single_multi_sample_trait_result, list) + and len(single_multi_sample_trait_result) == 1 + and single_multi_sample_trait_result[0].shape == torch.Size([1, 64, 64]) + ) + double_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, multi_sample_transform, center_crop]) + assert ( + isinstance(double_multi_sample_trait_result, list) + and len(double_multi_sample_trait_result) == 1 + and 
double_multi_sample_trait_result[0].shape == torch.Size([1, 64, 64]) + ) + TEST_COMPOSE_EXECUTE_TEST_CASES = [ [None, tuple()], From ee74761cb623688b7d8a7a8100f78dcfdd16e365 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Tue, 16 Sep 2025 11:36:53 +0200 Subject: [PATCH 09/10] DCO Remediation Commit for Lukas Folle I, Lukas Folle , hereby add my Signed-off-by to this commit: 2d5877455022444675e200afee36077a709fa784 Signed-off-by: Lukas Folle --- monai/transforms/transform.py | 7 ++----- tests/transforms/compose/test_compose.py | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index d9a16d53e7..05a08e0743 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -148,11 +148,8 @@ def apply_transform( res_item = apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) # Only extend if we're at the leaf level (map_items_ == 1) and the transform # actually returned a list (not preserving nested structure) - if isinstance(res_item, list) and map_items_ == 1: - if not isinstance(item, (list, tuple)): - res.extend(res_item) - else: - res.append(res_item) + if isinstance(res_item, list) and map_items_ == 1 and not isinstance(item, (list, tuple)): + res.extend(res_item) else: res.append(res_item) return res diff --git a/tests/transforms/compose/test_compose.py b/tests/transforms/compose/test_compose.py index e6727c976f..9abf635b13 100644 --- a/tests/transforms/compose/test_compose.py +++ b/tests/transforms/compose/test_compose.py @@ -282,6 +282,28 @@ def test_flatten_and_len(self): def test_backwards_compatible_imports(self): from monai.transforms.transform import MapTransform, RandomizableTransform, Transform # noqa: F401 + def test_list_extend_multi_sample_trait(self): + from monai.transforms import CenterSpatialCrop, RandSpatialCropSamples + + center_crop = CenterSpatialCrop([128, 128]) + multi_sample_transform = RandSpatialCropSamples([64, 64], 1) + + img = torch.zeros([1, 512, 512]) + + assert execute_compose(img, [center_crop]).shape == torch.Size([1, 128, 128]) + single_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, center_crop]) + assert ( + isinstance(single_multi_sample_trait_result, list) + and len(single_multi_sample_trait_result) == 1 + and single_multi_sample_trait_result[0].shape == torch.Size([1, 64, 64]) + ) + double_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, multi_sample_transform, center_crop]) + assert ( + isinstance(double_multi_sample_trait_result, list) + and len(double_multi_sample_trait_result) == 1 + and double_multi_sample_trait_result[0].shape == torch.Size([1, 64, 64]) + ) + TEST_COMPOSE_EXECUTE_TEST_CASES = [ [None, tuple()], From 9377b63faf9cbd16d35990532df12ea94b61c6e6 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Tue, 16 Sep 2025 11:51:08 +0200 Subject: [PATCH 10/10] added slight cleanup and additional test DCO Remediation Commit for Lukas Folle I, Lukas Folle , hereby add my Signed-off-by to this commit: 2d5877455022444675e200afee36077a709fa784 Signed-off-by: Lukas Folle --- tests/transforms/compose/test_compose.py | 36 ++++++++++++++---------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/tests/transforms/compose/test_compose.py b/tests/transforms/compose/test_compose.py index 9abf635b13..01c7d92e7d 100644 --- a/tests/transforms/compose/test_compose.py +++ b/tests/transforms/compose/test_compose.py @@ -283,26 +283,32 @@ def 
test_backwards_compatible_imports(self): from monai.transforms.transform import MapTransform, RandomizableTransform, Transform # noqa: F401 def test_list_extend_multi_sample_trait(self): - from monai.transforms import CenterSpatialCrop, RandSpatialCropSamples - - center_crop = CenterSpatialCrop([128, 128]) - multi_sample_transform = RandSpatialCropSamples([64, 64], 1) + center_crop = mt.CenterSpatialCrop([128, 128]) + multi_sample_transform = mt.RandSpatialCropSamples([64, 64], 1) img = torch.zeros([1, 512, 512]) - assert execute_compose(img, [center_crop]).shape == torch.Size([1, 128, 128]) + self.assertEqual(execute_compose(img, [center_crop]).shape, torch.Size([1, 128, 128])) single_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, center_crop]) - assert ( - isinstance(single_multi_sample_trait_result, list) - and len(single_multi_sample_trait_result) == 1 - and single_multi_sample_trait_result[0].shape == torch.Size([1, 64, 64]) - ) + self.assertIsInstance(single_multi_sample_trait_result, list) + self.assertEqual(len(single_multi_sample_trait_result), 1) + self.assertEqual(single_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64])) + double_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, multi_sample_transform, center_crop]) - assert ( - isinstance(double_multi_sample_trait_result, list) - and len(double_multi_sample_trait_result) == 1 - and double_multi_sample_trait_result[0].shape == torch.Size([1, 64, 64]) - ) + self.assertIsInstance(double_multi_sample_trait_result, list) + self.assertEqual(len(double_multi_sample_trait_result), 1) + self.assertEqual(double_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64])) + + def test_multi_sample_trait_cardinality(self): + img = torch.zeros([1, 128, 128]) + t2 = mt.RandSpatialCropSamples([32, 32], num_samples=2) + + # chaining should multiply counts: 2 x 2 = 4, flattened + res = execute_compose(img, [t2, t2]) + self.assertIsInstance(res, list) + self.assertEqual(len(res), 4) + for r in res: + self.assertEqual(r.shape, torch.Size([1, 32, 32])) TEST_COMPOSE_EXECUTE_TEST_CASES = [
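Taken together, the series changes what a caller of `Compose`/`execute_compose` sees when multi-sample transforms are chained: every sampling stage multiplies the number of outputs, and the results come back as one flat list, which is exactly what the cardinality test above asserts. The following is a hedged end-to-end sketch of that behaviour; the counts in the comments assume the patched `apply_transform` with the default `map_items=True`.

```python
import torch
from monai.transforms import CenterSpatialCrop, Compose, RandSpatialCropSamples

img = torch.zeros(1, 512, 512)

pipeline = Compose(
    [
        RandSpatialCropSamples([128, 128], num_samples=2),  # 1 image   -> 2 samples
        RandSpatialCropSamples([64, 64], num_samples=2),    # 2 samples -> 4, flattened
        CenterSpatialCrop([32, 32]),                        # applied to each of the 4
    ]
)

out = pipeline(img)
print(type(out).__name__, len(out), out[0].shape)
# expected with the patched flattening: list 4 torch.Size([1, 32, 32])
```

Before this series, the second sampling stage would have produced a list of lists, and the trailing `CenterSpatialCrop` would have been handed those inner lists rather than the individual samples.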