Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions fastMONAI/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,24 @@
'syms': { 'fastMONAI.dataset_info': { 'fastMONAI.dataset_info.MedDataset': ('dataset_info.html#meddataset', 'fastMONAI/dataset_info.py'),
'fastMONAI.dataset_info.MedDataset.__init__': ( 'dataset_info.html#meddataset.__init__',
'fastMONAI/dataset_info.py'),
'fastMONAI.dataset_info.MedDataset._auto_cache_path': ( 'dataset_info.html#meddataset._auto_cache_path',
'fastMONAI/dataset_info.py'),
'fastMONAI.dataset_info.MedDataset._create_data_frame': ( 'dataset_info.html#meddataset._create_data_frame',
'fastMONAI/dataset_info.py'),
'fastMONAI.dataset_info.MedDataset._get_data_info': ( 'dataset_info.html#meddataset._get_data_info',
'fastMONAI/dataset_info.py'),
'fastMONAI.dataset_info.MedDataset._load_cache': ( 'dataset_info.html#meddataset._load_cache',
'fastMONAI/dataset_info.py'),
'fastMONAI.dataset_info.MedDataset._process_with_cache': ( 'dataset_info.html#meddataset._process_with_cache',
'fastMONAI/dataset_info.py'),
'fastMONAI.dataset_info.MedDataset._save_cache': ( 'dataset_info.html#meddataset._save_cache',
'fastMONAI/dataset_info.py'),
'fastMONAI.dataset_info.MedDataset._visualize_single_case': ( 'dataset_info.html#meddataset._visualize_single_case',
'fastMONAI/dataset_info.py'),
'fastMONAI.dataset_info.MedDataset.calculate_target_size': ( 'dataset_info.html#meddataset.calculate_target_size',
'fastMONAI/dataset_info.py'),
'fastMONAI.dataset_info.MedDataset.fingerprint': ( 'dataset_info.html#meddataset.fingerprint',
'fastMONAI/dataset_info.py'),
'fastMONAI.dataset_info.MedDataset.get_size_statistics': ( 'dataset_info.html#meddataset.get_size_statistics',
'fastMONAI/dataset_info.py'),
'fastMONAI.dataset_info.MedDataset.get_suggestion': ( 'dataset_info.html#meddataset.get_suggestion',
Expand Down Expand Up @@ -79,7 +89,7 @@
'fastMONAI/utils.py'),
'fastMONAI.utils.ModelTrackingCallback._extract_training_params': ( 'utils.html#modeltrackingcallback._extract_training_params',
'fastMONAI/utils.py'),
'fastMONAI.utils.ModelTrackingCallback._log_datasets': ( 'utils.html#modeltrackingcallback._log_datasets',
'fastMONAI.utils.ModelTrackingCallback._log_split_df': ( 'utils.html#modeltrackingcallback._log_split_df',
'fastMONAI/utils.py'),
'fastMONAI.utils.ModelTrackingCallback._register_pytorch_model': ( 'utils.html#modeltrackingcallback._register_pytorch_model',
'fastMONAI/utils.py'),
Expand All @@ -103,13 +113,9 @@
'fastMONAI.utils._extract_loss_name': ('utils.html#_extract_loss_name', 'fastMONAI/utils.py'),
'fastMONAI.utils._extract_model_name': ('utils.html#_extract_model_name', 'fastMONAI/utils.py'),
'fastMONAI.utils._extract_patch_config': ('utils.html#_extract_patch_config', 'fastMONAI/utils.py'),
'fastMONAI.utils._extract_patch_dataset_dfs': ( 'utils.html#_extract_patch_dataset_dfs',
'fastMONAI/utils.py'),
'fastMONAI.utils._extract_size_from_transforms': ( 'utils.html#_extract_size_from_transforms',
'fastMONAI/utils.py'),
'fastMONAI.utils._extract_standard_config': ('utils.html#_extract_standard_config', 'fastMONAI/utils.py'),
'fastMONAI.utils._extract_standard_dataset_dfs': ( 'utils.html#_extract_standard_dataset_dfs',
'fastMONAI/utils.py'),
'fastMONAI.utils.create_mlflow_callback': ('utils.html#create_mlflow_callback', 'fastMONAI/utils.py'),
'fastMONAI.utils.load_patch_variables': ('utils.html#load_patch_variables', 'fastMONAI/utils.py'),
'fastMONAI.utils.load_variables': ('utils.html#load_variables', 'fastMONAI/utils.py'),
Expand Down Expand Up @@ -473,6 +479,8 @@
'fastMONAI/vision_patch.py'),
'fastMONAI.vision_patch.MedPatchDataLoaders.show_batch': ( 'vision_patch.html#medpatchdataloaders.show_batch',
'fastMONAI/vision_patch.py'),
'fastMONAI.vision_patch.MedPatchDataLoaders.split_df': ( 'vision_patch.html#medpatchdataloaders.split_df',
'fastMONAI/vision_patch.py'),
'fastMONAI.vision_patch.MedPatchDataLoaders.target_spacing': ( 'vision_patch.html#medpatchdataloaders.target_spacing',
'fastMONAI/vision_patch.py'),
'fastMONAI.vision_patch.MedPatchDataLoaders.to': ( 'vision_patch.html#medpatchdataloaders.to',
Expand Down
152 changes: 147 additions & 5 deletions fastMONAI/dataset_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@

from sklearn.utils.class_weight import compute_class_weight
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import pandas as pd
import numpy as np
import torch
import glob
import hashlib
import os
import pickle
import matplotlib.pyplot as plt

# %% ../nbs/08_dataset_info.ipynb #7401beac
Expand All @@ -23,7 +27,8 @@ class MedDataset:

def __init__(self, dataframe=None, image_col:str=None, mask_col:str="mask_path",
path=None, img_list=None, postfix:str='', apply_reorder:bool=True,
dtype:(MedImage, MedMask)=MedImage, max_workers:int=1):
dtype:(MedImage, MedMask)=MedImage, max_workers:int=1,
use_cache:bool=True, cache_path=None):
"""Constructs MedDataset object.

Args:
Expand All @@ -36,6 +41,12 @@ def __init__(self, dataframe=None, image_col:str=None, mask_col:str="mask_path",
apply_reorder: Whether to reorder images to RAS+ orientation.
dtype: MedImage for images or MedMask for segmentation masks.
max_workers: Number of parallel workers for processing.
use_cache: Enable metadata caching. When True (default) and cache_path
is not provided, auto-generates a cache path in ~/.cache/fastmonai/.
Set to False to disable caching entirely.
cache_path: Explicit path to a pickle file for metadata caching.
Only used when use_cache=True. If None and use_cache=True,
a path is auto-generated from the file list and config.
"""
self.input_df = dataframe
self.image_col = image_col
Expand All @@ -46,6 +57,8 @@ def __init__(self, dataframe=None, image_col:str=None, mask_col:str="mask_path",
self.apply_reorder = apply_reorder
self.dtype = dtype
self.max_workers = max_workers
self.use_cache = use_cache
self.cache_path = cache_path
self.df = self._create_data_frame()

def _create_data_frame(self):
Expand All @@ -70,9 +83,16 @@ def _create_data_frame(self):
print('Error: Must provide path, img_list, or dataframe with mask_col')
return pd.DataFrame()

# Process images to extract metadata
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
data_info_dict = list(executor.map(self._get_data_info, file_list))
# Resolve cache path
if self.use_cache and self.cache_path is None:
self.cache_path = self._auto_cache_path(file_list)

# Process images to extract metadata (with optional caching)
if self.use_cache and self.cache_path is not None:
data_info_dict = self._process_with_cache(file_list)
else:
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
data_info_dict = list(executor.map(self._get_data_info, file_list))

df = pd.DataFrame(data_info_dict)

Expand All @@ -85,6 +105,34 @@ def _create_data_frame(self):

return df

def _auto_cache_path(self, file_list):
"""Generate automatic cache path from file list and config."""
try:
cache_dir = Path.home() / '.cache' / 'fastmonai'
cache_dir.mkdir(parents=True, exist_ok=True)
config_str = f"reorder={self.apply_reorder}|dtype={self.dtype.__name__}"
abs_paths = sorted(os.path.abspath(fn) for fn in file_list)
key_input = config_str + '|' + '|'.join(abs_paths)
key = hashlib.md5(key_input.encode()).hexdigest()
return str(cache_dir / f'med_dataset_{key}.pkl')
except Exception:
return None

@property
def fingerprint(self):
    """MD5 digest summarising the dataset's per-file content hashes.

    The per-file hashes are sorted before being combined, so the
    fingerprint does not depend on row order in self.df. Returns None
    when the 'content_hash' column is absent or contains no non-null
    values (e.g. every file failed to process).
    """
    if 'content_hash' in self.df.columns:
        per_file = sorted(self.df['content_hash'].dropna())
        if per_file:
            return hashlib.md5(''.join(per_file).encode()).hexdigest()
    return None

def summary(self):
"""Summary DataFrame of the dataset with example path for similar data."""

Expand Down Expand Up @@ -125,7 +173,10 @@ def _get_data_info(self, fn: str):
try:
_, o, _ = med_img_reader(fn, apply_reorder=self.apply_reorder, only_tensor=False, dtype=self.dtype)

info_dict = {'path': fn, 'dim_0': o.shape[1], 'dim_1': o.shape[2], 'dim_2': o.shape[3],
# Hash loaded tensor data (implicitly captures apply_reorder state)
content_hash = hashlib.md5(o.data.numpy().tobytes()).hexdigest()

info_dict = {'path': fn, 'content_hash': content_hash, 'dim_0': o.shape[1], 'dim_1': o.shape[2], 'dim_2': o.shape[3],
'voxel_0': round(o.spacing[0], 4), 'voxel_1': round(o.spacing[1], 4), 'voxel_2': round(o.spacing[2], 4),
'orientation': f'{"".join(o.orientation)}+'}

Expand All @@ -149,6 +200,97 @@ def _get_data_info(self, fn: str):
print(f"Warning: Failed to process {fn}: {e}")
return {'path': fn, 'error': str(e)}

def _load_cache(self):
"""Load metadata cache from disk. Returns empty dict on any failure."""
try:
with open(self.cache_path, 'rb') as f:
cache = pickle.load(f)
if not isinstance(cache, dict) or cache.get('version') != 1:
return {}
return cache.get('entries', {})
except Exception:
return {}

def _save_cache(self, entries):
"""Save metadata cache to disk."""
try:
cache = {'version': 1, 'entries': entries}
with open(self.cache_path, 'wb') as f:
pickle.dump(cache, f)
except Exception as e:
print(f"Warning: Failed to save metadata cache: {e}")

def _process_with_cache(self, file_list):
"""Process file list with per-file metadata caching.

Loads existing cache, identifies files that need reprocessing
(cache misses), processes only those files, updates and saves
the cache, then returns all results in original order.
"""
entries = self._load_cache()

cached_results = []
files_to_process = []
process_indices = []

for i, fn in enumerate(file_list):
abs_path = os.path.abspath(fn)
entry = entries.get(abs_path)
if entry is not None:
try:
stat = os.stat(fn)
if (entry.get('mtime') == stat.st_mtime
and entry.get('size') == stat.st_size
and entry.get('apply_reorder') == self.apply_reorder
and entry.get('dtype') == self.dtype.__name__):
cached_results.append((i, entry['info']))
continue
except OSError:
pass
files_to_process.append(fn)
process_indices.append(i)

# Process cache misses
if files_to_process:
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
fresh_results = list(executor.map(self._get_data_info, files_to_process))
else:
fresh_results = []

# Update cache entries with fresh results (skip errors)
for fn, info in zip(files_to_process, fresh_results):
if 'error' not in info:
abs_path = os.path.abspath(fn)
try:
stat = os.stat(fn)
entries[abs_path] = {
'mtime': stat.st_mtime,
'size': stat.st_size,
'apply_reorder': self.apply_reorder,
'dtype': self.dtype.__name__,
'info': info
}
except OSError:
pass

self._save_cache(entries)

# Reconstruct results in original file_list order
all_results = [None] * len(file_list)
for i, info in cached_results:
all_results[i] = info
for i, info in zip(process_indices, fresh_results):
all_results[i] = info

n_cached = len(cached_results)
n_processed = len(files_to_process)
if n_cached > 0:
print(f"MedDataset cache: {n_cached} cached, {n_processed} processed")
elif n_processed > 0:
print(f"MedDataset: processed {n_processed} files (results cached)")

return all_results

def calculate_target_size(self, target_spacing: list = None) -> list:
"""Calculate the target image size for the dataset.

Expand Down
Loading