Skip to content

Commit

Permalink
Update LazyLoader and LazyPredictor to be compatible with the new tes…
Browse files Browse the repository at this point in the history
…t time patching strategy
  • Loading branch information
wolny committed Apr 15, 2024
1 parent 3dce0d7 commit f7cd369
Show file tree
Hide file tree
Showing 14 changed files with 333 additions and 159 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ predict3dunet --config <CONFIG>
In order to predict on your own data, just provide the path to your model as well as paths to HDF5 test files (see example [test_config_segmentation.yaml](resources/3DUnet_confocal_boundary/test_config.yml)).

### Prediction tips
1. In order to avoid patch boundary artifacts in the output prediction masks the patch predictions are averaged, so make sure that `patch/stride` params lead to overlapping blocks, e.g. `patch: [64, 128, 128] stride: [32, 96, 96]` will give you a 'halo' of 32 voxels in each direction.
1. If you're running prediction for a large dataset, consider using `LazyHDF5Dataset` and `LazyPredictor` in the config. This will save memory by loading data on the fly at the cost of slower prediction time. See [test_config_lazy](resources/3DUnet_confocal_boundary/test_config_lazy.yml) for an example config.
2. If your model predicts multiple classes (see e.g. [train_config_multiclass](resources/3DUnet_multiclass/train_config.yaml)), consider saving only the final segmentation instead of the probability maps, which can be time- and space-consuming.
To do so, set `save_segmentation: true` in the `predictor` section of the config (see [test_config_multiclass](resources/3DUnet_multiclass/test_config.yaml)).

Expand Down
200 changes: 141 additions & 59 deletions pytorch3dunet/datasets/hdf5.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import glob
import os
from abc import abstractmethod
from itertools import chain

import h5py
Expand All @@ -11,6 +12,24 @@
logger = get_logger('HDF5Dataset')


def _create_padded_indexes(indexes, halo_shape):
    """
    Widen each slice so that it also covers the halo on both sides of the patch.

    The start of every slice is kept and its stop is moved forward by twice the halo
    of the corresponding axis; the result is meant to index into a volume that was
    padded by `halo` voxels on each side.

    Args:
        indexes: iterable of slice objects, one per axis.
        halo_shape: iterable of halo sizes, paired with `indexes` position by position.
    Returns:
        tuple of slice objects (as long as the shorter of the two inputs).
    """
    widened = []
    for axis_slice, axis_halo in zip(indexes, halo_shape):
        new_stop = axis_slice.stop + 2 * axis_halo
        widened.append(slice(axis_slice.start, new_stop))
    return tuple(widened)


def traverse_h5_paths(file_paths):
    """
    Expand a list of paths into a flat list of file paths.

    Entries that are directories are replaced by the HDF5 files found directly inside
    them (matched by extension, non-recursively); every other entry is passed through
    unchanged, without checking that it exists.

    Args:
        file_paths: list of file and/or directory paths.
    Returns:
        list of paths, in input order, with directories expanded in place.
    """
    assert isinstance(file_paths, list)
    h5_patterns = ['*.h5', '*.hdf', '*.hdf5', '*.hd5']
    collected = []
    for entry in file_paths:
        if not os.path.isdir(entry):
            # plain file path (or anything that is not a directory): keep as is
            collected.append(entry)
            continue
        # directory: gather the matches pattern by pattern, in the order listed above
        for pattern in h5_patterns:
            collected.extend(glob.glob(os.path.join(entry, pattern)))
    return collected


