From d000e182a876d0f2bca4dfc3d799f8395e0cadbe Mon Sep 17 00:00:00 2001 From: Sathiesh Date: Fri, 13 Feb 2026 15:14:57 +0100 Subject: [PATCH] feat: add TTA, EMA checkpointing, RandomAnisotropy; bump to 0.8.0 - vision_patch: nnU-Net-style mirror TTA with 8 flip combinations - vision_metrics: EMACheckpoint callback for robust model selection - vision_augmentation: RandomAnisotropy, suggest_patch_augmentations(), RandomFlip axes default to 'LRAPIS' - tutorial: update 12a_tutorial_patch_training with new features --- .../skills/fastmonai-upstream-guide/SKILL.md | 107 ++++- .gitignore | 1 + fastMONAI/__init__.py | 2 +- fastMONAI/_modidx.py | 22 +- fastMONAI/vision_augmentation.py | 101 ++++- fastMONAI/vision_metrics.py | 67 ++- fastMONAI/vision_patch.py | 84 +++- nbs/03_vision_augment.ipynb | 41 +- nbs/05_vision_metrics.ipynb | 24 +- nbs/10_vision_patch.ipynb | 396 +----------------- nbs/12a_tutorial_patch_training.ipynb | 75 +--- settings.ini | 2 +- 12 files changed, 423 insertions(+), 499 deletions(-) diff --git a/.claude/skills/fastmonai-upstream-guide/SKILL.md b/.claude/skills/fastmonai-upstream-guide/SKILL.md index 0e0bcf7..28a9a9b 100644 --- a/.claude/skills/fastmonai-upstream-guide/SKILL.md +++ b/.claude/skills/fastmonai-upstream-guide/SKILL.md @@ -41,9 +41,18 @@ Use this skill proactively when: - **Losses:** https://docs.monai.io/en/stable/losses.html - **Transforms:** https://docs.monai.io/en/stable/transforms.html -### nnU-Net -- **Repository:** https://github.com/MIC-DKFZ/nnUNet -- **Documentation:** https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/ +### nnU-Net v2 +- **Repository:** https://github.com/MIC-DKFZ/nnUNet (complete rewrite of v1, Apache 2.0) +- **Documentation index:** https://github.com/MIC-DKFZ/nnUNet/tree/master/documentation +- **How to use:** https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/how_to_use_nnunet.md +- **Dataset format:** https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/dataset_format.md +- 
**Extending nnU-Net:** https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/extending_nnunet.md +- **Normalization explained:** https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/explanation_normalization.md +- **Plans files explained:** https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/explanation_plans_files.md +- **Pretraining/finetuning:** https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/pretraining_and_finetuning.md +- **Region-based training:** https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/region_based_training.md +- **Ignore label:** https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/ignore_label.md +- **ResEnc presets:** https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/resenc_presets.md ## Source Code URLs (for implementation details) @@ -59,9 +68,21 @@ Use this skill proactively when: - **Hausdorff:** https://raw.githubusercontent.com/Project-MONAI/MONAI/dev/monai/metrics/hausdorff_distance.py - **Losses:** https://raw.githubusercontent.com/Project-MONAI/MONAI/dev/monai/losses/ -### nnU-Net GitHub -- **Preprocessing:** https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/preprocessing -- **Training:** https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/training +### nnU-Net v2 GitHub (source code) +- **Package root:** https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2 +- **Preprocessing root:** https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/preprocessing +- **Normalization schemes:** https://raw.githubusercontent.com/MIC-DKFZ/nnUNet/master/nnunetv2/preprocessing/normalization/default_normalization_schemes.py +- **Resampling:** https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/preprocessing/resampling +- **Training root:** https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/training +- **Base trainer:** https://raw.githubusercontent.com/MIC-DKFZ/nnUNet/master/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py +- **Trainer variants:** 
https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/training/nnUNetTrainer/variants +- **Dice loss:** https://raw.githubusercontent.com/MIC-DKFZ/nnUNet/master/nnunetv2/training/loss/dice.py +- **Compound losses (DC+CE, DC+BCE, DC+TopK):** https://raw.githubusercontent.com/MIC-DKFZ/nnUNet/master/nnunetv2/training/loss/compound_losses.py +- **Robust CE loss:** https://raw.githubusercontent.com/MIC-DKFZ/nnUNet/master/nnunetv2/training/loss/robust_ce_loss.py +- **Deep supervision loss:** https://raw.githubusercontent.com/MIC-DKFZ/nnUNet/master/nnunetv2/training/loss/deep_supervision.py +- **Evaluation metrics:** https://raw.githubusercontent.com/MIC-DKFZ/nnUNet/master/nnunetv2/evaluation/evaluate_predictions.py +- **Sliding window inference:** https://raw.githubusercontent.com/MIC-DKFZ/nnUNet/master/nnunetv2/inference/sliding_window_prediction.py +- **Predict from raw data:** https://raw.githubusercontent.com/MIC-DKFZ/nnUNet/master/nnunetv2/inference/predict_from_raw_data.py ## Workflow Steps @@ -70,7 +91,9 @@ Use this skill proactively when: - Metric → MONAI (compute functions) + nnU-Net (accumulated patterns) - Loss function → MONAI - Patch workflow → TorchIO (Queue, GridSampler) - - Preprocessing → nnU-Net conventions + - Preprocessing/normalization → nnU-Net v2 conventions + - Sliding window inference → nnU-Net v2 (Gaussian weighting, tile step size) + - Evaluation metrics → nnU-Net v2 (TP/FP/FN accumulation, Dice, IoU) 2. **Fetch relevant documentation** Use WebFetch to retrieve the appropriate documentation page from the URLs above. 
@@ -117,13 +140,45 @@ tio.CropOrPad(target_size) # Avoid - uses TorchIO default - TorchIO uses: [C, D, H, W] for single images - fastMONAI MedImage/MedMask: 4D tensors with affine tracking -### Accumulated Metrics (nnU-Net pattern) +### Accumulated Metrics (nnU-Net v2 pattern) For patch-based training, use accumulated TP/FP/FN rather than per-batch averaging: ```python class AccumulatedDice(Metric): # Accumulates across batches, not averages ``` +### nnU-Net v2 Quick Reference + +**Architecture:** The `nnunetv2/` package is organized into: preprocessing, training, inference, evaluation, experiment_planning, postprocessing, utilities. + +**Normalization schemes** (in `default_normalization_schemes.py`): +- `ZScoreNormalization` - z-score with optional foreground masking (default for MR) +- `CTNormalization` - clip to 0.5th-99.5th percentiles then z-score (for CT) +- `NoNormalization` - passthrough (dtype conversion only) +- `RescaleTo01Normalization` - min-max to [0,1] +- `RGBTo01Normalization` - divide uint8 by 255 + +**Loss functions** (in `training/loss/`): +- `SoftDiceLoss(apply_nonlin, batch_dice, do_bg, smooth=1.0, ddp, clip_tp)` - standard soft Dice +- `MemoryEfficientSoftDiceLoss` - same API, ~1.6GB less memory via no-grad one-hot encoding +- `DC_and_CE_loss(soft_dice_kwargs, ce_kwargs, weight_ce=1, weight_dice=1, ignore_label)` - Dice + Cross-Entropy (default nnU-Net loss) +- `DC_and_BCE_loss(bce_kwargs, soft_dice_kwargs, weight_ce=1, weight_dice=1)` - Dice + Binary CE +- `DC_and_topk_loss(soft_dice_kwargs, ce_kwargs, weight_ce=1, weight_dice=1)` - Dice + TopK + +**Evaluation metrics** (in `evaluation/evaluate_predictions.py`): +- Computes per-class: Dice, IoU, TP, FP, FN, TN +- `compute_tp_fp_fn_tn(mask_ref, mask_pred, ignore_mask)` - core confusion matrix +- `compute_metrics(reference_file, prediction_file, image_reader_writer, labels_or_regions, ignore_label)` - per-file metrics +- `compute_metrics_on_folder()` / `compute_metrics_on_folder_simple()` - 
batch evaluation with multiprocessing + +**Sliding window inference** (in `inference/sliding_window_prediction.py`): +- `compute_steps_for_sliding_window(image_size, tile_size, tile_step_size)` - tile_step_size is 0-1 (0.5 = 50% overlap) +- `compute_gaussian(tile_size)` - Gaussian importance weighting for smooth tile aggregation (center-weighted, no zeros) +- Supports test-time augmentation via mirror axes + +**Trainer variants** (in `training/nnUNetTrainer/variants/`): +Subdirectories for: benchmarking, data_augmentation, loss, lr_schedule, network_architecture, optimizer, sampling, training_length + ### nbdev Workflow - All code changes in `nbs/*.ipynb` - Run `nbdev_prepare` to regenerate .py files and run tests @@ -140,10 +195,19 @@ When implementing, fetch relevant docs/code: → Fetch MONAI metrics documentation **"How does nnU-Net handle validation metrics?"** -→ Fetch nnU-Net training code for accumulated patterns +→ Fetch nnU-Net v2 evaluation code for TP/FP/FN accumulation patterns + +**"What loss does nnU-Net v2 use by default?"** +→ Fetch compound_losses.py for DC_and_CE_loss (Dice + Cross-Entropy) + +**"How does nnU-Net v2 do sliding window inference?"** +→ Fetch sliding_window_prediction.py for tile stepping and Gaussian weighting + +**"What normalization does nnU-Net v2 use for MR images?"** +→ Fetch default_normalization_schemes.py for ZScoreNormalization **"I want to add a new loss function"** -→ Fetch MONAI losses documentation and source +→ Fetch MONAI losses documentation and nnU-Net v2 compound_losses.py ## Quick Reference Commands @@ -156,6 +220,27 @@ WebFetch: https://torchio.readthedocs.io/transforms/augmentation.html # MONAI metric implementation WebFetch: https://raw.githubusercontent.com/Project-MONAI/MONAI/dev/monai/metrics/meandice.py -# nnU-Net preprocessing patterns +# nnU-Net v2 usage guide WebFetch: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/how_to_use_nnunet.md + +# nnU-Net v2 normalization schemes (ZScore, CT, 
RescaleTo01, etc.) +WebFetch: https://raw.githubusercontent.com/MIC-DKFZ/nnUNet/master/nnunetv2/preprocessing/normalization/default_normalization_schemes.py + +# nnU-Net v2 default loss (Dice + CE compound) +WebFetch: https://raw.githubusercontent.com/MIC-DKFZ/nnUNet/master/nnunetv2/training/loss/compound_losses.py + +# nnU-Net v2 Dice loss implementation +WebFetch: https://raw.githubusercontent.com/MIC-DKFZ/nnUNet/master/nnunetv2/training/loss/dice.py + +# nnU-Net v2 evaluation metrics (Dice, IoU, TP/FP/FN) +WebFetch: https://raw.githubusercontent.com/MIC-DKFZ/nnUNet/master/nnunetv2/evaluation/evaluate_predictions.py + +# nnU-Net v2 sliding window inference (Gaussian weighting) +WebFetch: https://raw.githubusercontent.com/MIC-DKFZ/nnUNet/master/nnunetv2/inference/sliding_window_prediction.py + +# nnU-Net v2 base trainer (training loop, validation, scheduling) +WebFetch: https://raw.githubusercontent.com/MIC-DKFZ/nnUNet/master/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py + +# nnU-Net v2 extending guide (custom trainers, architectures) +WebFetch: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/extending_nnunet.md ``` diff --git a/.gitignore b/.gitignore index 7fe723b..db62944 100644 --- a/.gitignore +++ b/.gitignore @@ -170,6 +170,7 @@ _proc/ # MLflow mlruns/ +mlruns.db # Reports reports/ diff --git a/fastMONAI/__init__.py b/fastMONAI/__init__.py index 49e0fc1..777f190 100644 --- a/fastMONAI/__init__.py +++ b/fastMONAI/__init__.py @@ -1 +1 @@ -__version__ = "0.7.0" +__version__ = "0.8.0" diff --git a/fastMONAI/_modidx.py b/fastMONAI/_modidx.py index 7bad5ad..d290a71 100644 --- a/fastMONAI/_modidx.py +++ b/fastMONAI/_modidx.py @@ -163,6 +163,14 @@ 'fastMONAI/vision_augmentation.py'), 'fastMONAI.vision_augmentation.RandomAffine.__init__': ( 'vision_augment.html#randomaffine.__init__', 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.RandomAnisotropy': ( 'vision_augment.html#randomanisotropy', + 'fastMONAI/vision_augmentation.py'), 
+ 'fastMONAI.vision_augmentation.RandomAnisotropy.__init__': ( 'vision_augment.html#randomanisotropy.__init__', + 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.RandomAnisotropy.encodes': ( 'vision_augment.html#randomanisotropy.encodes', + 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.RandomAnisotropy.tio_transform': ( 'vision_augment.html#randomanisotropy.tio_transform', + 'fastMONAI/vision_augmentation.py'), 'fastMONAI.vision_augmentation.RandomBiasField': ( 'vision_augment.html#randombiasfield', 'fastMONAI/vision_augmentation.py'), 'fastMONAI.vision_augmentation.RandomBiasField.__init__': ( 'vision_augment.html#randombiasfield.__init__', @@ -284,7 +292,9 @@ 'fastMONAI.vision_augmentation._create_ellipsoid_mask': ( 'vision_augment.html#_create_ellipsoid_mask', 'fastMONAI/vision_augmentation.py'), 'fastMONAI.vision_augmentation.do_pad_or_crop': ( 'vision_augment.html#do_pad_or_crop', - 'fastMONAI/vision_augmentation.py')}, + 'fastMONAI/vision_augmentation.py'), + 'fastMONAI.vision_augmentation.suggest_patch_augmentations': ( 'vision_augment.html#suggest_patch_augmentations', + 'fastMONAI/vision_augmentation.py')}, 'fastMONAI.vision_core': { 'fastMONAI.vision_core.MedBase': ('vision_core.html#medbase', 'fastMONAI/vision_core.py'), 'fastMONAI.vision_core.MedBase.__copy__': ( 'vision_core.html#medbase.__copy__', 'fastMONAI/vision_core.py'), @@ -383,6 +393,14 @@ 'fastMONAI/vision_metrics.py'), 'fastMONAI.vision_metrics.AccumulatedMultiDice.value': ( 'vision_metrics.html#accumulatedmultidice.value', 'fastMONAI/vision_metrics.py'), + 'fastMONAI.vision_metrics.EMACheckpoint': ( 'vision_metrics.html#emacheckpoint', + 'fastMONAI/vision_metrics.py'), + 'fastMONAI.vision_metrics.EMACheckpoint.__init__': ( 'vision_metrics.html#emacheckpoint.__init__', + 'fastMONAI/vision_metrics.py'), + 'fastMONAI.vision_metrics.EMACheckpoint.after_epoch': ( 'vision_metrics.html#emacheckpoint.after_epoch', + 'fastMONAI/vision_metrics.py'), + 
'fastMONAI.vision_metrics.EMACheckpoint.before_fit': ( 'vision_metrics.html#emacheckpoint.before_fit', + 'fastMONAI/vision_metrics.py'), 'fastMONAI.vision_metrics.binary_dice_score': ( 'vision_metrics.html#binary_dice_score', 'fastMONAI/vision_metrics.py'), 'fastMONAI.vision_metrics.binary_hausdorff_distance': ( 'vision_metrics.html#binary_hausdorff_distance', @@ -513,6 +531,8 @@ 'fastMONAI/vision_patch.py'), 'fastMONAI.vision_patch._normalize_patch_overlap': ( 'vision_patch.html#_normalize_patch_overlap', 'fastMONAI/vision_patch.py'), + 'fastMONAI.vision_patch._predict_patch_tta': ( 'vision_patch.html#_predict_patch_tta', + 'fastMONAI/vision_patch.py'), 'fastMONAI.vision_patch._warn_config_override': ( 'vision_patch.html#_warn_config_override', 'fastMONAI/vision_patch.py'), 'fastMONAI.vision_patch.create_patch_sampler': ( 'vision_patch.html#create_patch_sampler', diff --git a/fastMONAI/vision_augmentation.py b/fastMONAI/vision_augmentation.py index 3e95607..520863d 100644 --- a/fastMONAI/vision_augmentation.py +++ b/fastMONAI/vision_augmentation.py @@ -3,8 +3,8 @@ # %% auto #0 __all__ = ['CustomDictTransform', 'do_pad_or_crop', 'PadOrCrop', 'ZNormalization', 'RescaleIntensity', 'NormalizeIntensity', 'BraTSMaskConverter', 'BinaryConverter', 'RandomGhosting', 'RandomSpike', 'RandomNoise', 'RandomBiasField', - 'RandomBlur', 'RandomGamma', 'RandomIntensityScale', 'RandomMotion', 'RandomCutout', - 'RandomElasticDeformation', 'RandomAffine', 'RandomFlip', 'OneOf'] + 'RandomBlur', 'RandomGamma', 'RandomIntensityScale', 'RandomMotion', 'RandomAnisotropy', 'RandomCutout', + 'RandomElasticDeformation', 'RandomAffine', 'RandomFlip', 'OneOf', 'suggest_patch_augmentations'] # %% ../nbs/03_vision_augment.ipynb #2d6694aa from fastai.data.all import * @@ -480,6 +480,30 @@ def encodes(self, o: MedImage): def encodes(self, o: MedMask): return o +# %% ../nbs/03_vision_augment.ipynb #cl3ei8hm3z9 +class RandomAnisotropy(DisplayedTransform): + '''Apply TorchIO `RandomAnisotropy`.''' + 
+ split_idx, order = 0, 1 + + def __init__(self, axes=(0, 1, 2), downsampling=(1.5, 5), + image_interpolation='linear', scalars_only=True, p=0.5): + self.add_anisotropy = tio.RandomAnisotropy( + axes=axes, downsampling=downsampling, + image_interpolation=image_interpolation, + scalars_only=scalars_only, p=p) + + @property + def tio_transform(self): + """Return the underlying TorchIO transform.""" + return self.add_anisotropy + + def encodes(self, o: MedImage): + return MedImage.create(self.add_anisotropy(o)) + + def encodes(self, o: MedMask): + return o + # %% ../nbs/03_vision_augment.ipynb #e7ea6486 def _create_ellipsoid_mask(shape, center, radii): """Create a 3D ellipsoid mask. @@ -750,7 +774,7 @@ def __init__(self, scales=0, degrees=10, translation=0, isotropic=False, class RandomFlip(CustomDictTransform): """Apply TorchIO `RandomFlip`.""" - def __init__(self, axes='LR', p=0.5): + def __init__(self, axes='LRAPIS', p=0.5): super().__init__(tio.RandomFlip(axes=axes, flip_probability=p)) # %% ../nbs/03_vision_augment.ipynb #ddd7b99b @@ -759,3 +783,74 @@ class OneOf(CustomDictTransform): def __init__(self, transform_dict, p=1): super().__init__(tio.OneOf(transform_dict, p=p)) + +# %% ../nbs/03_vision_augment.ipynb #t6hak044rc +def suggest_patch_augmentations(patch_size, target_spacing, + anisotropy_threshold=3.0, + translation_fraction=0.15): + """Suggest patch-based augmentations with nnU-Net-inspired defaults. + + Derives rotation degrees, translation, and RandomAnisotropy axes from + patch geometry and voxel spacing. Returns a list of fastMONAI transform + instances ready for the ``patch_tfms`` parameter in MedPatchDataLoaders. + + Anisotropy detection: if max(spacing)/min(spacing) >= threshold, rotation + is restricted to 5 deg out-of-plane and 30 deg in-plane. Otherwise 30 deg + symmetric. Translation is patch_size * fraction per axis. + + Args: + patch_size: List/tuple of 3 ints -- patch dimensions. + target_spacing: List/tuple of 3 floats -- voxel spacing. 
+ anisotropy_threshold: Ratio threshold for anisotropy detection (default 3.0). + translation_fraction: Fraction of patch_size for translation (default 0.15). + + Returns: + list: fastMONAI transform instances (7 normally, 6 if RandomAnisotropy omitted). + + Example:: + + >>> patch_tfms = suggest_patch_augmentations([128, 128, 32], [0.5, 0.5, 1.5]) + >>> dls = MedPatchDataLoaders.from_config(..., patch_tfms=patch_tfms) + """ + if len(patch_size) != 3: + raise ValueError(f"patch_size must have 3 elements, got {len(patch_size)}") + if len(target_spacing) != 3: + raise ValueError(f"target_spacing must have 3 elements, got {len(target_spacing)}") + + # Determine anisotropy + spacing = list(target_spacing) + ratio = max(spacing) / min(spacing) + is_aniso = ratio >= anisotropy_threshold + aniso_axis = spacing.index(max(spacing)) if is_aniso else None + + # Rotation degrees + if is_aniso: + degrees = [5, 5, 5] + degrees[aniso_axis] = 30 + degrees = tuple(degrees) + else: + degrees = 30 + + # Translation + translation = tuple(round(p * translation_fraction) for p in patch_size) + + # RandomAnisotropy axes: all axes where patch_size > 1 + aniso_axes = tuple(i for i in range(3) if patch_size[i] > 1) + + transforms = [ + RandomAffine(scales=(0.7, 1.4), degrees=degrees, translation=translation, + default_pad_value=0., p=0.2), + ] + + if len(aniso_axes) > 0: + transforms.append(RandomAnisotropy(axes=aniso_axes, downsampling=(1.5, 4), p=0.25)) + + transforms.extend([ + RandomGamma(log_gamma=(-0.3, 0.3), p=0.3), + RandomIntensityScale(scale_range=(0.75, 1.25), p=0.1), + RandomNoise(std=0.1, p=0.1), + RandomBlur(std=(0.5, 1.0), p=0.2), + RandomFlip(p=0.5), + ]) + + return transforms diff --git a/fastMONAI/vision_metrics.py b/fastMONAI/vision_metrics.py index d9aba1a..cf4c609 100644 --- a/fastMONAI/vision_metrics.py +++ b/fastMONAI/vision_metrics.py @@ -5,7 +5,7 @@ 'multi_hausdorff_distance', 'calculate_confusion_metrics', 'binary_sensitivity', 'multi_sensitivity', 
'binary_precision', 'multi_precision', 'calculate_lesion_detection_rate', 'binary_lesion_detection_rate', 'multi_lesion_detection_rate', 'calculate_signed_rve', 'binary_signed_rve', 'multi_signed_rve', - 'AccumulatedDice', 'AccumulatedMultiDice'] + 'AccumulatedDice', 'AccumulatedMultiDice', 'EMACheckpoint'] # %% ../nbs/05_vision_metrics.ipynb #8b6a83ac import torch @@ -14,6 +14,7 @@ from scipy.ndimage import label as scipy_label from .vision_data import pred_to_binary_mask, batch_pred_to_multiclass_mask from fastai.learner import Metric +from fastai.callback.tracker import TrackerCallback # %% ../nbs/05_vision_metrics.ipynb #5c16dc6c-e07a-44b5-85af-79bd6c4ce390 def calculate_dsc(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: @@ -499,3 +500,67 @@ def value(self): @property def name(self): return 'accumulated_multi_dice' + +# %% ../nbs/05_vision_metrics.ipynb #mwwwrlj5sm +class EMACheckpoint(TrackerCallback): + """Save model checkpoint based on EMA of a monitored metric (nnU-Net style). + + Instead of saving the best model based on a single (noisy) epoch metric, + this tracks the exponential moving average and saves when the EMA improves. + More robust for patch-based training where per-epoch metrics fluctuate. + + Formula: ema = momentum * previous_ema + (1 - momentum) * current_value + + Unlike SaveModelCallback, this does NOT auto-load the best model after + training. Load explicitly with ``learn.load(fname)``. + + Args: + monitor: Metric name to track (default: 'accumulated_dice'). + momentum: EMA momentum (default: 0.9, matching nnU-Net). + Higher momentum = more smoothing. Range: (0, 1). + nnU-Net uses 0.9 (keeps 90% of history, adds 10% of current epoch). + comp: Comparison function (default: np.greater for higher-is-better). + fname: Model save filename (default: 'best_model'). + with_opt: Whether to save optimizer state (default: False). 
+ + Example: + ```python + save_best = EMACheckpoint( + monitor='accumulated_dice', + momentum=0.9, + fname='best_model' + ) + learn.fit_one_cycle(30, lr, cbs=[save_best]) + + # Load best model after training: + learn.load('best_model') + + # Access EMA history for plotting: + save_best.ema_history + ``` + """ + order = 60 # Same order as TrackerCallback (SaveModelCallback runs at 61) + + def __init__(self, monitor='accumulated_dice', momentum=0.9, comp=np.greater, + fname='best_model', with_opt=False): + super().__init__(monitor=monitor, comp=comp) + self.fname = fname + self.with_opt = with_opt + self.momentum = momentum + + def before_fit(self): + super().before_fit() # Sets self.idx, self.best via TrackerCallback + self.ema_value = None + self.ema_history = [] + + def after_epoch(self): + val = self.recorder.values[-1][self.idx] + if isinstance(val, torch.Tensor): val = val.item() + + self.ema_value = val if self.ema_value is None else ( + self.momentum * self.ema_value + (1 - self.momentum) * val) + self.ema_history.append(self.ema_value) + + if self.comp(self.ema_value, self.best): + self.best = self.ema_value + self.learn.save(self.fname, with_opt=self.with_opt) diff --git a/fastMONAI/vision_patch.py b/fastMONAI/vision_patch.py index 60706e4..79e960b 100644 --- a/fastMONAI/vision_patch.py +++ b/fastMONAI/vision_patch.py @@ -990,6 +990,45 @@ def _normalize_patch_overlap(patch_overlap, patch_size): return tuple(result) +# nnU-Net-style mirror TTA: all 2^3 = 8 flip combinations for 3D. +# Batch tensor shape: [B, C, D, H, W], spatial dims are 2, 3, 4. +_TTA_FLIP_AXES = ( + (), # original + (4,), # flip LR (W) + (3,), # flip AP (H) + (2,), # flip IS (D) + (3, 4), # flip LR+AP + (2, 4), # flip LR+IS + (2, 3), # flip AP+IS + (2, 3, 4), # flip all +) + + +def _predict_patch_tta(model, patch_input): + """nnU-Net-style mirror TTA: average probabilities over 8 flip combinations. + + Runs 8 forward passes with a running sum for memory efficiency (2x memory, + not 9x). 
Each pass: flip input -> forward -> activate -> flip back -> accumulate. + + Args: + model: PyTorch model in eval mode (already on device). + patch_input: Batch tensor [B, C, D, H, W] already on device. + + Returns: + Averaged probability tensor [B, C, D, H, W] on CPU. + """ + summed_probs = None + for axes in _TTA_FLIP_AXES: + flipped = torch.flip(patch_input, list(axes)) if axes else patch_input + logits = model(flipped) + n_classes = logits.shape[1] + probs = torch.sigmoid(logits) if n_classes == 1 else torch.softmax(logits, dim=1) + if axes: + probs = torch.flip(probs, list(axes)) + summed_probs = probs if summed_probs is None else summed_probs + probs + return (summed_probs / len(_TTA_FLIP_AXES)).cpu() + + class PatchInferenceEngine: """Patch-based inference with automatic volume reconstruction. @@ -1067,7 +1106,8 @@ def predict( self, img_path: Path | str, return_probabilities: bool = False, - return_affine: bool = False + return_affine: bool = False, + tta: bool = False ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]: """Predict on a single volume using patch-based inference. @@ -1075,6 +1115,10 @@ def predict( img_path: Path to input image. return_probabilities: If True, return probability map instead of argmax. return_affine: If True, return (prediction, affine) tuple instead of just prediction. + tta: If True, apply nnU-Net-style mirror test-time augmentation + (8 flip combinations, averaged probabilities). Requires ~8x inference + time but improves prediction quality. Works best when training used + RandomFlip(axes='LRAPIS', p=0.5). Defaults to False. Returns: Predicted segmentation mask tensor, or tuple (prediction, affine) if return_affine=True. 
@@ -1141,20 +1185,25 @@ def predict( patch_input = patches_batch['image'][tio.DATA].to(self._device) locations = patches_batch[tio.LOCATION] - # Forward pass - get logits - logits = self.model(patch_input) - - # Convert logits to probabilities BEFORE aggregation - # This is critical: softmax is non-linear, so we must aggregate - # probabilities, not logits, to get correct boundary handling - n_classes = logits.shape[1] - if n_classes == 1: - probs = torch.sigmoid(logits) + if tta: + probs = _predict_patch_tta(self.model, patch_input) else: - probs = torch.softmax(logits, dim=1) # dim=1 for batch [B, C, H, W, D] + # Forward pass - get logits + logits = self.model(patch_input) + + # Convert logits to probabilities BEFORE aggregation + # This is critical: softmax is non-linear, so we must aggregate + # probabilities, not logits, to get correct boundary handling + n_classes = logits.shape[1] + if n_classes == 1: + probs = torch.sigmoid(logits) + else: + probs = torch.softmax(logits, dim=1) # dim=1 for batch [B, C, H, W, D] + + probs = probs.cpu() # Add probabilities to aggregator - aggregator.add_batch(probs.cpu(), locations) + aggregator.add_batch(probs, locations) # Get reconstructed output (now contains probabilities, not logits) output = aggregator.get_output_tensor() @@ -1233,7 +1282,8 @@ def patch_inference( return_probabilities: bool = False, progress: bool = True, save_dir: str = None, - pre_inference_tfms: list = None + pre_inference_tfms: list = None, + tta: bool = False ) -> list: """Batch patch-based inference on multiple volumes. @@ -1250,6 +1300,7 @@ def patch_inference( save_dir: Directory to save predictions as NIfTI files. If None, predictions are not saved. pre_inference_tfms: List of TorchIO transforms to apply before patch extraction. IMPORTANT: Should match the pre_patch_tfms used during training (e.g., [tio.ZNormalization()]). + tta: If True, apply nnU-Net-style mirror TTA (8 flip combinations). Returns: List of predicted tensors. 
@@ -1282,14 +1333,15 @@ def patch_inference( save_path.mkdir(parents=True, exist_ok=True) predictions = [] - iterator = tqdm(file_paths, desc='Patch inference') if progress else file_paths + desc = 'Patch inference (TTA)' if tta else 'Patch inference' + iterator = tqdm(file_paths, desc=desc) if progress else file_paths for path in iterator: # Get prediction and affine when saving is needed if save_dir is not None: - pred, affine = engine.predict(path, return_probabilities, return_affine=True) + pred, affine = engine.predict(path, return_probabilities, return_affine=True, tta=tta) else: - pred = engine.predict(path, return_probabilities) + pred = engine.predict(path, return_probabilities, tta=tta) predictions.append(pred) # Save prediction if save_dir specified diff --git a/nbs/03_vision_augment.ipynb b/nbs/03_vision_augment.ipynb index a30da55..d1fb7ad 100644 --- a/nbs/03_vision_augment.ipynb +++ b/nbs/03_vision_augment.ipynb @@ -440,6 +440,14 @@ "outputs": [], "source": "#| export\nclass RandomMotion(DisplayedTransform):\n \"\"\"Apply TorchIO `RandomMotion`.\"\"\"\n\n split_idx, order = 0, 1\n\n def __init__(\n self, \n degrees=10, \n translation=10, \n num_transforms=2, \n image_interpolation='linear', \n p=0.5\n ):\n self.add_motion = tio.RandomMotion(\n degrees=degrees, \n translation=translation, \n num_transforms=num_transforms, \n image_interpolation=image_interpolation, \n p=p\n )\n\n @property\n def tio_transform(self):\n \"\"\"Return the underlying TorchIO transform.\"\"\"\n return self.add_motion\n\n def encodes(self, o: MedImage):\n return MedImage.create(self.add_motion(o))\n\n def encodes(self, o: MedMask):\n return o" }, + { + "cell_type": "code", + "execution_count": null, + "id": "cl3ei8hm3z9", + "metadata": {}, + "outputs": [], + "source": "#| export\nclass RandomAnisotropy(DisplayedTransform):\n '''Apply TorchIO `RandomAnisotropy`.'''\n\n split_idx, order = 0, 1\n\n def __init__(self, axes=(0, 1, 2), downsampling=(1.5, 5),\n 
image_interpolation='linear', scalars_only=True, p=0.5):\n self.add_anisotropy = tio.RandomAnisotropy(\n axes=axes, downsampling=downsampling,\n image_interpolation=image_interpolation,\n scalars_only=scalars_only, p=p)\n\n @property\n def tio_transform(self):\n \"\"\"Return the underlying TorchIO transform.\"\"\"\n return self.add_anisotropy\n\n def encodes(self, o: MedImage):\n return MedImage.create(self.add_anisotropy(o))\n\n def encodes(self, o: MedMask):\n return o" + }, { "cell_type": "code", "execution_count": null, @@ -760,14 +768,7 @@ "id": "022c90cf", "metadata": {}, "outputs": [], - "source": [ - "# | export\n", - "class RandomFlip(CustomDictTransform):\n", - " \"\"\"Apply TorchIO `RandomFlip`.\"\"\"\n", - "\n", - " def __init__(self, axes='LR', p=0.5):\n", - " super().__init__(tio.RandomFlip(axes=axes, flip_probability=p))" - ] + "source": "# | export\nclass RandomFlip(CustomDictTransform):\n \"\"\"Apply TorchIO `RandomFlip`.\"\"\"\n\n def __init__(self, axes='LRAPIS', p=0.5):\n super().__init__(tio.RandomFlip(axes=axes, flip_probability=p))" }, { "cell_type": "code", @@ -784,13 +785,35 @@ " super().__init__(tio.OneOf(transform_dict, p=p))" ] }, + { + "cell_type": "markdown", + "id": "3sz0qzsue2e", + "metadata": {}, + "source": "## Augmentation suggestion" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "t6hak044rc", + "metadata": {}, + "outputs": [], + "source": "#| export\ndef suggest_patch_augmentations(patch_size, target_spacing,\n anisotropy_threshold=3.0,\n translation_fraction=0.15):\n \"\"\"Suggest patch-based augmentations with nnU-Net-inspired defaults.\n\n Derives rotation degrees, translation, and RandomAnisotropy axes from\n patch geometry and voxel spacing. Returns a list of fastMONAI transform\n instances ready for the ``patch_tfms`` parameter in MedPatchDataLoaders.\n\n Anisotropy detection: if max(spacing)/min(spacing) >= threshold, rotation\n is restricted to 5 deg out-of-plane and 30 deg in-plane. 
Otherwise 30 deg\n symmetric. Translation is patch_size * fraction per axis.\n\n Args:\n patch_size: List/tuple of 3 ints -- patch dimensions.\n target_spacing: List/tuple of 3 floats -- voxel spacing.\n anisotropy_threshold: Ratio threshold for anisotropy detection (default 3.0).\n translation_fraction: Fraction of patch_size for translation (default 0.15).\n\n Returns:\n list: fastMONAI transform instances (7 normally, 6 if RandomAnisotropy omitted).\n\n Example::\n\n >>> patch_tfms = suggest_patch_augmentations([128, 128, 32], [0.5, 0.5, 1.5])\n >>> dls = MedPatchDataLoaders.from_config(..., patch_tfms=patch_tfms)\n \"\"\"\n if len(patch_size) != 3:\n raise ValueError(f\"patch_size must have 3 elements, got {len(patch_size)}\")\n if len(target_spacing) != 3:\n raise ValueError(f\"target_spacing must have 3 elements, got {len(target_spacing)}\")\n\n # Determine anisotropy\n spacing = list(target_spacing)\n ratio = max(spacing) / min(spacing)\n is_aniso = ratio >= anisotropy_threshold\n aniso_axis = spacing.index(max(spacing)) if is_aniso else None\n\n # Rotation degrees\n if is_aniso:\n degrees = [5, 5, 5]\n degrees[aniso_axis] = 30\n degrees = tuple(degrees)\n else:\n degrees = 30\n\n # Translation\n translation = tuple(round(p * translation_fraction) for p in patch_size)\n\n # RandomAnisotropy axes: all axes where patch_size > 1\n aniso_axes = tuple(i for i in range(3) if patch_size[i] > 1)\n\n transforms = [\n RandomAffine(scales=(0.7, 1.4), degrees=degrees, translation=translation,\n default_pad_value=0., p=0.2),\n ]\n\n if len(aniso_axes) > 0:\n transforms.append(RandomAnisotropy(axes=aniso_axes, downsampling=(1.5, 4), p=0.25))\n\n transforms.extend([\n RandomGamma(log_gamma=(-0.3, 0.3), p=0.3),\n RandomIntensityScale(scale_range=(0.75, 1.25), p=0.1),\n RandomNoise(std=0.1, p=0.1),\n RandomBlur(std=(0.5, 1.0), p=0.2),\n RandomFlip(p=0.5),\n ])\n\n return transforms" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eu5z2yomz9j", + 
"metadata": {}, + "outputs": [], + "source": "from fastcore.test import test_eq, test_fail\n\n# Isotropic case\ntfms = suggest_patch_augmentations([128, 128, 128], [1.0, 1.0, 1.0])\ntest_eq(len(tfms), 7)\ntest_eq(type(tfms[0]), RandomAffine)\ntest_eq(type(tfms[-1]), RandomFlip)\n\n# Anisotropic case (axis 2 thick): degrees=(5, 5, 30) -> (-5, 5, -5, 5, -30, 30)\ntfms = suggest_patch_augmentations([128, 128, 32], target_spacing=[0.5, 0.5, 1.5])\ntest_eq(len(tfms), 7)\naff = tfms[0].tio_transform\ntest_eq(aff.degrees, (-5, 5, -5, 5, -30, 30))\n\n# Anisotropic case (axis 0 thick): degrees=(30, 5, 5) -> (-30, 30, -5, 5, -5, 5)\ntfms = suggest_patch_augmentations([32, 128, 128], target_spacing=[3.0, 0.5, 0.5])\naff = tfms[0].tio_transform\ntest_eq(aff.degrees, (-30, 30, -5, 5, -5, 5))\n\n# Isotropic spacing -> symmetric degrees: 30 -> (-30, 30, -30, 30, -30, 30)\ntfms = suggest_patch_augmentations([64, 64, 64], [1.0, 1.0, 1.0])\naff = tfms[0].tio_transform\ntest_eq(aff.degrees, (-30, 30, -30, 30, -30, 30))\n\n# 2D-like patch [128, 128, 1]\ntfms = suggest_patch_augmentations([128, 128, 1], [1.0, 1.0, 1.0])\naniso_tfm = tfms[1]\ntest_eq(type(aniso_tfm), RandomAnisotropy)\ntest_eq(aniso_tfm.add_anisotropy.axes, (0, 1))\n\n# All dims 1 -> RandomAnisotropy omitted\ntfms = suggest_patch_augmentations([1, 1, 1], [1.0, 1.0, 1.0])\ntest_eq(len(tfms), 6)\ntest_eq(all(not isinstance(t, RandomAnisotropy) for t in tfms), True)\n\n# Wrong input lengths\ntest_fail(lambda: suggest_patch_augmentations([128, 128], [1.0, 1.0, 1.0]))\ntest_fail(lambda: suggest_patch_augmentations([128, 128, 128], [1.0, 1.0]))\n\n# All returned transforms have .tio_transform\ntfms = suggest_patch_augmentations([128, 128, 64], [1.0, 1.0, 1.0])\nfor t in tfms:\n assert hasattr(t, 'tio_transform'), f\"{type(t).__name__} missing .tio_transform\"" + }, { "cell_type": "code", "execution_count": null, "id": "5117c50a", "metadata": {}, "outputs": [], - "source": "# Test .tio_transform property\n# 
CustomDictTransform-based wrappers\ntest_eq(type(RandomAffine(degrees=10).tio_transform), tio.RandomAffine)\ntest_eq(type(RandomFlip(p=0.5).tio_transform), tio.RandomFlip)\ntest_eq(type(RandomElasticDeformation(p=0.5).tio_transform), tio.RandomElasticDeformation)\n\n# DisplayedTransform-based wrappers\ntest_eq(type(PadOrCrop([64, 64, 64]).tio_transform), tio.CropOrPad)\ntest_eq(type(ZNormalization().tio_transform), tio.ZNormalization)\ntest_eq(type(RescaleIntensity((-1, 1), (-1000, 1000)).tio_transform), tio.RescaleIntensity)\ntest_eq(type(RandomGamma(p=0.5).tio_transform), tio.RandomGamma)\ntest_eq(type(RandomNoise(p=0.5).tio_transform), tio.RandomNoise)\ntest_eq(type(RandomBiasField(p=0.5).tio_transform), tio.RandomBiasField)\ntest_eq(type(RandomBlur(p=0.5).tio_transform), tio.RandomBlur)\ntest_eq(type(RandomGhosting(p=0.5).tio_transform), tio.RandomGhosting)\ntest_eq(type(RandomSpike(p=0.5).tio_transform), tio.RandomSpike)\ntest_eq(type(RandomMotion(p=0.5).tio_transform), tio.RandomMotion)\n\n# Custom TorchIO wrappers (isinstance check since these are custom subclasses)\ntest_eq(isinstance(RandomIntensityScale(p=0.5).tio_transform, tio.IntensityTransform), True)\ntest_eq(isinstance(NormalizeIntensity().tio_transform, tio.IntensityTransform), True)" + "source": "# Test .tio_transform property\n# CustomDictTransform-based wrappers\ntest_eq(type(RandomAffine(degrees=10).tio_transform), tio.RandomAffine)\ntest_eq(type(RandomFlip(p=0.5).tio_transform), tio.RandomFlip)\ntest_eq(type(RandomElasticDeformation(p=0.5).tio_transform), tio.RandomElasticDeformation)\n\n# DisplayedTransform-based wrappers\ntest_eq(type(PadOrCrop([64, 64, 64]).tio_transform), tio.CropOrPad)\ntest_eq(type(ZNormalization().tio_transform), tio.ZNormalization)\ntest_eq(type(RescaleIntensity((-1, 1), (-1000, 1000)).tio_transform), tio.RescaleIntensity)\ntest_eq(type(RandomGamma(p=0.5).tio_transform), tio.RandomGamma)\ntest_eq(type(RandomNoise(p=0.5).tio_transform), 
tio.RandomNoise)\ntest_eq(type(RandomBiasField(p=0.5).tio_transform), tio.RandomBiasField)\ntest_eq(type(RandomBlur(p=0.5).tio_transform), tio.RandomBlur)\ntest_eq(type(RandomGhosting(p=0.5).tio_transform), tio.RandomGhosting)\ntest_eq(type(RandomSpike(p=0.5).tio_transform), tio.RandomSpike)\ntest_eq(type(RandomMotion(p=0.5).tio_transform), tio.RandomMotion)\ntest_eq(type(RandomAnisotropy(p=0.5).tio_transform), tio.RandomAnisotropy)\n\n# Custom TorchIO wrappers (isinstance check since these are custom subclasses)\ntest_eq(isinstance(RandomIntensityScale(p=0.5).tio_transform, tio.IntensityTransform), True)\ntest_eq(isinstance(NormalizeIntensity().tio_transform, tio.IntensityTransform), True)" }, { "cell_type": "code", diff --git a/nbs/05_vision_metrics.ipynb b/nbs/05_vision_metrics.ipynb index a02757b..fdbc57f 100644 --- a/nbs/05_vision_metrics.ipynb +++ b/nbs/05_vision_metrics.ipynb @@ -16,7 +16,7 @@ "id": "8b6a83ac", "metadata": {}, "outputs": [], - "source": "#| export\nimport torch\nimport numpy as np\nfrom monai.metrics import compute_hausdorff_distance, compute_dice, get_confusion_matrix, compute_confusion_matrix_metric\nfrom scipy.ndimage import label as scipy_label\nfrom fastMONAI.vision_data import pred_to_binary_mask, batch_pred_to_multiclass_mask\nfrom fastai.learner import Metric" + "source": "#| export\nimport torch\nimport numpy as np\nfrom monai.metrics import compute_hausdorff_distance, compute_dice, get_confusion_matrix, compute_confusion_matrix_metric\nfrom scipy.ndimage import label as scipy_label\nfrom fastMONAI.vision_data import pred_to_binary_mask, batch_pred_to_multiclass_mask\nfrom fastai.learner import Metric\nfrom fastai.callback.tracker import TrackerCallback" }, { "cell_type": "markdown", @@ -203,6 +203,28 @@ "outputs": [], "source": "#| export\nclass AccumulatedMultiDice(AccumulatedDice):\n \"\"\"Multi-class version of AccumulatedDice that returns per-class Dice scores.\n\n Instead of returning a single mean Dice, this returns a tensor 
with the\n Dice score for each foreground class. Useful for monitoring per-class\n performance during training.\n\n Example:\n ```python\n # For 3-class segmentation (background + 2 foreground classes)\n learn = Learner(dls, model, loss_func=loss_func,\n metrics=[AccumulatedMultiDice(n_classes=3)])\n ```\n \"\"\"\n @property\n def value(self):\n \"\"\"Return per-class Dice scores.\"\"\"\n dice = 2 * self.tp / (2 * self.tp + self.fp + self.fn + 1e-8)\n return dice # Returns tensor, fastai will display all values\n\n @property\n def name(self):\n return 'accumulated_multi_dice'" }, + { + "cell_type": "markdown", + "id": "lgp2x0jxbh9", + "metadata": {}, + "source": "## EMA Model Checkpoint\n\nExponential Moving Average (EMA) based model selection, inspired by nnU-Net's approach\nof using smoothed Dice scores rather than noisy per-epoch values for checkpoint selection." + }, + { + "cell_type": "code", + "execution_count": null, + "id": "mwwwrlj5sm", + "metadata": {}, + "outputs": [], + "source": "#| export\nclass EMACheckpoint(TrackerCallback):\n \"\"\"Save model checkpoint based on EMA of a monitored metric (nnU-Net style).\n\n Instead of saving the best model based on a single (noisy) epoch metric,\n this tracks the exponential moving average and saves when the EMA improves.\n More robust for patch-based training where per-epoch metrics fluctuate.\n\n Formula: ema = momentum * previous_ema + (1 - momentum) * current_value\n\n Unlike SaveModelCallback, this does NOT auto-load the best model after\n training. Load explicitly with ``learn.load(fname)``.\n\n Args:\n monitor: Metric name to track (default: 'accumulated_dice').\n momentum: EMA momentum (default: 0.9, matching nnU-Net).\n Higher momentum = more smoothing. 
Range: (0, 1).\n nnU-Net uses 0.9 (keeps 90% of history, adds 10% of current epoch).\n comp: Comparison function (default: np.greater for higher-is-better).\n fname: Model save filename (default: 'best_model').\n with_opt: Whether to save optimizer state (default: False).\n\n Example:\n ```python\n save_best = EMACheckpoint(\n monitor='accumulated_dice',\n momentum=0.9,\n fname='best_model'\n )\n learn.fit_one_cycle(30, lr, cbs=[save_best])\n\n # Load best model after training:\n learn.load('best_model')\n\n # Access EMA history for plotting:\n save_best.ema_history\n ```\n \"\"\"\n order = 60 # Same priority as SaveModelCallback\n\n def __init__(self, monitor='accumulated_dice', momentum=0.9, comp=np.greater,\n fname='best_model', with_opt=False):\n super().__init__(monitor=monitor, comp=comp)\n self.fname = fname\n self.with_opt = with_opt\n self.momentum = momentum\n\n def before_fit(self):\n super().before_fit() # Sets self.idx, self.best via TrackerCallback\n self.ema_value = None\n self.ema_history = []\n\n def after_epoch(self):\n val = self.recorder.values[-1][self.idx]\n if isinstance(val, torch.Tensor): val = val.item()\n\n self.ema_value = val if self.ema_value is None else (\n self.momentum * self.ema_value + (1 - self.momentum) * val)\n self.ema_history.append(self.ema_value)\n\n if self.comp(self.ema_value, self.best):\n self.best = self.ema_value\n self.learn.save(self.fname, with_opt=self.with_opt)" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "s71194eyra", + "metadata": {}, + "outputs": [], + "source": "#| hide\n# Test EMACheckpoint EMA calculation logic\nfrom unittest.mock import MagicMock\nfrom fastcore.test import test_eq\n\ncb = EMACheckpoint(monitor='accumulated_dice', momentum=0.9)\ncb.learn = MagicMock()\ncb.recorder = MagicMock()\n\n# Simulate before_fit: set idx manually (TrackerCallback would do this)\ncb.idx = 2\ncb.best = -float('inf')\ncb.ema_value = None\ncb.ema_history = []\n\n# Simulate epochs with known Dice 
values\ndice_values = [0.5, 0.6, 0.55, 0.7, 0.65]\nexpected_ema = []\nema = None\nfor d in dice_values:\n ema = d if ema is None else 0.9 * ema + 0.1 * d\n expected_ema.append(ema)\n\nfor d in dice_values:\n cb.recorder.values = [[1, 0.5, d]] # values[-1][idx=2] = d\n cb.after_epoch()\n\nfor actual, expected in zip(cb.ema_history, expected_ema):\n test_eq(round(actual, 10), round(expected, 10))\n\nassert cb.best is not None\nassert cb.learn.save.called\nprint(f\"EMA history: {[f'{v:.4f}' for v in cb.ema_history]}\")\nprint(\"EMACheckpoint test passed!\")" + }, { "cell_type": "code", "execution_count": null, diff --git a/nbs/10_vision_patch.ipynb b/nbs/10_vision_patch.ipynb index 92067c3..83e032d 100644 --- a/nbs/10_vision_patch.ipynb +++ b/nbs/10_vision_patch.ipynb @@ -780,292 +780,7 @@ "id": "cell-17", "metadata": {}, "outputs": [], - "source": [ - "#| export\n", - "import numbers\n", - "\n", - "def _normalize_patch_overlap(patch_overlap, patch_size):\n", - " \"\"\"Convert patch_overlap to integer pixel values for TorchIO compatibility.\n", - "\n", - " TorchIO's GridSampler expects patch_overlap as a tuple of even integers.\n", - " This function handles:\n", - " - Fractional overlap (0-1): converted to pixel values based on patch_size\n", - " - Numpy scalar types: converted to native Python types\n", - " - Sequences: converted to tuple of integers\n", - "\n", - " Note: Input validation (negative values, overlap >= patch_size) is handled\n", - " by PatchConfig.__post_init__(). 
This function focuses on format conversion.\n", - "\n", - " Args:\n", - " patch_overlap: int, float (0-1 for fraction), or sequence\n", - " patch_size: list/tuple of patch dimensions [x, y, z]\n", - "\n", - " Returns:\n", - " Tuple of even integers suitable for TorchIO GridSampler\n", - " \"\"\"\n", - " # Handle scalar fractional overlap (0 < x < 1)\n", - " # Note: excludes 1.0 as 100% overlap creates step_size=0 (infinite patches)\n", - " if isinstance(patch_overlap, (int, float, numbers.Number)) and 0 < float(patch_overlap) < 1:\n", - " # Convert fraction to pixel values, ensure even\n", - " result = []\n", - " for ps in patch_size:\n", - " pixels = int(int(ps) * float(patch_overlap))\n", - " # Ensure even (required by TorchIO)\n", - " if pixels % 2 != 0:\n", - " pixels = pixels - 1 if pixels > 0 else 0\n", - " result.append(pixels)\n", - " return tuple(result)\n", - "\n", - " # Handle scalar integer (including numpy scalars) - values > 1 are pixel counts\n", - " if isinstance(patch_overlap, (int, float, numbers.Number)):\n", - " val = int(patch_overlap)\n", - " # Ensure even\n", - " if val % 2 != 0:\n", - " val = val - 1 if val > 0 else 0\n", - " return tuple(val for _ in patch_size)\n", - "\n", - " # Handle sequences (list, tuple, ndarray)\n", - " result = []\n", - " for val in patch_overlap:\n", - " pixels = int(val)\n", - " if pixels % 2 != 0:\n", - " pixels = pixels - 1 if pixels > 0 else 0\n", - " result.append(pixels)\n", - " return tuple(result)\n", - "\n", - "\n", - "class PatchInferenceEngine:\n", - " \"\"\"Patch-based inference with automatic volume reconstruction.\n", - " \n", - " Uses TorchIO's GridSampler to extract overlapping patches and\n", - " GridAggregator to reconstruct the full volume from predictions.\n", - " \n", - " Args:\n", - " learner: fastai Learner or PyTorch model (nn.Module). When passing a raw\n", - " PyTorch model, load weights first with model.load_state_dict().\n", - " config: PatchConfig with inference settings. 
Preprocessing params (apply_reorder,\n", - " target_spacing, padding_mode) can be set here for DRY usage.\n", - " apply_reorder: Whether to reorder to RAS+ orientation. If None, uses config value.\n", - " target_spacing: Target voxel spacing. If None, uses config value.\n", - " batch_size: Number of patches to predict at once. Must be positive.\n", - " pre_inference_tfms: List of TorchIO transforms to apply before patch extraction.\n", - " IMPORTANT: Should match the pre_patch_tfms used during training (e.g., [tio.ZNormalization()]).\n", - " This ensures preprocessing consistency between training and inference.\n", - " Accepts both fastMONAI wrappers and raw TorchIO transforms.\n", - " \n", - " Example:\n", - " >>> # Option 1: From fastai Learner\n", - " >>> engine = PatchInferenceEngine(learn, config, pre_inference_tfms=[ZNormalization()])\n", - " >>> pred = engine.predict('image.nii.gz')\n", - " \n", - " >>> # Option 2: From raw PyTorch model (recommended for deployment)\n", - " >>> model = UNet(spatial_dims=3, in_channels=1, out_channels=2, ...)\n", - " >>> model.load_state_dict(torch.load('final_weights.pth'))\n", - " >>> model.cuda().eval()\n", - " >>> engine = PatchInferenceEngine(model, config, pre_inference_tfms=[ZNormalization()])\n", - " >>> pred = engine.predict('image.nii.gz')\n", - " \"\"\"\n", - " \n", - " def __init__(\n", - " self,\n", - " learner,\n", - " config: PatchConfig,\n", - " apply_reorder: bool = None,\n", - " target_spacing: list = None,\n", - " batch_size: int = 4,\n", - " pre_inference_tfms: list = None\n", - " ):\n", - " if batch_size <= 0:\n", - " raise ValueError(f\"batch_size must be positive, got {batch_size}\")\n", - " \n", - " # Extract model from Learner if needed (use isinstance for robust detection)\n", - " # Note: We check for Learner explicitly because some models (e.g., MONAI UNet)\n", - " # have a .model attribute that is NOT the full model but an internal Sequential.\n", - " if isinstance(learner, Learner):\n", - " 
self.model = learner.model\n", - " else:\n", - " self.model = learner # Assume it's already a PyTorch model\n", - " \n", - " self.config = config\n", - " self.batch_size = batch_size\n", - " \n", - " # Normalize transforms to raw TorchIO (accepts both fastMONAI wrappers and raw TorchIO)\n", - " normalized_tfms = normalize_patch_transforms(pre_inference_tfms)\n", - " self.pre_inference_tfms = tio.Compose(normalized_tfms) if normalized_tfms else None\n", - " \n", - " # Use config values, allow explicit overrides for backward compatibility\n", - " self.apply_reorder = apply_reorder if apply_reorder is not None else config.apply_reorder\n", - " self.target_spacing = target_spacing if target_spacing is not None else config.target_spacing\n", - " \n", - " # Warn if explicit args provided but differ from config (potential mistake)\n", - " _warn_config_override('apply_reorder', config.apply_reorder, apply_reorder)\n", - " _warn_config_override('target_spacing', config.target_spacing, target_spacing)\n", - " \n", - " # Get device from model parameters, with fallback for parameter-less models\n", - " try:\n", - " self._device = next(self.model.parameters()).device\n", - " except StopIteration:\n", - " self._device = _get_default_device()\n", - " \n", - " def predict(\n", - " self,\n", - " img_path: Path | str,\n", - " return_probabilities: bool = False,\n", - " return_affine: bool = False\n", - " ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:\n", - " \"\"\"Predict on a single volume using patch-based inference.\n", - "\n", - " Args:\n", - " img_path: Path to input image.\n", - " return_probabilities: If True, return probability map instead of argmax.\n", - " return_affine: If True, return (prediction, affine) tuple instead of just prediction.\n", - "\n", - " Returns:\n", - " Predicted segmentation mask tensor, or tuple (prediction, affine) if return_affine=True.\n", - " \"\"\"\n", - " # Load image - keep org_img and org_size for post-processing\n", - " # Note: 
med_img_reader handles reorder/resample internally, no global state needed\n", - " org_img, input_img, org_size = med_img_reader(\n", - " img_path, apply_reorder=self.apply_reorder, target_spacing=self.target_spacing, only_tensor=False\n", - " )\n", - "\n", - " # Create TorchIO Subject from preprocessed image\n", - " subject = tio.Subject(\n", - " image=tio.ScalarImage(tensor=input_img.data.float(), affine=input_img.affine)\n", - " )\n", - "\n", - " # Apply pre-inference transforms (e.g., ZNormalization) to match training\n", - " if self.pre_inference_tfms is not None:\n", - " subject = self.pre_inference_tfms(subject)\n", - "\n", - " # Pad dimensions smaller than patch_size, keep larger dimensions intact\n", - " # GridSampler handles large images via overlapping patches\n", - " img_shape = subject['image'].shape[1:] # Exclude channel dim\n", - " target_size = [max(s, p) for s, p in zip(img_shape, self.config.patch_size)]\n", - " \n", - " # Warn if volume needed padding (may cause artifacts if training didn't cover similar sizes)\n", - " if any(s < p for s, p in zip(img_shape, self.config.patch_size)):\n", - " padded_dims = [f\"dim{i}: {s}<{p}\" for i, (s, p) in enumerate(zip(img_shape, self.config.patch_size)) if s < p]\n", - " warnings.warn(\n", - " f\"Image size {list(img_shape)} smaller than patch_size {self.config.patch_size} \"\n", - " f\"in {padded_dims}. Padding with mode={self.config.padding_mode}. 
\"\n", - " \"Ensure training data covered similar sizes to avoid artifacts.\"\n", - " )\n", - " \n", - " # Use padding_mode from config (default: 0 for zero padding, nnU-Net standard)\n", - " subject = tio.CropOrPad(target_size, padding_mode=self.config.padding_mode)(subject)\n", - "\n", - " # Convert patch_overlap to integer pixel values for TorchIO compatibility\n", - " patch_overlap = _normalize_patch_overlap(self.config.patch_overlap, self.config.patch_size)\n", - "\n", - " # Create GridSampler\n", - " grid_sampler = tio.GridSampler(\n", - " subject,\n", - " patch_size=self.config.patch_size,\n", - " patch_overlap=patch_overlap\n", - " )\n", - "\n", - " # Create GridAggregator\n", - " aggregator = tio.GridAggregator(\n", - " grid_sampler,\n", - " overlap_mode=self.config.aggregation_mode\n", - " )\n", - "\n", - " # Create patch loader\n", - " patch_loader = DataLoader(\n", - " grid_sampler,\n", - " batch_size=self.batch_size,\n", - " num_workers=0\n", - " )\n", - "\n", - " # Predict patches\n", - " self.model.eval()\n", - " with torch.no_grad():\n", - " for patches_batch in patch_loader:\n", - " patch_input = patches_batch['image'][tio.DATA].to(self._device)\n", - " locations = patches_batch[tio.LOCATION]\n", - "\n", - " # Forward pass - get logits\n", - " logits = self.model(patch_input)\n", - "\n", - " # Convert logits to probabilities BEFORE aggregation\n", - " # This is critical: softmax is non-linear, so we must aggregate\n", - " # probabilities, not logits, to get correct boundary handling\n", - " n_classes = logits.shape[1]\n", - " if n_classes == 1:\n", - " probs = torch.sigmoid(logits)\n", - " else:\n", - " probs = torch.softmax(logits, dim=1) # dim=1 for batch [B, C, H, W, D]\n", - "\n", - " # Add probabilities to aggregator\n", - " aggregator.add_batch(probs.cpu(), locations)\n", - "\n", - " # Get reconstructed output (now contains probabilities, not logits)\n", - " output = aggregator.get_output_tensor()\n", - "\n", - " # Convert to prediction mask 
(only if not returning probabilities)\n", - " if return_probabilities:\n", - " result = output # Keep as float probabilities\n", - " else:\n", - " n_classes = output.shape[0]\n", - " if n_classes == 1:\n", - " result = (output > 0.5).float()\n", - " else:\n", - " result = output.argmax(dim=0, keepdim=True).float()\n", - "\n", - " # Apply keep_largest post-processing for binary segmentation\n", - " if not return_probabilities and self.config.keep_largest_component:\n", - " from fastMONAI.vision_inference import keep_largest\n", - " result = keep_largest(result.squeeze(0)).unsqueeze(0)\n", - "\n", - " # Post-processing: resize back to original size and reorient\n", - " # This matches the workflow in vision_inference.py\n", - " \n", - " # Wrap result in TorchIO Image for resizing\n", - " # Use ScalarImage for probabilities, LabelMap for masks\n", - " if return_probabilities:\n", - " pred_img = tio.ScalarImage(tensor=result.float(), affine=input_img.affine)\n", - " else:\n", - " pred_img = tio.LabelMap(tensor=result.float(), affine=input_img.affine)\n", - " \n", - " # Resize back to original size (before resampling)\n", - " pred_img = _do_resize(pred_img, org_size, image_interpolation='nearest')\n", - " \n", - " # Reorient to original orientation (if reorder was applied)\n", - " # Use explicit .cpu() for consistent device handling\n", - " if self.apply_reorder:\n", - " reoriented_array = _to_original_orientation(\n", - " pred_img.as_sitk(),\n", - " ('').join(org_img.orientation)\n", - " )\n", - " result = torch.from_numpy(reoriented_array).cpu()\n", - " # Only convert to long for masks, not probabilities\n", - " if not return_probabilities:\n", - " result = result.long()\n", - " else:\n", - " result = pred_img.data.cpu()\n", - " # Only convert to long for masks, not probabilities\n", - " if not return_probabilities:\n", - " result = result.long()\n", - "\n", - " # Use original affine matrix for correct spatial alignment\n", - " # org_img.affine is always available from 
med_img_reader\n", - " if not (hasattr(org_img, 'affine') and org_img.affine is not None):\n", - " raise RuntimeError(\n", - " \"org_img.affine not available. This should never happen - please report this bug.\"\n", - " )\n", - " affine = org_img.affine.copy()\n", - "\n", - " if return_affine:\n", - " return result, affine\n", - " return result\n", - " \n", - " def to(self, device):\n", - " \"\"\"Move engine to device.\"\"\"\n", - " self._device = device\n", - " self.model.to(device)\n", - " return self" - ] + "source": "#| export\nimport numbers\n\ndef _normalize_patch_overlap(patch_overlap, patch_size):\n \"\"\"Convert patch_overlap to integer pixel values for TorchIO compatibility.\n\n TorchIO's GridSampler expects patch_overlap as a tuple of even integers.\n This function handles:\n - Fractional overlap (0-1): converted to pixel values based on patch_size\n - Numpy scalar types: converted to native Python types\n - Sequences: converted to tuple of integers\n\n Note: Input validation (negative values, overlap >= patch_size) is handled\n by PatchConfig.__post_init__(). 
This function focuses on format conversion.\n\n Args:\n patch_overlap: int, float (0-1 for fraction), or sequence\n patch_size: list/tuple of patch dimensions [x, y, z]\n\n Returns:\n Tuple of even integers suitable for TorchIO GridSampler\n \"\"\"\n # Handle scalar fractional overlap (0 < x < 1)\n # Note: excludes 1.0 as 100% overlap creates step_size=0 (infinite patches)\n if isinstance(patch_overlap, (int, float, numbers.Number)) and 0 < float(patch_overlap) < 1:\n # Convert fraction to pixel values, ensure even\n result = []\n for ps in patch_size:\n pixels = int(int(ps) * float(patch_overlap))\n # Ensure even (required by TorchIO)\n if pixels % 2 != 0:\n pixels = pixels - 1 if pixels > 0 else 0\n result.append(pixels)\n return tuple(result)\n\n # Handle scalar integer (including numpy scalars) - values > 1 are pixel counts\n if isinstance(patch_overlap, (int, float, numbers.Number)):\n val = int(patch_overlap)\n # Ensure even\n if val % 2 != 0:\n val = val - 1 if val > 0 else 0\n return tuple(val for _ in patch_size)\n\n # Handle sequences (list, tuple, ndarray)\n result = []\n for val in patch_overlap:\n pixels = int(val)\n if pixels % 2 != 0:\n pixels = pixels - 1 if pixels > 0 else 0\n result.append(pixels)\n return tuple(result)\n\n\n# nnU-Net-style mirror TTA: all 2^3 = 8 flip combinations for 3D.\n# Batch tensor shape: [B, C, D, H, W], spatial dims are 2, 3, 4.\n_TTA_FLIP_AXES = (\n (), # original\n (4,), # flip LR (W)\n (3,), # flip AP (H)\n (2,), # flip IS (D)\n (3, 4), # flip LR+AP\n (2, 4), # flip LR+IS\n (2, 3), # flip AP+IS\n (2, 3, 4), # flip all\n)\n\n\ndef _predict_patch_tta(model, patch_input):\n \"\"\"nnU-Net-style mirror TTA: average probabilities over 8 flip combinations.\n\n Runs 8 forward passes with a running sum for memory efficiency (2x memory,\n not 9x). 
Each pass: flip input -> forward -> activate -> flip back -> accumulate.\n\n Args:\n model: PyTorch model in eval mode (already on device).\n patch_input: Batch tensor [B, C, D, H, W] already on device.\n\n Returns:\n Averaged probability tensor [B, C, D, H, W] on CPU.\n \"\"\"\n summed_probs = None\n for axes in _TTA_FLIP_AXES:\n flipped = torch.flip(patch_input, list(axes)) if axes else patch_input\n logits = model(flipped)\n n_classes = logits.shape[1]\n probs = torch.sigmoid(logits) if n_classes == 1 else torch.softmax(logits, dim=1)\n if axes:\n probs = torch.flip(probs, list(axes))\n summed_probs = probs if summed_probs is None else summed_probs + probs\n return (summed_probs / len(_TTA_FLIP_AXES)).cpu()\n\n\nclass PatchInferenceEngine:\n \"\"\"Patch-based inference with automatic volume reconstruction.\n \n Uses TorchIO's GridSampler to extract overlapping patches and\n GridAggregator to reconstruct the full volume from predictions.\n \n Args:\n learner: fastai Learner or PyTorch model (nn.Module). When passing a raw\n PyTorch model, load weights first with model.load_state_dict().\n config: PatchConfig with inference settings. Preprocessing params (apply_reorder,\n target_spacing, padding_mode) can be set here for DRY usage.\n apply_reorder: Whether to reorder to RAS+ orientation. If None, uses config value.\n target_spacing: Target voxel spacing. If None, uses config value.\n batch_size: Number of patches to predict at once. 
class PatchInferenceEngine:
    """Sliding-window inference over a volume with automatic re-assembly.

    Overlapping patches are drawn with TorchIO's ``GridSampler``; the per-patch
    probabilities are merged back into a full volume by ``GridAggregator``.

    Args:
        learner: fastai ``Learner`` or a plain PyTorch ``nn.Module``. For a raw
            module, load its weights (``model.load_state_dict()``) beforehand.
        config: ``PatchConfig`` holding the inference settings. The
            preprocessing parameters (``apply_reorder``, ``target_spacing``,
            ``padding_mode``) can live here so they are declared only once.
        apply_reorder: Reorder to RAS+ before inference. ``None`` defers to
            ``config.apply_reorder``.
        target_spacing: Voxel spacing to resample to. ``None`` defers to
            ``config.target_spacing``.
        batch_size: Number of patches per forward pass. Must be positive.
        pre_inference_tfms: Transforms applied to the whole volume before patch
            extraction. These should mirror the ``pre_patch_tfms`` used for
            training (e.g. ``[tio.ZNormalization()]``). fastMONAI wrappers and
            raw TorchIO transforms are both accepted.

    Example:
        >>> # Option 1: From fastai Learner
        >>> engine = PatchInferenceEngine(learn, config, pre_inference_tfms=[ZNormalization()])
        >>> pred = engine.predict('image.nii.gz')

        >>> # Option 2: From raw PyTorch model (recommended for deployment)
        >>> model = UNet(spatial_dims=3, in_channels=1, out_channels=2, ...)
        >>> model.load_state_dict(torch.load('final_weights.pth'))
        >>> model.cuda().eval()
        >>> engine = PatchInferenceEngine(model, config, pre_inference_tfms=[ZNormalization()])
        >>> pred = engine.predict('image.nii.gz')
    """

    def __init__(
        self,
        learner,
        config: PatchConfig,
        apply_reorder: bool = None,
        target_spacing: list = None,
        batch_size: int = 4,
        pre_inference_tfms: list = None
    ):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        # Unwrap only genuine Learners. Duck-typing on ``.model`` is unsafe:
        # some networks (e.g. MONAI UNet) expose a ``.model`` attribute that is
        # just an inner Sequential, not the full network.
        self.model = learner.model if isinstance(learner, Learner) else learner

        self.config = config
        self.batch_size = batch_size

        # Accept fastMONAI wrappers or raw TorchIO transforms alike.
        raw_tfms = normalize_patch_transforms(pre_inference_tfms)
        self.pre_inference_tfms = tio.Compose(raw_tfms) if raw_tfms else None

        # Explicit arguments win over the config (kept for backward compatibility).
        self.apply_reorder = config.apply_reorder if apply_reorder is None else apply_reorder
        self.target_spacing = config.target_spacing if target_spacing is None else target_spacing

        # Flag explicit arguments that disagree with the config (likely a mistake).
        _warn_config_override('apply_reorder', config.apply_reorder, apply_reorder)
        _warn_config_override('target_spacing', config.target_spacing, target_spacing)

        # Follow the model's device; parameter-less models use the default device.
        first_param = next(self.model.parameters(), None)
        self._device = _get_default_device() if first_param is None else first_param.device

    def predict(
        self,
        img_path: Path | str,
        return_probabilities: bool = False,
        return_affine: bool = False,
        tta: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
        """Run patch-based inference on one volume.

        Args:
            img_path: Path of the image to segment.
            return_probabilities: Return the aggregated probability map rather
                than a thresholded / argmax mask.
            return_affine: Return ``(prediction, affine)`` instead of only the
                prediction.
            tta: Average probabilities over the 8 mirror flips
                (nnU-Net-style test-time augmentation). Costs roughly eight
                forward passes per patch. Defaults to False.

        Returns:
            The prediction tensor, or ``(prediction, affine)`` when
            ``return_affine`` is True.

        Raises:
            RuntimeError: If the loaded original image carries no affine.
        """
        # med_img_reader performs reorder/resample itself; the original image
        # and size are kept for mapping the prediction back afterwards.
        org_img, input_img, org_size = med_img_reader(
            img_path,
            apply_reorder=self.apply_reorder,
            target_spacing=self.target_spacing,
            only_tensor=False
        )

        subject = tio.Subject(
            image=tio.ScalarImage(tensor=input_img.data.float(), affine=input_img.affine)
        )

        # Same whole-volume preprocessing as during training (e.g. ZNormalization).
        if self.pre_inference_tfms is not None:
            subject = self.pre_inference_tfms(subject)

        # Only grow dimensions that are smaller than the patch; larger ones are
        # left alone and covered by overlapping patches.
        img_shape = subject['image'].shape[1:]  # drop channel dim
        padded_dims = [
            f"dim{i}: {s}<{p}"
            for i, (s, p) in enumerate(zip(img_shape, self.config.patch_size))
            if s < p
        ]
        if padded_dims:
            warnings.warn(
                f"Image size {list(img_shape)} smaller than patch_size {self.config.patch_size} "
                f"in {padded_dims}. Padding with mode={self.config.padding_mode}. "
                "Ensure training data covered similar sizes to avoid artifacts."
            )

        target_size = [max(s, p) for s, p in zip(img_shape, self.config.patch_size)]
        subject = tio.CropOrPad(target_size, padding_mode=self.config.padding_mode)(subject)

        # TorchIO wants the overlap as even integer voxel counts.
        grid_sampler = tio.GridSampler(
            subject,
            patch_size=self.config.patch_size,
            patch_overlap=_normalize_patch_overlap(
                self.config.patch_overlap, self.config.patch_size
            )
        )
        aggregator = tio.GridAggregator(grid_sampler, overlap_mode=self.config.aggregation_mode)
        patch_loader = DataLoader(grid_sampler, batch_size=self.batch_size, num_workers=0)

        self.model.eval()
        with torch.no_grad():
            for patches_batch in patch_loader:
                patch_input = patches_batch['image'][tio.DATA].to(self._device)

                if tta:
                    probs = _predict_patch_tta(self.model, patch_input)
                else:
                    logits = self.model(patch_input)
                    # Activate BEFORE aggregation: softmax is non-linear, so
                    # overlapping regions must blend probabilities, not logits.
                    if logits.shape[1] == 1:
                        probs = torch.sigmoid(logits)
                    else:
                        probs = torch.softmax(logits, dim=1)  # class dim of the batch
                    probs = probs.cpu()

                aggregator.add_batch(probs, patches_batch[tio.LOCATION])

        # Aggregated tensor holds probabilities (channel first, no batch dim).
        output = aggregator.get_output_tensor()

        if return_probabilities:
            result = output
        elif output.shape[0] == 1:
            result = (output > 0.5).float()
        else:
            result = output.argmax(dim=0, keepdim=True).float()

        # Optional largest-component clean-up, masks only.
        if not return_probabilities and self.config.keep_largest_component:
            from fastMONAI.vision_inference import keep_largest
            result = keep_largest(result.squeeze(0)).unsqueeze(0)

        # Map back to the original grid (same workflow as vision_inference.py):
        # ScalarImage for probabilities, LabelMap for masks.
        image_cls = tio.ScalarImage if return_probabilities else tio.LabelMap
        pred_img = image_cls(tensor=result.float(), affine=input_img.affine)
        pred_img = _do_resize(pred_img, org_size, image_interpolation='nearest')

        # Undo the reorientation only if one was applied on load.
        if self.apply_reorder:
            reoriented_array = _to_original_orientation(
                pred_img.as_sitk(),
                ('').join(org_img.orientation)
            )
            result = torch.from_numpy(reoriented_array).cpu()
        else:
            result = pred_img.data.cpu()

        # Masks are integer labels; probabilities stay floating point.
        if not return_probabilities:
            result = result.long()

        # The original affine keeps the prediction spatially aligned with the input.
        if not hasattr(org_img, 'affine') or org_img.affine is None:
            raise RuntimeError(
                "org_img.affine not available. This should never happen - please report this bug."
            )
        affine = org_img.affine.copy()

        return (result, affine) if return_affine else result

    def to(self, device):
        """Move the wrapped model to ``device`` and return ``self`` for chaining."""
        self._device = device
        self.model.to(device)
        return self
learner=learn,\n", - " ... config=config, # apply_reorder and target_spacing from config\n", - " ... file_paths=val_paths,\n", - " ... pre_inference_tfms=[tio.ZNormalization()],\n", - " ... save_dir='predictions/patch_based'\n", - " ... )\n", - " \"\"\"\n", - " # Use config values if not explicitly provided\n", - " _apply_reorder = apply_reorder if apply_reorder is not None else config.apply_reorder\n", - " _target_spacing = target_spacing if target_spacing is not None else config.target_spacing\n", - " \n", - " engine = PatchInferenceEngine(\n", - " learner, config, _apply_reorder, _target_spacing, batch_size, pre_inference_tfms\n", - " )\n", - " \n", - " # Create save directory if specified\n", - " if save_dir is not None:\n", - " save_path = Path(save_dir)\n", - " save_path.mkdir(parents=True, exist_ok=True)\n", - " \n", - " predictions = []\n", - " iterator = tqdm(file_paths, desc='Patch inference') if progress else file_paths\n", - " \n", - " for path in iterator:\n", - " # Get prediction and affine when saving is needed\n", - " if save_dir is not None:\n", - " pred, affine = engine.predict(path, return_probabilities, return_affine=True)\n", - " else:\n", - " pred = engine.predict(path, return_probabilities)\n", - " predictions.append(pred)\n", - " \n", - " # Save prediction if save_dir specified\n", - " if save_dir is not None:\n", - " input_path = Path(path)\n", - " # Create output filename based on input using suffix-based approach\n", - " # This handles .nii.gz correctly without corrupting filenames with .nii elsewhere\n", - " stem = input_path.stem\n", - " if input_path.suffix == '.gz' and stem.endswith('.nii'):\n", - " # Handle .nii.gz files: stem is \"filename.nii\", strip the .nii\n", - " stem = stem[:-4]\n", - " out_name = f\"{stem}_pred.nii.gz\"\n", - " elif input_path.suffix == '.nii':\n", - " # Handle .nii files\n", - " out_name = f\"{stem}_pred.nii\"\n", - " else:\n", - " # Fallback for other formats\n", - " out_name = f\"{stem}_pred.nii.gz\"\n", 
def patch_inference(
    learner,
    config: PatchConfig,
    file_paths: list,
    apply_reorder: bool = None,
    target_spacing: list = None,
    batch_size: int = 4,
    return_probabilities: bool = False,
    progress: bool = True,
    save_dir: str = None,
    pre_inference_tfms: list = None,
    tta: bool = False
) -> list:
    """Run patch-based inference over several volumes, optionally saving each result.

    Args:
        learner: fastai Learner or PyTorch model.
        config: PatchConfig with the inference settings. ``apply_reorder`` and
            ``target_spacing`` may be declared here instead of being passed
            explicitly.
        file_paths: Image paths to predict on.
        apply_reorder: Reorder to RAS+ orientation. ``None`` defers to the config.
        target_spacing: Target voxel spacing. ``None`` defers to the config.
        batch_size: Patches per forward pass.
        return_probabilities: Return probability maps instead of masks.
        progress: Wrap the loop in a tqdm progress bar.
        save_dir: Output directory for NIfTI predictions (created if missing).
            ``None`` disables saving.
        pre_inference_tfms: Transforms applied before patch extraction. Should
            match the ``pre_patch_tfms`` used for training
            (e.g. ``[tio.ZNormalization()]``).
        tta: Apply nnU-Net-style mirror TTA (8 flip combinations).

    Returns:
        List with one predicted tensor per input path, in input order.

    Example:
        >>> config = PatchConfig(
        ...     patch_size=[96, 96, 96],
        ...     apply_reorder=True,
        ...     target_spacing=[0.4102, 0.4102, 1.5]
        ... )
        >>> predictions = patch_inference(
        ...     learner=learn,
        ...     config=config,  # apply_reorder and target_spacing from config
        ...     file_paths=val_paths,
        ...     pre_inference_tfms=[tio.ZNormalization()],
        ...     save_dir='predictions/patch_based'
        ... )
    """
    # Explicit arguments take precedence; otherwise fall back to the config.
    engine = PatchInferenceEngine(
        learner,
        config,
        config.apply_reorder if apply_reorder is None else apply_reorder,
        config.target_spacing if target_spacing is None else target_spacing,
        batch_size,
        pre_inference_tfms
    )

    saving = save_dir is not None
    if saving:
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

    desc = 'Patch inference (TTA)' if tta else 'Patch inference'
    iterator = tqdm(file_paths, desc=desc) if progress else file_paths

    predictions = []
    for path in iterator:
        if not saving:
            predictions.append(engine.predict(path, return_probabilities, tta=tta))
            continue

        # The affine is only requested when the prediction is written to disk.
        pred, affine = engine.predict(path, return_probabilities, return_affine=True, tta=tta)
        predictions.append(pred)

        # Derive the output name from the suffix rather than by substring
        # replacement, so a ".nii" elsewhere in the name is left untouched.
        input_path = Path(path)
        stem = input_path.stem
        if input_path.suffix == '.gz' and stem.endswith('.nii'):
            # "name.nii.gz" -> stem is "name.nii"; drop the inner ".nii".
            stem = stem[:-4]
            out_name = f"{stem}_pred.nii.gz"
        elif input_path.suffix == '.nii':
            out_name = f"{stem}_pred.nii"
        else:
            # Any other input format is written as compressed NIfTI.
            out_name = f"{stem}_pred.nii.gz"

        # ScalarImage for float probabilities, LabelMap for integer masks.
        image_cls = tio.ScalarImage if return_probabilities else tio.LabelMap
        image_cls(tensor=pred, affine=affine).save(save_path / out_name)

    return predictions
single pass\nconst_input = torch.ones(1, 1, 8, 8, 8) * 0.5\nwith torch.no_grad():\n single_logits = model_bin(const_input)\n single_probs = torch.sigmoid(single_logits).cpu()\n tta_probs = _predict_patch_tta(model_bin, const_input)\nassert torch.allclose(single_probs, tta_probs, atol=1e-6), \"TTA on constant input should match single forward pass\"\n\nprint(\"TTA tests passed!\")" }, { "cell_type": "markdown", diff --git a/nbs/12a_tutorial_patch_training.ipynb b/nbs/12a_tutorial_patch_training.ipynb index 58d6651..f9edef6 100644 --- a/nbs/12a_tutorial_patch_training.ipynb +++ b/nbs/12a_tutorial_patch_training.ipynb @@ -448,7 +448,7 @@ " patch_size=[160, 160, 80],\n", " samples_per_volume=8,\n", " sampler_type='label',\n", - " label_probabilities={0: 0.3, 1: 0.7},\n", + " label_probabilities={0: 0.67, 1: 0.33},\n", " patch_overlap=0.5,\n", " keep_largest_component=True,\n", " target_spacing=target_spacing,\n", @@ -495,9 +495,9 @@ "# Pre-patch transforms (applied to full volumes by Queue workers)\n", "pre_patch_tfms = [ZNormalization()]\n", "\n", - "# Patch augmentations (applied to training patches only)\n", "patch_tfms = [\n", - " RandomAffine(scales=(0.7, 1.4), degrees=30, translation=0, p=0.2),\n", + " RandomAffine(scales=(0.7, 1.4), degrees=30, translation=(25, 25, 10), p=0.2),\n", + " RandomAnisotropy(downsampling=(1.5, 3), p=0.25),\n", " RandomGamma(log_gamma=(-0.3, 0.3), p=0.3),\n", " RandomIntensityScale(scale_range=(0.75, 1.25), p=0.1),\n", " RandomNoise(std=0.1, p=0.1),\n", @@ -745,15 +745,7 @@ "id": "cell-callbacks", "metadata": {}, "outputs": [], - "source": [ - "best_model_fname = \"best_heart_patch\"\n", - "save_best = SaveModelCallback(\n", - " monitor='accumulated_dice',\n", - " comp=np.greater,\n", - " fname=best_model_fname,\n", - " with_opt=False\n", - ")" - ] + "source": "best_model_fname = \"best_heart_patch\"\nsave_best = EMACheckpoint(\n monitor='accumulated_dice',\n momentum=0.9,\n comp=np.greater,\n fname=best_model_fname,\n 
with_opt=False\n)" }, { "cell_type": "code", @@ -1206,63 +1198,8 @@ "execution_count": null, "id": "cell-eval-inference", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loaded best model: best_heart_patch\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Saved file doesn't contain an optimizer state.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c6d5114f62d44344aa5d4e6d585dd7e5", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Patch inference: 0%| | 0/1 [00:00=1.2.0 scikit-image==0.26.0 imagedata==3.8.14 mlflow==3.9.0 huggingface-hub gdown gradio opencv-python plum-dispatch