Conversation
- adapted segmentation_loader example - refactored assert_equal_shapes test
|
This PR is still in draft mode - is the plan to finalise it at some point or is the code ready for review? |
There was a problem hiding this comment.
Pull Request Overview
This PR adds support for fuzzy segmentation-based volume creation, where tissue classes can have probabilistic assignments to voxels rather than hard labels. The implementation includes GPU acceleration and updates the example to demonstrate the new functionality.
Key changes:
- Added fuzzy segmentation mode to
SegmentationBasedAdapterwith automatic detection based on input dtype - Enabled GPU tensor operations throughout the adapter (removed CPU-only constraint)
- Refactored
assert_equal_shapesfor improved clarity and correctness
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
segmentation_loader.py |
Added fuzzy parameter and one-hot encoding + Gaussian smoothing for fuzzy segmentation example |
data_sanity_testing.py |
Simplified shape comparison logic for better readability and correctness |
segmentation_based_adapter.py |
Implemented fuzzy/hard segmentation detection and weighted volume assignment with GPU support |
Comments suppressed due to low confidence (1)
simpa/core/simulation_modules/volume_creation_module/segmentation_based_adapter.py:65
- The 3D map case (lines 62-65) is not updated to handle fuzzy segmentation. When
fuzzy=True, this code path will fail becausesegmentation_volume == seg_classdoesn't work on the 4D fuzzy volume. This branch should either be disabled for fuzzy mode or properly implemented with weighted assignment similar to the scalar case.
elif len(torch.Tensor.size(class_properties[volume_key])) == 3: # 3D map
assigned_prop = class_properties[volume_key][torch.tensor(segmentation_volume == seg_class)]
assigned_prop[assigned_prop is None] = torch.nan
volumes[volume_key][torch.tensor(segmentation_volume == seg_class)] = assigned_prop
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| parser = ArgumentParser(description='Run the segmentation loader example') | ||
| parser.add_argument("--spacing", default=1, type=float, help='the voxel spacing in mm') | ||
| parser.add_argument("--input_spacing", default=0.2, type=float, help='the input spacing in mm') | ||
| parser.add_argument("--fuzzy", default=False, type=bool, help='whether to use fuzzy segmentation adapter') |
There was a problem hiding this comment.
Using type=bool with argparse doesn't work as expected. Any non-empty string will be converted to True, so --fuzzy False will actually set fuzzy to True. Use action='store_true' instead, or implement a proper boolean parser like lambda x: x.lower() == 'true'.
| if fuzzy: | ||
| segmentation_volume_mask = np.eye(C)[segmentation_volume_mask] | ||
| segmentation_volume_mask = np.moveaxis(segmentation_volume_mask, -1, 0) | ||
| segmentation_volume_mask = gaussian_filter(segmentation_volume_mask, sigma=1e-5, axes=(1, 2, 3)) # smooth the segmentation |
There was a problem hiding this comment.
The Gaussian filter sigma of 1e-5 is extremely small and will have negligible smoothing effect. This may not achieve the intended purpose of creating a fuzzy segmentation. Consider using a larger sigma value (e.g., 1.0 or higher) to create meaningful fuzzy boundaries.
| segmentation_volume_mask = gaussian_filter(segmentation_volume_mask, sigma=1e-5, axes=(1, 2, 3)) # smooth the segmentation | |
| segmentation_volume_mask = gaussian_filter(segmentation_volume_mask, sigma=1.0, axes=(1, 2, 3)) # smooth the segmentation |
| if torch.is_floating_point(segmentation_volume): | ||
| assert len(segmentation_volume.shape) == 4 and segmentation_volume.shape[0] == len(class_mapping), \ | ||
| "Fuzzy segmentation must be a 4D array with the first dimension being the number of classes." | ||
| fuzzy = True | ||
| segmentation_classes = np.arange(segmentation_volume.shape[0]) | ||
|
|
||
| else: | ||
| assert len(segmentation_volume.shape) == 3, "Hard segmentations must be a 3D array." | ||
| fuzzy = False | ||
| segmentation_classes = torch.unique(segmentation_volume, return_counts=False).cpu().numpy() | ||
|
|
There was a problem hiding this comment.
The fuzzy segmentation detection relies on dtype, which could be fragile. A hard segmentation converted to float would be incorrectly detected as fuzzy. Consider adding an explicit parameter or checking both dtype and shape (e.g., len(shape) == 4) to make the detection more robust.
| if torch.is_floating_point(segmentation_volume): | |
| assert len(segmentation_volume.shape) == 4 and segmentation_volume.shape[0] == len(class_mapping), \ | |
| "Fuzzy segmentation must be a 4D array with the first dimension being the number of classes." | |
| fuzzy = True | |
| segmentation_classes = np.arange(segmentation_volume.shape[0]) | |
| else: | |
| assert len(segmentation_volume.shape) == 3, "Hard segmentations must be a 3D array." | |
| fuzzy = False | |
| segmentation_classes = torch.unique(segmentation_volume, return_counts=False).cpu().numpy() | |
| # Robustly detect fuzzy vs. hard segmentation by checking both shape and dtype | |
| if len(segmentation_volume.shape) == 4 and segmentation_volume.shape[0] == len(class_mapping): | |
| # Fuzzy segmentation: 4D array, first dim = num classes | |
| assert torch.is_floating_point(segmentation_volume), \ | |
| "Fuzzy segmentation must be a floating point 4D array with the first dimension being the number of classes." | |
| fuzzy = True | |
| segmentation_classes = np.arange(segmentation_volume.shape[0]) | |
| elif len(segmentation_volume.shape) == 3: | |
| # Hard segmentation: 3D array | |
| fuzzy = False | |
| segmentation_classes = torch.unique(segmentation_volume, return_counts=False).cpu().numpy() | |
| else: | |
| raise AssertionError("Segmentation must be either a 3D (hard) or 4D (fuzzy) array.") |
| assigned_prop = torch.nan | ||
| volumes[volume_key][segmentation_volume == seg_class] = assigned_prop | ||
| if fuzzy: | ||
| volumes[volume_key] += segmentation_volume[seg_class] * assigned_prop |
There was a problem hiding this comment.
The fuzzy segmentation implementation doesn't handle the DATA_FIELD_SEGMENTATION field, which represents the segmentation map itself. In fuzzy mode, it's unclear how this field should be populated (the PR description asks 'how to handle "seg" field for multiple classes per voxels?'). This should be documented or explicitly handled, potentially by storing the class with maximum probability or raising a warning that segmentation field is not meaningful in fuzzy mode.
Please check the following before creating the pull request (PR):
List any specific code review questions
Provide issue / feature request fixed by this PR
Fixes #399