class AbstractHDF5Dataset(ConfigDataset):
"""
Implementation of torch.utils.data.Dataset backed by the HDF5 files, which iterates over the raw and label datasets
Expand All @@ -33,57 +52,83 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r

self.phase = phase
self.file_path = file_path
self.raw_internal_path = raw_internal_path
self.label_internal_path = label_internal_path
self.weight_internal_path = weight_internal_path

input_file = h5py.File(file_path, 'r')

self.raw = self._load_dataset(input_file, raw_internal_path)
self.halo_shape = slice_builder_config.get('halo_shape', [0, 0, 0])
self.raw_padded = mirror_pad(self.raw, self.halo_shape)

stats = calculate_stats(self.raw, global_normalization)
if global_normalization:
logger.info('Calculating mean and std of the raw data...')
with h5py.File(file_path, 'r') as f:
raw = f[raw_internal_path][:]
stats = calculate_stats(raw)
else:
stats = calculate_stats(None, True)

self.transformer = transforms.Transformer(transformer_config, stats)
self.raw_transform = self.transformer.raw_transform()

if phase != 'test':
# create label/weight transform only in train/val phase
self.label_transform = self.transformer.label_transform()
self.label = self._load_dataset(input_file, label_internal_path)

if weight_internal_path is not None:
# look for the weight map in the raw file
self.weight_map = self._load_dataset(input_file, weight_internal_path)
self.weight_transform = self.transformer.weight_transform()
else:
self.weight_map = None
self.weight_transform = None

self._check_volume_sizes(self.raw, self.label)
self._check_volume_sizes()
else:
# 'test' phase used only for predictions so ignore the label dataset
self.label = None
self.weight_map = None

# build slice indices for raw and label data sets
slice_builder = get_slice_builder(self.raw, self.label, self.weight_map, slice_builder_config)
self.raw_slices = slice_builder.raw_slices
self.label_slices = slice_builder.label_slices
self.weight_slices = slice_builder.weight_slices
# compare patch and stride configuration
patch_shape = slice_builder_config.get('patch_shape')
stride_shape = slice_builder_config.get('stride_shape')
if patch_shape != stride_shape:
logger.warning(f'Patch shape and stride shape should be equal for optimal prediction performance,'
f'but found patch_shape: {patch_shape} and stride_shape: {stride_shape} in the config!'
f'Overriding stride_shape to match patch_shape!')
slice_builder_config['stride_shape'] = patch_shape

with h5py.File(file_path, 'r') as f:
raw = f[raw_internal_path]
label = f[label_internal_path] if phase != 'test' else None
weight_map = f[weight_internal_path] if weight_internal_path is not None else None
# build slice indices for raw and label data sets
slice_builder = get_slice_builder(raw, label, weight_map, slice_builder_config)
self.raw_slices = slice_builder.raw_slices
self.label_slices = slice_builder.label_slices
self.weight_slices = slice_builder.weight_slices

self.patch_count = len(self.raw_slices)
logger.info(f'Number of patches: {self.patch_count}')

def load_dataset(self, input_file, internal_path):
@abstractmethod
def get_raw_patch(self, idx):
    """
    Return the raw data patch selected by `idx`; must be provided by subclasses.

    Args:
        idx: index (taken from `self.raw_slices`) selecting the patch in the raw dataset.
    """
    raise NotImplementedError()

@abstractmethod
def get_label_patch(self, idx):
    """
    Return the label patch selected by `idx`; must be provided by subclasses.

    Args:
        idx: index (taken from `self.label_slices`) selecting the patch in the label dataset.
    """
    raise NotImplementedError()

@abstractmethod
def get_weight_patch(self, idx):
    """
    Return the weight-map patch selected by `idx`; must be provided by subclasses.

    Args:
        idx: index (taken from `self.weight_slices`) selecting the patch in the weight map.
    """
    raise NotImplementedError()

def _load_dataset(self, input_file, internal_path):
assert internal_path in input_file, f"Internal path: {internal_path} not found in the H5 file"
ds = self.load_dataset(input_file, internal_path)
assert ds.ndim in [3, 4], \
f"Invalid dataset dimension: {ds.ndim}. Supported dataset formats: (C, Z, Y, X) or (Z, Y, X)"
return ds
@abstractmethod
def get_raw_padded_patch(self, idx):
    """
    Return the patch selected by `idx` from the halo-padded raw volume; must be
    provided by subclasses.

    Args:
        idx: halo-extended index into the padded raw volume.
    """
    raise NotImplementedError()

def _create_padded_indexes(self, indexes, halo_shape):
return tuple(slice(index.start, index.stop + 2 * halo) for index, halo in zip(indexes, halo_shape))
def volume_shape(self):
    """
    Read the shape of the raw dataset from the H5 file, without the leading axis
    when the dataset is not 3D.

    Returns:
        the full shape for a 3D raw dataset, otherwise the shape with its first
        (channel) axis dropped.
    """
    with h5py.File(self.file_path, 'r') as h5_file:
        raw_ds = h5_file[self.raw_internal_path]
        full_shape = raw_ds.shape
        return full_shape if raw_ds.ndim == 3 else full_shape[1:]

def __getitem__(self, idx):
if idx >= len(self):
Expand All @@ -95,39 +140,44 @@ def __getitem__(self, idx):
if len(raw_idx) == 4:
# discard the channel dimension in the slices: predictor requires only the spatial dimensions of the volume
raw_idx = raw_idx[1:] # Remove the first element if raw_idx has 4 elements
raw_idx_padded = (slice(None),) + self._create_padded_indexes(raw_idx, self.halo_shape)
raw_idx_padded = (slice(None),) + _create_padded_indexes(raw_idx, self.halo_shape)
else:
raw_idx_padded = self._create_padded_indexes(raw_idx, self.halo_shape)
raw_idx_padded = _create_padded_indexes(raw_idx, self.halo_shape)

raw_patch_transformed = self.raw_transform(self.raw_padded[raw_idx_padded])
raw_patch_transformed = self.raw_transform(self.get_raw_padded_patch(raw_idx_padded))
return raw_patch_transformed, raw_idx
else:
raw_patch_transformed = self.raw_transform(self.raw[raw_idx])
raw_patch_transformed = self.raw_transform(self.get_raw_patch(raw_idx))

# get the slice for a given index 'idx'
label_idx = self.label_slices[idx]
label_patch_transformed = self.label_transform(self.label[label_idx])
if self.weight_map is not None:
label_patch_transformed = self.label_transform(self.get_label_patch(label_idx))
if self.weight_internal_path is not None:
weight_idx = self.weight_slices[idx]
weight_patch_transformed = self.weight_transform(self.weight_map[weight_idx])
weight_patch_transformed = self.weight_transform(self.get_weight_patch(weight_idx))
return raw_patch_transformed, label_patch_transformed, weight_patch_transformed
# return the transformed raw and label patches
return raw_patch_transformed, label_patch_transformed

def __len__(self):
return self.patch_count

@staticmethod
def _check_volume_sizes(raw, label):
def _check_volume_sizes(self):
def _volume_shape(volume):
if volume.ndim == 3:
return volume.shape
return volume.shape[1:]

assert raw.ndim in [3, 4], 'Raw dataset must be 3D (DxHxW) or 4D (CxDxHxW)'
assert label.ndim in [3, 4], 'Label dataset must be 3D (DxHxW) or 4D (CxDxHxW)'

assert _volume_shape(raw) == _volume_shape(label), 'Raw and labels have to be of the same size'
with h5py.File(self.file_path, 'r') as f:
raw = f[self.raw_internal_path]
label = f[self.label_internal_path]
assert raw.ndim in [3, 4], 'Raw dataset must be 3D (DxHxW) or 4D (CxDxHxW)'
assert label.ndim in [3, 4], 'Label dataset must be 3D (DxHxW) or 4D (CxDxHxW)'
assert _volume_shape(raw) == _volume_shape(label), 'Raw and labels have to be of the same size'
if self.weight_internal_path is not None:
weight_map = f[self.weight_internal_path]
assert weight_map.ndim in [3, 4], 'Weight map dataset must be 3D (DxHxW) or 4D (CxDxHxW)'
assert _volume_shape(raw) == _volume_shape(weight_map), 'Raw and weight map have to be of the same size'

@classmethod
def create_datasets(cls, dataset_config, phase):
Expand All @@ -141,7 +191,7 @@ def create_datasets(cls, dataset_config, phase):
file_paths = phase_config['file_paths']
# file_paths may contain both files and directories; if the file_path is a directory all H5 files inside
# are going to be included in the final file_paths
file_paths = cls.traverse_h5_paths(file_paths)
file_paths = traverse_h5_paths(file_paths)

datasets = []
for file_path in file_paths:
Expand All @@ -160,20 +210,6 @@ def create_datasets(cls, dataset_config, phase):
logger.error(f'Skipping {phase} set: {file_path}', exc_info=True)
return datasets

@staticmethod
def traverse_h5_paths(file_paths):
assert isinstance(file_paths, list)
results = []
for file_path in file_paths:
if os.path.isdir(file_path):
# if file path is a directory take all H5 files in that directory
iters = [glob.glob(os.path.join(file_path, ext)) for ext in ['*.h5', '*.hdf', '*.hdf5', '*.hd5']]
for fp in chain(*iters):
results.append(fp)
else:
results.append(file_path)
return results


class StandardHDF5Dataset(AbstractHDF5Dataset):
"""
Expand All @@ -188,10 +224,38 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config,
transformer_config=transformer_config, raw_internal_path=raw_internal_path,
label_internal_path=label_internal_path, weight_internal_path=weight_internal_path,
global_normalization=global_normalization)

