Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
e0cda55
added list extend to MultiSampleTrait
lukas-folle-snkeos Aug 8, 2025
1ad24af
DCO Remediation Commit for Lukas Folle <[email protected]>
lukas-folle-snkeos Aug 8, 2025
35658f2
Merge branch 'dev' of github.com:lukas-folle-snkeos/MONAI into dev
lukas-folle-snkeos Aug 8, 2025
eeb7e12
fixed type errors
lukas-folle-snkeos Aug 8, 2025
c011103
DCO Remediation Commit for Lukas Folle <[email protected]>
lukas-folle-snkeos Aug 8, 2025
6bb6110
Merge branch 'dev' of github.com:lukas-folle-snkeos/MONAI into dev
lukas-folle-snkeos Aug 8, 2025
e7a9185
DCO Remediation Commit for Lukas Folle <[email protected]>
lukas-folle-snkeos Aug 8, 2025
a5d2261
Merge branch 'dev' of github.com:lukas-folle-snkeos/MONAI into dev
lukas-folle-snkeos Aug 8, 2025
b0dd089
DCO Remediation Commit for Lukas Folle <[email protected]>
lukas-folle-snkeos Aug 8, 2025
7560a37
Merge branch 'dev' of github.com:lukas-folle-snkeos/MONAI into dev
lukas-folle-snkeos Aug 8, 2025
77c138d
DCO Remediation Commit for Lukas Folle <[email protected]>
lukas-folle-snkeos Aug 8, 2025
7df8cb9
avoided breaking map_item functionality
lukas-folle-snkeos Aug 8, 2025
be46018
fixed wrong type annotation
lukas-folle-snkeos Aug 8, 2025
3aa1288
Merge branch 'dev' into dev
ericspod Aug 11, 2025
2d58774
added test for many multisample transforms; refactored code
lukas-folle-snkeos Sep 16, 2025
ee74761
DCO Remediation Commit for Lukas Folle <[email protected]>
lukas-folle-snkeos Sep 16, 2025
2c18f36
Merge branch 'dev' of github.com:lukas-folle-snkeos/MONAI into dev
lukas-folle-snkeos Sep 16, 2025
7ae8a26
DCO Remediation Commit for Lukas Folle <[email protected]>
lukas-folle-snkeos Sep 16, 2025
1d04028
DCO Remediation Commit for Lukas Folle <[email protected]>
lukas-folle-snkeos Sep 16, 2025
9377b63
added slight cleanup and additional test
lukas-folle-snkeos Sep 16, 2025
a56c0c3
Merge branch 'dev' into dev
lukas-folle-snkeos Sep 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,16 @@ def apply_transform(
try:
map_items_ = int(map_items) if isinstance(map_items, bool) else map_items
if isinstance(data, (list, tuple)) and map_items_ > 0:
return [
apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
for item in data
]
res: list[Any] = []
for item in data:
res_item = apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
# Only extend if we're at the leaf level (map_items_ == 1) and the transform
# actually returned a list (not preserving nested structure)
if isinstance(res_item, list) and map_items_ == 1 and not isinstance(item, (list, tuple)):
res.extend(res_item)
else:
res.append(res_item)
return res
return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
except Exception as e:
# if in debug mode, don't swallow exception so that the breakpoint
Expand Down Expand Up @@ -482,8 +488,7 @@ def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Iterable
yield (key,) + tuple(_ex_iters) if extra_iterables else key
elif not self.allow_missing_keys:
raise KeyError(
f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data"
" and allow_missing_keys==False."
f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data and allow_missing_keys==False."
)

def first_key(self, data: dict[Hashable, Any]):
Expand Down
28 changes: 28 additions & 0 deletions tests/transforms/compose/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,34 @@ def test_flatten_and_len(self):
def test_backwards_compatible_imports(self):
from monai.transforms.transform import MapTransform, RandomizableTransform, Transform # noqa: F401

def test_list_extend_multi_sample_trait(self):
center_crop = mt.CenterSpatialCrop([128, 128])
multi_sample_transform = mt.RandSpatialCropSamples([64, 64], 1)

img = torch.zeros([1, 512, 512])

self.assertEqual(execute_compose(img, [center_crop]).shape, torch.Size([1, 128, 128]))
single_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, center_crop])
self.assertIsInstance(single_multi_sample_trait_result, list)
self.assertEqual(len(single_multi_sample_trait_result), 1)
self.assertEqual(single_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64]))

double_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, multi_sample_transform, center_crop])
self.assertIsInstance(double_multi_sample_trait_result, list)
self.assertEqual(len(double_multi_sample_trait_result), 1)
self.assertEqual(double_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64]))

def test_multi_sample_trait_cardinality(self):
img = torch.zeros([1, 128, 128])
t2 = mt.RandSpatialCropSamples([32, 32], num_samples=2)

# chaining should multiply counts: 2 x 2 = 4, flattened
res = execute_compose(img, [t2, t2])
self.assertIsInstance(res, list)
self.assertEqual(len(res), 4)
for r in res:
self.assertEqual(r.shape, torch.Size([1, 32, 32]))


TEST_COMPOSE_EXECUTE_TEST_CASES = [
[None, tuple()],
Expand Down
Loading