diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 1a365b8d8e..05a08e0743 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -143,10 +143,16 @@ def apply_transform( try: map_items_ = int(map_items) if isinstance(map_items, bool) else map_items if isinstance(data, (list, tuple)) and map_items_ > 0: - return [ - apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) - for item in data - ] + res: list[Any] = [] + for item in data: + res_item = apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) + # Only extend if we're at the leaf level (map_items_ == 1) and the transform + # actually returned a list (not preserving nested structure) + if isinstance(res_item, list) and map_items_ == 1 and not isinstance(item, (list, tuple)): + res.extend(res_item) + else: + res.append(res_item) + return res return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats) except Exception as e: # if in debug mode, don't swallow exception so that the breakpoint @@ -482,8 +488,7 @@ def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Iterable yield (key,) + tuple(_ex_iters) if extra_iterables else key elif not self.allow_missing_keys: raise KeyError( - f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data" - " and allow_missing_keys==False." + f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data and allow_missing_keys==False." ) def first_key(self, data: dict[Hashable, Any]): diff --git a/tests/transforms/compose/test_compose.py b/tests/transforms/compose/test_compose.py index e6727c976f..01c7d92e7d 100644 --- a/tests/transforms/compose/test_compose.py +++ b/tests/transforms/compose/test_compose.py @@ -282,6 +282,34 @@ def test_flatten_and_len(self): def test_backwards_compatible_imports(self): from monai.transforms.transform import MapTransform, RandomizableTransform, Transform # noqa: F401 + def test_list_extend_multi_sample_trait(self): + center_crop = mt.CenterSpatialCrop([128, 128]) + multi_sample_transform = mt.RandSpatialCropSamples([64, 64], 1) + + img = torch.zeros([1, 512, 512]) + + self.assertEqual(execute_compose(img, [center_crop]).shape, torch.Size([1, 128, 128])) + single_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, center_crop]) + self.assertIsInstance(single_multi_sample_trait_result, list) + self.assertEqual(len(single_multi_sample_trait_result), 1) + self.assertEqual(single_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64])) + + double_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, multi_sample_transform, center_crop]) + self.assertIsInstance(double_multi_sample_trait_result, list) + self.assertEqual(len(double_multi_sample_trait_result), 1) + self.assertEqual(double_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64])) + + def test_multi_sample_trait_cardinality(self): + img = torch.zeros([1, 128, 128]) + t2 = mt.RandSpatialCropSamples([32, 32], num_samples=2) + + # chaining should multiply counts: 2 x 2 = 4, flattened + res = execute_compose(img, [t2, t2]) + self.assertIsInstance(res, list) + self.assertEqual(len(res), 4) + for r in res: + self.assertEqual(r.shape, torch.Size([1, 32, 32])) + TEST_COMPOSE_EXECUTE_TEST_CASES = [ [None, tuple()],