def load_dataset(self, input_file, internal_path):
# load the dataset from the H5 file into memory
return input_file[internal_path][:]
self._raw = None
self._raw_padded = None
self._label = None
self._weight_map = None

def get_raw_patch(self, idx):
    """
    Return `idx` of the raw dataset, reading the whole dataset into memory on first use.

    Args:
        idx: index applied to the in-memory raw array.
    """
    if self._raw is not None:
        # already cached by a previous call
        return self._raw[idx]

    with h5py.File(self.file_path, 'r') as h5_file:
        assert self.raw_internal_path in h5_file, \
            f'Dataset {self.raw_internal_path} not found in {self.file_path}'
        self._raw = h5_file[self.raw_internal_path][:]
    return self._raw[idx]

def get_label_patch(self, idx):
    """
    Return `idx` of the label dataset, reading the whole dataset into memory on first use.

    Args:
        idx: index applied to the in-memory label array.
    """
    if self._label is not None:
        # already cached by a previous call
        return self._label[idx]

    with h5py.File(self.file_path, 'r') as h5_file:
        assert self.label_internal_path in h5_file, \
            f'Dataset {self.label_internal_path} not found in {self.file_path}'
        self._label = h5_file[self.label_internal_path][:]
    return self._label[idx]

def get_weight_patch(self, idx):
    """
    Return `idx` of the weight map, reading the whole dataset into memory on first use.

    Args:
        idx: index applied to the in-memory weight-map array.
    """
    if self._weight_map is not None:
        # already cached by a previous call
        return self._weight_map[idx]

    with h5py.File(self.file_path, 'r') as h5_file:
        assert self.weight_internal_path in h5_file, \
            f'Dataset {self.weight_internal_path} not found in {self.file_path}'
        self._weight_map = h5_file[self.weight_internal_path][:]
    return self._weight_map[idx]

