From 6fbb5a3567786f577199920f8f8502acc432e537 Mon Sep 17 00:00:00 2001
From: Adrian Wolny
Date: Sun, 18 Feb 2024 23:43:17 +0100
Subject: [PATCH 1/2] fix lazy hdf5 loader

---
 pytorch3dunet/datasets/hdf5.py  | 63 +++++++++------------------------
 pytorch3dunet/datasets/utils.py | 10 +++---
 2 files changed, 22 insertions(+), 51 deletions(-)

diff --git a/pytorch3dunet/datasets/hdf5.py b/pytorch3dunet/datasets/hdf5.py
index d4a02e1c..60ff3c09 100644
--- a/pytorch3dunet/datasets/hdf5.py
+++ b/pytorch3dunet/datasets/hdf5.py
@@ -34,9 +34,9 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r
         self.phase = phase
         self.file_path = file_path

-        input_file = self.create_h5_file(file_path)
+        input_file = h5py.File(file_path, 'r')

-        self.raw = self.load_dataset(input_file, raw_internal_path)
+        self.raw = self._load_dataset(input_file, raw_internal_path)

         stats = calculate_stats(self.raw, global_normalization)

@@ -46,11 +46,11 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r
         if phase != 'test':
             # create label/weight transform only in train/val phase
             self.label_transform = self.transformer.label_transform()
-            self.label = self.load_dataset(input_file, label_internal_path)
+            self.label = self._load_dataset(input_file, label_internal_path)

             if weight_internal_path is not None:
                 # look for the weight map in the raw file
-                self.weight_map = self.load_dataset(input_file, weight_internal_path)
+                self.weight_map = self._load_dataset(input_file, weight_internal_path)
                 self.weight_transform = self.transformer.weight_transform()
             else:
                 self.weight_map = None
@@ -70,10 +70,12 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r
         self.patch_count = len(self.raw_slices)
         logger.info(f'Number of patches: {self.patch_count}')

-    @staticmethod
-    def load_dataset(input_file, internal_path):
+    def load_dataset(self, input_file, internal_path):
+        raise NotImplementedError
+
+    def _load_dataset(self, input_file, internal_path):
         assert internal_path in input_file, f"Internal path: {internal_path} not found in the H5 file"
-        ds = input_file[internal_path][:]
+        ds = self.load_dataset(input_file, internal_path)
         assert ds.ndim in [3, 4], \
             f"Invalid dataset dimension: {ds.ndim}. Supported dataset formats: (C, Z, Y, X) or (Z, Y, X)"
         return ds
@@ -106,10 +108,6 @@ def __getitem__(self, idx):
     def __len__(self):
         return self.patch_count

-    @staticmethod
-    def create_h5_file(file_path):
-        raise NotImplementedError
-
     @staticmethod
     def _check_volume_sizes(raw, label):
         def _volume_shape(volume):
@@ -182,9 +180,9 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config,
                          label_internal_path=label_internal_path, weight_internal_path=weight_internal_path,
                          global_normalization=global_normalization)

-    @staticmethod
-    def create_h5_file(file_path):
-        return h5py.File(file_path, 'r')
+    def load_dataset(self, input_file, internal_path):
+        # load the dataset from the H5 file into memory
+        return input_file[internal_path][:]


 class LazyHDF5Dataset(AbstractHDF5Dataset):
@@ -198,37 +196,8 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config,
                          label_internal_path=label_internal_path, weight_internal_path=weight_internal_path,
                          global_normalization=global_normalization)
-        logger.info("Using modified HDF5Dataset!")
-
-    @staticmethod
-    def create_h5_file(file_path):
-        return LazyHDF5File(file_path)
-
-
-class LazyHDF5File:
-    """Implementation of the LazyHDF5File class for the LazyHDF5Dataset."""
-
-    def __init__(self, path, internal_path=None):
-        self.path = path
-        self.internal_path = internal_path
-        if self.internal_path:
-            with h5py.File(self.path, "r") as f:
-                self.ndim = f[self.internal_path].ndim
-                self.shape = f[self.internal_path].shape
-
-    def ravel(self):
-        with h5py.File(self.path, "r") as f:
-            data = f[self.internal_path][:].ravel()
-        return data
-
-    def __getitem__(self, arg):
-        if isinstance(arg, str) and not self.internal_path:
-            return LazyHDF5File(self.path, arg)
-
-        if arg == Ellipsis:
-            return LazyHDF5File(self.path, self.internal_path)
-
-        with h5py.File(self.path, "r") as f:
-            data = f[self.internal_path][arg]
+        logger.info("Using LazyHDF5Dataset")

-        return data
+    def load_dataset(self, input_file, internal_path):
+        # load the dataset from the H5 file lazily
+        return input_file[internal_path]

diff --git a/pytorch3dunet/datasets/utils.py b/pytorch3dunet/datasets/utils.py
index 6ce311d6..99e4a255 100644
--- a/pytorch3dunet/datasets/utils.py
+++ b/pytorch3dunet/datasets/utils.py
@@ -134,7 +134,7 @@ class FilterSliceBuilder(SliceBuilder):
     Filter patches containing more than `1 - threshold` of ignore_index label
     """

-    def __init__(self, raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, ignore_index=(0,),
+    def __init__(self, raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, ignore_index=None,
                  threshold=0.6, slack_acceptance=0.01, **kwargs):
         super().__init__(raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, **kwargs)
         if label_dataset is None:
@@ -144,15 +144,17 @@ def __init__(self, raw_dataset, label_dataset, weight_dataset, patch_shape, stri

         def ignore_predicate(raw_label_idx):
             label_idx = raw_label_idx[1]
-            patch = np.copy(label_dataset[label_idx])
-            for ii in ignore_index:
-                patch[patch == ii] = 0
+            patch = label_dataset[label_idx]
+            if ignore_index is not None:
+                patch = np.copy(patch)
+                patch[patch == ignore_index] = 0
             non_ignore_counts = np.count_nonzero(patch != 0)
             non_ignore_counts = non_ignore_counts / patch.size
             return non_ignore_counts > threshold or rand_state.rand() < slack_acceptance

         zipped_slices = zip(self.raw_slices, self.label_slices)
         # ignore slices containing too much ignore_index
+        logger.info(f'Filtering slices...')
         filtered_slices = list(filter(ignore_predicate, zipped_slices))
         # unzip and save slices
         raw_slices, label_slices = zip(*filtered_slices)

From df8cf5c03994bd2fc1b37ff59604becbad973a36 Mon Sep 17 00:00:00 2001
From: Adrian Wolny
Date: Sun, 18 Feb 2024 23:45:53 +0100
Subject: [PATCH 2/2] fallback to conda build, due to mambabuild problems

---
 .github/workflows/conda-build.yml | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/.github/workflows/conda-build.yml b/.github/workflows/conda-build.yml
index 3aa39e07..a0c55c96 100644
--- a/.github/workflows/conda-build.yml
+++ b/.github/workflows/conda-build.yml
@@ -20,15 +20,14 @@ jobs:
           channel-priority: false
       - shell: bash -l {0}
         run: conda info --envs
-      - name: Build pytorch-3dunet using Boa
+      - name: Build pytorch-3dunet
         shell: bash -l {0}
         run: |
-          conda install --yes -c conda-forge mamba
-          mamba install -q boa
-          conda mambabuild -c pytorch -c nvidia -c conda-forge conda-recipe
+          conda install -q conda-build
+          conda build -c pytorch -c nvidia -c conda-forge conda-recipe
       - name: Create pytorch3dunet env
         run: |
-          mamba create -n pytorch3dunet -c pytorch -c nvidia -c conda-forge pytorch-3dunet pytest
+          conda create -n pytorch3dunet -c pytorch -c nvidia -c conda-forge pytorch-3dunet pytest
       - name: Run pytest
         shell: bash -l {0}
         run: |
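Note on PATCH 1/2: AbstractHDF5Dataset now opens the H5 file itself and routes every read through _load_dataset(), which validates the dataset (ndim must be 3 or 4) and delegates the raw access to load_dataset(). StandardHDF5Dataset overrides load_dataset() to copy the volume into memory with [:], while LazyHDF5Dataset returns the open h5py.Dataset handle so data is only read from disk when a patch is sliced. Below is a minimal, self-contained sketch of that eager-vs-lazy contract; the file name 'example.h5', the dataset key 'raw', and the two loader class names are illustrative placeholders, not code from the patch.

    # Sketch of the eager vs. lazy load_dataset() contract (names are hypothetical).
    import h5py
    import numpy as np

    # create a throwaway volume so the example runs end-to-end
    with h5py.File('example.h5', 'w') as f:
        f.create_dataset('raw', data=np.random.rand(32, 128, 128).astype(np.float32))

    class EagerLoader:
        def load_dataset(self, input_file, internal_path):
            # mirrors StandardHDF5Dataset: [:] copies the whole volume into memory
            return input_file[internal_path][:]

    class LazyLoader:
        def load_dataset(self, input_file, internal_path):
            # mirrors LazyHDF5Dataset: keep the h5py.Dataset handle, read on slicing
            return input_file[internal_path]

    input_file = h5py.File('example.h5', 'r')
    eager = EagerLoader().load_dataset(input_file, 'raw')  # numpy.ndarray in RAM
    lazy = LazyLoader().load_dataset(input_file, 'raw')    # h5py.Dataset on disk
    patch = lazy[0:16, 0:64, 0:64]        # only this patch is read from disk
    assert eager.ndim == lazy.ndim == 3   # both pass the _load_dataset() checks

This also suggests why the old LazyHDF5File wrapper could be deleted: h5py's own Dataset already provides ndim, shape, and slicing, which is all that _load_dataset() and the slice builders need.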
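The FilterSliceBuilder change in the same patch makes ignore_index an optional scalar (default None) instead of the tuple default (0,): when no ignore label is set the patch is used as-is and only the background test patch != 0 applies, and np.copy is taken only in the branch that actually modifies the patch. A standalone sketch of the patched counting logic, with a made-up label array (the function name is hypothetical):

    import numpy as np

    def non_ignore_fraction(patch, ignore_index=None):
        # mirrors the patched ignore_predicate: zero out the ignore label (if any),
        # then measure the fraction of non-background voxels
        if ignore_index is not None:
            patch = np.copy(patch)
            patch[patch == ignore_index] = 0
        return np.count_nonzero(patch != 0) / patch.size

    label = np.array([[0, 1, 1],
                      [0, 2, -1]])
    print(non_ignore_fraction(label))                   # 4/6: -1 counts as label
    print(non_ignore_fraction(label, ignore_index=-1))  # 3/6: -1 zeroed out first

A patch then passes the filter when this fraction exceeds threshold (0.6 by default), or with probability slack_acceptance, so an occasional near-empty patch still gets through.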