Fix lazy HDF5 loader #110

Merged
merged 2 commits into from
Feb 18, 2024
9 changes: 4 additions & 5 deletions .github/workflows/conda-build.yml
@@ -20,15 +20,14 @@ jobs:
           channel-priority: false
       - shell: bash -l {0}
         run: conda info --envs
-      - name: Build pytorch-3dunet using Boa
+      - name: Build pytorch-3dunet
         shell: bash -l {0}
         run: |
-          conda install --yes -c conda-forge mamba
-          mamba install -q boa
-          conda mambabuild -c pytorch -c nvidia -c conda-forge conda-recipe
+          conda install -q conda-build
+          conda build -c pytorch -c nvidia -c conda-forge conda-recipe
       - name: Create pytorch3dunet env
         run: |
-          mamba create -n pytorch3dunet -c pytorch -c nvidia -c conda-forge pytorch-3dunet pytest
+          conda create -n pytorch3dunet -c pytorch -c nvidia -c conda-forge pytorch-3dunet pytest
       - name: Run pytest
         shell: bash -l {0}
         run: |
63 changes: 16 additions & 47 deletions pytorch3dunet/datasets/hdf5.py
@@ -34,9 +34,9 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r
         self.phase = phase
         self.file_path = file_path

-        input_file = self.create_h5_file(file_path)
+        input_file = h5py.File(file_path, 'r')

-        self.raw = self.load_dataset(input_file, raw_internal_path)
+        self.raw = self._load_dataset(input_file, raw_internal_path)

         stats = calculate_stats(self.raw, global_normalization)

@@ -46,11 +46,11 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r
         if phase != 'test':
             # create label/weight transform only in train/val phase
             self.label_transform = self.transformer.label_transform()
-            self.label = self.load_dataset(input_file, label_internal_path)
+            self.label = self._load_dataset(input_file, label_internal_path)

             if weight_internal_path is not None:
                 # look for the weight map in the raw file
-                self.weight_map = self.load_dataset(input_file, weight_internal_path)
+                self.weight_map = self._load_dataset(input_file, weight_internal_path)
                 self.weight_transform = self.transformer.weight_transform()
             else:
                 self.weight_map = None
@@ -70,10 +70,12 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r
         self.patch_count = len(self.raw_slices)
         logger.info(f'Number of patches: {self.patch_count}')

-    @staticmethod
-    def load_dataset(input_file, internal_path):
+    def load_dataset(self, input_file, internal_path):
+        raise NotImplementedError
+
+    def _load_dataset(self, input_file, internal_path):
         assert internal_path in input_file, f"Internal path: {internal_path} not found in the H5 file"
-        ds = input_file[internal_path][:]
+        ds = self.load_dataset(input_file, internal_path)
         assert ds.ndim in [3, 4], \
             f"Invalid dataset dimension: {ds.ndim}. Supported dataset formats: (C, Z, Y, X) or (Z, Y, X)"
         return ds
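The hunk above splits dataset loading into a template method: `_load_dataset` keeps the shared validation and delegates the actual read to `load_dataset`, which each subclass overrides. A minimal sketch of that pattern, not taken from the repo (a plain dict of numpy arrays stands in for an open `h5py.File`, and the class names are illustrative):

import numpy as np


class AbstractDataset:
    def load_dataset(self, input_file, internal_path):
        # subclasses decide how the data is read (eagerly or lazily)
        raise NotImplementedError

    def _load_dataset(self, input_file, internal_path):
        # shared validation lives in the base class
        assert internal_path in input_file, f"Internal path: {internal_path} not found in the H5 file"
        ds = self.load_dataset(input_file, internal_path)
        assert ds.ndim in [3, 4], "Supported dataset formats: (C, Z, Y, X) or (Z, Y, X)"
        return ds


class InMemoryDataset(AbstractDataset):
    def load_dataset(self, input_file, internal_path):
        # with a real h5py.File, [:] reads the whole dataset into memory
        return input_file[internal_path][:]


fake_h5 = {"raw": np.zeros((4, 8, 8), dtype="float32")}        # stand-in for an open H5 file
print(InMemoryDataset()._load_dataset(fake_h5, "raw").shape)   # (4, 8, 8)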
@@ -106,10 +108,6 @@ def __getitem__(self, idx):
     def __len__(self):
         return self.patch_count

-    @staticmethod
-    def create_h5_file(file_path):
-        raise NotImplementedError
-
     @staticmethod
     def _check_volume_sizes(raw, label):
         def _volume_shape(volume):
@@ -182,9 +180,9 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config,
                          label_internal_path=label_internal_path, weight_internal_path=weight_internal_path,
                          global_normalization=global_normalization)

-    @staticmethod
-    def create_h5_file(file_path):
-        return h5py.File(file_path, 'r')
+    def load_dataset(self, input_file, internal_path):
+        # load the dataset from the H5 file into memory
+        return input_file[internal_path][:]


 class LazyHDF5Dataset(AbstractHDF5Dataset):
@@ -198,37 +196,8 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config,
                          label_internal_path=label_internal_path, weight_internal_path=weight_internal_path,
                          global_normalization=global_normalization)

-        logger.info("Using modified HDF5Dataset!")
-
-    @staticmethod
-    def create_h5_file(file_path):
-        return LazyHDF5File(file_path)
-
-
-class LazyHDF5File:
-    """Implementation of the LazyHDF5File class for the LazyHDF5Dataset."""
-
-    def __init__(self, path, internal_path=None):
-        self.path = path
-        self.internal_path = internal_path
-        if self.internal_path:
-            with h5py.File(self.path, "r") as f:
-                self.ndim = f[self.internal_path].ndim
-                self.shape = f[self.internal_path].shape
-
-    def ravel(self):
-        with h5py.File(self.path, "r") as f:
-            data = f[self.internal_path][:].ravel()
-        return data
-
-    def __getitem__(self, arg):
-        if isinstance(arg, str) and not self.internal_path:
-            return LazyHDF5File(self.path, arg)
-
-        if arg == Ellipsis:
-            return LazyHDF5File(self.path, self.internal_path)
-
-        with h5py.File(self.path, "r") as f:
-            data = f[self.internal_path][arg]
-
-        return data
+        logger.info("Using LazyHDF5Dataset")
+
+    def load_dataset(self, input_file, internal_path):
+        # load the dataset from the H5 file lazily
+        return input_file[internal_path]
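The `LazyHDF5File` wrapper can be dropped because an `h5py.Dataset` handle already supports the array-style slicing the slice builder needs, and indexing it reads only the requested region from disk, while `[:]` materializes the whole volume in memory. A rough illustration of that difference (file name, shapes, and the patch slice are made up for the example):

import h5py
import numpy as np

# hypothetical demo file, not part of the PR
with h5py.File("volume.h5", "w") as f:
    f.create_dataset("raw", data=np.random.rand(64, 128, 128).astype("float32"))

f = h5py.File("volume.h5", "r")
eager = f["raw"][:]        # numpy array: the whole volume is loaded into memory
lazy = f["raw"]            # h5py.Dataset: nothing is read yet

patch_slice = (slice(0, 16), slice(0, 64), slice(0, 64))
patch = lazy[patch_slice]  # only this (16, 64, 64) patch is read from disk
assert patch.shape == eager[patch_slice].shape
print(type(eager), type(lazy), patch.shape)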
10 changes: 6 additions & 4 deletions pytorch3dunet/datasets/utils.py
@@ -134,7 +134,7 @@ class FilterSliceBuilder(SliceBuilder):
     Filter patches containing more than `1 - threshold` of ignore_index label
     """

-    def __init__(self, raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, ignore_index=(0,),
+    def __init__(self, raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, ignore_index=None,
                  threshold=0.6, slack_acceptance=0.01, **kwargs):
         super().__init__(raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, **kwargs)
         if label_dataset is None:
@@ -144,15 +144,17 @@ def __init__(self, raw_dataset, label_dataset, weight_dataset, patch_shape, stri

         def ignore_predicate(raw_label_idx):
             label_idx = raw_label_idx[1]
-            patch = np.copy(label_dataset[label_idx])
-            for ii in ignore_index:
-                patch[patch == ii] = 0
+            patch = label_dataset[label_idx]
+            if ignore_index is not None:
+                patch = np.copy(patch)
+                patch[patch == ignore_index] = 0
             non_ignore_counts = np.count_nonzero(patch != 0)
             non_ignore_counts = non_ignore_counts / patch.size
             return non_ignore_counts > threshold or rand_state.rand() < slack_acceptance

         zipped_slices = zip(self.raw_slices, self.label_slices)
         # ignore slices containing too much ignore_index
+        logger.info(f'Filtering slices...')
         filtered_slices = list(filter(ignore_predicate, zipped_slices))
         # unzip and save slices
         raw_slices, label_slices = zip(*filtered_slices)
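With `ignore_index=None` the predicate now skips the relabelling copy entirely; otherwise it zeroes out the ignored label and keeps a patch only when the non-ignored fraction exceeds `threshold` (or the random `slack_acceptance` fires). A small worked example of that arithmetic with made-up numbers, assuming `ignore_index=0` and `threshold=0.6`:

import numpy as np

patch = np.array([[0, 0, 1],
                  [2, 0, 1],
                  [1, 2, 0]])              # 9 voxels, 4 of them equal to ignore_index

ignore_index = 0
threshold = 0.6

work = np.copy(patch)
work[work == ignore_index] = 0             # zero out ignored voxels
non_ignore_fraction = np.count_nonzero(work != 0) / work.size
print(non_ignore_fraction)                 # 5/9 ≈ 0.56
print(non_ignore_fraction > threshold)     # False -> kept only if slack_acceptance fires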