def get_raw_padded_patch(self, idx):
    """
    Return `idx` of the mirror-padded raw volume.

    On first use the raw dataset is read from the H5 file, padded with
    `mirror_pad` using `self.halo_shape`, and kept in memory for later calls.

    Args:
        idx: index applied to the in-memory padded raw array.
    """
    if self._raw_padded is not None:
        # already cached by a previous call
        return self._raw_padded[idx]

    with h5py.File(self.file_path, 'r') as h5_file:
        assert self.raw_internal_path in h5_file, \
            f'Dataset {self.raw_internal_path} not found in {self.file_path}'
        raw_volume = h5_file[self.raw_internal_path][:]
        self._raw_padded = mirror_pad(raw_volume, self.halo_shape)
    return self._raw_padded[idx]


class LazyHDF5Dataset(AbstractHDF5Dataset):
Expand All @@ -207,6 +271,24 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config,

logger.info("Using LazyHDF5Dataset")

def load_dataset(self, input_file, internal_path):
# load the dataset from the H5 file lazily
return input_file[internal_path]
def get_raw_patch(self, idx):
    """
    Read `idx` of the raw dataset from the H5 file; the file is opened
    read-only for this call and closed again before returning.

    Args:
        idx: index applied to the on-disk raw dataset.
    """
    with h5py.File(self.file_path, 'r') as h5_file:
        raw_ds = h5_file[self.raw_internal_path]
        return raw_ds[idx]

def get_label_patch(self, idx):
    """
    Read `idx` of the label dataset from the H5 file; the file is opened
    read-only for this call and closed again before returning.

    Args:
        idx: index applied to the on-disk label dataset.
    """
    with h5py.File(self.file_path, 'r') as h5_file:
        label_ds = h5_file[self.label_internal_path]
        return label_ds[idx]

