Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions fastMONAI/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,9 @@ class ModelTrackingCallback(Callback):
A FastAI callback for comprehensive MLflow experiment tracking.

This callback automatically logs hyperparameters, metrics, model artifacts,
and configuration to MLflow during training. If a SaveModelCallback is present,
the best model checkpoint will also be logged as an artifact.
and configuration to MLflow during training. If a checkpoint callback
(SaveModelCallback, EMACheckpoint, or any TrackerCallback with fname) is
present, the best model checkpoint will also be logged as an artifact.

Supports auto-managed runs when created via `create_mlflow_callback()`.
"""
Expand Down Expand Up @@ -455,18 +456,18 @@ def _save_model_artifacts(self, temp_dir: Path) -> None:
"""Save model weights, learner, and configuration as artifacts."""
import shutil

# Save final epoch weights
# Save final epoch weights (without optimizer state to reduce file size)
weights_path = temp_dir / "final_weights"
self.learn.save(str(weights_path))
self.learn.save(str(weights_path), with_opt=False)
weights_file = f"{weights_path}.pth"
if os.path.exists(weights_file):
mlflow.log_artifact(weights_file, "model")

# Auto-detect SaveModelCallback and log best model weights
from fastai.callback.tracker import SaveModelCallback
# Auto-detect checkpoint callback (SaveModelCallback, EMACheckpoint, etc.)
from fastai.callback.tracker import TrackerCallback
best_model_cb = None
for cb in self.learn.cbs:
if isinstance(cb, SaveModelCallback):
if isinstance(cb, TrackerCallback) and hasattr(cb, 'fname'):
best_model_path = self.learn.path / self.learn.model_dir / f'{cb.fname}.pth'
if best_model_path.exists():
best_weights_dest = temp_dir / "best_weights.pth"
Expand Down
9 changes: 9 additions & 0 deletions fastMONAI/vision_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,15 @@ def __post_init__(self):
f"Overlap >= patch_size creates step_size <= 0 (infinite patches)."
)

# Warn if patch_size dimensions are not divisible by 16
non_div = [s for s in self.patch_size if s % 16 != 0]
if non_div:
warnings.warn(
f"patch_size {self.patch_size} has dimensions not divisible by 16. "
f"Most encoder-decoder architectures (e.g., U-Net) require patch sizes "
f"divisible by 16 (2^4 for 4 downsampling levels)."
)

@classmethod
def from_dataset(
cls,
Expand Down
2 changes: 1 addition & 1 deletion nbs/07_utils.ipynb

Large diffs are not rendered by default.

154 changes: 1 addition & 153 deletions nbs/10_vision_patch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -178,159 +178,7 @@
"id": "cell-5",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"@dataclass\n",
"class PatchConfig:\n",
" \"\"\"Configuration for patch-based training and inference.\n",
" \n",
" Args:\n",
" patch_size: Size of patches [x, y, z].\n",
" patch_overlap: Overlap for inference GridSampler (int, float 0-1, or list).\n",
" - Float 0-1: fraction of patch_size (e.g., 0.5 = 50% overlap)\n",
" - Int >= 1: pixel overlap (e.g., 48 = 48 pixel overlap)\n",
" - List: per-dimension overlap in pixels\n",
" samples_per_volume: Number of patches to extract per volume during training.\n",
" sampler_type: Type of sampler ('uniform', 'label', 'weighted').\n",
" label_probabilities: For LabelSampler, dict mapping label values to probabilities.\n",
" queue_length: Maximum number of patches to store in queue.\n",
" queue_num_workers: Number of workers for parallel patch extraction.\n",
" aggregation_mode: For inference, how to combine overlapping patches ('crop', 'average', 'hann').\n",
" apply_reorder: Whether to reorder to RAS+ canonical orientation. Must match between\n",
" training and inference. Defaults to True (the common case).\n",
" target_spacing: Target voxel spacing [x, y, z] for resampling. Must match between\n",
" training and inference.\n",
" padding_mode: Padding mode for CropOrPad when image < patch_size. Default is 0 (zero padding)\n",
" to align with nnU-Net's approach. Can be int, float, or string (e.g., 'minimum', 'mean').\n",
" keep_largest_component: If True, keep only the largest connected component\n",
" in binary segmentation predictions. Only applies during inference when\n",
" return_probabilities=False. Defaults to False.\n",
" \n",
" Example:\n",
" >>> config = PatchConfig(\n",
" ... patch_size=[96, 96, 96],\n",
" ... samples_per_volume=16,\n",
" ... sampler_type='label',\n",
" ... label_probabilities={0: 0.1, 1: 0.9},\n",
" ... target_spacing=[0.5, 0.5, 0.5]\n",
" ... )\n",
" \"\"\"\n",
" patch_size: list = field(default_factory=lambda: [96, 96, 96])\n",
" patch_overlap: int | float | list = 0\n",
" samples_per_volume: int = 8\n",
" sampler_type: str = 'uniform'\n",
" label_probabilities: dict = None\n",
" queue_length: int = 300\n",
" queue_num_workers: int = 4\n",
" aggregation_mode: str = 'hann'\n",
" # Preprocessing parameters - must match between training and inference\n",
" apply_reorder: bool = True # Defaults to True (the common case)\n",
" target_spacing: list = None\n",
" padding_mode: int | float | str = 0 # Zero padding (nnU-Net standard)\n",
" # Post-processing (binary segmentation only)\n",
" keep_largest_component: bool = False\n",
" \n",
" def __post_init__(self):\n",
" \"\"\"Validate configuration.\"\"\"\n",
" valid_samplers = ['uniform', 'label', 'weighted']\n",
" if self.sampler_type not in valid_samplers:\n",
" raise ValueError(f\"sampler_type must be one of {valid_samplers}\")\n",
" \n",
" valid_aggregation = ['crop', 'average', 'hann']\n",
" if self.aggregation_mode not in valid_aggregation:\n",
" raise ValueError(f\"aggregation_mode must be one of {valid_aggregation}\")\n",
" \n",
" # Validate patch_overlap\n",
" # Negative overlap doesn't make sense\n",
" if isinstance(self.patch_overlap, (int, float)):\n",
" if self.patch_overlap < 0:\n",
" raise ValueError(\"patch_overlap cannot be negative\")\n",
" # Check if overlap as pixels would exceed patch_size (causes step_size=0)\n",
" if self.patch_overlap >= 1: # Pixel value, not fraction\n",
" for ps in self.patch_size:\n",
" if self.patch_overlap >= ps:\n",
" raise ValueError(\n",
" f\"patch_overlap ({self.patch_overlap}) must be less than patch_size ({ps}). \"\n",
" f\"Overlap >= patch_size creates step_size <= 0 (infinite patches).\"\n",
" )\n",
" elif isinstance(self.patch_overlap, (list, tuple)):\n",
" for i, (overlap, ps) in enumerate(zip(self.patch_overlap, self.patch_size)):\n",
" if overlap < 0:\n",
" raise ValueError(f\"patch_overlap[{i}] cannot be negative\")\n",
" if overlap >= ps:\n",
" raise ValueError(\n",
" f\"patch_overlap[{i}] ({overlap}) must be less than patch_size[{i}] ({ps}). \"\n",
" f\"Overlap >= patch_size creates step_size <= 0 (infinite patches).\"\n",
" )\n",
"\n",
" @classmethod\n",
" def from_dataset(\n",
" cls,\n",
" dataset: 'MedDataset',\n",
" target_spacing: list = None,\n",
" min_patch_size: list = None,\n",
" max_patch_size: list = None,\n",
" divisor: int = 16,\n",
" **kwargs\n",
" ) -> 'PatchConfig':\n",
" \"\"\"Create PatchConfig with automatic patch_size from dataset analysis.\n",
"\n",
" Combines dataset preprocessing suggestions with patch size calculation\n",
" for a complete, DRY configuration.\n",
"\n",
" Args:\n",
" dataset: MedDataset instance with analyzed images.\n",
" target_spacing: Target voxel spacing [x, y, z]. If None, uses\n",
" dataset.get_suggestion()['target_spacing'].\n",
" min_patch_size: Minimum per dimension [32, 32, 32].\n",
" max_patch_size: Maximum per dimension [256, 256, 256].\n",
" divisor: Divisibility constraint (default 16 for UNet compatibility).\n",
" **kwargs: Additional PatchConfig parameters (samples_per_volume,\n",
" sampler_type, label_probabilities, etc.).\n",
"\n",
" Returns:\n",
" PatchConfig with suggested patch_size, apply_reorder, target_spacing.\n",
"\n",
" Example:\n",
" >>> from fastMONAI.dataset_info import MedDataset\n",
" >>> dataset = MedDataset(dataframe=df, mask_col='mask_path', dtype=MedMask)\n",
" >>> \n",
" >>> # Use recommended spacing\n",
" >>> config = PatchConfig.from_dataset(dataset, samples_per_volume=16)\n",
" >>> \n",
" >>> # Use custom spacing\n",
" >>> config = PatchConfig.from_dataset(\n",
" ... dataset,\n",
" ... target_spacing=[1.0, 1.0, 2.0],\n",
" ... samples_per_volume=16\n",
" ... )\n",
" \"\"\"\n",
" # Get preprocessing suggestion from dataset\n",
" suggestion = dataset.get_suggestion()\n",
"\n",
" # Use explicit spacing or dataset suggestion\n",
" _target_spacing = target_spacing if target_spacing is not None else suggestion['target_spacing']\n",
"\n",
" # Calculate patch size for the target spacing\n",
" patch_size = suggest_patch_size(\n",
" dataset,\n",
" target_spacing=_target_spacing,\n",
" min_patch_size=min_patch_size,\n",
" max_patch_size=max_patch_size,\n",
" divisor=divisor\n",
" )\n",
"\n",
" # Merge with explicit kwargs (kwargs override defaults)\n",
" # Use dataset.apply_reorder directly (not from get_suggestion() since it's not data-derived)\n",
" config_kwargs = {\n",
" 'patch_size': patch_size,\n",
" 'apply_reorder': dataset.apply_reorder,\n",
" 'target_spacing': _target_spacing,\n",
" }\n",
" config_kwargs.update(kwargs)\n",
"\n",
" return cls(**config_kwargs)"
]
"source": "#| export\n@dataclass\nclass PatchConfig:\n \"\"\"Configuration for patch-based training and inference.\n \n Args:\n patch_size: Size of patches [x, y, z].\n patch_overlap: Overlap for inference GridSampler (int, float 0-1, or list).\n - Float 0-1: fraction of patch_size (e.g., 0.5 = 50% overlap)\n - Int >= 1: pixel overlap (e.g., 48 = 48 pixel overlap)\n - List: per-dimension overlap in pixels\n samples_per_volume: Number of patches to extract per volume during training.\n sampler_type: Type of sampler ('uniform', 'label', 'weighted').\n label_probabilities: For LabelSampler, dict mapping label values to probabilities.\n queue_length: Maximum number of patches to store in queue.\n queue_num_workers: Number of workers for parallel patch extraction.\n aggregation_mode: For inference, how to combine overlapping patches ('crop', 'average', 'hann').\n apply_reorder: Whether to reorder to RAS+ canonical orientation. Must match between\n training and inference. Defaults to True (the common case).\n target_spacing: Target voxel spacing [x, y, z] for resampling. Must match between\n training and inference.\n padding_mode: Padding mode for CropOrPad when image < patch_size. Default is 0 (zero padding)\n to align with nnU-Net's approach. Can be int, float, or string (e.g., 'minimum', 'mean').\n keep_largest_component: If True, keep only the largest connected component\n in binary segmentation predictions. Only applies during inference when\n return_probabilities=False. Defaults to False.\n \n Example:\n >>> config = PatchConfig(\n ... patch_size=[96, 96, 96],\n ... samples_per_volume=16,\n ... sampler_type='label',\n ... label_probabilities={0: 0.1, 1: 0.9},\n ... target_spacing=[0.5, 0.5, 0.5]\n ... )\n \"\"\"\n patch_size: list = field(default_factory=lambda: [96, 96, 96])\n patch_overlap: int | float | list = 0\n samples_per_volume: int = 8\n sampler_type: str = 'uniform'\n label_probabilities: dict = None\n queue_length: int = 300\n queue_num_workers: int = 4\n aggregation_mode: str = 'hann'\n # Preprocessing parameters - must match between training and inference\n apply_reorder: bool = True # Defaults to True (the common case)\n target_spacing: list = None\n padding_mode: int | float | str = 0 # Zero padding (nnU-Net standard)\n # Post-processing (binary segmentation only)\n keep_largest_component: bool = False\n \n def __post_init__(self):\n \"\"\"Validate configuration.\"\"\"\n valid_samplers = ['uniform', 'label', 'weighted']\n if self.sampler_type not in valid_samplers:\n raise ValueError(f\"sampler_type must be one of {valid_samplers}\")\n \n valid_aggregation = ['crop', 'average', 'hann']\n if self.aggregation_mode not in valid_aggregation:\n raise ValueError(f\"aggregation_mode must be one of {valid_aggregation}\")\n \n # Validate patch_overlap\n # Negative overlap doesn't make sense\n if isinstance(self.patch_overlap, (int, float)):\n if self.patch_overlap < 0:\n raise ValueError(\"patch_overlap cannot be negative\")\n # Check if overlap as pixels would exceed patch_size (causes step_size=0)\n if self.patch_overlap >= 1: # Pixel value, not fraction\n for ps in self.patch_size:\n if self.patch_overlap >= ps:\n raise ValueError(\n f\"patch_overlap ({self.patch_overlap}) must be less than patch_size ({ps}). \"\n f\"Overlap >= patch_size creates step_size <= 0 (infinite patches).\"\n )\n elif isinstance(self.patch_overlap, (list, tuple)):\n for i, (overlap, ps) in enumerate(zip(self.patch_overlap, self.patch_size)):\n if overlap < 0:\n raise ValueError(f\"patch_overlap[{i}] cannot be negative\")\n if overlap >= ps:\n raise ValueError(\n f\"patch_overlap[{i}] ({overlap}) must be less than patch_size[{i}] ({ps}). \"\n f\"Overlap >= patch_size creates step_size <= 0 (infinite patches).\"\n )\n\n # Warn if patch_size dimensions are not divisible by 16\n non_div = [s for s in self.patch_size if s % 16 != 0]\n if non_div:\n warnings.warn(\n f\"patch_size {self.patch_size} has dimensions not divisible by 16. \"\n f\"Most encoder-decoder architectures (e.g., U-Net) require patch sizes \"\n f\"divisible by 16 (2^4 for 4 downsampling levels).\"\n )\n\n @classmethod\n def from_dataset(\n cls,\n dataset: 'MedDataset',\n target_spacing: list = None,\n min_patch_size: list = None,\n max_patch_size: list = None,\n divisor: int = 16,\n **kwargs\n ) -> 'PatchConfig':\n \"\"\"Create PatchConfig with automatic patch_size from dataset analysis.\n\n Combines dataset preprocessing suggestions with patch size calculation\n for a complete, DRY configuration.\n\n Args:\n dataset: MedDataset instance with analyzed images.\n target_spacing: Target voxel spacing [x, y, z]. If None, uses\n dataset.get_suggestion()['target_spacing'].\n min_patch_size: Minimum per dimension [32, 32, 32].\n max_patch_size: Maximum per dimension [256, 256, 256].\n divisor: Divisibility constraint (default 16 for UNet compatibility).\n **kwargs: Additional PatchConfig parameters (samples_per_volume,\n sampler_type, label_probabilities, etc.).\n\n Returns:\n PatchConfig with suggested patch_size, apply_reorder, target_spacing.\n\n Example:\n >>> from fastMONAI.dataset_info import MedDataset\n >>> dataset = MedDataset(dataframe=df, mask_col='mask_path', dtype=MedMask)\n >>> \n >>> # Use recommended spacing\n >>> config = PatchConfig.from_dataset(dataset, samples_per_volume=16)\n >>> \n >>> # Use custom spacing\n >>> config = PatchConfig.from_dataset(\n ... dataset,\n ... target_spacing=[1.0, 1.0, 2.0],\n ... samples_per_volume=16\n ... )\n \"\"\"\n # Get preprocessing suggestion from dataset\n suggestion = dataset.get_suggestion()\n\n # Use explicit spacing or dataset suggestion\n _target_spacing = target_spacing if target_spacing is not None else suggestion['target_spacing']\n\n # Calculate patch size for the target spacing\n patch_size = suggest_patch_size(\n dataset,\n target_spacing=_target_spacing,\n min_patch_size=min_patch_size,\n max_patch_size=max_patch_size,\n divisor=divisor\n )\n\n # Merge with explicit kwargs (kwargs override defaults)\n # Use dataset.apply_reorder directly (not from get_suggestion() since it's not data-derived)\n config_kwargs = {\n 'patch_size': patch_size,\n 'apply_reorder': dataset.apply_reorder,\n 'target_spacing': _target_spacing,\n }\n config_kwargs.update(kwargs)\n\n return cls(**config_kwargs)"
},
{
"cell_type": "code",
Expand Down
Loading