def get_weight_patch(self, idx):
    """
    Read `idx` of the weight map from the H5 file; the file is opened
    read-only for this call and closed again before returning.

    Args:
        idx: index applied to the on-disk weight-map dataset.
    """
    with h5py.File(self.file_path, 'r') as h5_file:
        weight_ds = h5_file[self.weight_internal_path]
        return weight_ds[idx]

def get_raw_padded_patch(self, idx):
    """
    Return `idx` of the mirror-padded raw volume, using a padded copy cached
    inside the H5 file under the name 'raw_padded'.

    If a cached 'raw_padded' dataset exists and its shape matches what padding the
    current raw dataset with `self.halo_shape` would produce, the patch is read
    from it. Otherwise (no cache, or a cache written for a different halo / raw
    shape) the raw dataset is loaded, padded with `mirror_pad`, stored as
    'raw_padded' (gzip-compressed) and the patch is taken from the in-memory result.

    NOTE: the file is opened in 'r+' mode and may be modified, so it must be writable.
    NOTE(review): only the shape of the cache is validated; a cache built from raw
    data with the same shape but different content is not detected.

    Args:
        idx: halo-extended index into the padded raw volume.
    Returns:
        the selected patch of the padded raw volume.
    """
    with h5py.File(self.file_path, 'r+') as f:
        raw_ds = f[self.raw_internal_path]
        # mirror_pad pads only the three trailing (ZYX) axes, by `halo` on each side
        expected_shape = tuple(raw_ds.shape[:-3]) + tuple(
            size + 2 * halo for size, halo in zip(raw_ds.shape[-3:], self.halo_shape)
        )

        if 'raw_padded' in f:
            cached = f['raw_padded']
            if tuple(cached.shape) == expected_shape:
                return cached[idx]
            # stale cache (e.g. created with a different halo_shape): indexing it with
            # the current halo-extended slices would return misaligned patches
            logger.warning(f"Cached 'raw_padded' in {self.file_path} has shape {tuple(cached.shape)}, "
                           f"expected {expected_shape}. Recomputing the padded volume.")
            del f['raw_padded']

        raw_padded = mirror_pad(raw_ds[:], self.halo_shape)
        f.create_dataset('raw_padded', data=raw_padded, compression='gzip')
        return raw_padded[idx]
25 changes: 17 additions & 8 deletions pytorch3dunet/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import collections
from typing import Any

import numpy as np
import torch
Expand Down Expand Up @@ -286,16 +287,19 @@ def default_prediction_collate(batch):
raise TypeError((error_msg.format(type(batch[0]))))


def calculate_stats(images, global_normalization=True):
def calculate_stats(img: np.array, skip: bool = False) -> dict[str, Any]:
"""
Calculates min, max, mean, std given a list of nd-arrays
Calculates the minimum percentile, maximum percentile, mean, and standard deviation of the image.
Args:
img: The input image array.
skip: if True, skip the calculation and return None for all values.
Returns:
tuple[float, float, float, float]: The minimum percentile, maximum percentile, mean, and std dev
"""
if global_normalization:
# flatten first since the images might not be the same size
flat = np.concatenate(
[img.ravel() for img in images]
)
pmin, pmax, mean, std = np.percentile(flat, 1), np.percentile(flat, 99.6), np.mean(flat), np.std(flat)
if not skip:
pmin, pmax, mean, std = np.percentile(img, 1), np.percentile(img, 99.6), np.mean(img), np.std(img)
else:
pmin, pmax, mean, std = None, None, None, None

Expand Down Expand Up @@ -323,13 +327,18 @@ def mirror_pad(image, padding_shape):
Raises:
ValueError: If any element of padding_shape is negative.
"""
assert len(padding_shape) == 3, "Padding shape must be specified for each dimension: ZYX"

if any(p < 0 for p in padding_shape):
raise ValueError("padding_shape must be non-negative")

if all(p == 0 for p in padding_shape):
return image

pad_width = [(p, p) for p in padding_shape]

if image.ndim == 4:
pad_width = [(0, 0)] + pad_width
return np.pad(image, pad_width, mode='reflect')


Expand Down
Loading

0 comments on commit f7cd369

Please sign in to comment.