From ec3134c8d87862e1044867ae279e8cf82fcddc1c Mon Sep 17 00:00:00 2001 From: shota_mizusaki Date: Wed, 22 May 2024 00:48:19 +0900 Subject: [PATCH 1/4] first commit --- build/lib/pytorch3dunet/__init__.py | 1 + build/lib/pytorch3dunet/__version__.py | 1 + build/lib/pytorch3dunet/augment/__init__.py | 0 build/lib/pytorch3dunet/augment/transforms.py | 761 ++++++++++++++++++ build/lib/pytorch3dunet/datasets/__init__.py | 0 build/lib/pytorch3dunet/datasets/dsb.py | 108 +++ build/lib/pytorch3dunet/datasets/hdf5.py | 293 +++++++ build/lib/pytorch3dunet/datasets/utils.py | 361 +++++++++ build/lib/pytorch3dunet/predict.py | 59 ++ build/lib/pytorch3dunet/train.py | 35 + build/lib/pytorch3dunet/unet3d/__init__.py | 0 .../pytorch3dunet/unet3d/buildingblocks.py | 545 +++++++++++++ build/lib/pytorch3dunet/unet3d/config.py | 79 ++ build/lib/pytorch3dunet/unet3d/losses.py | 345 ++++++++ build/lib/pytorch3dunet/unet3d/metrics.py | 445 ++++++++++ build/lib/pytorch3dunet/unet3d/model.py | 249 ++++++ build/lib/pytorch3dunet/unet3d/predictor.py | 281 +++++++ build/lib/pytorch3dunet/unet3d/se.py | 113 +++ build/lib/pytorch3dunet/unet3d/seg_metrics.py | 123 +++ build/lib/pytorch3dunet/unet3d/trainer.py | 404 ++++++++++ build/lib/pytorch3dunet/unet3d/utils.py | 366 +++++++++ 21 files changed, 4569 insertions(+) create mode 100644 build/lib/pytorch3dunet/__init__.py create mode 100644 build/lib/pytorch3dunet/__version__.py create mode 100644 build/lib/pytorch3dunet/augment/__init__.py create mode 100644 build/lib/pytorch3dunet/augment/transforms.py create mode 100644 build/lib/pytorch3dunet/datasets/__init__.py create mode 100644 build/lib/pytorch3dunet/datasets/dsb.py create mode 100644 build/lib/pytorch3dunet/datasets/hdf5.py create mode 100644 build/lib/pytorch3dunet/datasets/utils.py create mode 100644 build/lib/pytorch3dunet/predict.py create mode 100644 build/lib/pytorch3dunet/train.py create mode 100644 build/lib/pytorch3dunet/unet3d/__init__.py create mode 100644 build/lib/pytorch3dunet/unet3d/buildingblocks.py create mode 100644 build/lib/pytorch3dunet/unet3d/config.py create mode 100644 build/lib/pytorch3dunet/unet3d/losses.py create mode 100644 build/lib/pytorch3dunet/unet3d/metrics.py create mode 100644 build/lib/pytorch3dunet/unet3d/model.py create mode 100644 build/lib/pytorch3dunet/unet3d/predictor.py create mode 100644 build/lib/pytorch3dunet/unet3d/se.py create mode 100644 build/lib/pytorch3dunet/unet3d/seg_metrics.py create mode 100644 build/lib/pytorch3dunet/unet3d/trainer.py create mode 100644 build/lib/pytorch3dunet/unet3d/utils.py diff --git a/build/lib/pytorch3dunet/__init__.py b/build/lib/pytorch3dunet/__init__.py new file mode 100644 index 00000000..9226fe7e --- /dev/null +++ b/build/lib/pytorch3dunet/__init__.py @@ -0,0 +1 @@ +from .__version__ import __version__ diff --git a/build/lib/pytorch3dunet/__version__.py b/build/lib/pytorch3dunet/__version__.py new file mode 100644 index 00000000..655be529 --- /dev/null +++ b/build/lib/pytorch3dunet/__version__.py @@ -0,0 +1 @@ +__version__ = '1.8.7' diff --git a/build/lib/pytorch3dunet/augment/__init__.py b/build/lib/pytorch3dunet/augment/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/pytorch3dunet/augment/transforms.py b/build/lib/pytorch3dunet/augment/transforms.py new file mode 100644 index 00000000..527d596b --- /dev/null +++ b/build/lib/pytorch3dunet/augment/transforms.py @@ -0,0 +1,761 @@ +import importlib +import random + +import numpy as np +import torch +from scipy.ndimage import rotate, 
map_coordinates, gaussian_filter, convolve +from skimage import measure +from skimage.filters import gaussian +from skimage.segmentation import find_boundaries + +# WARN: use fixed random state for reproducibility; if you want to randomize on each run seed with `time.time()` e.g. +GLOBAL_RANDOM_STATE = np.random.RandomState(47) + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, m): + for t in self.transforms: + m = t(m) + return m + + +class RandomFlip: + """ + Randomly flips the image across the given axes. Image can be either 3D (DxHxW) or 4D (CxDxHxW). + + When creating make sure that the provided RandomStates are consistent between raw and labeled datasets, + otherwise the models won't converge. + """ + + def __init__(self, random_state, axis_prob=0.5, **kwargs): + assert random_state is not None, 'RandomState cannot be None' + self.random_state = random_state + self.axes = (0, 1, 2) + self.axis_prob = axis_prob + + def __call__(self, m): + assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images' + + for axis in self.axes: + if self.random_state.uniform() > self.axis_prob: + if m.ndim == 3: + m = np.flip(m, axis) + else: + channels = [np.flip(m[c], axis) for c in range(m.shape[0])] + m = np.stack(channels, axis=0) + + return m + + +class RandomRotate90: + """ + Rotate an array by 90 degrees around a randomly chosen plane. Image can be either 3D (DxHxW) or 4D (CxDxHxW). + + When creating make sure that the provided RandomStates are consistent between raw and labeled datasets, + otherwise the models won't converge. + + IMPORTANT: assumes DHW axis order (that's why rotation is performed across (1,2) axis) + """ + + def __init__(self, random_state, **kwargs): + self.random_state = random_state + # always rotate around z-axis + self.axis = (1, 2) + + def __call__(self, m): + assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images' + + # pick number of rotations at random + k = self.random_state.randint(0, 4) + # rotate k times around a given plane + if m.ndim == 3: + m = np.rot90(m, k, self.axis) + else: + channels = [np.rot90(m[c], k, self.axis) for c in range(m.shape[0])] + m = np.stack(channels, axis=0) + + return m + + +class RandomRotate: + """ + Rotate an array by a random degrees from taken from (-angle_spectrum, angle_spectrum) interval. + Rotation axis is picked at random from the list of provided axes. + """ + + def __init__(self, random_state, angle_spectrum=30, axes=None, mode='reflect', order=0, **kwargs): + if axes is None: + axes = [(1, 0), (2, 1), (2, 0)] + else: + assert isinstance(axes, list) and len(axes) > 0 + + self.random_state = random_state + self.angle_spectrum = angle_spectrum + self.axes = axes + self.mode = mode + self.order = order + + def __call__(self, m): + axis = self.axes[self.random_state.randint(len(self.axes))] + angle = self.random_state.randint(-self.angle_spectrum, self.angle_spectrum) + + if m.ndim == 3: + m = rotate(m, angle, axes=axis, reshape=False, order=self.order, mode=self.mode, cval=-1) + else: + channels = [rotate(m[c], angle, axes=axis, reshape=False, order=self.order, mode=self.mode, cval=-1) for c + in range(m.shape[0])] + m = np.stack(channels, axis=0) + + return m + + +class RandomContrast: + """ + Adjust contrast by scaling each voxel to `mean + alpha * (v - mean)`. 
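+
+    Example (illustrative, not part of the original docstring): with mean=0.0 and a sampled
+    alpha of 1.5, a voxel value of 0.4 maps to 0.0 + 1.5 * (0.4 - 0.0) = 0.6; the result is
+    then clipped to the [-1, 1] range.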
+ """ + + def __init__(self, random_state, alpha=(0.5, 1.5), mean=0.0, execution_probability=0.1, **kwargs): + self.random_state = random_state + assert len(alpha) == 2 + self.alpha = alpha + self.mean = mean + self.execution_probability = execution_probability + + def __call__(self, m): + if self.random_state.uniform() < self.execution_probability: + alpha = self.random_state.uniform(self.alpha[0], self.alpha[1]) + result = self.mean + alpha * (m - self.mean) + return np.clip(result, -1, 1) + + return m + + +# it's relatively slow, i.e. ~1s per patch of size 64x200x200, so use multiple workers in the DataLoader +# remember to use spline_order=0 when transforming the labels +class ElasticDeformation: + """ + Apply elasitc deformations of 3D patches on a per-voxel mesh. Assumes ZYX axis order (or CZYX if the data is 4D). + Based on: https://github.com/fcalvet/image_tools/blob/master/image_augmentation.py#L62 + """ + + def __init__(self, random_state, spline_order, alpha=2000, sigma=50, execution_probability=0.1, apply_3d=True, + **kwargs): + """ + :param spline_order: the order of spline interpolation (use 0 for labeled images) + :param alpha: scaling factor for deformations + :param sigma: smoothing factor for Gaussian filter + :param execution_probability: probability of executing this transform + :param apply_3d: if True apply deformations in each axis + """ + self.random_state = random_state + self.spline_order = spline_order + self.alpha = alpha + self.sigma = sigma + self.execution_probability = execution_probability + self.apply_3d = apply_3d + + def __call__(self, m): + if self.random_state.uniform() < self.execution_probability: + assert m.ndim in [3, 4] + + if m.ndim == 3: + volume_shape = m.shape + else: + volume_shape = m[0].shape + + if self.apply_3d: + dz = gaussian_filter(self.random_state.randn(*volume_shape), self.sigma, mode="reflect") * self.alpha + else: + dz = np.zeros_like(m) + + dy, dx = [ + gaussian_filter( + self.random_state.randn(*volume_shape), + self.sigma, mode="reflect" + ) * self.alpha for _ in range(2) + ] + + z_dim, y_dim, x_dim = volume_shape + z, y, x = np.meshgrid(np.arange(z_dim), np.arange(y_dim), np.arange(x_dim), indexing='ij') + indices = z + dz, y + dy, x + dx + + if m.ndim == 3: + return map_coordinates(m, indices, order=self.spline_order, mode='reflect') + else: + channels = [map_coordinates(c, indices, order=self.spline_order, mode='reflect') for c in m] + return np.stack(channels, axis=0) + + return m + + +class CropToFixed: + def __init__(self, random_state, size=(256, 256), centered=False, **kwargs): + self.random_state = random_state + self.crop_y, self.crop_x = size + self.centered = centered + + def __call__(self, m): + def _padding(pad_total): + half_total = pad_total // 2 + return (half_total, pad_total - half_total) + + def _rand_range_and_pad(crop_size, max_size): + """ + Returns a tuple: + max_value (int) for the corner dimension. 
The corner dimension is chosen as `self.random_state(max_value)`
+                pad (int): padding in both directions; if crop_size is less than max_size the pad is 0
+            """
+            if crop_size < max_size:
+                return max_size - crop_size, (0, 0)
+            else:
+                return 1, _padding(crop_size - max_size)
+
+        def _start_and_pad(crop_size, max_size):
+            if crop_size < max_size:
+                return (max_size - crop_size) // 2, (0, 0)
+            else:
+                return 0, _padding(crop_size - max_size)
+
+        assert m.ndim in (3, 4)
+        if m.ndim == 3:
+            _, y, x = m.shape
+        else:
+            _, _, y, x = m.shape
+
+        if not self.centered:
+            y_range, y_pad = _rand_range_and_pad(self.crop_y, y)
+            x_range, x_pad = _rand_range_and_pad(self.crop_x, x)
+
+            y_start = self.random_state.randint(y_range)
+            x_start = self.random_state.randint(x_range)
+
+        else:
+            y_start, y_pad = _start_and_pad(self.crop_y, y)
+            x_start, x_pad = _start_and_pad(self.crop_x, x)
+
+        if m.ndim == 3:
+            result = m[:, y_start:y_start + self.crop_y, x_start:x_start + self.crop_x]
+            return np.pad(result, pad_width=((0, 0), y_pad, x_pad), mode='reflect')
+        else:
+            channels = []
+            for c in range(m.shape[0]):
+                result = m[c][:, y_start:y_start + self.crop_y, x_start:x_start + self.crop_x]
+                channels.append(np.pad(result, pad_width=((0, 0), y_pad, x_pad), mode='reflect'))
+            return np.stack(channels, axis=0)
+
+
+class AbstractLabelToBoundary:
+    AXES_TRANSPOSE = [
+        (0, 1, 2),  # X
+        (0, 2, 1),  # Y
+        (2, 0, 1)   # Z
+    ]
+
+    def __init__(self, ignore_index=None, aggregate_affinities=False, append_label=False, **kwargs):
+        """
+        :param ignore_index: label to be ignored in the output, i.e. after computing the boundary the label
+            ignore_index will be restored where it was in the patch originally
+        :param aggregate_affinities: aggregate affinities with the same offset across Z,Y,X axes
+        :param append_label: if True append the original ground truth labels to the last channel
+        """
+        self.ignore_index = ignore_index
+        self.aggregate_affinities = aggregate_affinities
+        self.append_label = append_label
+
+    def __call__(self, m):
+        """
+        Extract boundaries from a given 3D label tensor.
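+        Boundaries are detected by convolving the labels with the difference kernels returned by
+        `get_kernels()`; a non-zero response marks a change of label within the kernel's offset.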
+        :param m: input 3D tensor
+        :return: binary mask, with 1-label corresponding to the boundary and 0-label corresponding to the background
+        """
+        assert m.ndim == 3
+
+        kernels = self.get_kernels()
+        boundary_arr = [np.where(np.abs(convolve(m, kernel)) > 0, 1, 0) for kernel in kernels]
+        channels = np.stack(boundary_arr)
+        results = []
+        if self.aggregate_affinities:
+            assert len(kernels) % 3 == 0, "Number of kernels must be divisible by 3 (one kernel per offset per Z,Y,X axes)"
+            # aggregate affinities with the same offset
+            for i in range(0, len(kernels), 3):
+                # merge across X,Y,Z axes (logical OR)
+                xyz_aggregated_affinities = np.logical_or.reduce(channels[i:i + 3, ...]).astype(np.int32)
+                # recover ignore index
+                xyz_aggregated_affinities = _recover_ignore_index(xyz_aggregated_affinities, m, self.ignore_index)
+                results.append(xyz_aggregated_affinities)
+        else:
+            results = [_recover_ignore_index(channels[i], m, self.ignore_index) for i in range(channels.shape[0])]
+
+        if self.append_label:
+            # append original input data
+            results.append(m)
+
+        # stack across channel dim
+        return np.stack(results, axis=0)
+
+    @staticmethod
+    def create_kernel(axis, offset):
+        # create conv kernel
+        k_size = offset + 1
+        k = np.zeros((1, 1, k_size), dtype=np.int32)
+        k[0, 0, 0] = 1
+        k[0, 0, offset] = -1
+        return np.transpose(k, axis)
+
+    def get_kernels(self):
+        raise NotImplementedError
+
+
+class StandardLabelToBoundary:
+    def __init__(self, ignore_index=None, append_label=False, mode='thick', foreground=False,
+                 **kwargs):
+        self.ignore_index = ignore_index
+        self.append_label = append_label
+        self.mode = mode
+        self.foreground = foreground
+
+    def __call__(self, m):
+        assert m.ndim == 3
+
+        boundaries = find_boundaries(m, connectivity=2, mode=self.mode)
+        boundaries = boundaries.astype('int32')
+
+        results = []
+        if self.foreground:
+            foreground = (m > 0).astype('uint8')
+            results.append(_recover_ignore_index(foreground, m, self.ignore_index))
+
+        results.append(_recover_ignore_index(boundaries, m, self.ignore_index))
+
+        if self.append_label:
+            # append original input data
+            results.append(m)
+
+        return np.stack(results, axis=0)
+
+
+class BlobsToMask:
+    """
+    Returns a binary mask from a labeled image, i.e. every label greater than 0 is treated as foreground.
+    """
+
+    def __init__(self, append_label=False, boundary=False, cross_entropy=False, **kwargs):
+        self.cross_entropy = cross_entropy
+        self.boundary = boundary
+        self.append_label = append_label
+
+    def __call__(self, m):
+        assert m.ndim == 3
+
+        # get the segmentation mask
+        mask = (m > 0).astype('uint8')
+        results = [mask]
+
+        if self.boundary:
+            outer = find_boundaries(m, connectivity=2, mode='outer')
+            if self.cross_entropy:
+                # boundary is class 2
+                mask[outer > 0] = 2
+                results = [mask]
+            else:
+                results.append(outer)
+
+        if self.append_label:
+            results.append(m)
+
+        return np.stack(results, axis=0)
+
+
+class RandomLabelToAffinities(AbstractLabelToBoundary):
+    """
+    Converts a given volumetric label array to a binary mask corresponding to borders between labels.
+    One can specify the max_offset (thickness) of the border. The offset is then picked at random every time the
+    transformer is called (from the range 1:max_offset) for each axis, and the boundary is computed.
+    One may use this scheme in order to make the network more robust against various thicknesses of borders in the
+    ground truth (think of it as a boundary denoising scheme).
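+
+    Example (illustrative): with max_offset=4 and z_offset_scale=2, each call draws an offset
+    uniformly from {1, 2, 3, 4}; when the randomly picked axis is Z, the offset is scaled down
+    to max(1, offset // 2) to account for anisotropy.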
+ """ + + def __init__(self, random_state, max_offset=10, ignore_index=None, append_label=False, z_offset_scale=2, **kwargs): + super().__init__(ignore_index=ignore_index, append_label=append_label, aggregate_affinities=False) + self.random_state = random_state + self.offsets = tuple(range(1, max_offset + 1)) + self.z_offset_scale = z_offset_scale + + def get_kernels(self): + rand_offset = self.random_state.choice(self.offsets) + axis_ind = self.random_state.randint(3) + # scale down z-affinities due to anisotropy + if axis_ind == 2: + rand_offset = max(1, rand_offset // self.z_offset_scale) + + rand_axis = self.AXES_TRANSPOSE[axis_ind] + # return a single kernel + return [self.create_kernel(rand_axis, rand_offset)] + + +class LabelToAffinities(AbstractLabelToBoundary): + """ + Converts a given volumetric label array to binary mask corresponding to borders between labels (which can be seen + as an affinity graph: https://arxiv.org/pdf/1706.00120.pdf) + One specify the offsets (thickness) of the border. The boundary will be computed via the convolution operator. + """ + + def __init__(self, offsets, ignore_index=None, append_label=False, aggregate_affinities=False, z_offsets=None, + **kwargs): + super().__init__(ignore_index=ignore_index, append_label=append_label, + aggregate_affinities=aggregate_affinities) + + assert isinstance(offsets, list) or isinstance(offsets, tuple), 'offsets must be a list or a tuple' + assert all(a > 0 for a in offsets), "'offsets must be positive" + assert len(set(offsets)) == len(offsets), "'offsets' must be unique" + if z_offsets is not None: + assert len(offsets) == len(z_offsets), 'z_offsets length must be the same as the length of offsets' + else: + # if z_offsets is None just use the offsets for z-affinities + z_offsets = list(offsets) + self.z_offsets = z_offsets + + self.kernels = [] + # create kernel for every axis-offset pair + for xy_offset, z_offset in zip(offsets, z_offsets): + for axis_ind, axis in enumerate(self.AXES_TRANSPOSE): + final_offset = xy_offset + if axis_ind == 2: + final_offset = z_offset + # create kernels for a given offset in every direction + self.kernels.append(self.create_kernel(axis, final_offset)) + + def get_kernels(self): + return self.kernels + + +class LabelToZAffinities(AbstractLabelToBoundary): + """ + Converts a given volumetric label array to binary mask corresponding to borders between labels (which can be seen + as an affinity graph: https://arxiv.org/pdf/1706.00120.pdf) + One specify the offsets (thickness) of the border. The boundary will be computed via the convolution operator. + """ + + def __init__(self, offsets, ignore_index=None, append_label=False, **kwargs): + super().__init__(ignore_index=ignore_index, append_label=append_label) + + assert isinstance(offsets, list) or isinstance(offsets, tuple), 'offsets must be a list or a tuple' + assert all(a > 0 for a in offsets), "'offsets must be positive" + assert len(set(offsets)) == len(offsets), "'offsets' must be unique" + + self.kernels = [] + z_axis = self.AXES_TRANSPOSE[2] + # create kernels + for z_offset in offsets: + self.kernels.append(self.create_kernel(z_axis, z_offset)) + + def get_kernels(self): + return self.kernels + + +class LabelToBoundaryAndAffinities: + """ + Combines the StandardLabelToBoundary and LabelToAffinities in the hope + that that training the network to predict both would improve the main task: boundary prediction. 
+ """ + + def __init__(self, xy_offsets, z_offsets, append_label=False, blur=False, sigma=1, ignore_index=None, mode='thick', + foreground=False, **kwargs): + # blur only StandardLabelToBoundary results; we don't want to blur the affinities + self.l2b = StandardLabelToBoundary(blur=blur, sigma=sigma, ignore_index=ignore_index, mode=mode, + foreground=foreground) + self.l2a = LabelToAffinities(offsets=xy_offsets, z_offsets=z_offsets, append_label=append_label, + ignore_index=ignore_index) + + def __call__(self, m): + boundary = self.l2b(m) + affinities = self.l2a(m) + return np.concatenate((boundary, affinities), axis=0) + + +class LabelToMaskAndAffinities: + def __init__(self, xy_offsets, z_offsets, append_label=False, background=0, ignore_index=None, **kwargs): + self.background = background + self.l2a = LabelToAffinities(offsets=xy_offsets, z_offsets=z_offsets, append_label=append_label, + ignore_index=ignore_index) + + def __call__(self, m): + mask = m > self.background + mask = np.expand_dims(mask.astype(np.uint8), axis=0) + affinities = self.l2a(m) + return np.concatenate((mask, affinities), axis=0) + + +class Standardize: + """ + Apply Z-score normalization to a given input tensor, i.e. re-scaling the values to be 0-mean and 1-std. + """ + + def __init__(self, eps=1e-10, mean=None, std=None, channelwise=False, **kwargs): + if mean is not None or std is not None: + assert mean is not None and std is not None + self.mean = mean + self.std = std + self.eps = eps + self.channelwise = channelwise + + def __call__(self, m): + if self.mean is not None: + mean, std = self.mean, self.std + else: + if self.channelwise: + # normalize per-channel + axes = list(range(m.ndim)) + # average across channels + axes = tuple(axes[1:]) + mean = np.mean(m, axis=axes, keepdims=True) + std = np.std(m, axis=axes, keepdims=True) + else: + mean = np.mean(m) + std = np.std(m) + + return (m - mean) / np.clip(std, a_min=self.eps, a_max=None) + + +class PercentileNormalizer: + def __init__(self, pmin=1, pmax=99.6, channelwise=False, eps=1e-10, **kwargs): + self.eps = eps + self.pmin = pmin + self.pmax = pmax + self.channelwise = channelwise + + def __call__(self, m): + if self.channelwise: + axes = list(range(m.ndim)) + # average across channels + axes = tuple(axes[1:]) + pmin = np.percentile(m, self.pmin, axis=axes, keepdims=True) + pmax = np.percentile(m, self.pmax, axis=axes, keepdims=True) + else: + pmin = np.percentile(m, self.pmin) + pmax = np.percentile(m, self.pmax) + + return (m - pmin) / (pmax - pmin + self.eps) + + +class Normalize: + """ + Apply simple min-max scaling to a given input tensor, i.e. shrinks the range of the data + in a fixed range of [-1, 1] or in case of norm01==True to [0, 1]. In addition, data can be + clipped by specifying min_value/max_value either globally using single values or via a + list/tuple channelwise if enabled. 
+ """ + + def __init__(self, min_value=None, max_value=None, norm01=False, channelwise=False, + eps=1e-10, **kwargs): + if min_value is not None and max_value is not None: + assert max_value > min_value + self.min_value = min_value + self.max_value = max_value + self.norm01 = norm01 + self.channelwise = channelwise + self.eps = eps + + def __call__(self, m): + if self.channelwise: + # get min/max channelwise + axes = list(range(m.ndim)) + axes = tuple(axes[1:]) + if self.min_value is None or 'None' in self.min_value: + min_value = np.min(m, axis=axes, keepdims=True) + + if self.max_value is None or 'None' in self.max_value: + max_value = np.max(m, axis=axes, keepdims=True) + + # check if non None in self.min_value/self.max_value + # if present and if so copy value to min_value + if self.min_value is not None: + for i,v in enumerate(self.min_value): + if v != 'None': + min_value[i] = v + + if self.max_value is not None: + for i,v in enumerate(self.max_value): + if v != 'None': + max_value[i] = v + else: + if self.min_value is None: + min_value = np.min(m) + else: + min_value = self.min_value + + if self.max_value is None: + max_value = np.max(m) + else: + max_value = self.max_value + + # calculate norm_0_1 with min_value / max_value with the same dimension + # in case of channelwise application + norm_0_1 = (m - min_value) / (max_value - min_value + self.eps) + + if self.norm01 is True: + return np.clip(norm_0_1, 0, 1) + else: + return np.clip(2 * norm_0_1 - 1, -1, 1) + + +class AdditiveGaussianNoise: + def __init__(self, random_state, scale=(0.0, 1.0), execution_probability=0.1, **kwargs): + self.execution_probability = execution_probability + self.random_state = random_state + self.scale = scale + + def __call__(self, m): + if self.random_state.uniform() < self.execution_probability: + std = self.random_state.uniform(self.scale[0], self.scale[1]) + gaussian_noise = self.random_state.normal(0, std, size=m.shape) + return m + gaussian_noise + return m + + +class AdditivePoissonNoise: + def __init__(self, random_state, lam=(0.0, 1.0), execution_probability=0.1, **kwargs): + self.execution_probability = execution_probability + self.random_state = random_state + self.lam = lam + + def __call__(self, m): + if self.random_state.uniform() < self.execution_probability: + lam = self.random_state.uniform(self.lam[0], self.lam[1]) + poisson_noise = self.random_state.poisson(lam, size=m.shape) + return m + poisson_noise + return m + + +class ToTensor: + """ + Converts a given input numpy.ndarray into torch.Tensor. + + Args: + expand_dims (bool): if True, adds a channel dimension to the input data + dtype (np.dtype): the desired output data type + """ + + def __init__(self, expand_dims, dtype=np.float32, **kwargs): + self.expand_dims = expand_dims + self.dtype = dtype + + def __call__(self, m): + assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images' + # add channel dimension + if self.expand_dims and m.ndim == 3: + m = np.expand_dims(m, axis=0) + + return torch.from_numpy(m.astype(dtype=self.dtype)) + + +class Relabel: + """ + Relabel a numpy array of labels into a consecutive numbers, e.g. + [10, 10, 0, 6, 6] -> [2, 2, 0, 1, 1]. Useful when one has an instance segmentation volume + at hand and would like to create a one-hot-encoding for it. Without a consecutive labeling the task would be harder. 
+ """ + + def __init__(self, append_original=False, run_cc=True, ignore_label=None, **kwargs): + self.append_original = append_original + self.ignore_label = ignore_label + self.run_cc = run_cc + + if ignore_label is not None: + assert append_original, "ignore_label present, so append_original must be true, so that one can localize the ignore region" + + def __call__(self, m): + orig = m + if self.run_cc: + # assign 0 to the ignore region + m = measure.label(m, background=self.ignore_label) + + _, unique_labels = np.unique(m, return_inverse=True) + result = unique_labels.reshape(m.shape) + if self.append_original: + result = np.stack([result, orig]) + return result + + +class Identity: + def __init__(self, **kwargs): + pass + + def __call__(self, m): + return m + + +class RgbToLabel: + def __call__(self, img): + img = np.array(img) + assert img.ndim == 3 and img.shape[2] == 3 + result = img[..., 0] * 65536 + img[..., 1] * 256 + img[..., 2] + return result + + +class LabelToTensor: + def __call__(self, m): + m = np.array(m) + return torch.from_numpy(m.astype(dtype='int64')) + + +class GaussianBlur3D: + def __init__(self, sigma=[.1, 2.], execution_probability=0.5, **kwargs): + self.sigma = sigma + self.execution_probability = execution_probability + + def __call__(self, x): + if random.random() < self.execution_probability: + sigma = random.uniform(self.sigma[0], self.sigma[1]) + x = gaussian(x, sigma=sigma) + return x + return x + + +class Transformer: + def __init__(self, phase_config, base_config): + self.phase_config = phase_config + self.config_base = base_config + self.seed = GLOBAL_RANDOM_STATE.randint(10000000) + + def raw_transform(self): + return self._create_transform('raw') + + def label_transform(self): + return self._create_transform('label') + + def weight_transform(self): + return self._create_transform('weight') + + @staticmethod + def _transformer_class(class_name): + m = importlib.import_module('pytorch3dunet.augment.transforms') + clazz = getattr(m, class_name) + return clazz + + def _create_transform(self, name): + assert name in self.phase_config, f'Could not find {name} transform' + return Compose([ + self._create_augmentation(c) for c in self.phase_config[name] + ]) + + def _create_augmentation(self, c): + config = dict(self.config_base) + config.update(c) + config['random_state'] = np.random.RandomState(self.seed) + aug_class = self._transformer_class(config['name']) + return aug_class(**config) + + +def _recover_ignore_index(input, orig, ignore_index): + if ignore_index is not None: + mask = orig == ignore_index + input[mask] = ignore_index + + return input diff --git a/build/lib/pytorch3dunet/datasets/__init__.py b/build/lib/pytorch3dunet/datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/pytorch3dunet/datasets/dsb.py b/build/lib/pytorch3dunet/datasets/dsb.py new file mode 100644 index 00000000..5d0cde86 --- /dev/null +++ b/build/lib/pytorch3dunet/datasets/dsb.py @@ -0,0 +1,108 @@ +import collections +import os + +import imageio +import numpy as np +import torch + +from pytorch3dunet.augment import transforms +from pytorch3dunet.datasets.utils import ConfigDataset, calculate_stats +from pytorch3dunet.unet3d.utils import get_logger + +logger = get_logger('DSB2018Dataset') + + +def dsb_prediction_collate(batch): + """ + Forms a mini-batch of (images, paths) during test time for the DSB-like datasets. 
+ """ + error_msg = "batch must contain tensors or str; found {}" + if isinstance(batch[0], torch.Tensor): + return torch.stack(batch, 0) + elif isinstance(batch[0], str): + return list(batch) + elif isinstance(batch[0], collections.Sequence): + # transpose tuples, i.e. [[1, 2], ['a', 'b']] to be [[1, 'a'], [2, 'b']] + transposed = zip(*batch) + return [dsb_prediction_collate(samples) for samples in transposed] + + raise TypeError((error_msg.format(type(batch[0])))) + + +class DSB2018Dataset(ConfigDataset): + def __init__(self, root_dir, phase, transformer_config, expand_dims=True): + assert os.path.isdir(root_dir), f'{root_dir} is not a directory' + assert phase in ['train', 'val', 'test'] + + self.phase = phase + + # load raw images + images_dir = os.path.join(root_dir, 'images') + assert os.path.isdir(images_dir) + self.images, self.paths = self._load_files(images_dir, expand_dims) + self.file_path = images_dir + + stats = calculate_stats(self.images, True) + + transformer = transforms.Transformer(transformer_config, stats) + + # load raw images transformer + self.raw_transform = transformer.raw_transform() + + if phase != 'test': + # load labeled images + masks_dir = os.path.join(root_dir, 'masks') + assert os.path.isdir(masks_dir) + self.masks, _ = self._load_files(masks_dir, expand_dims) + assert len(self.images) == len(self.masks) + # load label images transformer + self.masks_transform = transformer.label_transform() + else: + self.masks = None + self.masks_transform = None + + def __getitem__(self, idx): + if idx >= len(self): + raise StopIteration + + img = self.images[idx] + if self.phase != 'test': + mask = self.masks[idx] + return self.raw_transform(img), self.masks_transform(mask) + else: + return self.raw_transform(img), self.paths[idx] + + def __len__(self): + return len(self.images) + + @classmethod + def prediction_collate(cls, batch): + return dsb_prediction_collate(batch) + + @classmethod + def create_datasets(cls, dataset_config, phase): + phase_config = dataset_config[phase] + # load data augmentation configuration + transformer_config = phase_config['transformer'] + # load files to process + file_paths = phase_config['file_paths'] + expand_dims = dataset_config.get('expand_dims', True) + return [cls(file_paths[0], phase, transformer_config, expand_dims)] + + @staticmethod + def _load_files(dir, expand_dims): + files_data = [] + paths = [] + for file in os.listdir(dir): + path = os.path.join(dir, file) + img = np.asarray(imageio.imread(path)) + if expand_dims: + dims = img.ndim + img = np.expand_dims(img, axis=0) + if dims == 3: + img = np.transpose(img, (3, 0, 1, 2)) + + files_data.append(img) + paths.append(path) + + return files_data, paths diff --git a/build/lib/pytorch3dunet/datasets/hdf5.py b/build/lib/pytorch3dunet/datasets/hdf5.py new file mode 100644 index 00000000..040adb85 --- /dev/null +++ b/build/lib/pytorch3dunet/datasets/hdf5.py @@ -0,0 +1,293 @@ +import glob +import os +from abc import abstractmethod +from itertools import chain + +import h5py + +import pytorch3dunet.augment.transforms as transforms +from pytorch3dunet.datasets.utils import get_slice_builder, ConfigDataset, calculate_stats, mirror_pad +from pytorch3dunet.unet3d.utils import get_logger + +logger = get_logger('HDF5Dataset') + + +def _create_padded_indexes(indexes, halo_shape): + return tuple(slice(index.start, index.stop + 2 * halo) for index, halo in zip(indexes, halo_shape)) + + +def traverse_h5_paths(file_paths): + assert isinstance(file_paths, list) + results = [] + for file_path in 
file_paths: + if os.path.isdir(file_path): + # if file path is a directory take all H5 files in that directory + iters = [glob.glob(os.path.join(file_path, ext)) for ext in ['*.h5', '*.hdf', '*.hdf5', '*.hd5']] + for fp in chain(*iters): + results.append(fp) + else: + results.append(file_path) + return results + + +class AbstractHDF5Dataset(ConfigDataset): + """ + Implementation of torch.utils.data.Dataset backed by the HDF5 files, which iterates over the raw and label datasets + patch by patch with a given stride. + + Args: + file_path (str): path to H5 file containing raw data as well as labels and per pixel weights (optional) + phase (str): 'train' for training, 'val' for validation, 'test' for testing + slice_builder_config (dict): configuration of the SliceBuilder + transformer_config (dict): data augmentation configuration + raw_internal_path (str or list): H5 internal path to the raw dataset + label_internal_path (str or list): H5 internal path to the label dataset + weight_internal_path (str or list): H5 internal path to the per pixel weights (optional) + global_normalization (bool): if True, the mean and std of the raw data will be calculated over the whole dataset + """ + + def __init__(self, file_path, phase, slice_builder_config, transformer_config, raw_internal_path='raw', + label_internal_path='label', weight_internal_path=None, global_normalization=True): + assert phase in ['train', 'val', 'test'] + + self.phase = phase + self.file_path = file_path + self.raw_internal_path = raw_internal_path + self.label_internal_path = label_internal_path + self.weight_internal_path = weight_internal_path + + self.halo_shape = slice_builder_config.get('halo_shape', [0, 0, 0]) + + if global_normalization: + logger.info('Calculating mean and std of the raw data...') + with h5py.File(file_path, 'r') as f: + raw = f[raw_internal_path][:] + stats = calculate_stats(raw) + else: + stats = calculate_stats(None, True) + + self.transformer = transforms.Transformer(transformer_config, stats) + self.raw_transform = self.transformer.raw_transform() + + if phase != 'test': + # create label/weight transform only in train/val phase + self.label_transform = self.transformer.label_transform() + + if weight_internal_path is not None: + self.weight_transform = self.transformer.weight_transform() + else: + self.weight_transform = None + + self._check_volume_sizes() + else: + # 'test' phase used only for predictions so ignore the label dataset + self.label = None + self.weight_map = None + + # compare patch and stride configuration + patch_shape = slice_builder_config.get('patch_shape') + stride_shape = slice_builder_config.get('stride_shape') + if sum(self.halo_shape) != 0 and patch_shape != stride_shape: + logger.warning(f'Found non-zero halo shape {self.halo_shape}. 
' + f'In this case: patch shape and stride shape should be equal for optimal prediction ' + f'performance, but found patch_shape: {patch_shape} and stride_shape: {stride_shape}!') + + with h5py.File(file_path, 'r') as f: + raw = f[raw_internal_path] + label = f[label_internal_path] if phase != 'test' else None + weight_map = f[weight_internal_path] if weight_internal_path is not None else None + # build slice indices for raw and label data sets + slice_builder = get_slice_builder(raw, label, weight_map, slice_builder_config) + self.raw_slices = slice_builder.raw_slices + self.label_slices = slice_builder.label_slices + self.weight_slices = slice_builder.weight_slices + + self.patch_count = len(self.raw_slices) + logger.info(f'Number of patches: {self.patch_count}') + + @abstractmethod + def get_raw_patch(self, idx): + raise NotImplementedError + + @abstractmethod + def get_label_patch(self, idx): + raise NotImplementedError + + @abstractmethod + def get_weight_patch(self, idx): + raise NotImplementedError + + @abstractmethod + def get_raw_padded_patch(self, idx): + raise NotImplementedError + + def volume_shape(self): + with h5py.File(self.file_path, 'r') as f: + raw = f[self.raw_internal_path] + if raw.ndim == 3: + return raw.shape + else: + return raw.shape[1:] + + def __getitem__(self, idx): + if idx >= len(self): + raise StopIteration + + raw_idx = self.raw_slices[idx] + + if self.phase == 'test': + if len(raw_idx) == 4: + # discard the channel dimension in the slices: predictor requires only the spatial dimensions of the volume + raw_idx = raw_idx[1:] # Remove the first element if raw_idx has 4 elements + raw_idx_padded = (slice(None),) + _create_padded_indexes(raw_idx, self.halo_shape) + else: + raw_idx_padded = _create_padded_indexes(raw_idx, self.halo_shape) + + raw_patch_transformed = self.raw_transform(self.get_raw_padded_patch(raw_idx_padded)) + return raw_patch_transformed, raw_idx + else: + raw_patch_transformed = self.raw_transform(self.get_raw_patch(raw_idx)) + + # get the slice for a given index 'idx' + label_idx = self.label_slices[idx] + label_patch_transformed = self.label_transform(self.get_label_patch(label_idx)) + if self.weight_internal_path is not None: + weight_idx = self.weight_slices[idx] + weight_patch_transformed = self.weight_transform(self.get_weight_patch(weight_idx)) + return raw_patch_transformed, label_patch_transformed, weight_patch_transformed + # return the transformed raw and label patches + return raw_patch_transformed, label_patch_transformed + + def __len__(self): + return self.patch_count + + def _check_volume_sizes(self): + def _volume_shape(volume): + if volume.ndim == 3: + return volume.shape + return volume.shape[1:] + + with h5py.File(self.file_path, 'r') as f: + raw = f[self.raw_internal_path] + label = f[self.label_internal_path] + assert raw.ndim in [3, 4], 'Raw dataset must be 3D (DxHxW) or 4D (CxDxHxW)' + assert label.ndim in [3, 4], 'Label dataset must be 3D (DxHxW) or 4D (CxDxHxW)' + assert _volume_shape(raw) == _volume_shape(label), 'Raw and labels have to be of the same size' + if self.weight_internal_path is not None: + weight_map = f[self.weight_internal_path] + assert weight_map.ndim in [3, 4], 'Weight map dataset must be 3D (DxHxW) or 4D (CxDxHxW)' + assert _volume_shape(raw) == _volume_shape(weight_map), 'Raw and weight map have to be of the same size' + + @classmethod + def create_datasets(cls, dataset_config, phase): + phase_config = dataset_config[phase] + + # load data augmentation configuration + transformer_config = 
phase_config['transformer'] + # load slice builder config + slice_builder_config = phase_config['slice_builder'] + # load files to process + file_paths = phase_config['file_paths'] + # file_paths may contain both files and directories; if the file_path is a directory all H5 files inside + # are going to be included in the final file_paths + file_paths = traverse_h5_paths(file_paths) + + datasets = [] + for file_path in file_paths: + try: + logger.info(f'Loading {phase} set from: {file_path}...') + dataset = cls(file_path=file_path, + phase=phase, + slice_builder_config=slice_builder_config, + transformer_config=transformer_config, + raw_internal_path=dataset_config.get('raw_internal_path', 'raw'), + label_internal_path=dataset_config.get('label_internal_path', 'label'), + weight_internal_path=dataset_config.get('weight_internal_path', None), + global_normalization=dataset_config.get('global_normalization', None)) + datasets.append(dataset) + except Exception: + logger.error(f'Skipping {phase} set: {file_path}', exc_info=True) + return datasets + + +class StandardHDF5Dataset(AbstractHDF5Dataset): + """ + Implementation of the HDF5 dataset which loads the data from the H5 files into the memory. + Fast but might consume a lot of memory. + """ + + def __init__(self, file_path, phase, slice_builder_config, transformer_config, + raw_internal_path='raw', label_internal_path='label', weight_internal_path=None, + global_normalization=True): + super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config, + transformer_config=transformer_config, raw_internal_path=raw_internal_path, + label_internal_path=label_internal_path, weight_internal_path=weight_internal_path, + global_normalization=global_normalization) + self._raw = None + self._raw_padded = None + self._label = None + self._weight_map = None + + def get_raw_patch(self, idx): + if self._raw is None: + with h5py.File(self.file_path, 'r') as f: + assert self.raw_internal_path in f, f'Dataset {self.raw_internal_path} not found in {self.file_path}' + self._raw = f[self.raw_internal_path][:] + return self._raw[idx] + + def get_label_patch(self, idx): + if self._label is None: + with h5py.File(self.file_path, 'r') as f: + assert self.label_internal_path in f, f'Dataset {self.label_internal_path} not found in {self.file_path}' + self._label = f[self.label_internal_path][:] + return self._label[idx] + + def get_weight_patch(self, idx): + if self._weight_map is None: + with h5py.File(self.file_path, 'r') as f: + assert self.weight_internal_path in f, f'Dataset {self.weight_internal_path} not found in {self.file_path}' + self._weight_map = f[self.weight_internal_path][:] + return self._weight_map[idx] + + def get_raw_padded_patch(self, idx): + if self._raw_padded is None: + with h5py.File(self.file_path, 'r') as f: + assert self.raw_internal_path in f, f'Dataset {self.raw_internal_path} not found in {self.file_path}' + self._raw_padded = mirror_pad(f[self.raw_internal_path][:], self.halo_shape) + return self._raw_padded[idx] + + +class LazyHDF5Dataset(AbstractHDF5Dataset): + """Implementation of the HDF5 dataset which loads the data lazily. 
It's slower, but has a low memory footprint.""" + + def __init__(self, file_path, phase, slice_builder_config, transformer_config, + raw_internal_path='raw', label_internal_path='label', weight_internal_path=None, + global_normalization=False): + super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config, + transformer_config=transformer_config, raw_internal_path=raw_internal_path, + label_internal_path=label_internal_path, weight_internal_path=weight_internal_path, + global_normalization=global_normalization) + + logger.info("Using LazyHDF5Dataset") + + def get_raw_patch(self, idx): + with h5py.File(self.file_path, 'r') as f: + return f[self.raw_internal_path][idx] + + def get_label_patch(self, idx): + with h5py.File(self.file_path, 'r') as f: + return f[self.label_internal_path][idx] + + def get_weight_patch(self, idx): + with h5py.File(self.file_path, 'r') as f: + return f[self.weight_internal_path][idx] + + def get_raw_padded_patch(self, idx): + with h5py.File(self.file_path, 'r+') as f: + if 'raw_padded' in f: + return f['raw_padded'][idx] + + raw = f[self.raw_internal_path][:] + raw_padded = mirror_pad(raw, self.halo_shape) + f.create_dataset('raw_padded', data=raw_padded, compression='gzip') + return raw_padded[idx] diff --git a/build/lib/pytorch3dunet/datasets/utils.py b/build/lib/pytorch3dunet/datasets/utils.py new file mode 100644 index 00000000..1ffeefe4 --- /dev/null +++ b/build/lib/pytorch3dunet/datasets/utils.py @@ -0,0 +1,361 @@ +import collections +from typing import Any + +import numpy as np +import torch +from torch.utils.data import DataLoader, ConcatDataset, Dataset + +from pytorch3dunet.unet3d.utils import get_logger, get_class + +logger = get_logger('Dataset') + + +class ConfigDataset(Dataset): + def __getitem__(self, index): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + @classmethod + def create_datasets(cls, dataset_config, phase): + """ + Factory method for creating a list of datasets based on the provided config. + + Args: + dataset_config (dict): dataset configuration + phase (str): one of ['train', 'val', 'test'] + + Returns: + list of `Dataset` instances + """ + raise NotImplementedError + + @classmethod + def prediction_collate(cls, batch): + """Default collate_fn. Override in child class for non-standard datasets.""" + return default_prediction_collate(batch) + + +class SliceBuilder: + """ + Builds the position of the patches in a given raw/label/weight ndarray based on the patch and stride shape. 
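+
+    For example (illustrative): a volume of shape (100, 200, 200) traversed with
+    patch_shape=(64, 128, 128) and stride_shape=(32, 96, 96) places a patch every 32 voxels
+    along Z (96 along Y/X), plus a final patch anchored at the far border whenever the stride
+    does not tile the volume exactly.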
+ + Args: + raw_dataset (ndarray): raw data + label_dataset (ndarray): ground truth labels + weight_dataset (ndarray): weights for the labels + patch_shape (tuple): the shape of the patch DxHxW + stride_shape (tuple): the shape of the stride DxHxW + kwargs: additional metadata + """ + + def __init__(self, raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, **kwargs): + patch_shape = tuple(patch_shape) + stride_shape = tuple(stride_shape) + skip_shape_check = kwargs.get('skip_shape_check', False) + if not skip_shape_check: + self._check_patch_shape(patch_shape) + + self._raw_slices = self._build_slices(raw_dataset, patch_shape, stride_shape) + if label_dataset is None: + self._label_slices = None + else: + # take the first element in the label_dataset to build slices + self._label_slices = self._build_slices(label_dataset, patch_shape, stride_shape) + assert len(self._raw_slices) == len(self._label_slices) + if weight_dataset is None: + self._weight_slices = None + else: + self._weight_slices = self._build_slices(weight_dataset, patch_shape, stride_shape) + assert len(self.raw_slices) == len(self._weight_slices) + + @property + def raw_slices(self): + return self._raw_slices + + @property + def label_slices(self): + return self._label_slices + + @property + def weight_slices(self): + return self._weight_slices + + @staticmethod + def _build_slices(dataset, patch_shape, stride_shape): + """Iterates over a given n-dim dataset patch-by-patch with a given stride + and builds an array of slice positions. + + Returns: + list of slices, i.e. + [(slice, slice, slice, slice), ...] if len(shape) == 4 + [(slice, slice, slice), ...] if len(shape) == 3 + """ + slices = [] + if dataset.ndim == 4: + in_channels, i_z, i_y, i_x = dataset.shape + else: + i_z, i_y, i_x = dataset.shape + + k_z, k_y, k_x = patch_shape + s_z, s_y, s_x = stride_shape + z_steps = SliceBuilder._gen_indices(i_z, k_z, s_z) + for z in z_steps: + y_steps = SliceBuilder._gen_indices(i_y, k_y, s_y) + for y in y_steps: + x_steps = SliceBuilder._gen_indices(i_x, k_x, s_x) + for x in x_steps: + slice_idx = ( + slice(z, z + k_z), + slice(y, y + k_y), + slice(x, x + k_x), + ) + if dataset.ndim == 4: + slice_idx = (slice(0, in_channels),) + slice_idx + slices.append(slice_idx) + return slices + + @staticmethod + def _gen_indices(i, k, s): + assert i >= k, 'Sample size has to be bigger than the patch size' + for j in range(0, i - k + 1, s): + yield j + if j + k < i: + yield i - k + + @staticmethod + def _check_patch_shape(patch_shape): + assert len(patch_shape) == 3, 'patch_shape must be a 3D tuple' + assert patch_shape[1] >= 64 and patch_shape[2] >= 64, 'Height and Width must be greater or equal 64' + + +class FilterSliceBuilder(SliceBuilder): + """ + Filter patches containing more than `1 - threshold` of ignore_index label + """ + + def __init__(self, raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, ignore_index=None, + threshold=0.6, slack_acceptance=0.01, **kwargs): + super().__init__(raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, **kwargs) + if label_dataset is None: + return + + rand_state = np.random.RandomState(47) + + def ignore_predicate(raw_label_idx): + label_idx = raw_label_idx[1] + patch = label_dataset[label_idx] + if ignore_index is not None: + patch = np.copy(patch) + patch[patch == ignore_index] = 0 + non_ignore_counts = np.count_nonzero(patch != 0) + non_ignore_counts = non_ignore_counts / patch.size + return non_ignore_counts > threshold or rand_state.rand() < 
slack_acceptance + + zipped_slices = zip(self.raw_slices, self.label_slices) + # ignore slices containing too much ignore_index + logger.info(f'Filtering slices...') + filtered_slices = list(filter(ignore_predicate, zipped_slices)) + # unzip and save slices + raw_slices, label_slices = zip(*filtered_slices) + self._raw_slices = list(raw_slices) + self._label_slices = list(label_slices) + + +def _loader_classes(class_name): + modules = [ + 'pytorch3dunet.datasets.hdf5', + 'pytorch3dunet.datasets.dsb', + 'pytorch3dunet.datasets.utils' + ] + return get_class(class_name, modules) + + +def get_slice_builder(raws, labels, weight_maps, config): + assert 'name' in config + logger.info(f"Slice builder config: {config}") + slice_builder_cls = _loader_classes(config['name']) + return slice_builder_cls(raws, labels, weight_maps, **config) + + +def get_train_loaders(config): + """ + Returns dictionary containing the training and validation loaders (torch.utils.data.DataLoader). + + :param config: a top level configuration object containing the 'loaders' key + :return: dict { + 'train': + 'val': + } + """ + assert 'loaders' in config, 'Could not find data loaders configuration' + loaders_config = config['loaders'] + + logger.info('Creating training and validation set loaders...') + + # get dataset class + dataset_cls_str = loaders_config.get('dataset', None) + if dataset_cls_str is None: + dataset_cls_str = 'StandardHDF5Dataset' + logger.warning(f"Cannot find dataset class in the config. Using default '{dataset_cls_str}'.") + dataset_class = _loader_classes(dataset_cls_str) + + assert set(loaders_config['train']['file_paths']).isdisjoint(loaders_config['val']['file_paths']), \ + "Train and validation 'file_paths' overlap. One cannot use validation data for training!" + + train_datasets = dataset_class.create_datasets(loaders_config, phase='train') + + val_datasets = dataset_class.create_datasets(loaders_config, phase='val') + + num_workers = loaders_config.get('num_workers', 1) + logger.info(f'Number of workers for train/val dataloader: {num_workers}') + batch_size = loaders_config.get('batch_size', 1) + if torch.cuda.device_count() > 1 and not config['device'] == 'cpu': + logger.info( + f'{torch.cuda.device_count()} GPUs available. Using batch_size = {torch.cuda.device_count()} * {batch_size}') + batch_size = batch_size * torch.cuda.device_count() + + logger.info(f'Batch size for train/val loader: {batch_size}') + # when training with volumetric data use batch_size of 1 due to GPU memory constraints + return { + 'train': DataLoader(ConcatDataset(train_datasets), batch_size=batch_size, shuffle=True, pin_memory=True, + num_workers=num_workers), + # don't shuffle during validation: useful when showing how predictions for a given batch get better over time + 'val': DataLoader(ConcatDataset(val_datasets), batch_size=batch_size, shuffle=False, pin_memory=True, + num_workers=num_workers) + } + + +def get_test_loaders(config): + """ + Returns test DataLoader. + + :return: generator of DataLoader objects + """ + + assert 'loaders' in config, 'Could not find data loaders configuration' + loaders_config = config['loaders'] + + logger.info('Creating test set loaders...') + + # get dataset class + dataset_cls_str = loaders_config.get('dataset', None) + if dataset_cls_str is None: + dataset_cls_str = 'StandardHDF5Dataset' + logger.warning(f"Cannot find dataset class in the config. 
Using default '{dataset_cls_str}'.") + dataset_class = _loader_classes(dataset_cls_str) + + test_datasets = dataset_class.create_datasets(loaders_config, phase='test') + + num_workers = loaders_config.get('num_workers', 1) + logger.info(f'Number of workers for the dataloader: {num_workers}') + + batch_size = loaders_config.get('batch_size', 1) + if torch.cuda.device_count() > 1 and not config['device'] == 'cpu': + logger.info( + f'{torch.cuda.device_count()} GPUs available. Using batch_size = {torch.cuda.device_count()} * {batch_size}') + batch_size = batch_size * torch.cuda.device_count() + + logger.info(f'Batch size for dataloader: {batch_size}') + + # use generator in order to create data loaders lazily one by one + for test_dataset in test_datasets: + logger.info(f'Loading test set from: {test_dataset.file_path}...') + if hasattr(test_dataset, 'prediction_collate'): + collate_fn = test_dataset.prediction_collate + else: + collate_fn = default_prediction_collate + + yield DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True, + collate_fn=collate_fn) + + +def default_prediction_collate(batch): + """ + Default collate_fn to form a mini-batch of Tensor(s) for HDF5 based datasets + """ + error_msg = "batch must contain tensors or slice; found {}" + if isinstance(batch[0], torch.Tensor): + return torch.stack(batch, 0) + elif isinstance(batch[0], tuple) and isinstance(batch[0][0], slice): + return batch + elif isinstance(batch[0], collections.abc.Sequence): + transposed = zip(*batch) + return [default_prediction_collate(samples) for samples in transposed] + + raise TypeError((error_msg.format(type(batch[0])))) + + +def calculate_stats(img: np.array, skip: bool = False) -> dict[str, Any]: + """ + Calculates the minimum percentile, maximum percentile, mean, and standard deviation of the image. + + Args: + img: The input image array. + skip: if True, skip the calculation and return None for all values. + + Returns: + tuple[float, float, float, float]: The minimum percentile, maximum percentile, mean, and std dev + """ + if not skip: + pmin, pmax, mean, std = np.percentile(img, 1), np.percentile(img, 99.6), np.mean(img), np.std(img) + else: + pmin, pmax, mean, std = None, None, None, None + + return { + 'pmin': pmin, + 'pmax': pmax, + 'mean': mean, + 'std': std + } + + +def mirror_pad(image, padding_shape): + """ + Pad the image with a mirror reflection of itself. + + This function is used on data in its original shape before it is split into patches. + + Args: + image (np.ndarray): The input image array to be padded. + padding_shape (tuple of int): Specifies the amount of padding for each dimension, should be YX or ZYX. + + Returns: + np.ndarray: The mirror-padded image. + + Raises: + ValueError: If any element of padding_shape is negative. + """ + assert len(padding_shape) == 3, "Padding shape must be specified for each dimension: ZYX" + + if any(p < 0 for p in padding_shape): + raise ValueError("padding_shape must be non-negative") + + if all(p == 0 for p in padding_shape): + return image + + pad_width = [(p, p) for p in padding_shape] + + if image.ndim == 4: + pad_width = [(0, 0)] + pad_width + return np.pad(image, pad_width, mode='reflect') + + +def remove_padding(m, padding_shape): + """ + Removes padding from the margins of a multi-dimensional array. + + Args: + m (np.ndarray): The input array to be unpadded. + padding_shape (tuple of int, optional): The amount of padding to remove from each dimension. 
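+            For example (illustrative), padding_shape=(4, 8, 8) removes four voxels from both
+            ends of the Z axis and eight from both ends of Y and X.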
+ Assumes the tuple length matches the array dimensions. + + Returns: + np.ndarray: The unpadded array. + """ + if padding_shape is None: + return m + + # Correctly construct slice objects for each dimension in padding_shape and apply them to m. + return m[(..., *(slice(p, -p or None) for p in padding_shape))] diff --git a/build/lib/pytorch3dunet/predict.py b/build/lib/pytorch3dunet/predict.py new file mode 100644 index 00000000..cc54fcf7 --- /dev/null +++ b/build/lib/pytorch3dunet/predict.py @@ -0,0 +1,59 @@ +import importlib +import os + +import torch +import torch.nn as nn + +from pytorch3dunet.datasets.utils import get_test_loaders +from pytorch3dunet.unet3d import utils +from pytorch3dunet.unet3d.config import load_config +from pytorch3dunet.unet3d.model import get_model + +logger = utils.get_logger('UNet3DPredict') + + +def get_predictor(model, config): + output_dir = config['loaders'].get('output_dir', None) + # override output_dir if provided in the 'predictor' section of the config + output_dir = config.get('predictor', {}).get('output_dir', output_dir) + if output_dir is not None: + os.makedirs(output_dir, exist_ok=True) + + predictor_config = config.get('predictor', {}) + class_name = predictor_config.get('name', 'StandardPredictor') + + m = importlib.import_module('pytorch3dunet.unet3d.predictor') + predictor_class = getattr(m, class_name) + out_channels = config['model'].get('out_channels') + return predictor_class(model, output_dir, out_channels, **predictor_config) + + +def main(): + # Load configuration + config, _ = load_config() + + # Create the model + model = get_model(config['model']) + + # Load model state + model_path = config['model_path'] + logger.info(f'Loading model from {model_path}...') + utils.load_checkpoint(model_path, model) + # use DataParallel if more than 1 GPU available + + if torch.cuda.device_count() > 1 and not config['device'] == 'cpu': + model = nn.DataParallel(model) + logger.info(f'Using {torch.cuda.device_count()} GPUs for prediction') + if torch.cuda.is_available() and not config['device'] == 'cpu': + model = model.cuda() + + # create predictor instance + predictor = get_predictor(model, config) + + for test_loader in get_test_loaders(config): + # run the model prediction on the test_loader and save the results in the output_dir + predictor(test_loader) + + +if __name__ == '__main__': + main() diff --git a/build/lib/pytorch3dunet/train.py b/build/lib/pytorch3dunet/train.py new file mode 100644 index 00000000..eceaf719 --- /dev/null +++ b/build/lib/pytorch3dunet/train.py @@ -0,0 +1,35 @@ +import random + +import torch + +from pytorch3dunet.unet3d.config import load_config, copy_config +from pytorch3dunet.unet3d.trainer import create_trainer +from pytorch3dunet.unet3d.utils import get_logger + +logger = get_logger('TrainingSetup') + + +def main(): + # Load and log experiment configuration + config, config_path = load_config() + logger.info(config) + + manual_seed = config.get('manual_seed', None) + if manual_seed is not None: + logger.info(f'Seed the RNG for all devices with {manual_seed}') + logger.warning('Using CuDNN deterministic setting. 
This may slow down the training!')
+        random.seed(manual_seed)
+        torch.manual_seed(manual_seed)
+        # see https://pytorch.org/docs/stable/notes/randomness.html
+        torch.backends.cudnn.deterministic = True
+
+    # Create trainer
+    trainer = create_trainer(config)
+    # Copy config file
+    copy_config(config, config_path)
+    # Start training
+    trainer.fit()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/build/lib/pytorch3dunet/unet3d/__init__.py b/build/lib/pytorch3dunet/unet3d/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/build/lib/pytorch3dunet/unet3d/buildingblocks.py b/build/lib/pytorch3dunet/unet3d/buildingblocks.py
new file mode 100644
index 00000000..25679c24
--- /dev/null
+++ b/build/lib/pytorch3dunet/unet3d/buildingblocks.py
@@ -0,0 +1,545 @@
+from functools import partial
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from pytorch3dunet.unet3d.se import ChannelSELayer3D, ChannelSpatialSELayer3D, SpatialSELayer3D
+
+
+def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding,
+                dropout_prob, is3d):
+    """
+    Create a list of modules which together constitute a single conv layer with non-linearity
+    and optional batchnorm/groupnorm.
+
+    Args:
+        in_channels (int): number of input channels
+        out_channels (int): number of output channels
+        kernel_size(int or tuple): size of the convolving kernel
+        order (string): order of the layers, e.g.
+            'cr' -> conv + ReLU
+            'gcr' -> groupnorm + conv + ReLU
+            'cl' -> conv + LeakyReLU
+            'ce' -> conv + ELU
+            'bcr' -> batchnorm + conv + ReLU
+            'cbrd' -> conv + batchnorm + ReLU + dropout
+            'cbrD' -> conv + batchnorm + ReLU + dropout2d
+        num_groups (int): number of groups for the GroupNorm
+        padding (int or tuple): zero-padding added to all three sides of the input
+        dropout_prob (float): dropout probability
+        is3d (bool): if True use Conv3d, otherwise use Conv2d
+    Return:
+        list of tuple (name, module)
+    """
+    assert 'c' in order, "Conv layer MUST be present"
+    assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'
+
+    modules = []
+    for i, char in enumerate(order):
+        if char == 'r':
+            modules.append(('ReLU', nn.ReLU(inplace=True)))
+        elif char == 'l':
+            modules.append(('LeakyReLU', nn.LeakyReLU(inplace=True)))
+        elif char == 'e':
+            modules.append(('ELU', nn.ELU(inplace=True)))
+        elif char == 'c':
+            # add learnable bias only in the absence of batchnorm/groupnorm
+            bias = not ('g' in order or 'b' in order)
+            if is3d:
+                conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
+            else:
+                conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
+
+            modules.append(('conv', conv))
+        elif char == 'g':
+            is_before_conv = i < order.index('c')
+            if is_before_conv:
+                num_channels = in_channels
+            else:
+                num_channels = out_channels
+
+            # use only one group if the given number of groups is greater than the number of channels
+            if num_channels < num_groups:
+                num_groups = 1
+
+            assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups.
num_channels={num_channels}, num_groups={num_groups}' + modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels))) + elif char == 'b': + is_before_conv = i < order.index('c') + if is3d: + bn = nn.BatchNorm3d + else: + bn = nn.BatchNorm2d + + if is_before_conv: + modules.append(('batchnorm', bn(in_channels))) + else: + modules.append(('batchnorm', bn(out_channels))) + elif char == 'd': + modules.append(('dropout', nn.Dropout(p=dropout_prob))) + elif char == 'D': + modules.append(('dropout2d', nn.Dropout2d(p=dropout_prob))) + else: + raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c', 'd', 'D']") + + return modules + + +class SingleConv(nn.Sequential): + """ + Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order + of operations can be specified via the `order` parameter + + Args: + in_channels (int): number of input channels + out_channels (int): number of output channels + kernel_size (int or tuple): size of the convolving kernel + order (string): determines the order of layers, e.g. + 'cr' -> conv + ReLU + 'crg' -> conv + ReLU + groupnorm + 'cl' -> conv + LeakyReLU + 'ce' -> conv + ELU + num_groups (int): number of groups for the GroupNorm + padding (int or tuple): add zero-padding + dropout_prob (float): dropout probability, default 0.1 + is3d (bool): if True use Conv3d, otherwise use Conv2d + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, order='gcr', num_groups=8, + padding=1, dropout_prob=0.1, is3d=True): + super(SingleConv, self).__init__() + + for name, module in create_conv(in_channels, out_channels, kernel_size, order, + num_groups, padding, dropout_prob, is3d): + self.add_module(name, module) + + +class DoubleConv(nn.Sequential): + """ + A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d). + We use (Conv3d+ReLU+GroupNorm3d) by default. + This can be changed however by providing the 'order' argument, e.g. in order + to change to Conv3d+BatchNorm3d+ELU use order='cbe'. + Use padded convolutions to make sure that the output (H_out, W_out) is the same + as (H_in, W_in), so that you don't have to crop in the decoder path. + + Args: + in_channels (int): number of input channels + out_channels (int): number of output channels + encoder (bool): if True we're in the encoder path, otherwise we're in the decoder + kernel_size (int or tuple): size of the convolving kernel + order (string): determines the order of layers, e.g. 
+ 'cr' -> conv + ReLU + 'crg' -> conv + ReLU + groupnorm + 'cl' -> conv + LeakyReLU + 'ce' -> conv + ELU + num_groups (int): number of groups for the GroupNorm + padding (int or tuple): add zero-padding added to all three sides of the input + upscale (int): number of the convolution to upscale in encoder if DoubleConv, default: 2 + dropout_prob (float or tuple): dropout probability for each convolution, default 0.1 + is3d (bool): if True use Conv3d instead of Conv2d layers + """ + + def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr', + num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True): + super(DoubleConv, self).__init__() + if encoder: + # we're in the encoder path + conv1_in_channels = in_channels + if upscale == 1: + conv1_out_channels = out_channels + else: + conv1_out_channels = out_channels // 2 + if conv1_out_channels < in_channels: + conv1_out_channels = in_channels + conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels + else: + # we're in the decoder path, decrease the number of channels in the 1st convolution + conv1_in_channels, conv1_out_channels = in_channels, out_channels + conv2_in_channels, conv2_out_channels = out_channels, out_channels + + # check if dropout_prob is a tuple and if so + # split it for different dropout probabilities for each convolution. + if isinstance(dropout_prob, list) or isinstance(dropout_prob, tuple): + dropout_prob1 = dropout_prob[0] + dropout_prob2 = dropout_prob[1] + else: + dropout_prob1 = dropout_prob2 = dropout_prob + + # conv1 + self.add_module('SingleConv1', + SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups, + padding=padding, dropout_prob=dropout_prob1, is3d=is3d)) + # conv2 + self.add_module('SingleConv2', + SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups, + padding=padding, dropout_prob=dropout_prob2, is3d=is3d)) + + +class ResNetBlock(nn.Module): + """ + Residual block that can be used instead of standard DoubleConv in the Encoder module. + Motivated by: https://arxiv.org/pdf/1706.00120.pdf + + Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm. 
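+
+    Example (a minimal sketch; assumes the default `order='cge'` and `num_groups=8`):
+    the block computes `non_linearity(conv3(conv2(conv1(x))) + conv1(x))`, where
+    `conv1` is a 1x1 projection used only to match the channel counts:
+
+        >>> import torch
+        >>> block = ResNetBlock(in_channels=16, out_channels=32)
+        >>> x = torch.randn(1, 16, 8, 8, 8)
+        >>> tuple(block(x).shape)  # channels change, spatial size is preserved
+        (1, 32, 8, 8, 8)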
+ """ + + def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, is3d=True, **kwargs): + super(ResNetBlock, self).__init__() + + if in_channels != out_channels: + # conv1x1 for increasing the number of channels + if is3d: + self.conv1 = nn.Conv3d(in_channels, out_channels, 1) + else: + self.conv1 = nn.Conv2d(in_channels, out_channels, 1) + else: + self.conv1 = nn.Identity() + + # residual block + self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups, + is3d=is3d) + # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual + n_order = order + for c in 'rel': + n_order = n_order.replace(c, '') + self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order, + num_groups=num_groups, is3d=is3d) + + # create non-linearity separately + if 'l' in order: + self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif 'e' in order: + self.non_linearity = nn.ELU(inplace=True) + else: + self.non_linearity = nn.ReLU(inplace=True) + + def forward(self, x): + # apply first convolution to bring the number of channels to out_channels + residual = self.conv1(x) + + # residual block + out = self.conv2(residual) + out = self.conv3(out) + + out += residual + out = self.non_linearity(out) + + return out + + +class ResNetBlockSE(ResNetBlock): + def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, se_module='scse', **kwargs): + super(ResNetBlockSE, self).__init__( + in_channels, out_channels, kernel_size=kernel_size, order=order, + num_groups=num_groups, **kwargs) + assert se_module in ['scse', 'cse', 'sse'] + if se_module == 'scse': + self.se_module = ChannelSpatialSELayer3D(num_channels=out_channels, reduction_ratio=1) + elif se_module == 'cse': + self.se_module = ChannelSELayer3D(num_channels=out_channels, reduction_ratio=1) + elif se_module == 'sse': + self.se_module = SpatialSELayer3D(num_channels=out_channels) + + def forward(self, x): + out = super().forward(x) + out = self.se_module(out) + return out + + +class Encoder(nn.Module): + """ + A single module from the encoder path consisting of the optional max + pooling layer (one may specify the MaxPool kernel_size to be different + from the standard (2,2,2), e.g. if the volumetric data is anisotropic + (make sure to use complementary scale_factor in the decoder path) followed by + a basic module (DoubleConv or ResNetBlock). + + Args: + in_channels (int): number of input channels + out_channels (int): number of output channels + conv_kernel_size (int or tuple): size of the convolving kernel + apply_pooling (bool): if True use MaxPool3d before DoubleConv + pool_kernel_size (int or tuple): the size of the window + pool_type (str): pooling layer: 'max' or 'avg' + basic_module(nn.Module): either ResNetBlock or DoubleConv + conv_layer_order (string): determines the order of layers + in `DoubleConv` module. See `DoubleConv` for more info. 
+ num_groups (int): number of groups for the GroupNorm + padding (int or tuple): add zero-padding added to all three sides of the input + upscale (int): number of the convolution to upscale in encoder if DoubleConv, default: 2 + dropout_prob (float or tuple): dropout probability, default 0.1 + is3d (bool): use 3d or 2d convolutions/pooling operation + """ + + def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True, + pool_kernel_size=2, pool_type='max', basic_module=DoubleConv, conv_layer_order='gcr', + num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True): + super(Encoder, self).__init__() + assert pool_type in ['max', 'avg'] + if apply_pooling: + if pool_type == 'max': + if is3d: + self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size) + else: + self.pooling = nn.MaxPool2d(kernel_size=pool_kernel_size) + else: + if is3d: + self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size) + else: + self.pooling = nn.AvgPool2d(kernel_size=pool_kernel_size) + else: + self.pooling = None + + self.basic_module = basic_module(in_channels, out_channels, + encoder=True, + kernel_size=conv_kernel_size, + order=conv_layer_order, + num_groups=num_groups, + padding=padding, + upscale=upscale, + dropout_prob=dropout_prob, + is3d=is3d) + + def forward(self, x): + if self.pooling is not None: + x = self.pooling(x) + x = self.basic_module(x) + return x + + +class Decoder(nn.Module): + """ + A single module for decoder path consisting of the upsampling layer + (either learned ConvTranspose3d or nearest neighbor interpolation) + followed by a basic module (DoubleConv or ResNetBlock). + + Args: + in_channels (int): number of input channels + out_channels (int): number of output channels + conv_kernel_size (int or tuple): size of the convolving kernel + scale_factor (int or tuple): used as the multiplier for the image H/W/D in + case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation + from the corresponding encoder + basic_module(nn.Module): either ResNetBlock or DoubleConv + conv_layer_order (string): determines the order of layers + in `DoubleConv` module. See `DoubleConv` for more info. 
+        num_groups (int): number of groups for the GroupNorm
+        padding (int or tuple): zero-padding added to all three sides of the input
+        upsample (str): algorithm used for upsampling:
+            InterpolateUpsampling: 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'
+            TransposeConvUpsampling: 'deconv'
+            No upsampling: None
+            Default: 'default' (chooses automatically)
+        dropout_prob (float or tuple): dropout probability, default 0.1
+    """
+
+    def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=2, basic_module=DoubleConv,
+                 conv_layer_order='gcr', num_groups=8, padding=1, upsample='default',
+                 dropout_prob=0.1, is3d=True):
+        super(Decoder, self).__init__()
+
+        # perform concat joining per default
+        concat = True
+
+        # don't adapt channels after join operation
+        adapt_channels = False
+
+        if upsample is not None and upsample != 'none':
+            if upsample == 'default':
+                if basic_module == DoubleConv:
+                    upsample = 'nearest'  # use nearest neighbor interpolation for upsampling
+                    concat = True  # use concat joining
+                    adapt_channels = False  # don't adapt channels
+                elif basic_module == ResNetBlock or basic_module == ResNetBlockSE:
+                    upsample = 'deconv'  # use deconvolution upsampling
+                    concat = False  # use summation joining
+                    adapt_channels = True  # adapt channels after joining
+
+            # perform deconvolution upsampling if mode is deconv
+            if upsample == 'deconv':
+                self.upsampling = TransposeConvUpsampling(in_channels=in_channels, out_channels=out_channels,
+                                                          kernel_size=conv_kernel_size, scale_factor=scale_factor,
+                                                          is3d=is3d)
+            else:
+                self.upsampling = InterpolateUpsampling(mode=upsample)
+        else:
+            # no upsampling
+            self.upsampling = NoUpsampling()
+
+        # perform joining operation (concat is still True in the no-upsampling branch above)
+        self.joining = partial(self._joining, concat=concat)
+
+        # adapt the number of in_channels for the ResNetBlock
+        if adapt_channels is True:
+            in_channels = out_channels
+
+        self.basic_module = basic_module(in_channels, out_channels,
+                                         encoder=False,
+                                         kernel_size=conv_kernel_size,
+                                         order=conv_layer_order,
+                                         num_groups=num_groups,
+                                         padding=padding,
+                                         dropout_prob=dropout_prob,
+                                         is3d=is3d)
+
+    def forward(self, encoder_features, x):
+        x = self.upsampling(encoder_features=encoder_features, x=x)
+        x = self.joining(encoder_features, x)
+        x = self.basic_module(x)
+        return x
+
+    @staticmethod
+    def _joining(encoder_features, x, concat):
+        if concat:
+            return torch.cat((encoder_features, x), dim=1)
+        else:
+            return encoder_features + x
+
+
+def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding,
+                    conv_upscale, dropout_prob,
+                    layer_order, num_groups, pool_kernel_size, is3d):
+    # create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)`
+    encoders = []
+    for i, out_feature_num in enumerate(f_maps):
+        if i == 0:
+            # the first encoder operates directly on the input volume, hence no pooling
+            encoder = Encoder(in_channels, out_feature_num,
+                              apply_pooling=False,  # skip pooling in the first encoder
+                              basic_module=basic_module,
+                              conv_layer_order=layer_order,
+                              conv_kernel_size=conv_kernel_size,
+                              num_groups=num_groups,
+                              padding=conv_padding,
+                              upscale=conv_upscale,
+                              dropout_prob=dropout_prob,
+                              is3d=is3d)
+        else:
+            encoder = Encoder(f_maps[i - 1], out_feature_num,
+                              basic_module=basic_module,
+                              conv_layer_order=layer_order,
+                              conv_kernel_size=conv_kernel_size,
+                              num_groups=num_groups,
+                              pool_kernel_size=pool_kernel_size,
+                              padding=conv_padding,
+                              upscale=conv_upscale,
+                              dropout_prob=dropout_prob,
+                              is3d=is3d)
+
+        encoders.append(encoder)
+
+    return nn.ModuleList(encoders)
+
+
+def create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order,
+                    num_groups, upsample, dropout_prob, is3d):
+    # create decoder path consisting of the Decoder modules. The length of the decoder list is equal to `len(f_maps) - 1`
+    decoders = []
+    reversed_f_maps = list(reversed(f_maps))
+    for i in range(len(reversed_f_maps) - 1):
+        if basic_module == DoubleConv and upsample != 'deconv':
+            in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
+        else:
+            in_feature_num = reversed_f_maps[i]
+
+        out_feature_num = reversed_f_maps[i + 1]
+
+        decoder = Decoder(in_feature_num, out_feature_num,
+                          basic_module=basic_module,
+                          conv_layer_order=layer_order,
+                          conv_kernel_size=conv_kernel_size,
+                          num_groups=num_groups,
+                          padding=conv_padding,
+                          upsample=upsample,
+                          dropout_prob=dropout_prob,
+                          is3d=is3d)
+        decoders.append(decoder)
+    return nn.ModuleList(decoders)
+
+
+class AbstractUpsampling(nn.Module):
+    """
+    Abstract class for upsampling. A given implementation should upsample a given 5D input tensor using either
+    interpolation or learned transposed convolution.
+    """
+
+    def __init__(self, upsample):
+        super(AbstractUpsampling, self).__init__()
+        self.upsample = upsample
+
+    def forward(self, encoder_features, x):
+        # get the spatial dimensions of the output given the encoder_features
+        output_size = encoder_features.size()[2:]
+        # upsample the input and return
+        return self.upsample(x, output_size)
+
+
+class InterpolateUpsampling(AbstractUpsampling):
+    """
+    Args:
+        mode (str): algorithm used for upsampling:
+            'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
+    """
+
+    def __init__(self, mode='nearest'):
+        upsample = partial(self._interpolate, mode=mode)
+        super().__init__(upsample)
+
+    @staticmethod
+    def _interpolate(x, size, mode):
+        return F.interpolate(x, size=size, mode=mode)
+
+
+class TransposeConvUpsampling(AbstractUpsampling):
+    """
+    Args:
+        in_channels (int): number of input channels for transposed conv
+        out_channels (int): number of output channels for transposed conv
+        kernel_size (int or tuple): size of the convolving kernel
+        scale_factor (int or tuple): stride of the convolution
+        is3d (bool): if True use ConvTranspose3d, otherwise use ConvTranspose2d
+    """
+
+    class Upsample(nn.Module):
+        """
+        Workaround the 'ValueError: requested an output size...' in the `_output_padding` method in
+        transposed convolution.
It performs transposed conv followed by the interpolation to the correct size if necessary.
+        """
+
+        def __init__(self, conv_transposed, is3d):
+            super().__init__()
+            self.conv_transposed = conv_transposed
+            self.is3d = is3d
+
+        def forward(self, x, size):
+            x = self.conv_transposed(x)
+            return F.interpolate(x, size=size)
+
+    def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=2, is3d=True):
+        # make sure that the output size reverses the MaxPool3d from the corresponding encoder
+        if is3d is True:
+            conv_transposed = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size,
+                                                 stride=scale_factor, padding=1, bias=False)
+        else:
+            conv_transposed = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
+                                                 stride=scale_factor, padding=1, bias=False)
+        upsample = self.Upsample(conv_transposed, is3d)
+        super().__init__(upsample)
+
+
+class NoUpsampling(AbstractUpsampling):
+    def __init__(self):
+        super().__init__(self._no_upsampling)
+
+    @staticmethod
+    def _no_upsampling(x, size):
+        return x
diff --git a/build/lib/pytorch3dunet/unet3d/config.py b/build/lib/pytorch3dunet/unet3d/config.py
new file mode 100644
index 00000000..bb011632
--- /dev/null
+++ b/build/lib/pytorch3dunet/unet3d/config.py
@@ -0,0 +1,79 @@
+import argparse
+import os
+import shutil
+
+import torch
+import yaml
+
+from pytorch3dunet.unet3d import utils
+
+logger = utils.get_logger('ConfigLoader')
+
+
+def _override_config(args, config):
+    """Overrides config params with the ones given in command line."""
+
+    args_dict = vars(args)
+    # remove the first argument which is the config file path
+    args_dict.pop('config')
+
+    for key, value in args_dict.items():
+        if value is None:
+            continue
+        c = config
+        for k in key.split('.'):
+            if k not in c:
+                raise ValueError(f'Invalid config key: {key}')
+            if isinstance(c[k], dict):
+                c = c[k]
+            else:
+                c[k] = value
+
+
+def load_config():
+    parser = argparse.ArgumentParser(description='UNet3D')
+    parser.add_argument('--config', type=str, help='Path to the YAML config file', required=True)
+    # add additional command line arguments for the prediction that override the ones in the config file
+    parser.add_argument('--model_path', type=str, required=False)
+    parser.add_argument('--loaders.output_dir', type=str, required=False)
+    parser.add_argument('--loaders.test.file_paths', type=str, nargs="+", required=False)
+    parser.add_argument('--loaders.test.slice_builder.patch_shape', type=int, nargs="+", required=False)
+    parser.add_argument('--loaders.test.slice_builder.stride_shape', type=int, nargs="+", required=False)
+
+    args = parser.parse_args()
+    config_path = args.config
+    with open(config_path, 'r') as f:
+        config = yaml.safe_load(f)
+    _override_config(args, config)
+
+    device = config.get('device', None)
+    if device == 'cpu':
+        logger.warning('CPU mode forced in config, this will likely result in slow training/prediction')
+        config['device'] = 'cpu'
+        # return the config path as well; callers unpack `config, config_path = load_config()`
+        return config, config_path
+
+    if torch.cuda.is_available():
+        config['device'] = 'cuda'
+    else:
+        logger.warning('CUDA not available, using CPU')
+        config['device'] = 'cpu'
+    return config, config_path
+
+
+def copy_config(config, config_path):
+    """Copies the config file to the checkpoint folder."""
+
+    def _get_last_subfolder_path(path):
+        subfolders = [f.path for f in os.scandir(path) if f.is_dir()]
+        return max(subfolders, default=None)
+
+    checkpoint_dir = os.path.join(
+        config['trainer'].pop('checkpoint_dir'), 'logs')
+    last_run_dir = _get_last_subfolder_path(checkpoint_dir)
+    config_file_name =
os.path.basename(config_path) + + if last_run_dir: + shutil.copy2(config_path, os.path.join(last_run_dir, config_file_name)) + + +def _load_config_yaml(config_file): + return yaml.safe_load(open(config_file, 'r')) diff --git a/build/lib/pytorch3dunet/unet3d/losses.py b/build/lib/pytorch3dunet/unet3d/losses.py new file mode 100644 index 00000000..6a53966f --- /dev/null +++ b/build/lib/pytorch3dunet/unet3d/losses.py @@ -0,0 +1,345 @@ +import torch +import torch.nn.functional as F +from torch import nn as nn +from torch.nn import MSELoss, SmoothL1Loss, L1Loss + + +def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None): + """ + Computes DiceCoefficient as defined in https://arxiv.org/abs/1606.04797 given a multi channel input and target. + Assumes the input is a normalized probability, e.g. a result of Sigmoid or Softmax function. + + Args: + input (torch.Tensor): NxCxSpatial input tensor + target (torch.Tensor): NxCxSpatial target tensor + epsilon (float): prevents division by zero + weight (torch.Tensor): Cx1 tensor of weight per channel/class + """ + + # input and target shapes must match + assert input.size() == target.size(), "'input' and 'target' must have the same shape" + + input = flatten(input) + target = flatten(target) + target = target.float() + + # compute per channel Dice Coefficient + intersect = (input * target).sum(-1) + if weight is not None: + intersect = weight * intersect + + # here we can use standard dice (input + target).sum(-1) or extension (see V-Net) (input^2 + target^2).sum(-1) + denominator = (input * input).sum(-1) + (target * target).sum(-1) + return 2 * (intersect / denominator.clamp(min=epsilon)) + + +class _MaskingLossWrapper(nn.Module): + """ + Loss wrapper which prevents the gradient of the loss to be computed where target is equal to `ignore_index`. + """ + + def __init__(self, loss, ignore_index): + super(_MaskingLossWrapper, self).__init__() + assert ignore_index is not None, 'ignore_index cannot be None' + self.loss = loss + self.ignore_index = ignore_index + + def forward(self, input, target): + mask = target.clone().ne_(self.ignore_index) + mask.requires_grad = False + + # mask out input/target so that the gradient is zero where on the mask + input = input * mask + target = target * mask + + # forward masked input and target to the loss + return self.loss(input, target) + + +class SkipLastTargetChannelWrapper(nn.Module): + """ + Loss wrapper which removes additional target channel + """ + + def __init__(self, loss, squeeze_channel=False): + super(SkipLastTargetChannelWrapper, self).__init__() + self.loss = loss + self.squeeze_channel = squeeze_channel + + def forward(self, input, target, weight=None): + assert target.size(1) > 1, 'Target tensor has a singleton channel dimension, cannot remove channel' + + # skips last target channel if needed + target = target[:, :-1, ...] + + if self.squeeze_channel: + # squeeze channel dimension + target = torch.squeeze(target, dim=1) + if weight is not None: + return self.loss(input, target, weight) + return self.loss(input, target) + + +class _AbstractDiceLoss(nn.Module): + """ + Base class for different implementations of Dice loss. + """ + + def __init__(self, weight=None, normalization='sigmoid'): + super(_AbstractDiceLoss, self).__init__() + self.register_buffer('weight', weight) + # The output from the network during training is assumed to be un-normalized probabilities and we would + # like to normalize the logits. 
Since Dice (or soft Dice in this case) is usually used for binary data, + # normalizing the channels with Sigmoid is the default choice even for multi-class segmentation problems. + # However if one would like to apply Softmax in order to get the proper probability distribution from the + # output, just specify `normalization=Softmax` + assert normalization in ['sigmoid', 'softmax', 'none'] + if normalization == 'sigmoid': + self.normalization = nn.Sigmoid() + elif normalization == 'softmax': + self.normalization = nn.Softmax(dim=1) + else: + self.normalization = lambda x: x + + def dice(self, input, target, weight): + # actual Dice score computation; to be implemented by the subclass + raise NotImplementedError + + def forward(self, input, target): + # get probabilities from logits + input = self.normalization(input) + + # compute per channel Dice coefficient + per_channel_dice = self.dice(input, target, weight=self.weight) + + # average Dice score across all channels/classes + return 1. - torch.mean(per_channel_dice) + + +class DiceLoss(_AbstractDiceLoss): + """Computes Dice Loss according to https://arxiv.org/abs/1606.04797. + For multi-class segmentation `weight` parameter can be used to assign different weights per class. + The input to the loss function is assumed to be a logit and will be normalized by the Sigmoid function. + """ + + def __init__(self, weight=None, normalization='sigmoid'): + super().__init__(weight, normalization) + + def dice(self, input, target, weight): + return compute_per_channel_dice(input, target, weight=self.weight) + + +class GeneralizedDiceLoss(_AbstractDiceLoss): + """Computes Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf. + """ + + def __init__(self, normalization='sigmoid', epsilon=1e-6): + super().__init__(weight=None, normalization=normalization) + self.epsilon = epsilon + + def dice(self, input, target, weight): + assert input.size() == target.size(), "'input' and 'target' must have the same shape" + + input = flatten(input) + target = flatten(target) + target = target.float() + + if input.size(0) == 1: + # for GDL to make sense we need at least 2 channels (see https://arxiv.org/pdf/1707.03237.pdf) + # put foreground and background voxels in separate channels + input = torch.cat((input, 1 - input), dim=0) + target = torch.cat((target, 1 - target), dim=0) + + # GDL weighting: the contribution of each label is corrected by the inverse of its volume + w_l = target.sum(-1) + w_l = 1 / (w_l * w_l).clamp(min=self.epsilon) + w_l.requires_grad = False + + intersect = (input * target).sum(-1) + intersect = intersect * w_l + + denominator = (input + target).sum(-1) + denominator = (denominator * w_l).clamp(min=self.epsilon) + + return 2 * (intersect.sum() / denominator.sum()) + + +class BCEDiceLoss(nn.Module): + """Linear combination of BCE and Dice losses""" + + def __init__(self, alpha, beta): + super(BCEDiceLoss, self).__init__() + self.alpha = alpha + self.bce = nn.BCEWithLogitsLoss() + self.beta = beta + self.dice = DiceLoss() + + def forward(self, input, target): + return self.alpha * self.bce(input, target) + self.beta * self.dice(input, target) + + +class WeightedCrossEntropyLoss(nn.Module): + """WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf + """ + + def __init__(self, ignore_index=-1): + super(WeightedCrossEntropyLoss, self).__init__() + self.ignore_index = ignore_index + + def forward(self, input, target): + weight = self._class_weights(input) + return F.cross_entropy(input, 
target, weight=weight, ignore_index=self.ignore_index) + + @staticmethod + def _class_weights(input): + # normalize the input first + input = F.softmax(input, dim=1) + flattened = flatten(input) + nominator = (1. - flattened).sum(-1) + denominator = flattened.sum(-1) + class_weights = nominator / denominator + return class_weights.detach() + + +class PixelWiseCrossEntropyLoss(nn.Module): + def __init__(self, ignore_index=None): + super(PixelWiseCrossEntropyLoss, self).__init__() + self.ignore_index = ignore_index + self.log_softmax = nn.LogSoftmax(dim=1) + + def forward(self, input, target, weights): + assert target.size() == weights.size() + # normalize the input + log_probabilities = self.log_softmax(input) + # standard CrossEntropyLoss requires the target to be (NxDxHxW), so we need to expand it to (NxCxDxHxW) + if self.ignore_index is not None: + mask = target == self.ignore_index + target[mask] = 0 + else: + mask = torch.zeros_like(target) + # add channel dimension and invert the mask + mask = 1 - mask.unsqueeze(1) + # convert target to one-hot encoding + target = F.one_hot(target.long()) + if target.ndim == 5: + # permute target to (NxCxDxHxW) + target = target.permute(0, 4, 1, 2, 3).contiguous() + else: + target = target.permute(0, 3, 1, 2).contiguous() + # apply the mask on the target + target = target * mask + # add channel dimension to the weights + weights = weights.unsqueeze(1) + # compute the losses + result = -weights * target * log_probabilities + return result.mean() + + +class WeightedSmoothL1Loss(nn.SmoothL1Loss): + def __init__(self, threshold, initial_weight, apply_below_threshold=True): + super().__init__(reduction="none") + self.threshold = threshold + self.apply_below_threshold = apply_below_threshold + self.weight = initial_weight + + def forward(self, input, target): + l1 = super().forward(input, target) + + if self.apply_below_threshold: + mask = target < self.threshold + else: + mask = target >= self.threshold + + l1[mask] = l1[mask] * self.weight + + return l1.mean() + + +def flatten(tensor): + """Flattens a given tensor such that the channel axis is first. 
+ The shapes are transformed as follows: + (N, C, D, H, W) -> (C, N * D * H * W) + """ + # number of channels + C = tensor.size(1) + # new axis order + axis_order = (1, 0) + tuple(range(2, tensor.dim())) + # Transpose: (N, C, D, H, W) -> (C, N, D, H, W) + transposed = tensor.permute(axis_order) + # Flatten: (C, N, D, H, W) -> (C, N * D * H * W) + return transposed.contiguous().view(C, -1) + + +def get_loss_criterion(config): + """ + Returns the loss function based on provided configuration + :param config: (dict) a top level configuration object containing the 'loss' key + :return: an instance of the loss function + """ + assert 'loss' in config, 'Could not find loss function configuration' + loss_config = config['loss'] + name = loss_config.pop('name') + + ignore_index = loss_config.pop('ignore_index', None) + skip_last_target = loss_config.pop('skip_last_target', False) + weight = loss_config.pop('weight', None) + + if weight is not None: + weight = torch.tensor(weight) + + pos_weight = loss_config.pop('pos_weight', None) + if pos_weight is not None: + pos_weight = torch.tensor(pos_weight) + + loss = _create_loss(name, loss_config, weight, ignore_index, pos_weight) + + if not (ignore_index is None or name in ['CrossEntropyLoss', 'WeightedCrossEntropyLoss']): + # use MaskingLossWrapper only for non-cross-entropy losses, since CE losses allow specifying 'ignore_index' directly + loss = _MaskingLossWrapper(loss, ignore_index) + + if skip_last_target: + loss = SkipLastTargetChannelWrapper(loss, loss_config.get('squeeze_channel', False)) + + if torch.cuda.is_available(): + loss = loss.cuda() + + return loss + + +####################################################################################################################### + +def _create_loss(name, loss_config, weight, ignore_index, pos_weight): + if name == 'BCEWithLogitsLoss': + return nn.BCEWithLogitsLoss(pos_weight=pos_weight) + elif name == 'BCEDiceLoss': + alpha = loss_config.get('alpha', 1.) + beta = loss_config.get('beta', 1.) 
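+        # BCEDiceLoss (defined above) evaluates to
+        #     alpha * BCEWithLogitsLoss(input, target) + beta * DiceLoss(input, target)
+        # so with the defaults alpha = beta = 1.0 both terms contribute equally; both
+        # terms expect raw logits, since DiceLoss applies its own Sigmoid normalization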
+ return BCEDiceLoss(alpha, beta) + elif name == 'CrossEntropyLoss': + if ignore_index is None: + ignore_index = -100 # use the default 'ignore_index' as defined in the CrossEntropyLoss + return nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index) + elif name == 'WeightedCrossEntropyLoss': + if ignore_index is None: + ignore_index = -100 # use the default 'ignore_index' as defined in the CrossEntropyLoss + return WeightedCrossEntropyLoss(ignore_index=ignore_index) + elif name == 'PixelWiseCrossEntropyLoss': + return PixelWiseCrossEntropyLoss(ignore_index=ignore_index) + elif name == 'GeneralizedDiceLoss': + normalization = loss_config.get('normalization', 'sigmoid') + return GeneralizedDiceLoss(normalization=normalization) + elif name == 'DiceLoss': + normalization = loss_config.get('normalization', 'sigmoid') + return DiceLoss(weight=weight, normalization=normalization) + elif name == 'MSELoss': + return MSELoss() + elif name == 'SmoothL1Loss': + return SmoothL1Loss() + elif name == 'L1Loss': + return L1Loss() + elif name == 'WeightedSmoothL1Loss': + return WeightedSmoothL1Loss(threshold=loss_config['threshold'], + initial_weight=loss_config['initial_weight'], + apply_below_threshold=loss_config.get('apply_below_threshold', True)) + else: + raise RuntimeError(f"Unsupported loss function: '{name}'") diff --git a/build/lib/pytorch3dunet/unet3d/metrics.py b/build/lib/pytorch3dunet/unet3d/metrics.py new file mode 100644 index 00000000..2b60b4b7 --- /dev/null +++ b/build/lib/pytorch3dunet/unet3d/metrics.py @@ -0,0 +1,445 @@ +import importlib + +import numpy as np +import torch +from skimage import measure +from skimage.metrics import adapted_rand_error, peak_signal_noise_ratio, mean_squared_error + +from pytorch3dunet.unet3d.losses import compute_per_channel_dice +from pytorch3dunet.unet3d.seg_metrics import AveragePrecision, Accuracy +from pytorch3dunet.unet3d.utils import get_logger, expand_as_one_hot, convert_to_numpy + +logger = get_logger('EvalMetric') + + +class DiceCoefficient: + """Computes Dice Coefficient. + Generalized to multiple channels by computing per-channel Dice Score + (as described in https://arxiv.org/pdf/1707.03237.pdf) and then simply taking the average. + Input is expected to be probabilities instead of logits. + This metric is mostly useful when channels contain the same semantic class (e.g. affinities computed with different offsets). + DO NOT USE this metric when training with DiceLoss, otherwise the results will be biased towards the loss. + """ + + def __init__(self, epsilon=1e-6, **kwargs): + self.epsilon = epsilon + + def __call__(self, input, target): + # Average across channels in order to get the final score + return torch.mean(compute_per_channel_dice(input, target, epsilon=self.epsilon)) + + +class MeanIoU: + """ + Computes IoU for each class separately and then averages over all classes. + """ + + def __init__(self, skip_channels=(), ignore_index=None, **kwargs): + """ + :param skip_channels: list/tuple of channels to be ignored from the IoU computation + :param ignore_index: id of the label to be ignored from IoU computation + """ + self.ignore_index = ignore_index + self.skip_channels = skip_channels + + def __call__(self, input, target): + """ + :param input: 5D probability maps torch float tensor (NxCxDxHxW) + :param target: 4D or 5D ground truth torch tensor. 
4D (NxDxHxW) tensor will be expanded to 5D as one-hot + :return: intersection over union averaged over all channels + """ + assert input.dim() == 5 + + n_classes = input.size()[1] + + if target.dim() == 4: + target = expand_as_one_hot(target, C=n_classes, ignore_index=self.ignore_index) + + assert input.size() == target.size() + + per_batch_iou = [] + for _input, _target in zip(input, target): + binary_prediction = self._binarize_predictions(_input, n_classes) + + if self.ignore_index is not None: + # zero out ignore_index + mask = _target == self.ignore_index + binary_prediction[mask] = 0 + _target[mask] = 0 + + # convert to uint8 just in case + binary_prediction = binary_prediction.byte() + _target = _target.byte() + + per_channel_iou = [] + for c in range(n_classes): + if c in self.skip_channels: + continue + + per_channel_iou.append(self._jaccard_index(binary_prediction[c], _target[c])) + + assert per_channel_iou, "All channels were ignored from the computation" + mean_iou = torch.mean(torch.tensor(per_channel_iou)) + per_batch_iou.append(mean_iou) + + return torch.mean(torch.tensor(per_batch_iou)) + + def _binarize_predictions(self, input, n_classes): + """ + Puts 1 for the class/channel with the highest probability and 0 in other channels. Returns byte tensor of the + same size as the input tensor. + """ + if n_classes == 1: + # for single channel input just threshold the probability map + result = input > 0.5 + return result.long() + + _, max_index = torch.max(input, dim=0, keepdim=True) + return torch.zeros_like(input, dtype=torch.uint8).scatter_(0, max_index, 1) + + def _jaccard_index(self, prediction, target): + """ + Computes IoU for a given target and prediction tensors + """ + return torch.sum(prediction & target).float() / torch.clamp(torch.sum(prediction | target).float(), min=1e-8) + + +class AdaptedRandError: + """ + A functor which computes an Adapted Rand error as defined by the SNEMI3D contest + (http://brainiac2.mit.edu/SNEMI3D/evaluation). + + This is a generic implementation which takes the input, converts it to the segmentation image (see `input_to_segm()`) + and then computes the ARand between the segmentation and the ground truth target. Depending on one's use case + it's enough to extend this class and implement the `input_to_segm` method. + + Args: + use_last_target (bool): if true, use the last channel from the target to compute the ARand, otherwise the first. + """ + + def __init__(self, use_last_target=False, ignore_index=None, **kwargs): + self.use_last_target = use_last_target + self.ignore_index = ignore_index + + def __call__(self, input, target): + """ + Compute ARand Error for each input, target pair in the batch and return the mean value. + + Args: + input (torch.tensor): 5D (NCDHW) output from the network + target (torch.tensor): 5D (NCDHW) ground truth segmentation + + Returns: + average ARand Error across the batch + """ + + # converts input and target to numpy arrays + input, target = convert_to_numpy(input, target) + if self.use_last_target: + target = target[:, -1, ...] # 4D + else: + # use 1st target channel + target = target[:, 0, ...] 
# 4D
+
+        # ensure target is of integer type
+        target = target.astype(np.int32)
+
+        if self.ignore_index is not None:
+            target[target == self.ignore_index] = 0
+
+        per_batch_arand = []
+        for _input, _target in zip(input, target):
+            if np.all(_target == _target.flat[0]):
+                # skip the ARand evaluation if there is only one label in the patch (avoids division by zero)
+                logger.info('Skipping ARandError computation: only 1 label present in the ground truth')
+                per_batch_arand.append(0.)
+                continue
+
+            # convert _input to segmentation CDHW
+            segm = self.input_to_segm(_input)
+            assert segm.ndim == 4
+
+            # compute per channel arand and return the minimum value
+            per_channel_arand = [adapted_rand_error(_target, channel_segm)[0] for channel_segm in segm]
+            per_batch_arand.append(np.min(per_channel_arand))
+
+        # return mean arand error
+        mean_arand = torch.mean(torch.tensor(per_batch_arand))
+        logger.info(f'ARand: {mean_arand.item()}')
+        return mean_arand
+
+    def input_to_segm(self, input):
+        """
+        Converts input tensor (output from the network) to the segmentation image. E.g. if the input is the boundary
+        pmaps then one option would be to threshold it and run connected components in order to return the segmentation.
+
+        :param input: 4D tensor (CDHW)
+        :return: 4D segmentation volume (one segmentation per channel)
+        """
+        # by default assume that the input is a segmentation volume itself
+        return input
+
+
+class BoundaryAdaptedRandError(AdaptedRandError):
+    """
+    Compute ARand between the input boundary map and target segmentation.
+    The boundary map is thresholded and connected components are run to get the predicted segmentation.
+    """
+
+    def __init__(self, thresholds=None, use_last_target=True, ignore_index=None, input_channel=None, invert_pmaps=True,
+                 save_plots=False, plots_dir='.', **kwargs):
+        super().__init__(use_last_target=use_last_target, ignore_index=ignore_index, save_plots=save_plots,
+                         plots_dir=plots_dir, **kwargs)
+
+        if thresholds is None:
+            thresholds = [0.3, 0.4, 0.5, 0.6]
+        assert isinstance(thresholds, list)
+        self.thresholds = thresholds
+        self.input_channel = input_channel
+        self.invert_pmaps = invert_pmaps
+
+    def input_to_segm(self, input):
+        if self.input_channel is not None:
+            input = np.expand_dims(input[self.input_channel], axis=0)
+
+        segs = []
+        for predictions in input:
+            for th in self.thresholds:
+                # threshold the probability map into a fresh array, so that each threshold
+                # is applied to the original predictions rather than to the previous mask
+                mask = predictions > th
+
+                if self.invert_pmaps:
+                    # for connected component analysis we need to treat boundary signal as background
+                    # assign 0-label to boundary mask
+                    mask = np.logical_not(mask)
+
+                mask = mask.astype(np.uint8)
+                # run connected components on the predicted mask; consider only 1-connectivity
+                seg = measure.label(mask, background=0, connectivity=1)
+                segs.append(seg)
+
+        return np.stack(segs)
+
+
+class GenericAdaptedRandError(AdaptedRandError):
+    def __init__(self, input_channels, thresholds=None, use_last_target=True, ignore_index=None, invert_channels=None,
+                 **kwargs):
+
+        super().__init__(use_last_target=use_last_target, ignore_index=ignore_index, **kwargs)
+        assert isinstance(input_channels, list) or isinstance(input_channels, tuple)
+        self.input_channels = input_channels
+        if thresholds is None:
+            thresholds = [0.3, 0.4, 0.5, 0.6]
+        assert isinstance(thresholds, list)
+        self.thresholds = thresholds
+        if invert_channels is None:
+            invert_channels = []
+        self.invert_channels = invert_channels
+
+    def input_to_segm(self, input):
+        # pick only the channels specified in the input_channels
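+        # `invert_channels` handles channels that encode the complement of the signal
+        # to be thresholded (e.g. a background probability map): such a channel is
+        # flipped to `1 - c`, turning it into a foreground map before the
+        # per-threshold connected-component labelling below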
results = [] + for i in self.input_channels: + c = input[i] + # invert channel if necessary + if i in self.invert_channels: + c = 1 - c + results.append(c) + + input = np.stack(results) + + segs = [] + for predictions in input: + for th in self.thresholds: + # run connected components on the predicted mask; consider only 1-connectivity + seg = measure.label((predictions > th).astype(np.uint8), background=0, connectivity=1) + segs.append(seg) + + return np.stack(segs) + + +class GenericAveragePrecision: + def __init__(self, min_instance_size=None, use_last_target=False, metric='ap', **kwargs): + self.min_instance_size = min_instance_size + self.use_last_target = use_last_target + assert metric in ['ap', 'acc'] + if metric == 'ap': + # use AveragePrecision + self.metric = AveragePrecision() + else: + # use Accuracy at 0.5 IoU + self.metric = Accuracy(iou_threshold=0.5) + + def __call__(self, input, target): + if target.dim() == 5: + if self.use_last_target: + target = target[:, -1, ...] # 4D + else: + # use 1st target channel + target = target[:, 0, ...] # 4D + + input1 = input2 = input + multi_head = isinstance(input, tuple) + if multi_head: + input1, input2 = input + + input1, input2, target = convert_to_numpy(input1, input2, target) + + batch_aps = [] + i_batch = 0 + # iterate over the batch + for inp1, inp2, tar in zip(input1, input2, target): + if multi_head: + inp = (inp1, inp2) + else: + inp = inp1 + + segs = self.input_to_seg(inp, tar) # expects 4D + assert segs.ndim == 4 + # convert target to seg + tar = self.target_to_seg(tar) + + # filter small instances if necessary + tar = self._filter_instances(tar) + + # compute average precision per channel + segs_aps = [self.metric(self._filter_instances(seg), tar) for seg in segs] + + logger.info(f'Batch: {i_batch}. Max Average Precision for channel: {np.argmax(segs_aps)}') + # save max AP + batch_aps.append(np.max(segs_aps)) + i_batch += 1 + + return torch.tensor(batch_aps).mean() + + def _filter_instances(self, input): + """ + Filters instances smaller than 'min_instance_size' by overriding them with 0-index + :param input: input instance segmentation + """ + if self.min_instance_size is not None: + labels, counts = np.unique(input, return_counts=True) + for label, count in zip(labels, counts): + if count < self.min_instance_size: + input[input == label] = 0 + return input + + def input_to_seg(self, input, target=None): + raise NotImplementedError + + def target_to_seg(self, target): + return target + + +class BlobsAveragePrecision(GenericAveragePrecision): + """ + Computes Average Precision given foreground prediction and ground truth instance segmentation. + """ + + def __init__(self, thresholds=None, metric='ap', min_instance_size=None, input_channel=0, **kwargs): + super().__init__(min_instance_size=min_instance_size, use_last_target=True, metric=metric) + if thresholds is None: + thresholds = [0.4, 0.5, 0.6, 0.7, 0.8] + assert isinstance(thresholds, list) + self.thresholds = thresholds + self.input_channel = input_channel + + def input_to_seg(self, input, target=None): + input = input[self.input_channel] + segs = [] + for th in self.thresholds: + # threshold and run connected components + mask = (input > th).astype(np.uint8) + seg = measure.label(mask, background=0, connectivity=1) + segs.append(seg) + return np.stack(segs) + + +class BlobsBoundaryAveragePrecision(GenericAveragePrecision): + """ + Computes Average Precision given foreground prediction, boundary prediction and ground truth instance segmentation. 
+ Segmentation mask is computed as (P_mask - P_boundary) > th followed by a connected component + """ + + def __init__(self, thresholds=None, metric='ap', min_instance_size=None, **kwargs): + super().__init__(min_instance_size=min_instance_size, use_last_target=True, metric=metric) + if thresholds is None: + thresholds = [0.3, 0.4, 0.5, 0.6, 0.7] + assert isinstance(thresholds, list) + self.thresholds = thresholds + + def input_to_seg(self, input, target=None): + # input = P_mask - P_boundary + input = input[0] - input[1] + segs = [] + for th in self.thresholds: + # threshold and run connected components + mask = (input > th).astype(np.uint8) + seg = measure.label(mask, background=0, connectivity=1) + segs.append(seg) + return np.stack(segs) + + +class BoundaryAveragePrecision(GenericAveragePrecision): + """ + Computes Average Precision given boundary prediction and ground truth instance segmentation. + """ + + def __init__(self, thresholds=None, min_instance_size=None, input_channel=0, **kwargs): + super().__init__(min_instance_size=min_instance_size, use_last_target=True) + if thresholds is None: + thresholds = [0.3, 0.4, 0.5, 0.6] + assert isinstance(thresholds, list) + self.thresholds = thresholds + self.input_channel = input_channel + + def input_to_seg(self, input, target=None): + input = input[self.input_channel] + segs = [] + for th in self.thresholds: + seg = measure.label(np.logical_not(input > th).astype(np.uint8), background=0, connectivity=1) + segs.append(seg) + return np.stack(segs) + + +class PSNR: + """ + Computes Peak Signal to Noise Ratio. Use e.g. as an eval metric for denoising task + """ + + def __init__(self, **kwargs): + pass + + def __call__(self, input, target): + input, target = convert_to_numpy(input, target) + return peak_signal_noise_ratio(target, input) + + +class MSE: + """ + Computes MSE between input and target + """ + + def __init__(self, **kwargs): + pass + + def __call__(self, input, target): + input, target = convert_to_numpy(input, target) + return mean_squared_error(input, target) + + +def get_evaluation_metric(config): + """ + Returns the evaluation metric function based on provided configuration + :param config: (dict) a top level configuration object containing the 'eval_metric' key + :return: an instance of the evaluation metric + """ + + def _metric_class(class_name): + m = importlib.import_module('pytorch3dunet.unet3d.metrics') + clazz = getattr(m, class_name) + return clazz + + assert 'eval_metric' in config, 'Could not find evaluation metric configuration' + metric_config = config['eval_metric'] + metric_class = _metric_class(metric_config['name']) + return metric_class(**metric_config) diff --git a/build/lib/pytorch3dunet/unet3d/model.py b/build/lib/pytorch3dunet/unet3d/model.py new file mode 100644 index 00000000..e4de49a7 --- /dev/null +++ b/build/lib/pytorch3dunet/unet3d/model.py @@ -0,0 +1,249 @@ +import torch.nn as nn + +from pytorch3dunet.unet3d.buildingblocks import DoubleConv, ResNetBlock, ResNetBlockSE, \ + create_decoders, create_encoders +from pytorch3dunet.unet3d.utils import get_class, number_of_features_per_level + + +class AbstractUNet(nn.Module): + """ + Base class for standard and residual UNet. + + Args: + in_channels (int): number of input channels + out_channels (int): number of output segmentation masks; + Note that the of out_channels might correspond to either + different semantic classes or to different binary segmentation mask. 
+ It's up to the user of the class to interpret the out_channels and + use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class) + or BCEWithLogitsLoss (two-class) respectively) + f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number + of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4 + final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the final 1x1 convolution, + otherwise apply nn.Softmax. In effect only if `self.training == False`, i.e. during validation/testing + basic_module: basic model for the encoder/decoder (DoubleConv, ResNetBlock, ....) + layer_order (string): determines the order of layers in `SingleConv` module. + E.g. 'crg' stands for GroupNorm3d+Conv3d+ReLU. See `SingleConv` for more info + num_groups (int): number of groups for the GroupNorm + num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int) + default: 4 + is_segmentation (bool): if True and the model is in eval mode, Sigmoid/Softmax normalization is applied + after the final convolution; if False (regression problem) the normalization layer is skipped + conv_kernel_size (int or tuple): size of the convolving kernel in the basic_module + pool_kernel_size (int or tuple): the size of the window + conv_padding (int or tuple): add zero-padding added to all three sides of the input + conv_upscale (int): number of the convolution to upscale in encoder if DoubleConv, default: 2 + upsample (str): algorithm used for decoder upsampling: + InterpolateUpsampling: 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area' + TransposeConvUpsampling: 'deconv' + No upsampling: None + Default: 'default' (chooses automatically) + dropout_prob (float or tuple): dropout probability, default: 0.1 + is3d (bool): if True the model is 3D, otherwise 2D, default: True + """ + + def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr', + num_groups=8, num_levels=4, is_segmentation=True, conv_kernel_size=3, pool_kernel_size=2, + conv_padding=1, conv_upscale=2, upsample='default', dropout_prob=0.1, is3d=True): + super(AbstractUNet, self).__init__() + + if isinstance(f_maps, int): + f_maps = number_of_features_per_level(f_maps, num_levels=num_levels) + + assert isinstance(f_maps, list) or isinstance(f_maps, tuple) + assert len(f_maps) > 1, "Required at least 2 levels in the U-Net" + if 'g' in layer_order: + assert num_groups is not None, "num_groups must be specified if GroupNorm is used" + + # create encoder path + self.encoders = create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, + conv_padding, conv_upscale, dropout_prob, + layer_order, num_groups, pool_kernel_size, is3d) + + # create decoder path + self.decoders = create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, + layer_order, num_groups, upsample, dropout_prob, + is3d) + + # in the last layer a 1×1 convolution reduces the number of output channels to the number of labels + if is3d: + self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1) + else: + self.final_conv = nn.Conv2d(f_maps[0], out_channels, 1) + + if is_segmentation: + # semantic segmentation problem + if final_sigmoid: + self.final_activation = nn.Sigmoid() + else: + self.final_activation = nn.Softmax(dim=1) + else: + # regression problem + self.final_activation = None + + def forward(self, x): + # encoder part + encoders_features = [] + for encoder in self.encoders: + x = encoder(x) + # 
reverse the encoder outputs to be aligned with the decoder
+            encoders_features.insert(0, x)
+
+        # remove the last encoder's output from the list
+        # !!remember: it's the 1st in the list
+        encoders_features = encoders_features[1:]
+
+        # decoder part
+        for decoder, encoder_features in zip(self.decoders, encoders_features):
+            # pass the output from the corresponding encoder and the output
+            # of the previous decoder
+            x = decoder(encoder_features, x)
+
+        x = self.final_conv(x)
+
+        # apply final_activation (i.e. Sigmoid or Softmax) only during prediction.
+        # During training the network outputs logits
+        if not self.training and self.final_activation is not None:
+            x = self.final_activation(x)
+
+        return x
+
+
+class UNet3D(AbstractUNet):
+    """
+    3DUnet model from
+    `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation" <https://arxiv.org/abs/1606.06650>`.
+
+    Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder
+    """
+
+    def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
+                 num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1,
+                 conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs):
+        super(UNet3D, self).__init__(in_channels=in_channels,
+                                     out_channels=out_channels,
+                                     final_sigmoid=final_sigmoid,
+                                     basic_module=DoubleConv,
+                                     f_maps=f_maps,
+                                     layer_order=layer_order,
+                                     num_groups=num_groups,
+                                     num_levels=num_levels,
+                                     is_segmentation=is_segmentation,
+                                     conv_padding=conv_padding,
+                                     conv_upscale=conv_upscale,
+                                     upsample=upsample,
+                                     dropout_prob=dropout_prob,
+                                     is3d=True)
+
+
+class ResidualUNet3D(AbstractUNet):
+    """
+    Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
+    Uses ResNetBlock as a basic building block, summation joining instead
+    of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts).
+    Since the model effectively becomes a residual net, in theory it allows for deeper UNet.
+    """
+
+    def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
+                 num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1,
+                 conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs):
+        super(ResidualUNet3D, self).__init__(in_channels=in_channels,
+                                             out_channels=out_channels,
+                                             final_sigmoid=final_sigmoid,
+                                             basic_module=ResNetBlock,
+                                             f_maps=f_maps,
+                                             layer_order=layer_order,
+                                             num_groups=num_groups,
+                                             num_levels=num_levels,
+                                             is_segmentation=is_segmentation,
+                                             conv_padding=conv_padding,
+                                             conv_upscale=conv_upscale,
+                                             upsample=upsample,
+                                             dropout_prob=dropout_prob,
+                                             is3d=True)
+
+
+class ResidualUNetSE3D(AbstractUNet):
+    """
+    Residual 3DUnet model implementation with squeeze and excitation based on
+    https://arxiv.org/pdf/1706.00120.pdf.
+    Uses ResNetBlockSE as a basic building block, summation joining instead
+    of concatenation joining and transposed convolutions for upsampling (watch
+    out for block artifacts). Since the model effectively becomes a residual
+    net, in theory it allows for deeper UNet.
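+
+    Example (a minimal sketch; assumes the input spatial size is divisible by
+    2**(num_levels - 1), here 4, so the poolings compose cleanly):
+
+        >>> import torch
+        >>> model = ResidualUNetSE3D(in_channels=1, out_channels=2, f_maps=16, num_levels=3)
+        >>> _ = model.eval()  # eval mode applies the final Sigmoid (is_segmentation=True)
+        >>> with torch.no_grad():
+        ...     y = model(torch.randn(1, 1, 32, 32, 32))
+        >>> tuple(y.shape)
+        (1, 2, 32, 32, 32)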
+ """ + + def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', + num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1, + conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs): + super(ResidualUNetSE3D, self).__init__(in_channels=in_channels, + out_channels=out_channels, + final_sigmoid=final_sigmoid, + basic_module=ResNetBlockSE, + f_maps=f_maps, + layer_order=layer_order, + num_groups=num_groups, + num_levels=num_levels, + is_segmentation=is_segmentation, + conv_padding=conv_padding, + conv_upscale=conv_upscale, + upsample=upsample, + dropout_prob=dropout_prob, + is3d=True) + + +class UNet2D(AbstractUNet): + """ + 2DUnet model from + `"U-Net: Convolutional Networks for Biomedical Image Segmentation" ` + """ + + def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', + num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1, + conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs): + super(UNet2D, self).__init__(in_channels=in_channels, + out_channels=out_channels, + final_sigmoid=final_sigmoid, + basic_module=DoubleConv, + f_maps=f_maps, + layer_order=layer_order, + num_groups=num_groups, + num_levels=num_levels, + is_segmentation=is_segmentation, + conv_padding=conv_padding, + conv_upscale=conv_upscale, + upsample=upsample, + dropout_prob=dropout_prob, + is3d=False) + + +class ResidualUNet2D(AbstractUNet): + """ + Residual 2DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf. + """ + + def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', + num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1, + conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs): + super(ResidualUNet2D, self).__init__(in_channels=in_channels, + out_channels=out_channels, + final_sigmoid=final_sigmoid, + basic_module=ResNetBlock, + f_maps=f_maps, + layer_order=layer_order, + num_groups=num_groups, + num_levels=num_levels, + is_segmentation=is_segmentation, + conv_padding=conv_padding, + conv_upscale=conv_upscale, + upsample=upsample, + dropout_prob=dropout_prob, + is3d=False) + + +def get_model(model_config): + model_class = get_class(model_config['name'], modules=[ + 'pytorch3dunet.unet3d.model' + ]) + return model_class(**model_config) diff --git a/build/lib/pytorch3dunet/unet3d/predictor.py b/build/lib/pytorch3dunet/unet3d/predictor.py new file mode 100644 index 00000000..c9b4f6eb --- /dev/null +++ b/build/lib/pytorch3dunet/unet3d/predictor.py @@ -0,0 +1,281 @@ +import os +import time +from concurrent import futures +from pathlib import Path + +import h5py +import numpy as np +import torch +from skimage import measure +from torch import nn +from tqdm import tqdm + +from pytorch3dunet.datasets.hdf5 import AbstractHDF5Dataset +from pytorch3dunet.datasets.utils import SliceBuilder, remove_padding +from pytorch3dunet.unet3d.model import UNet2D +from pytorch3dunet.unet3d.utils import get_logger + +logger = get_logger('UNetPredictor') + + +def _get_output_file(dataset, suffix='_predictions', output_dir=None): + input_dir, file_name = os.path.split(dataset.file_path) + if output_dir is None: + output_dir = input_dir + output_filename = os.path.splitext(file_name)[0] + suffix + '.h5' + return Path(output_dir) / output_filename + + +def _is_2d_model(model): + if isinstance(model, nn.DataParallel): + model = model.module + return isinstance(model, UNet2D) + + +class _AbstractPredictor: + def __init__(self, + model: 
nn.Module, + output_dir: str, + out_channels: int, + output_dataset: str = 'predictions', + save_segmentation: bool = False, + prediction_channel: int = None, + **kwargs): + """ + Base class for predictors. + Args: + model: segmentation model + output_dir: directory where the predictions will be saved + out_channels: number of output channels of the model + output_dataset: name of the dataset in the H5 file where the predictions will be saved + save_segmentation: if true the segmentation will be saved instead of the probability maps + prediction_channel: save only the specified channel from the network output + """ + self.model = model + self.output_dir = output_dir + self.out_channels = out_channels + self.output_dataset = output_dataset + self.save_segmentation = save_segmentation + self.prediction_channel = prediction_channel + + def __call__(self, test_loader): + raise NotImplementedError + + +class StandardPredictor(_AbstractPredictor): + """ + Applies the model on the given dataset and saves the result as H5 file. + Predictions from the network are kept in memory. If the results from the network don't fit in into RAM + use `LazyPredictor` instead. + + The output dataset names inside the H5 is given by `output_dataset` config argument. + """ + + def __init__(self, + model: nn.Module, + output_dir: str, + out_channels: int, + output_dataset: str = 'predictions', + save_segmentation: bool = False, + prediction_channel: int = None, + **kwargs): + super().__init__(model, output_dir, out_channels, output_dataset, save_segmentation, prediction_channel, + **kwargs) + + def __call__(self, test_loader): + assert isinstance(test_loader.dataset, AbstractHDF5Dataset) + logger.info(f"Processing '{test_loader.dataset.file_path}'...") + start = time.perf_counter() + + logger.info(f'Running inference on {len(test_loader)} batches') + # dimensionality of the output predictions + volume_shape = test_loader.dataset.volume_shape() + if self.prediction_channel is not None: + # single channel prediction map + prediction_maps_shape = (1,) + volume_shape + else: + prediction_maps_shape = (self.out_channels,) + volume_shape + + # create destination H5 file + output_file = _get_output_file(dataset=test_loader.dataset, output_dir=self.output_dir) + with h5py.File(output_file, 'w') as h5_output_file: + # allocate prediction and normalization arrays + logger.info('Allocating prediction and normalization arrays...') + prediction_map, normalization_mask = self._allocate_prediction_maps(prediction_maps_shape, h5_output_file) + + # determine halo used for padding + patch_halo = test_loader.dataset.halo_shape + + # Sets the module in evaluation mode explicitly + # It is necessary for batchnorm/dropout layers if present as well as final Sigmoid/Softmax to be applied + self.model.eval() + # Run predictions on the entire input dataset + with torch.no_grad(): + for input, indices in tqdm(test_loader): + # send batch to gpu + if torch.cuda.is_available(): + input = input.pin_memory().cuda(non_blocking=True) + + if _is_2d_model(self.model): + # remove the singleton z-dimension from the input + input = torch.squeeze(input, dim=-3) + # forward pass + prediction = self.model(input) + # add the singleton z-dimension to the output + prediction = torch.unsqueeze(prediction, dim=-3) + else: + # forward pass + prediction = self.model(input) + + # unpad the predicted patch + prediction = remove_padding(prediction, patch_halo) + # convert to numpy array + prediction = prediction.cpu().numpy() + # for each batch sample + for pred, 
index in zip(prediction, indices): + # save patch index: (C,D,H,W) + if self.prediction_channel is None: + channel_slice = slice(0, self.out_channels) + else: + # use only the specified channel + channel_slice = slice(0, 1) + pred = np.expand_dims(pred[self.prediction_channel], axis=0) + + # add channel dimension to the index + index = (channel_slice,) + tuple(index) + # accumulate probabilities into the output prediction array + prediction_map[index] += pred + # count voxel visits for normalization + normalization_mask[index] += 1 + + logger.info(f'Finished inference in {time.perf_counter() - start:.2f} seconds') + # save results + output_type = 'segmentation' if self.save_segmentation else 'probability maps' + logger.info(f'Saving {output_type} to: {output_file}') + self._save_results(prediction_map, normalization_mask, h5_output_file, test_loader.dataset) + + def _allocate_prediction_maps(self, output_shape, output_file): + # initialize the output prediction arrays + prediction_map = np.zeros(output_shape, dtype='float32') + # initialize normalization mask in order to average out probabilities of overlapping patches + normalization_mask = np.zeros(output_shape, dtype='uint8') + return prediction_map, normalization_mask + + def _save_results(self, prediction_map, normalization_mask, output_file, dataset): + result = prediction_map / normalization_mask + if self.save_segmentation: + result = np.argmax(result, axis=0).astype('uint16') + output_file.create_dataset(self.output_dataset, data=result, compression="gzip") + + +class LazyPredictor(StandardPredictor): + """ + Applies the model on the given dataset and saves the result in the `output_file` in the H5 format. + Predicted patches are directly saved into the H5 and they won't be stored in memory. Since this predictor + is slower than the `StandardPredictor` it should only be used when the predicted volume does not fit into RAM. 
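+
+    A minimal usage sketch (assumes a trained `model` and a `test_loader` whose dataset
+    is an `AbstractHDF5Dataset`; the argument values are illustrative):
+
+        predictor = LazyPredictor(model, output_dir='out', out_channels=2)
+        predictor(test_loader)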
+    """
+
+    def __init__(self,
+                 model: nn.Module,
+                 output_dir: str,
+                 out_channels: int,
+                 output_dataset: str = 'predictions',
+                 save_segmentation: bool = False,
+                 prediction_channel: int = None,
+                 **kwargs):
+        super().__init__(model, output_dir, out_channels, output_dataset, save_segmentation, prediction_channel,
+                         **kwargs)
+
+    def _allocate_prediction_maps(self, output_shape, output_file):
+        # allocate datasets for probability maps
+        prediction_map = output_file.create_dataset(self.output_dataset,
+                                                    shape=output_shape,
+                                                    dtype='float32',
+                                                    chunks=True,
+                                                    compression='gzip')
+        # allocate datasets for normalization masks
+        normalization_mask = output_file.create_dataset('normalization',
+                                                        shape=output_shape,
+                                                        dtype='uint8',
+                                                        chunks=True,
+                                                        compression='gzip')
+        return prediction_map, normalization_mask
+
+    def _save_results(self, prediction_map, normalization_mask, output_file, dataset):
+        z, y, x = prediction_map.shape[1:]
+        # take slices which are 1/27 of the original volume
+        patch_shape = (z // 3, y // 3, x // 3)
+        if self.save_segmentation:
+            output_file.create_dataset('segmentation', shape=(z, y, x), dtype='uint16', chunks=True, compression='gzip')
+
+        for index in SliceBuilder._build_slices(prediction_map, patch_shape=patch_shape, stride_shape=patch_shape):
+            logger.info(f'Normalizing slice: {index}')
+            prediction_map[index] /= normalization_mask[index]
+            # make sure to reset the slice that has been visited already in order to avoid 'double' normalization
+            # when the patches overlap with each other
+            normalization_mask[index] = 1
+            # save segmentation
+            if self.save_segmentation:
+                output_file['segmentation'][index[1:]] = np.argmax(prediction_map[index], axis=0).astype('uint16')
+
+        del output_file['normalization']
+        if self.save_segmentation:
+            del output_file[self.output_dataset]
+
+
+class DSB2018Predictor(_AbstractPredictor):
+    def __init__(self, model, output_dir, config, save_segmentation=True, pmaps_threshold=0.5, **kwargs):
+        super().__init__(model, output_dir, config, **kwargs)
+        self.pmaps_threshold = pmaps_threshold
+        self.save_segmentation = save_segmentation
+
+    def _slice_from_pad(self, pad):
+        if pad == 0:
+            return slice(None, None)
+        else:
+            return slice(pad, -pad)
+
+    def __call__(self, test_loader):
+        # Sets the module in evaluation mode explicitly
+        self.model.eval()
+        # initialize the process pool for saving results to disk
+        executor = futures.ProcessPoolExecutor(max_workers=32)
+        # Run predictions on the entire input dataset
+        with torch.no_grad():
+            for img, path in test_loader:
+                # send batch to gpu
+                if torch.cuda.is_available():
+                    img = img.cuda(non_blocking=True)
+                # forward pass
+                pred = self.model(img)
+
+                # move the predictions to the CPU before handing them over to the worker process
+                executor.submit(
+                    dsb_save_batch,
+                    self.output_dir,
+                    path,
+                    pred.cpu().numpy()
+                )
+
+        logger.info('Waiting for all predictions to be saved to disk...')
+        executor.shutdown(wait=True)
+
+
+def dsb_save_batch(output_dir, path, pred, save_segmentation=True, pmaps_threshold=0.5):
+    def _pmaps_to_seg(pred):
+        mask = (pred > pmaps_threshold)
+        return measure.label(mask).astype('uint16')
+
+    # iterate over the samples in the batch
+    for single_pred, single_path in zip(pred, path):
+        logger.info(f'Processing {single_path}')
+        single_pred = single_pred.squeeze()
+
+        # save to h5 file
+        out_file = os.path.splitext(single_path)[0] + '_predictions.h5'
+        if output_dir is not None:
+            out_file = os.path.join(output_dir, os.path.split(out_file)[1])
+
+        with h5py.File(out_file, 'w') as f:
+            # logger.info(f'Saving output to {out_file}')
+            f.create_dataset('predictions', data=single_pred, compression='gzip')
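+            # optionally also store an instance segmentation derived from the probability
+            # map by thresholding and connected-component labelling (see _pmaps_to_seg)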
+            if save_segmentation:
+                f.create_dataset('segmentation', data=_pmaps_to_seg(single_pred), compression='gzip')
diff --git a/build/lib/pytorch3dunet/unet3d/se.py b/build/lib/pytorch3dunet/unet3d/se.py
new file mode 100644
index 00000000..23fac3d7
--- /dev/null
+++ b/build/lib/pytorch3dunet/unet3d/se.py
@@ -0,0 +1,113 @@
+"""
+3D Squeeze and Excitation Modules
+*****************************
+3D Extensions of the following 2D squeeze and excitation blocks:
+    1. `Channel Squeeze and Excitation <https://arxiv.org/abs/1709.01507>`_
+    2. `Spatial Squeeze and Excitation <https://arxiv.org/abs/1803.02579>`_
+    3. `Channel and Spatial Squeeze and Excitation <https://arxiv.org/abs/1803.02579>`_
+New Project & Excite block, designed specifically for 3D inputs
+    'quote'
+    Coded by -- Anne-Marie Rickmann (https://github.com/arickm)
+"""
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+class ChannelSELayer3D(nn.Module):
+    """
+    3D extension of Squeeze-and-Excitation (SE) block described in:
+        *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507*
+        *Zhu et al., AnatomyNet, arXiv:1808.05238*
+    """
+
+    def __init__(self, num_channels, reduction_ratio=2):
+        """
+        Args:
+            num_channels (int): No of input channels
+            reduction_ratio (int): By how much the num_channels should be reduced
+        """
+        super(ChannelSELayer3D, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool3d(1)
+        num_channels_reduced = num_channels // reduction_ratio
+        self.reduction_ratio = reduction_ratio
+        self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
+        self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
+        self.relu = nn.ReLU()
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        batch_size, num_channels, D, H, W = x.size()
+        # Average along each channel
+        squeeze_tensor = self.avg_pool(x)
+
+        # channel excitation
+        fc_out_1 = self.relu(self.fc1(squeeze_tensor.view(batch_size, num_channels)))
+        fc_out_2 = self.sigmoid(self.fc2(fc_out_1))
+
+        output_tensor = torch.mul(x, fc_out_2.view(batch_size, num_channels, 1, 1, 1))
+
+        return output_tensor
+
+
+class SpatialSELayer3D(nn.Module):
+    """
+    3D extension of SE block -- squeezing spatially and exciting channel-wise described in:
+        *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018*
+    """
+
+    def __init__(self, num_channels):
+        """
+        Args:
+            num_channels (int): No of input channels
+        """
+        super(SpatialSELayer3D, self).__init__()
+        self.conv = nn.Conv3d(num_channels, 1, 1)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x, weights=None):
+        """
+        Args:
+            weights (torch.Tensor): weights for few shot learning
+            x: X, shape = (batch_size, num_channels, D, H, W)
+
+        Returns:
+            (torch.Tensor): output_tensor
+        """
+        # channel squeeze
+        batch_size, channel, D, H, W = x.size()
+
+        if weights is not None:
+            weights = weights.view(1, channel, 1, 1, 1)
+            out = F.conv3d(x, weights)
+        else:
+            out = self.conv(x)
+
+        squeeze_tensor = self.sigmoid(out)
+
+        # spatial excitation
+        output_tensor = torch.mul(x, squeeze_tensor.view(batch_size, 1, D, H, W))
+
+        return output_tensor
+
+
+class ChannelSpatialSELayer3D(nn.Module):
+    """
+    3D extension of concurrent spatial and channel squeeze & excitation:
+        *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, arXiv:1803.02579*
+    """
+
+    def __init__(self, num_channels, reduction_ratio=2):
+        """
+        Args:
+            num_channels (int): No of input channels
+            reduction_ratio (int): By how much the num_channels should be reduced
+        """
+        super(ChannelSpatialSELayer3D, self).__init__()
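+        # the channel (cSE) and spatial (sSE) branches run in parallel on the same
+        # input and are fused in `forward` with an element-wise max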
+        self.cSE = ChannelSELayer3D(num_channels, reduction_ratio)
+        self.sSE = SpatialSELayer3D(num_channels)
+
+    def forward(self, input_tensor):
+        output_tensor = torch.max(self.cSE(input_tensor), self.sSE(input_tensor))
+        return output_tensor
diff --git a/build/lib/pytorch3dunet/unet3d/seg_metrics.py b/build/lib/pytorch3dunet/unet3d/seg_metrics.py
new file mode 100644
index 00000000..e713ea23
--- /dev/null
+++ b/build/lib/pytorch3dunet/unet3d/seg_metrics.py
@@ -0,0 +1,123 @@
+import numpy as np
+from skimage.metrics import contingency_table
+
+
+def precision(tp, fp, fn):
+    return tp / (tp + fp) if tp > 0 else 0
+
+
+def recall(tp, fp, fn):
+    return tp / (tp + fn) if tp > 0 else 0
+
+
+def accuracy(tp, fp, fn):
+    return tp / (tp + fp + fn) if tp > 0 else 0
+
+
+def f1(tp, fp, fn):
+    return (2 * tp) / (2 * tp + fp + fn) if tp > 0 else 0
+
+
+def _relabel(input):
+    _, unique_labels = np.unique(input, return_inverse=True)
+    return unique_labels.reshape(input.shape)
+
+
+def _iou_matrix(gt, seg):
+    # relabel gt and seg for smaller memory footprint of contingency table
+    gt = _relabel(gt)
+    seg = _relabel(seg)
+
+    # get number of overlapping pixels between GT and SEG
+    n_inter = contingency_table(gt, seg).A
+
+    # number of pixels for GT instances
+    n_gt = n_inter.sum(axis=1, keepdims=True)
+    # number of pixels for SEG instances
+    n_seg = n_inter.sum(axis=0, keepdims=True)
+
+    # number of pixels in the union between GT and SEG instances
+    n_union = n_gt + n_seg - n_inter
+
+    iou_matrix = n_inter / n_union
+    # make sure that the values are within [0,1] range
+    assert 0 <= np.min(iou_matrix) <= np.max(iou_matrix) <= 1
+
+    return iou_matrix
+
+
+class SegmentationMetrics:
+    """
+    Computes precision, recall, accuracy, f1 score for a given ground truth and predicted segmentation.
+    Contingency table for a given ground truth and predicted segmentation is computed eagerly upon construction
+    of the instance of `SegmentationMetrics`.
+
+    Args:
+        gt (ndarray): ground truth segmentation
+        seg (ndarray): predicted segmentation
+    """
+
+    def __init__(self, gt, seg):
+        self.iou_matrix = _iou_matrix(gt, seg)
+
+    def metrics(self, iou_threshold):
+        """
+        Computes precision, recall, accuracy, f1 score at a given IoU threshold
+        """
+        # ignore background
+        iou_matrix = self.iou_matrix[1:, 1:]
+        detection_matrix = (iou_matrix > iou_threshold).astype(np.uint8)
+        n_gt, n_seg = detection_matrix.shape
+
+        # if the iou_matrix is empty or all values are 0
+        trivial = min(n_gt, n_seg) == 0 or np.all(detection_matrix == 0)
+        if trivial:
+            tp = fp = fn = 0
+        else:
+            # count non-zero rows to get the number of TP
+            tp = np.count_nonzero(detection_matrix.sum(axis=1))
+            # count zero rows to get the number of FN
+            fn = n_gt - tp
+            # count zero columns to get the number of FP
+            fp = n_seg - np.count_nonzero(detection_matrix.sum(axis=0))
+
+        return {
+            'precision': precision(tp, fp, fn),
+            'recall': recall(tp, fp, fn),
+            'accuracy': accuracy(tp, fp, fn),
+            'f1': f1(tp, fp, fn)
+        }
+
+
+class Accuracy:
+    """
+    Computes accuracy between ground truth and predicted segmentation at a given threshold value.
+    Defined as: AC = TP / (TP + FP + FN).
+    Kaggle DSB2018 calls it Precision, see:
+    https://www.kaggle.com/stkbailey/step-by-step-explanation-of-scoring-metric.
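+
+    A usage sketch (both arguments are labelled instance segmentations, e.g. uint16
+    ndarrays; the variable names are illustrative):
+
+        acc = Accuracy(iou_threshold=0.5)
+        score = acc(predicted_instances, gt_instances)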
+ """ + + def __init__(self, iou_threshold): + self.iou_threshold = iou_threshold + + def __call__(self, input_seg, gt_seg): + metrics = SegmentationMetrics(gt_seg, input_seg).metrics(self.iou_threshold) + return metrics['accuracy'] + + +class AveragePrecision: + """ + Average precision taken for the IoU range (0.5, 0.95) with a step of 0.05 as defined in: + https://www.kaggle.com/stkbailey/step-by-step-explanation-of-scoring-metric + """ + + def __init__(self): + self.iou_range = np.linspace(0.50, 0.95, 10) + + def __call__(self, input_seg, gt_seg): + # compute contingency_table + sm = SegmentationMetrics(gt_seg, input_seg) + # compute accuracy for each threshold + acc = [sm.metrics(iou)['accuracy'] for iou in self.iou_range] + # return the average + return np.mean(acc) diff --git a/build/lib/pytorch3dunet/unet3d/trainer.py b/build/lib/pytorch3dunet/unet3d/trainer.py new file mode 100644 index 00000000..4b59d568 --- /dev/null +++ b/build/lib/pytorch3dunet/unet3d/trainer.py @@ -0,0 +1,404 @@ +import os +import torch +import torch.nn as nn +from torch.optim.lr_scheduler import ReduceLROnPlateau +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime + +from pytorch3dunet.datasets.utils import get_train_loaders +from pytorch3dunet.unet3d.losses import get_loss_criterion +from pytorch3dunet.unet3d.metrics import get_evaluation_metric +from pytorch3dunet.unet3d.model import get_model, UNet2D +from pytorch3dunet.unet3d.utils import get_logger, get_tensorboard_formatter, create_optimizer, \ + create_lr_scheduler, get_number_of_learnable_parameters +from . import utils + +logger = get_logger('UNetTrainer') + + +def create_trainer(config): + # Create the model + model = get_model(config['model']) + + if torch.cuda.device_count() > 1 and not config['device'] == 'cpu': + model = nn.DataParallel(model) + logger.info(f'Using {torch.cuda.device_count()} GPUs for prediction') + if torch.cuda.is_available() and not config['device'] == 'cpu': + model = model.cuda() + + # Log the number of learnable parameters + logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}') + + # Create loss criterion + loss_criterion = get_loss_criterion(config) + # Create evaluation metric + eval_criterion = get_evaluation_metric(config) + + # Create data loaders + loaders = get_train_loaders(config) + + # Create the optimizer + optimizer = create_optimizer(config['optimizer'], model) + + # Create learning rate adjustment strategy + lr_scheduler = create_lr_scheduler(config.get('lr_scheduler', None), optimizer) + + trainer_config = config['trainer'] + # Create tensorboard formatter + tensorboard_formatter = get_tensorboard_formatter(trainer_config.pop('tensorboard_formatter', None)) + # Create trainer + resume = trainer_config.pop('resume', None) + pre_trained = trainer_config.pop('pre_trained', None) + + return UNetTrainer(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, loss_criterion=loss_criterion, + eval_criterion=eval_criterion, loaders=loaders, tensorboard_formatter=tensorboard_formatter, + resume=resume, pre_trained=pre_trained, **trainer_config) + + +class UNetTrainer: + """UNet trainer. + + Args: + model (Unet3D): UNet 3D model to be trained + optimizer (nn.optim.Optimizer): optimizer used for training + lr_scheduler (torch.optim.lr_scheduler._LRScheduler): learning rate scheduler + WARN: bear in mind that lr_scheduler.step() is invoked after every validation step + (i.e. validate_after_iters) not after every epoch. So e.g. 
if one uses StepLR with step_size=30 + the learning rate will be adjusted after every 30 * validate_after_iters iterations. + loss_criterion (callable): loss function + eval_criterion (callable): used to compute training/validation metric (such as Dice, IoU, AP or Rand score) + saving the best checkpoint is based on the result of this function on the validation set + loaders (dict): 'train' and 'val' loaders + checkpoint_dir (string): dir for saving checkpoints and tensorboard logs + max_num_epochs (int): maximum number of epochs + max_num_iterations (int): maximum number of iterations + validate_after_iters (int): validate after that many iterations + log_after_iters (int): number of iterations before logging to tensorboard + validate_iters (int): number of validation iterations, if None validate + on the whole validation set + eval_score_higher_is_better (bool): if True higher eval scores are considered better + best_eval_score (float): best validation score so far (higher better) + num_iterations (int): useful when loading the model from the checkpoint + num_epoch (int): useful when loading the model from the checkpoint + tensorboard_formatter (callable): converts a given batch of input/output/target image to a series of images + that can be displayed in tensorboard + skip_train_validation (bool): if True eval_criterion is not evaluated on the training set (used mostly when + evaluation is expensive) + """ + + def __init__(self, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders, checkpoint_dir, + max_num_epochs, max_num_iterations, validate_after_iters=200, log_after_iters=100, validate_iters=None, + num_iterations=1, num_epoch=0, eval_score_higher_is_better=True, tensorboard_formatter=None, + skip_train_validation=False, resume=None, pre_trained=None, **kwargs): + + self.model = model + self.optimizer = optimizer + self.scheduler = lr_scheduler + self.loss_criterion = loss_criterion + self.eval_criterion = eval_criterion + self.loaders = loaders + self.checkpoint_dir = checkpoint_dir + self.max_num_epochs = max_num_epochs + self.max_num_iterations = max_num_iterations + self.validate_after_iters = validate_after_iters + self.log_after_iters = log_after_iters + self.validate_iters = validate_iters + self.eval_score_higher_is_better = eval_score_higher_is_better + + logger.info(model) + logger.info(f'eval_score_higher_is_better: {eval_score_higher_is_better}') + + # initialize the best_eval_score + if eval_score_higher_is_better: + self.best_eval_score = float('-inf') + else: + self.best_eval_score = float('+inf') + + self.writer = SummaryWriter( + log_dir=os.path.join( + checkpoint_dir, 'logs', + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ) + ) + + assert tensorboard_formatter is not None, 'TensorboardFormatter must be provided' + self.tensorboard_formatter = tensorboard_formatter + + self.num_iterations = num_iterations + self.num_epochs = num_epoch + self.skip_train_validation = skip_train_validation + + if resume is not None: + logger.info(f"Loading checkpoint '{resume}'...") + state = utils.load_checkpoint(resume, self.model, self.optimizer) + logger.info( + f"Checkpoint loaded from '{resume}'. Epoch: {state['num_epochs']}. Iteration: {state['num_iterations']}. " + f"Best val score: {state['best_eval_score']}." 
+ ) + self.best_eval_score = state['best_eval_score'] + self.num_iterations = state['num_iterations'] + self.num_epochs = state['num_epochs'] + self.checkpoint_dir = os.path.split(resume)[0] + elif pre_trained is not None: + logger.info(f"Logging pre-trained model from '{pre_trained}'...") + utils.load_checkpoint(pre_trained, self.model, None) + if 'checkpoint_dir' not in kwargs: + self.checkpoint_dir = os.path.split(pre_trained)[0] + + def fit(self): + for _ in range(self.num_epochs, self.max_num_epochs): + # train for one epoch + should_terminate = self.train() + + if should_terminate: + logger.info('Stopping criterion is satisfied. Finishing training') + return + + self.num_epochs += 1 + logger.info(f"Reached maximum number of epochs: {self.max_num_epochs}. Finishing training...") + + def train(self): + """Trains the model for 1 epoch. + + Returns: + True if the training should be terminated immediately, False otherwise + """ + train_losses = utils.RunningAverage() + train_eval_scores = utils.RunningAverage() + + # sets the model in training mode + self.model.train() + + for t in self.loaders['train']: + logger.info(f'Training iteration [{self.num_iterations}/{self.max_num_iterations}]. ' + f'Epoch [{self.num_epochs}/{self.max_num_epochs - 1}]') + + input, target, weight = self._split_training_batch(t) + + output, loss = self._forward_pass(input, target, weight) + + train_losses.update(loss.item(), self._batch_size(input)) + + # compute gradients and update parameters + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + if self.num_iterations % self.validate_after_iters == 0: + # set the model in eval mode + self.model.eval() + # evaluate on validation set + eval_score = self.validate() + # set the model back to training mode + self.model.train() + + # adjust learning rate if necessary + if isinstance(self.scheduler, ReduceLROnPlateau): + self.scheduler.step(eval_score) + elif self.scheduler is not None: + self.scheduler.step() + + # log current learning rate in tensorboard + self._log_lr() + # remember best validation metric + is_best = self._is_best_eval_score(eval_score) + + # save checkpoint + self._save_checkpoint(is_best) + + if self.num_iterations % self.log_after_iters == 0: + # compute eval criterion + if not self.skip_train_validation: + # apply final activation before calculating eval score + if isinstance(self.model, nn.DataParallel): + final_activation = self.model.module.final_activation + else: + final_activation = self.model.final_activation + + if final_activation is not None: + act_output = final_activation(output) + else: + act_output = output + eval_score = self.eval_criterion(act_output, target) + train_eval_scores.update(eval_score.item(), self._batch_size(input)) + + # log stats, params and images + logger.info( + f'Training stats. Loss: {train_losses.avg}. 
Evaluation score: {train_eval_scores.avg}') + self._log_stats('train', train_losses.avg, train_eval_scores.avg) + # self._log_params() + self._log_images(input, target, output, 'train_') + + if self.should_stop(): + return True + + self.num_iterations += 1 + + return False + + def should_stop(self): + """ + Training will terminate if maximum number of iterations is exceeded or the learning rate drops below + some predefined threshold (1e-6 in our case) + """ + if self.max_num_iterations < self.num_iterations: + logger.info(f'Maximum number of iterations {self.max_num_iterations} exceeded.') + return True + + min_lr = 1e-6 + lr = self.optimizer.param_groups[0]['lr'] + if lr < min_lr: + logger.info(f'Learning rate below the minimum {min_lr}.') + return True + + return False + + def validate(self): + logger.info('Validating...') + + val_losses = utils.RunningAverage() + val_scores = utils.RunningAverage() + + with torch.no_grad(): + for i, t in enumerate(self.loaders['val']): + logger.info(f'Validation iteration {i}') + + input, target, weight = self._split_training_batch(t) + + output, loss = self._forward_pass(input, target, weight) + val_losses.update(loss.item(), self._batch_size(input)) + + if i % 100 == 0: + self._log_images(input, target, output, 'val_') + + eval_score = self.eval_criterion(output, target) + val_scores.update(eval_score.item(), self._batch_size(input)) + + if self.validate_iters is not None and self.validate_iters <= i: + # stop validation + break + + self._log_stats('val', val_losses.avg, val_scores.avg) + logger.info(f'Validation finished. Loss: {val_losses.avg}. Evaluation score: {val_scores.avg}') + return val_scores.avg + + def _split_training_batch(self, t): + def _move_to_gpu(input): + if isinstance(input, tuple) or isinstance(input, list): + return tuple([_move_to_gpu(x) for x in input]) + else: + if torch.cuda.is_available(): + input = input.cuda(non_blocking=True) + return input + + t = _move_to_gpu(t) + weight = None + if len(t) == 2: + input, target = t + else: + input, target, weight = t + return input, target, weight + + def _forward_pass(self, input, target, weight=None): + if isinstance(self.model, UNet2D): + # remove the singleton z-dimension from the input + input = torch.squeeze(input, dim=-3) + # forward pass + output = self.model(input) + # add the singleton z-dimension to the output + output = torch.unsqueeze(output, dim=-3) + else: + # forward pass + output = self.model(input) + + # compute the loss + if weight is None: + loss = self.loss_criterion(output, target) + else: + loss = self.loss_criterion(output, target, weight) + + return output, loss + + def _is_best_eval_score(self, eval_score): + if self.eval_score_higher_is_better: + is_best = eval_score > self.best_eval_score + else: + is_best = eval_score < self.best_eval_score + + if is_best: + logger.info(f'Saving new best evaluation metric: {eval_score}') + self.best_eval_score = eval_score + + return is_best + + def _save_checkpoint(self, is_best): + # remove `module` prefix from layer names when using `nn.DataParallel` + # see: https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/20 + if isinstance(self.model, nn.DataParallel): + state_dict = self.model.module.state_dict() + else: + state_dict = self.model.state_dict() + + last_file_path = os.path.join(self.checkpoint_dir, 'last_checkpoint.pytorch') + logger.info(f"Saving checkpoint to '{last_file_path}'") + + utils.save_checkpoint({ + 'num_epochs': self.num_epochs + 1, + 
'num_iterations': self.num_iterations, + 'model_state_dict': state_dict, + 'best_eval_score': self.best_eval_score, + 'optimizer_state_dict': self.optimizer.state_dict(), + }, is_best, checkpoint_dir=self.checkpoint_dir) + + def _log_lr(self): + lr = self.optimizer.param_groups[0]['lr'] + self.writer.add_scalar('learning_rate', lr, self.num_iterations) + + def _log_stats(self, phase, loss_avg, eval_score_avg): + tag_value = { + f'{phase}_loss_avg': loss_avg, + f'{phase}_eval_score_avg': eval_score_avg + } + + for tag, value in tag_value.items(): + self.writer.add_scalar(tag, value, self.num_iterations) + + def _log_params(self): + logger.info('Logging model parameters and gradients') + for name, value in self.model.named_parameters(): + self.writer.add_histogram(name, value.data.cpu().numpy(), self.num_iterations) + self.writer.add_histogram(name + '/grad', value.grad.data.cpu().numpy(), self.num_iterations) + + def _log_images(self, input, target, prediction, prefix=''): + + if isinstance(self.model, nn.DataParallel): + net = self.model.module + else: + net = self.model + + if net.final_activation is not None: + prediction = net.final_activation(prediction) + + inputs_map = { + 'inputs': input, + 'targets': target, + 'predictions': prediction + } + img_sources = {} + for name, batch in inputs_map.items(): + if isinstance(batch, list) or isinstance(batch, tuple): + for i, b in enumerate(batch): + img_sources[f'{name}{i}'] = b.data.cpu().numpy() + else: + img_sources[name] = batch.data.cpu().numpy() + + for name, batch in img_sources.items(): + for tag, image in self.tensorboard_formatter(name, batch): + self.writer.add_image(prefix + tag, image, self.num_iterations) + + @staticmethod + def _batch_size(input): + if isinstance(input, list) or isinstance(input, tuple): + return input[0].size(0) + else: + return input.size(0) diff --git a/build/lib/pytorch3dunet/unet3d/utils.py b/build/lib/pytorch3dunet/unet3d/utils.py new file mode 100644 index 00000000..01d5559c --- /dev/null +++ b/build/lib/pytorch3dunet/unet3d/utils.py @@ -0,0 +1,366 @@ +import importlib +import logging +import os +import shutil +import sys + +import h5py +import numpy as np +import torch +from torch import optim + + +def save_checkpoint(state, is_best, checkpoint_dir): + """Saves model and training parameters at '{checkpoint_dir}/last_checkpoint.pytorch'. + If is_best==True saves '{checkpoint_dir}/best_checkpoint.pytorch' as well. + + Args: + state (dict): contains model's state_dict, optimizer's state_dict, epoch + and best evaluation metric value so far + is_best (bool): if True state contains the best model seen so far + checkpoint_dir (string): directory where the checkpoint are to be saved + """ + + if not os.path.exists(checkpoint_dir): + os.mkdir(checkpoint_dir) + + last_file_path = os.path.join(checkpoint_dir, 'last_checkpoint.pytorch') + torch.save(state, last_file_path) + if is_best: + best_file_path = os.path.join(checkpoint_dir, 'best_checkpoint.pytorch') + shutil.copyfile(last_file_path, best_file_path) + + +def load_checkpoint(checkpoint_path, model, optimizer=None, + model_key='model_state_dict', optimizer_key='optimizer_state_dict'): + """Loads model and training parameters from a given checkpoint_path + If optimizer is provided, loads optimizer's state_dict of as well. 
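+
+    A usage sketch (the checkpoint path is illustrative):
+
+        state = load_checkpoint('checkpoints/best_checkpoint.pytorch', model, optimizer)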
+ + Args: + checkpoint_path (string): path to the checkpoint to be loaded + model (torch.nn.Module): model into which the parameters are to be copied + optimizer (torch.optim.Optimizer) optional: optimizer instance into + which the parameters are to be copied + + Returns: + state + """ + if not os.path.exists(checkpoint_path): + raise IOError(f"Checkpoint '{checkpoint_path}' does not exist") + + state = torch.load(checkpoint_path, map_location='cpu') + model.load_state_dict(state[model_key]) + + if optimizer is not None: + optimizer.load_state_dict(state[optimizer_key]) + + return state + + +def save_network_output(output_path, output, logger=None): + if logger is not None: + logger.info(f'Saving network output to: {output_path}...') + output = output.detach().cpu()[0] + with h5py.File(output_path, 'w') as f: + f.create_dataset('predictions', data=output, compression='gzip') + + +loggers = {} + + +def get_logger(name, level=logging.INFO): + global loggers + if loggers.get(name) is not None: + return loggers[name] + else: + logger = logging.getLogger(name) + logger.setLevel(level) + # Logging to console + stream_handler = logging.StreamHandler(sys.stdout) + formatter = logging.Formatter( + '%(asctime)s [%(threadName)s] %(levelname)s %(name)s - %(message)s') + stream_handler.setFormatter(formatter) + logger.addHandler(stream_handler) + + loggers[name] = logger + + return logger + + +def get_number_of_learnable_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +class RunningAverage: + """Computes and stores the average + """ + + def __init__(self): + self.count = 0 + self.sum = 0 + self.avg = 0 + + def update(self, value, n=1): + self.count += n + self.sum += value * n + self.avg = self.sum / self.count + + +def number_of_features_per_level(init_channel_number, num_levels): + return [init_channel_number * 2 ** k for k in range(num_levels)] + + +class _TensorboardFormatter: + """ + Tensorboard formatters converts a given batch of images (be it input/output to the network or the target segmentation + image) to a series of images that can be displayed in tensorboard. This is the parent class for all tensorboard + formatters which ensures that returned images are in the 'CHW' format. + """ + + def __init__(self, **kwargs): + pass + + def __call__(self, name, batch): + """ + Transform a batch to a series of tuples of the form (tag, img), where `tag` corresponds to the image tag + and `img` is the image itself. + + Args: + name (str): one of 'inputs'/'targets'/'predictions' + batch (torch.tensor): 4D or 5D torch tensor + """ + + def _check_img(tag_img): + tag, img = tag_img + + assert img.ndim == 2 or img.ndim == 3, 'Only 2D (HW) and 3D (CHW) images are accepted for display' + + if img.ndim == 2: + img = np.expand_dims(img, axis=0) + else: + C = img.shape[0] + assert C == 1 or C == 3, 'Only (1, H, W) or (3, H, W) images are supported' + + return tag, img + + tagged_images = self.process_batch(name, batch) + + return list(map(_check_img, tagged_images)) + + def process_batch(self, name, batch): + raise NotImplementedError + + +class DefaultTensorboardFormatter(_TensorboardFormatter): + def __init__(self, skip_last_target=False, **kwargs): + super().__init__(**kwargs) + self.skip_last_target = skip_last_target + + def process_batch(self, name, batch): + if name == 'targets' and self.skip_last_target: + batch = batch[:, :-1, ...] 
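+            # the last target channel is assumed to be auxiliary (e.g. a weight map),
+            # so it is excluded from the tensorboard images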
+ + tag_template = '{}/batch_{}/channel_{}/slice_{}' + + tagged_images = [] + + if batch.ndim == 5: + # NCDHW + slice_idx = batch.shape[2] // 2 # get the middle slice + for batch_idx in range(batch.shape[0]): + for channel_idx in range(batch.shape[1]): + tag = tag_template.format(name, batch_idx, channel_idx, slice_idx) + img = batch[batch_idx, channel_idx, slice_idx, ...] + tagged_images.append((tag, self._normalize_img(img))) + else: + # batch has no channel dim: NDHW + slice_idx = batch.shape[1] // 2 # get the middle slice + for batch_idx in range(batch.shape[0]): + tag = tag_template.format(name, batch_idx, 0, slice_idx) + img = batch[batch_idx, slice_idx, ...] + tagged_images.append((tag, self._normalize_img(img))) + + return tagged_images + + @staticmethod + def _normalize_img(img): + return np.nan_to_num((img - np.min(img)) / np.ptp(img)) + + +def _find_masks(batch, min_size=10): + """Center the z-slice in the 'middle' of a given instance, given a batch of instances + + Args: + batch (ndarray): 5d numpy tensor (NCDHW) + """ + result = [] + for b in batch: + assert b.shape[0] == 1 + patch = b[0] + z_sum = patch.sum(axis=(1, 2)) + coords = np.where(z_sum > min_size)[0] + if len(coords) > 0: + ind = coords[len(coords) // 2] + result.append(b[:, ind:ind + 1, ...]) + else: + ind = b.shape[1] // 2 + result.append(b[:, ind:ind + 1, ...]) + + return np.stack(result, axis=0) + + +def get_tensorboard_formatter(formatter_config): + if formatter_config is None: + return DefaultTensorboardFormatter() + + class_name = formatter_config['name'] + m = importlib.import_module('pytorch3dunet.unet3d.utils') + clazz = getattr(m, class_name) + return clazz(**formatter_config) + + +def expand_as_one_hot(input, C, ignore_index=None): + """ + Converts NxSPATIAL label image to NxCxSPATIAL, where each label gets converted to its corresponding one-hot vector. + It is assumed that the batch dimension is present. 
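+    For example, with C=3 a (N, D, H, W) label volume becomes a (N, 3, D, H, W) one-hot tensor.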
+ Args: + input (torch.Tensor): 3D/4D input image + C (int): number of channels/labels + ignore_index (int): ignore index to be kept during the expansion + Returns: + 4D/5D output torch.Tensor (NxCxSPATIAL) + """ + assert input.dim() == 4 + + # expand the input tensor to Nx1xSPATIAL before scattering + input = input.unsqueeze(1) + # create output tensor shape (NxCxSPATIAL) + shape = list(input.size()) + shape[1] = C + + if ignore_index is not None: + # create ignore_index mask for the result + mask = input.expand(shape) == ignore_index + # clone the src tensor and zero out ignore_index in the input + input = input.clone() + input[input == ignore_index] = 0 + # scatter to get the one-hot tensor + result = torch.zeros(shape).to(input.device).scatter_(1, input, 1) + # bring back the ignore_index in the result + result[mask] = ignore_index + return result + else: + # scatter to get the one-hot tensor + return torch.zeros(shape).to(input.device).scatter_(1, input, 1) + + +def convert_to_numpy(*inputs): + """ + Coverts input tensors to numpy ndarrays + + Args: + inputs (iteable of torch.Tensor): torch tensor + + Returns: + tuple of ndarrays + """ + + def _to_numpy(i): + assert isinstance(i, torch.Tensor), "Expected input to be torch.Tensor" + return i.detach().cpu().numpy() + + return (_to_numpy(i) for i in inputs) + + +def create_optimizer(optimizer_config, model): + optim_name = optimizer_config.get('name', 'Adam') + # common optimizer settings + learning_rate = optimizer_config.get('learning_rate', 1e-3) + weight_decay = optimizer_config.get('weight_decay', 0) + + # grab optimizer specific settings and init + # optimizer + if optim_name == 'Adadelta': + rho = optimizer_config.get('rho', 0.9) + optimizer = optim.Adadelta(model.parameters(), lr=learning_rate, rho=rho, + weight_decay=weight_decay) + elif optim_name == 'Adagrad': + lr_decay = optimizer_config.get('lr_decay', 0) + optimizer = optim.Adagrad(model.parameters(), lr=learning_rate, lr_decay=lr_decay, + weight_decay=weight_decay) + elif optim_name == 'AdamW': + betas = tuple(optimizer_config.get('betas', (0.9, 0.999))) + optimizer = optim.AdamW(model.parameters(), lr=learning_rate, betas=betas, + weight_decay=weight_decay) + elif optim_name == 'SparseAdam': + betas = tuple(optimizer_config.get('betas', (0.9, 0.999))) + optimizer = optim.SparseAdam(model.parameters(), lr=learning_rate, betas=betas) + elif optim_name == 'Adamax': + betas = tuple(optimizer_config.get('betas', (0.9, 0.999))) + optimizer = optim.Adamax(model.parameters(), lr=learning_rate, betas=betas, + weight_decay=weight_decay) + elif optim_name == 'ASGD': + lambd = optimizer_config.get('lambd', 0.0001) + alpha = optimizer_config.get('alpha', 0.75) + t0 = optimizer_config.get('t0', 1e6) + optimizer = optim.Adamax(model.parameters(), lr=learning_rate, lambd=lambd, + alpha=alpha, t0=t0, weight_decay=weight_decay) + elif optim_name == 'LBFGS': + max_iter = optimizer_config.get('max_iter', 20) + max_eval = optimizer_config.get('max_eval', None) + tolerance_grad = optimizer_config.get('tolerance_grad', 1e-7) + tolerance_change = optimizer_config.get('tolerance_change', 1e-9) + history_size = optimizer_config.get('history_size', 100) + optimizer = optim.LBFGS(model.parameters(), lr=learning_rate, max_iter=max_iter, + max_eval=max_eval, tolerance_grad=tolerance_grad, + tolerance_change=tolerance_change, history_size=history_size) + elif optim_name == 'NAdam': + betas = tuple(optimizer_config.get('betas', (0.9, 0.999))) + momentum_decay = optimizer_config.get('momentum_decay', 
4e-3) + optimizer = optim.NAdam(model.parameters(), lr=learning_rate, betas=betas, + momentum_decay=momentum_decay, + weight_decay=weight_decay) + elif optim_name == 'RAdam': + betas = tuple(optimizer_config.get('betas', (0.9, 0.999))) + optimizer = optim.RAdam(model.parameters(), lr=learning_rate, betas=betas, + weight_decay=weight_decay) + elif optim_name == 'RMSprop': + alpha = optimizer_config.get('alpha', 0.99) + optimizer = optim.RMSprop(model.parameters(), lr=learning_rate, alpha=alpha, + weight_decay=weight_decay) + elif optim_name == 'Rprop': + momentum = optimizer_config.get('momentum', 0) + optimizer = optim.RMSprop(model.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=momentum) + elif optim_name == 'SGD': + momentum = optimizer_config.get('momentum', 0) + dampening = optimizer_config.get('dampening', 0) + nesterov = optimizer_config.get('nesterov', False) + optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, + dampening=dampening, nesterov=nesterov, + weight_decay=weight_decay) + else: # Adam is default + betas = tuple(optimizer_config.get('betas', (0.9, 0.999))) + optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=betas, + weight_decay=weight_decay) + + return optimizer + + +def create_lr_scheduler(lr_config, optimizer): + if lr_config is None: + return None + class_name = lr_config.pop('name') + m = importlib.import_module('torch.optim.lr_scheduler') + clazz = getattr(m, class_name) + # add optimizer to the config + lr_config['optimizer'] = optimizer + return clazz(**lr_config) + + +def get_class(class_name, modules): + for module in modules: + m = importlib.import_module(module) + clazz = getattr(m, class_name, None) + if clazz is not None: + return clazz + raise RuntimeError(f'Unsupported dataset class: {class_name}') From 0987f24d8d7c1e1ac45865a5402f2bab09f98de1 Mon Sep 17 00:00:00 2001 From: Shota Mizusaki Date: Fri, 12 Jul 2024 14:35:29 +0900 Subject: [PATCH 2/4] Fix: Add missing return statement --- pytorch3dunet/unet3d/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch3dunet/unet3d/config.py b/pytorch3dunet/unet3d/config.py index bb011632..0dbecffd 100644 --- a/pytorch3dunet/unet3d/config.py +++ b/pytorch3dunet/unet3d/config.py @@ -49,7 +49,7 @@ def load_config(): if device == 'cpu': logger.warning('CPU mode forced in config, this will likely result in slow training/prediction') config['device'] = 'cpu' - return config + return config, config_path if torch.cuda.is_available(): config['device'] = 'cuda' From 6b00041129871bff9790ea75c18e201828d2761a Mon Sep 17 00:00:00 2001 From: Shota Mizusaki Date: Fri, 12 Jul 2024 14:49:07 +0900 Subject: [PATCH 3/4] delete build file --- build/lib/pytorch3dunet/__init__.py | 1 - build/lib/pytorch3dunet/__version__.py | 1 - build/lib/pytorch3dunet/augment/__init__.py | 0 build/lib/pytorch3dunet/augment/transforms.py | 761 ------------------ build/lib/pytorch3dunet/datasets/__init__.py | 0 build/lib/pytorch3dunet/datasets/dsb.py | 108 --- build/lib/pytorch3dunet/datasets/hdf5.py | 293 ------- build/lib/pytorch3dunet/datasets/utils.py | 361 --------- build/lib/pytorch3dunet/predict.py | 59 -- build/lib/pytorch3dunet/train.py | 35 - build/lib/pytorch3dunet/unet3d/__init__.py | 0 .../pytorch3dunet/unet3d/buildingblocks.py | 545 ------------- build/lib/pytorch3dunet/unet3d/config.py | 79 -- build/lib/pytorch3dunet/unet3d/losses.py | 345 -------- build/lib/pytorch3dunet/unet3d/metrics.py | 445 ---------- build/lib/pytorch3dunet/unet3d/model.py | 
249 ------ build/lib/pytorch3dunet/unet3d/predictor.py | 281 ------- build/lib/pytorch3dunet/unet3d/se.py | 113 --- build/lib/pytorch3dunet/unet3d/seg_metrics.py | 123 --- build/lib/pytorch3dunet/unet3d/trainer.py | 404 ---------- build/lib/pytorch3dunet/unet3d/utils.py | 366 --------- 21 files changed, 4569 deletions(-) delete mode 100644 build/lib/pytorch3dunet/__init__.py delete mode 100644 build/lib/pytorch3dunet/__version__.py delete mode 100644 build/lib/pytorch3dunet/augment/__init__.py delete mode 100644 build/lib/pytorch3dunet/augment/transforms.py delete mode 100644 build/lib/pytorch3dunet/datasets/__init__.py delete mode 100644 build/lib/pytorch3dunet/datasets/dsb.py delete mode 100644 build/lib/pytorch3dunet/datasets/hdf5.py delete mode 100644 build/lib/pytorch3dunet/datasets/utils.py delete mode 100644 build/lib/pytorch3dunet/predict.py delete mode 100644 build/lib/pytorch3dunet/train.py delete mode 100644 build/lib/pytorch3dunet/unet3d/__init__.py delete mode 100644 build/lib/pytorch3dunet/unet3d/buildingblocks.py delete mode 100644 build/lib/pytorch3dunet/unet3d/config.py delete mode 100644 build/lib/pytorch3dunet/unet3d/losses.py delete mode 100644 build/lib/pytorch3dunet/unet3d/metrics.py delete mode 100644 build/lib/pytorch3dunet/unet3d/model.py delete mode 100644 build/lib/pytorch3dunet/unet3d/predictor.py delete mode 100644 build/lib/pytorch3dunet/unet3d/se.py delete mode 100644 build/lib/pytorch3dunet/unet3d/seg_metrics.py delete mode 100644 build/lib/pytorch3dunet/unet3d/trainer.py delete mode 100644 build/lib/pytorch3dunet/unet3d/utils.py diff --git a/build/lib/pytorch3dunet/__init__.py b/build/lib/pytorch3dunet/__init__.py deleted file mode 100644 index 9226fe7e..00000000 --- a/build/lib/pytorch3dunet/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .__version__ import __version__ diff --git a/build/lib/pytorch3dunet/__version__.py b/build/lib/pytorch3dunet/__version__.py deleted file mode 100644 index 655be529..00000000 --- a/build/lib/pytorch3dunet/__version__.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = '1.8.7' diff --git a/build/lib/pytorch3dunet/augment/__init__.py b/build/lib/pytorch3dunet/augment/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/build/lib/pytorch3dunet/augment/transforms.py b/build/lib/pytorch3dunet/augment/transforms.py deleted file mode 100644 index 527d596b..00000000 --- a/build/lib/pytorch3dunet/augment/transforms.py +++ /dev/null @@ -1,761 +0,0 @@ -import importlib -import random - -import numpy as np -import torch -from scipy.ndimage import rotate, map_coordinates, gaussian_filter, convolve -from skimage import measure -from skimage.filters import gaussian -from skimage.segmentation import find_boundaries - -# WARN: use fixed random state for reproducibility; if you want to randomize on each run seed with `time.time()` e.g. -GLOBAL_RANDOM_STATE = np.random.RandomState(47) - - -class Compose(object): - def __init__(self, transforms): - self.transforms = transforms - - def __call__(self, m): - for t in self.transforms: - m = t(m) - return m - - -class RandomFlip: - """ - Randomly flips the image across the given axes. Image can be either 3D (DxHxW) or 4D (CxDxHxW). - - When creating make sure that the provided RandomStates are consistent between raw and labeled datasets, - otherwise the models won't converge. 
- """ - - def __init__(self, random_state, axis_prob=0.5, **kwargs): - assert random_state is not None, 'RandomState cannot be None' - self.random_state = random_state - self.axes = (0, 1, 2) - self.axis_prob = axis_prob - - def __call__(self, m): - assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images' - - for axis in self.axes: - if self.random_state.uniform() > self.axis_prob: - if m.ndim == 3: - m = np.flip(m, axis) - else: - channels = [np.flip(m[c], axis) for c in range(m.shape[0])] - m = np.stack(channels, axis=0) - - return m - - -class RandomRotate90: - """ - Rotate an array by 90 degrees around a randomly chosen plane. Image can be either 3D (DxHxW) or 4D (CxDxHxW). - - When creating make sure that the provided RandomStates are consistent between raw and labeled datasets, - otherwise the models won't converge. - - IMPORTANT: assumes DHW axis order (that's why rotation is performed across (1,2) axis) - """ - - def __init__(self, random_state, **kwargs): - self.random_state = random_state - # always rotate around z-axis - self.axis = (1, 2) - - def __call__(self, m): - assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images' - - # pick number of rotations at random - k = self.random_state.randint(0, 4) - # rotate k times around a given plane - if m.ndim == 3: - m = np.rot90(m, k, self.axis) - else: - channels = [np.rot90(m[c], k, self.axis) for c in range(m.shape[0])] - m = np.stack(channels, axis=0) - - return m - - -class RandomRotate: - """ - Rotate an array by a random degrees from taken from (-angle_spectrum, angle_spectrum) interval. - Rotation axis is picked at random from the list of provided axes. - """ - - def __init__(self, random_state, angle_spectrum=30, axes=None, mode='reflect', order=0, **kwargs): - if axes is None: - axes = [(1, 0), (2, 1), (2, 0)] - else: - assert isinstance(axes, list) and len(axes) > 0 - - self.random_state = random_state - self.angle_spectrum = angle_spectrum - self.axes = axes - self.mode = mode - self.order = order - - def __call__(self, m): - axis = self.axes[self.random_state.randint(len(self.axes))] - angle = self.random_state.randint(-self.angle_spectrum, self.angle_spectrum) - - if m.ndim == 3: - m = rotate(m, angle, axes=axis, reshape=False, order=self.order, mode=self.mode, cval=-1) - else: - channels = [rotate(m[c], angle, axes=axis, reshape=False, order=self.order, mode=self.mode, cval=-1) for c - in range(m.shape[0])] - m = np.stack(channels, axis=0) - - return m - - -class RandomContrast: - """ - Adjust contrast by scaling each voxel to `mean + alpha * (v - mean)`. - """ - - def __init__(self, random_state, alpha=(0.5, 1.5), mean=0.0, execution_probability=0.1, **kwargs): - self.random_state = random_state - assert len(alpha) == 2 - self.alpha = alpha - self.mean = mean - self.execution_probability = execution_probability - - def __call__(self, m): - if self.random_state.uniform() < self.execution_probability: - alpha = self.random_state.uniform(self.alpha[0], self.alpha[1]) - result = self.mean + alpha * (m - self.mean) - return np.clip(result, -1, 1) - - return m - - -# it's relatively slow, i.e. ~1s per patch of size 64x200x200, so use multiple workers in the DataLoader -# remember to use spline_order=0 when transforming the labels -class ElasticDeformation: - """ - Apply elasitc deformations of 3D patches on a per-voxel mesh. Assumes ZYX axis order (or CZYX if the data is 4D). 
- Based on: https://github.com/fcalvet/image_tools/blob/master/image_augmentation.py#L62 - """ - - def __init__(self, random_state, spline_order, alpha=2000, sigma=50, execution_probability=0.1, apply_3d=True, - **kwargs): - """ - :param spline_order: the order of spline interpolation (use 0 for labeled images) - :param alpha: scaling factor for deformations - :param sigma: smoothing factor for Gaussian filter - :param execution_probability: probability of executing this transform - :param apply_3d: if True apply deformations in each axis - """ - self.random_state = random_state - self.spline_order = spline_order - self.alpha = alpha - self.sigma = sigma - self.execution_probability = execution_probability - self.apply_3d = apply_3d - - def __call__(self, m): - if self.random_state.uniform() < self.execution_probability: - assert m.ndim in [3, 4] - - if m.ndim == 3: - volume_shape = m.shape - else: - volume_shape = m[0].shape - - if self.apply_3d: - dz = gaussian_filter(self.random_state.randn(*volume_shape), self.sigma, mode="reflect") * self.alpha - else: - dz = np.zeros_like(m) - - dy, dx = [ - gaussian_filter( - self.random_state.randn(*volume_shape), - self.sigma, mode="reflect" - ) * self.alpha for _ in range(2) - ] - - z_dim, y_dim, x_dim = volume_shape - z, y, x = np.meshgrid(np.arange(z_dim), np.arange(y_dim), np.arange(x_dim), indexing='ij') - indices = z + dz, y + dy, x + dx - - if m.ndim == 3: - return map_coordinates(m, indices, order=self.spline_order, mode='reflect') - else: - channels = [map_coordinates(c, indices, order=self.spline_order, mode='reflect') for c in m] - return np.stack(channels, axis=0) - - return m - - -class CropToFixed: - def __init__(self, random_state, size=(256, 256), centered=False, **kwargs): - self.random_state = random_state - self.crop_y, self.crop_x = size - self.centered = centered - - def __call__(self, m): - def _padding(pad_total): - half_total = pad_total // 2 - return (half_total, pad_total - half_total) - - def _rand_range_and_pad(crop_size, max_size): - """ - Returns a tuple: - max_value (int) for the corner dimension. 
The corner dimension is chosen as `self.random_state(max_value)` - pad (int): padding in both directions; if crop_size is lt max_size the pad is 0 - """ - if crop_size < max_size: - return max_size - crop_size, (0, 0) - else: - return 1, _padding(crop_size - max_size) - - def _start_and_pad(crop_size, max_size): - if crop_size < max_size: - return (max_size - crop_size) // 2, (0, 0) - else: - return 0, _padding(crop_size - max_size) - - assert m.ndim in (3, 4) - if m.ndim == 3: - _, y, x = m.shape - else: - _, _, y, x = m.shape - - if not self.centered: - y_range, y_pad = _rand_range_and_pad(self.crop_y, y) - x_range, x_pad = _rand_range_and_pad(self.crop_x, x) - - y_start = self.random_state.randint(y_range) - x_start = self.random_state.randint(x_range) - - else: - y_start, y_pad = _start_and_pad(self.crop_y, y) - x_start, x_pad = _start_and_pad(self.crop_x, x) - - if m.ndim == 3: - result = m[:, y_start:y_start + self.crop_y, x_start:x_start + self.crop_x] - return np.pad(result, pad_width=((0, 0), y_pad, x_pad), mode='reflect') - else: - channels = [] - for c in range(m.shape[0]): - result = m[c][:, y_start:y_start + self.crop_y, x_start:x_start + self.crop_x] - channels.append(np.pad(result, pad_width=((0, 0), y_pad, x_pad), mode='reflect')) - return np.stack(channels, axis=0) - - -class AbstractLabelToBoundary: - AXES_TRANSPOSE = [ - (0, 1, 2), # X - (0, 2, 1), # Y - (2, 0, 1) # Z - ] - - def __init__(self, ignore_index=None, aggregate_affinities=False, append_label=False, **kwargs): - """ - :param ignore_index: label to be ignored in the output, i.e. after computing the boundary the label ignore_index - will be restored where is was in the patch originally - :param aggregate_affinities: aggregate affinities with the same offset across Z,Y,X axes - :param append_label: if True append the orignal ground truth labels to the last channel - :param blur: Gaussian blur the boundaries - :param sigma: standard deviation for Gaussian kernel - """ - self.ignore_index = ignore_index - self.aggregate_affinities = aggregate_affinities - self.append_label = append_label - - def __call__(self, m): - """ - Extract boundaries from a given 3D label tensor. 
- :param m: input 3D tensor - :return: binary mask, with 1-label corresponding to the boundary and 0-label corresponding to the background - """ - assert m.ndim == 3 - - kernels = self.get_kernels() - boundary_arr = [np.where(np.abs(convolve(m, kernel)) > 0, 1, 0) for kernel in kernels] - channels = np.stack(boundary_arr) - results = [] - if self.aggregate_affinities: - assert len(kernels) % 3 == 0, "Number of kernels must be divided by 3 (one kernel per offset per Z,Y,X axes" - # aggregate affinities with the same offset - for i in range(0, len(kernels), 3): - # merge across X,Y,Z axes (logical OR) - xyz_aggregated_affinities = np.logical_or.reduce(channels[i:i + 3, ...]).astype(np.int32) - # recover ignore index - xyz_aggregated_affinities = _recover_ignore_index(xyz_aggregated_affinities, m, self.ignore_index) - results.append(xyz_aggregated_affinities) - else: - results = [_recover_ignore_index(channels[i], m, self.ignore_index) for i in range(channels.shape[0])] - - if self.append_label: - # append original input data - results.append(m) - - # stack across channel dim - return np.stack(results, axis=0) - - @staticmethod - def create_kernel(axis, offset): - # create conv kernel - k_size = offset + 1 - k = np.zeros((1, 1, k_size), dtype=np.int32) - k[0, 0, 0] = 1 - k[0, 0, offset] = -1 - return np.transpose(k, axis) - - def get_kernels(self): - raise NotImplementedError - - -class StandardLabelToBoundary: - def __init__(self, ignore_index=None, append_label=False, mode='thick', foreground=False, - **kwargs): - self.ignore_index = ignore_index - self.append_label = append_label - self.mode = mode - self.foreground = foreground - - def __call__(self, m): - assert m.ndim == 3 - - boundaries = find_boundaries(m, connectivity=2, mode=self.mode) - boundaries = boundaries.astype('int32') - - results = [] - if self.foreground: - foreground = (m > 0).astype('uint8') - results.append(_recover_ignore_index(foreground, m, self.ignore_index)) - - results.append(_recover_ignore_index(boundaries, m, self.ignore_index)) - - if self.append_label: - # append original input data - results.append(m) - - return np.stack(results, axis=0) - - -class BlobsToMask: - """ - Returns binary mask from labeled image, i.e. every label greater than 0 is treated as foreground. - - """ - - def __init__(self, append_label=False, boundary=False, cross_entropy=False, **kwargs): - self.cross_entropy = cross_entropy - self.boundary = boundary - self.append_label = append_label - - def __call__(self, m): - assert m.ndim == 3 - - # get the segmentation mask - mask = (m > 0).astype('uint8') - results = [mask] - - if self.boundary: - outer = find_boundaries(m, connectivity=2, mode='outer') - if self.cross_entropy: - # boundary is class 2 - mask[outer > 0] = 2 - results = [mask] - else: - results.append(outer) - - if self.append_label: - results.append(m) - - return np.stack(results, axis=0) - - -class RandomLabelToAffinities(AbstractLabelToBoundary): - """ - Converts a given volumetric label array to binary mask corresponding to borders between labels. - One specify the max_offset (thickness) of the border. Then the offset is picked at random every time you call - the transformer (offset is picked form the range 1:max_offset) for each axis and the boundary computed. - One may use this scheme in order to make the network more robust against various thickness of borders in the ground - truth (think of it as a boundary denoising scheme). 
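-
-    Example (illustrative sketch)::
-
-        t = RandomLabelToAffinities(np.random.RandomState(47), max_offset=4)
-        target = t(labels)  # labels: (D, H, W) int array -> target: (1, D, H, W)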
- """ - - def __init__(self, random_state, max_offset=10, ignore_index=None, append_label=False, z_offset_scale=2, **kwargs): - super().__init__(ignore_index=ignore_index, append_label=append_label, aggregate_affinities=False) - self.random_state = random_state - self.offsets = tuple(range(1, max_offset + 1)) - self.z_offset_scale = z_offset_scale - - def get_kernels(self): - rand_offset = self.random_state.choice(self.offsets) - axis_ind = self.random_state.randint(3) - # scale down z-affinities due to anisotropy - if axis_ind == 2: - rand_offset = max(1, rand_offset // self.z_offset_scale) - - rand_axis = self.AXES_TRANSPOSE[axis_ind] - # return a single kernel - return [self.create_kernel(rand_axis, rand_offset)] - - -class LabelToAffinities(AbstractLabelToBoundary): - """ - Converts a given volumetric label array to binary mask corresponding to borders between labels (which can be seen - as an affinity graph: https://arxiv.org/pdf/1706.00120.pdf) - One specify the offsets (thickness) of the border. The boundary will be computed via the convolution operator. - """ - - def __init__(self, offsets, ignore_index=None, append_label=False, aggregate_affinities=False, z_offsets=None, - **kwargs): - super().__init__(ignore_index=ignore_index, append_label=append_label, - aggregate_affinities=aggregate_affinities) - - assert isinstance(offsets, list) or isinstance(offsets, tuple), 'offsets must be a list or a tuple' - assert all(a > 0 for a in offsets), "'offsets must be positive" - assert len(set(offsets)) == len(offsets), "'offsets' must be unique" - if z_offsets is not None: - assert len(offsets) == len(z_offsets), 'z_offsets length must be the same as the length of offsets' - else: - # if z_offsets is None just use the offsets for z-affinities - z_offsets = list(offsets) - self.z_offsets = z_offsets - - self.kernels = [] - # create kernel for every axis-offset pair - for xy_offset, z_offset in zip(offsets, z_offsets): - for axis_ind, axis in enumerate(self.AXES_TRANSPOSE): - final_offset = xy_offset - if axis_ind == 2: - final_offset = z_offset - # create kernels for a given offset in every direction - self.kernels.append(self.create_kernel(axis, final_offset)) - - def get_kernels(self): - return self.kernels - - -class LabelToZAffinities(AbstractLabelToBoundary): - """ - Converts a given volumetric label array to binary mask corresponding to borders between labels (which can be seen - as an affinity graph: https://arxiv.org/pdf/1706.00120.pdf) - One specify the offsets (thickness) of the border. The boundary will be computed via the convolution operator. - """ - - def __init__(self, offsets, ignore_index=None, append_label=False, **kwargs): - super().__init__(ignore_index=ignore_index, append_label=append_label) - - assert isinstance(offsets, list) or isinstance(offsets, tuple), 'offsets must be a list or a tuple' - assert all(a > 0 for a in offsets), "'offsets must be positive" - assert len(set(offsets)) == len(offsets), "'offsets' must be unique" - - self.kernels = [] - z_axis = self.AXES_TRANSPOSE[2] - # create kernels - for z_offset in offsets: - self.kernels.append(self.create_kernel(z_axis, z_offset)) - - def get_kernels(self): - return self.kernels - - -class LabelToBoundaryAndAffinities: - """ - Combines the StandardLabelToBoundary and LabelToAffinities in the hope - that that training the network to predict both would improve the main task: boundary prediction. 
- """ - - def __init__(self, xy_offsets, z_offsets, append_label=False, blur=False, sigma=1, ignore_index=None, mode='thick', - foreground=False, **kwargs): - # blur only StandardLabelToBoundary results; we don't want to blur the affinities - self.l2b = StandardLabelToBoundary(blur=blur, sigma=sigma, ignore_index=ignore_index, mode=mode, - foreground=foreground) - self.l2a = LabelToAffinities(offsets=xy_offsets, z_offsets=z_offsets, append_label=append_label, - ignore_index=ignore_index) - - def __call__(self, m): - boundary = self.l2b(m) - affinities = self.l2a(m) - return np.concatenate((boundary, affinities), axis=0) - - -class LabelToMaskAndAffinities: - def __init__(self, xy_offsets, z_offsets, append_label=False, background=0, ignore_index=None, **kwargs): - self.background = background - self.l2a = LabelToAffinities(offsets=xy_offsets, z_offsets=z_offsets, append_label=append_label, - ignore_index=ignore_index) - - def __call__(self, m): - mask = m > self.background - mask = np.expand_dims(mask.astype(np.uint8), axis=0) - affinities = self.l2a(m) - return np.concatenate((mask, affinities), axis=0) - - -class Standardize: - """ - Apply Z-score normalization to a given input tensor, i.e. re-scaling the values to be 0-mean and 1-std. - """ - - def __init__(self, eps=1e-10, mean=None, std=None, channelwise=False, **kwargs): - if mean is not None or std is not None: - assert mean is not None and std is not None - self.mean = mean - self.std = std - self.eps = eps - self.channelwise = channelwise - - def __call__(self, m): - if self.mean is not None: - mean, std = self.mean, self.std - else: - if self.channelwise: - # normalize per-channel - axes = list(range(m.ndim)) - # average across channels - axes = tuple(axes[1:]) - mean = np.mean(m, axis=axes, keepdims=True) - std = np.std(m, axis=axes, keepdims=True) - else: - mean = np.mean(m) - std = np.std(m) - - return (m - mean) / np.clip(std, a_min=self.eps, a_max=None) - - -class PercentileNormalizer: - def __init__(self, pmin=1, pmax=99.6, channelwise=False, eps=1e-10, **kwargs): - self.eps = eps - self.pmin = pmin - self.pmax = pmax - self.channelwise = channelwise - - def __call__(self, m): - if self.channelwise: - axes = list(range(m.ndim)) - # average across channels - axes = tuple(axes[1:]) - pmin = np.percentile(m, self.pmin, axis=axes, keepdims=True) - pmax = np.percentile(m, self.pmax, axis=axes, keepdims=True) - else: - pmin = np.percentile(m, self.pmin) - pmax = np.percentile(m, self.pmax) - - return (m - pmin) / (pmax - pmin + self.eps) - - -class Normalize: - """ - Apply simple min-max scaling to a given input tensor, i.e. shrinks the range of the data - in a fixed range of [-1, 1] or in case of norm01==True to [0, 1]. In addition, data can be - clipped by specifying min_value/max_value either globally using single values or via a - list/tuple channelwise if enabled. 
- """ - - def __init__(self, min_value=None, max_value=None, norm01=False, channelwise=False, - eps=1e-10, **kwargs): - if min_value is not None and max_value is not None: - assert max_value > min_value - self.min_value = min_value - self.max_value = max_value - self.norm01 = norm01 - self.channelwise = channelwise - self.eps = eps - - def __call__(self, m): - if self.channelwise: - # get min/max channelwise - axes = list(range(m.ndim)) - axes = tuple(axes[1:]) - if self.min_value is None or 'None' in self.min_value: - min_value = np.min(m, axis=axes, keepdims=True) - - if self.max_value is None or 'None' in self.max_value: - max_value = np.max(m, axis=axes, keepdims=True) - - # check if non None in self.min_value/self.max_value - # if present and if so copy value to min_value - if self.min_value is not None: - for i,v in enumerate(self.min_value): - if v != 'None': - min_value[i] = v - - if self.max_value is not None: - for i,v in enumerate(self.max_value): - if v != 'None': - max_value[i] = v - else: - if self.min_value is None: - min_value = np.min(m) - else: - min_value = self.min_value - - if self.max_value is None: - max_value = np.max(m) - else: - max_value = self.max_value - - # calculate norm_0_1 with min_value / max_value with the same dimension - # in case of channelwise application - norm_0_1 = (m - min_value) / (max_value - min_value + self.eps) - - if self.norm01 is True: - return np.clip(norm_0_1, 0, 1) - else: - return np.clip(2 * norm_0_1 - 1, -1, 1) - - -class AdditiveGaussianNoise: - def __init__(self, random_state, scale=(0.0, 1.0), execution_probability=0.1, **kwargs): - self.execution_probability = execution_probability - self.random_state = random_state - self.scale = scale - - def __call__(self, m): - if self.random_state.uniform() < self.execution_probability: - std = self.random_state.uniform(self.scale[0], self.scale[1]) - gaussian_noise = self.random_state.normal(0, std, size=m.shape) - return m + gaussian_noise - return m - - -class AdditivePoissonNoise: - def __init__(self, random_state, lam=(0.0, 1.0), execution_probability=0.1, **kwargs): - self.execution_probability = execution_probability - self.random_state = random_state - self.lam = lam - - def __call__(self, m): - if self.random_state.uniform() < self.execution_probability: - lam = self.random_state.uniform(self.lam[0], self.lam[1]) - poisson_noise = self.random_state.poisson(lam, size=m.shape) - return m + poisson_noise - return m - - -class ToTensor: - """ - Converts a given input numpy.ndarray into torch.Tensor. - - Args: - expand_dims (bool): if True, adds a channel dimension to the input data - dtype (np.dtype): the desired output data type - """ - - def __init__(self, expand_dims, dtype=np.float32, **kwargs): - self.expand_dims = expand_dims - self.dtype = dtype - - def __call__(self, m): - assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images' - # add channel dimension - if self.expand_dims and m.ndim == 3: - m = np.expand_dims(m, axis=0) - - return torch.from_numpy(m.astype(dtype=self.dtype)) - - -class Relabel: - """ - Relabel a numpy array of labels into a consecutive numbers, e.g. - [10, 10, 0, 6, 6] -> [2, 2, 0, 1, 1]. Useful when one has an instance segmentation volume - at hand and would like to create a one-hot-encoding for it. Without a consecutive labeling the task would be harder. 
- """ - - def __init__(self, append_original=False, run_cc=True, ignore_label=None, **kwargs): - self.append_original = append_original - self.ignore_label = ignore_label - self.run_cc = run_cc - - if ignore_label is not None: - assert append_original, "ignore_label present, so append_original must be true, so that one can localize the ignore region" - - def __call__(self, m): - orig = m - if self.run_cc: - # assign 0 to the ignore region - m = measure.label(m, background=self.ignore_label) - - _, unique_labels = np.unique(m, return_inverse=True) - result = unique_labels.reshape(m.shape) - if self.append_original: - result = np.stack([result, orig]) - return result - - -class Identity: - def __init__(self, **kwargs): - pass - - def __call__(self, m): - return m - - -class RgbToLabel: - def __call__(self, img): - img = np.array(img) - assert img.ndim == 3 and img.shape[2] == 3 - result = img[..., 0] * 65536 + img[..., 1] * 256 + img[..., 2] - return result - - -class LabelToTensor: - def __call__(self, m): - m = np.array(m) - return torch.from_numpy(m.astype(dtype='int64')) - - -class GaussianBlur3D: - def __init__(self, sigma=[.1, 2.], execution_probability=0.5, **kwargs): - self.sigma = sigma - self.execution_probability = execution_probability - - def __call__(self, x): - if random.random() < self.execution_probability: - sigma = random.uniform(self.sigma[0], self.sigma[1]) - x = gaussian(x, sigma=sigma) - return x - return x - - -class Transformer: - def __init__(self, phase_config, base_config): - self.phase_config = phase_config - self.config_base = base_config - self.seed = GLOBAL_RANDOM_STATE.randint(10000000) - - def raw_transform(self): - return self._create_transform('raw') - - def label_transform(self): - return self._create_transform('label') - - def weight_transform(self): - return self._create_transform('weight') - - @staticmethod - def _transformer_class(class_name): - m = importlib.import_module('pytorch3dunet.augment.transforms') - clazz = getattr(m, class_name) - return clazz - - def _create_transform(self, name): - assert name in self.phase_config, f'Could not find {name} transform' - return Compose([ - self._create_augmentation(c) for c in self.phase_config[name] - ]) - - def _create_augmentation(self, c): - config = dict(self.config_base) - config.update(c) - config['random_state'] = np.random.RandomState(self.seed) - aug_class = self._transformer_class(config['name']) - return aug_class(**config) - - -def _recover_ignore_index(input, orig, ignore_index): - if ignore_index is not None: - mask = orig == ignore_index - input[mask] = ignore_index - - return input diff --git a/build/lib/pytorch3dunet/datasets/__init__.py b/build/lib/pytorch3dunet/datasets/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/build/lib/pytorch3dunet/datasets/dsb.py b/build/lib/pytorch3dunet/datasets/dsb.py deleted file mode 100644 index 5d0cde86..00000000 --- a/build/lib/pytorch3dunet/datasets/dsb.py +++ /dev/null @@ -1,108 +0,0 @@ -import collections -import os - -import imageio -import numpy as np -import torch - -from pytorch3dunet.augment import transforms -from pytorch3dunet.datasets.utils import ConfigDataset, calculate_stats -from pytorch3dunet.unet3d.utils import get_logger - -logger = get_logger('DSB2018Dataset') - - -def dsb_prediction_collate(batch): - """ - Forms a mini-batch of (images, paths) during test time for the DSB-like datasets. 
- """ - error_msg = "batch must contain tensors or str; found {}" - if isinstance(batch[0], torch.Tensor): - return torch.stack(batch, 0) - elif isinstance(batch[0], str): - return list(batch) - elif isinstance(batch[0], collections.Sequence): - # transpose tuples, i.e. [[1, 2], ['a', 'b']] to be [[1, 'a'], [2, 'b']] - transposed = zip(*batch) - return [dsb_prediction_collate(samples) for samples in transposed] - - raise TypeError((error_msg.format(type(batch[0])))) - - -class DSB2018Dataset(ConfigDataset): - def __init__(self, root_dir, phase, transformer_config, expand_dims=True): - assert os.path.isdir(root_dir), f'{root_dir} is not a directory' - assert phase in ['train', 'val', 'test'] - - self.phase = phase - - # load raw images - images_dir = os.path.join(root_dir, 'images') - assert os.path.isdir(images_dir) - self.images, self.paths = self._load_files(images_dir, expand_dims) - self.file_path = images_dir - - stats = calculate_stats(self.images, True) - - transformer = transforms.Transformer(transformer_config, stats) - - # load raw images transformer - self.raw_transform = transformer.raw_transform() - - if phase != 'test': - # load labeled images - masks_dir = os.path.join(root_dir, 'masks') - assert os.path.isdir(masks_dir) - self.masks, _ = self._load_files(masks_dir, expand_dims) - assert len(self.images) == len(self.masks) - # load label images transformer - self.masks_transform = transformer.label_transform() - else: - self.masks = None - self.masks_transform = None - - def __getitem__(self, idx): - if idx >= len(self): - raise StopIteration - - img = self.images[idx] - if self.phase != 'test': - mask = self.masks[idx] - return self.raw_transform(img), self.masks_transform(mask) - else: - return self.raw_transform(img), self.paths[idx] - - def __len__(self): - return len(self.images) - - @classmethod - def prediction_collate(cls, batch): - return dsb_prediction_collate(batch) - - @classmethod - def create_datasets(cls, dataset_config, phase): - phase_config = dataset_config[phase] - # load data augmentation configuration - transformer_config = phase_config['transformer'] - # load files to process - file_paths = phase_config['file_paths'] - expand_dims = dataset_config.get('expand_dims', True) - return [cls(file_paths[0], phase, transformer_config, expand_dims)] - - @staticmethod - def _load_files(dir, expand_dims): - files_data = [] - paths = [] - for file in os.listdir(dir): - path = os.path.join(dir, file) - img = np.asarray(imageio.imread(path)) - if expand_dims: - dims = img.ndim - img = np.expand_dims(img, axis=0) - if dims == 3: - img = np.transpose(img, (3, 0, 1, 2)) - - files_data.append(img) - paths.append(path) - - return files_data, paths diff --git a/build/lib/pytorch3dunet/datasets/hdf5.py b/build/lib/pytorch3dunet/datasets/hdf5.py deleted file mode 100644 index 040adb85..00000000 --- a/build/lib/pytorch3dunet/datasets/hdf5.py +++ /dev/null @@ -1,293 +0,0 @@ -import glob -import os -from abc import abstractmethod -from itertools import chain - -import h5py - -import pytorch3dunet.augment.transforms as transforms -from pytorch3dunet.datasets.utils import get_slice_builder, ConfigDataset, calculate_stats, mirror_pad -from pytorch3dunet.unet3d.utils import get_logger - -logger = get_logger('HDF5Dataset') - - -def _create_padded_indexes(indexes, halo_shape): - return tuple(slice(index.start, index.stop + 2 * halo) for index, halo in zip(indexes, halo_shape)) - - -def traverse_h5_paths(file_paths): - assert isinstance(file_paths, list) - results = [] - for file_path 
in file_paths: - if os.path.isdir(file_path): - # if file path is a directory take all H5 files in that directory - iters = [glob.glob(os.path.join(file_path, ext)) for ext in ['*.h5', '*.hdf', '*.hdf5', '*.hd5']] - for fp in chain(*iters): - results.append(fp) - else: - results.append(file_path) - return results - - -class AbstractHDF5Dataset(ConfigDataset): - """ - Implementation of torch.utils.data.Dataset backed by the HDF5 files, which iterates over the raw and label datasets - patch by patch with a given stride. - - Args: - file_path (str): path to H5 file containing raw data as well as labels and per pixel weights (optional) - phase (str): 'train' for training, 'val' for validation, 'test' for testing - slice_builder_config (dict): configuration of the SliceBuilder - transformer_config (dict): data augmentation configuration - raw_internal_path (str or list): H5 internal path to the raw dataset - label_internal_path (str or list): H5 internal path to the label dataset - weight_internal_path (str or list): H5 internal path to the per pixel weights (optional) - global_normalization (bool): if True, the mean and std of the raw data will be calculated over the whole dataset - """ - - def __init__(self, file_path, phase, slice_builder_config, transformer_config, raw_internal_path='raw', - label_internal_path='label', weight_internal_path=None, global_normalization=True): - assert phase in ['train', 'val', 'test'] - - self.phase = phase - self.file_path = file_path - self.raw_internal_path = raw_internal_path - self.label_internal_path = label_internal_path - self.weight_internal_path = weight_internal_path - - self.halo_shape = slice_builder_config.get('halo_shape', [0, 0, 0]) - - if global_normalization: - logger.info('Calculating mean and std of the raw data...') - with h5py.File(file_path, 'r') as f: - raw = f[raw_internal_path][:] - stats = calculate_stats(raw) - else: - stats = calculate_stats(None, True) - - self.transformer = transforms.Transformer(transformer_config, stats) - self.raw_transform = self.transformer.raw_transform() - - if phase != 'test': - # create label/weight transform only in train/val phase - self.label_transform = self.transformer.label_transform() - - if weight_internal_path is not None: - self.weight_transform = self.transformer.weight_transform() - else: - self.weight_transform = None - - self._check_volume_sizes() - else: - # 'test' phase used only for predictions so ignore the label dataset - self.label = None - self.weight_map = None - - # compare patch and stride configuration - patch_shape = slice_builder_config.get('patch_shape') - stride_shape = slice_builder_config.get('stride_shape') - if sum(self.halo_shape) != 0 and patch_shape != stride_shape: - logger.warning(f'Found non-zero halo shape {self.halo_shape}. 
' - f'In this case: patch shape and stride shape should be equal for optimal prediction ' - f'performance, but found patch_shape: {patch_shape} and stride_shape: {stride_shape}!') - - with h5py.File(file_path, 'r') as f: - raw = f[raw_internal_path] - label = f[label_internal_path] if phase != 'test' else None - weight_map = f[weight_internal_path] if weight_internal_path is not None else None - # build slice indices for raw and label data sets - slice_builder = get_slice_builder(raw, label, weight_map, slice_builder_config) - self.raw_slices = slice_builder.raw_slices - self.label_slices = slice_builder.label_slices - self.weight_slices = slice_builder.weight_slices - - self.patch_count = len(self.raw_slices) - logger.info(f'Number of patches: {self.patch_count}') - - @abstractmethod - def get_raw_patch(self, idx): - raise NotImplementedError - - @abstractmethod - def get_label_patch(self, idx): - raise NotImplementedError - - @abstractmethod - def get_weight_patch(self, idx): - raise NotImplementedError - - @abstractmethod - def get_raw_padded_patch(self, idx): - raise NotImplementedError - - def volume_shape(self): - with h5py.File(self.file_path, 'r') as f: - raw = f[self.raw_internal_path] - if raw.ndim == 3: - return raw.shape - else: - return raw.shape[1:] - - def __getitem__(self, idx): - if idx >= len(self): - raise StopIteration - - raw_idx = self.raw_slices[idx] - - if self.phase == 'test': - if len(raw_idx) == 4: - # discard the channel dimension in the slices: predictor requires only the spatial dimensions of the volume - raw_idx = raw_idx[1:] # Remove the first element if raw_idx has 4 elements - raw_idx_padded = (slice(None),) + _create_padded_indexes(raw_idx, self.halo_shape) - else: - raw_idx_padded = _create_padded_indexes(raw_idx, self.halo_shape) - - raw_patch_transformed = self.raw_transform(self.get_raw_padded_patch(raw_idx_padded)) - return raw_patch_transformed, raw_idx - else: - raw_patch_transformed = self.raw_transform(self.get_raw_patch(raw_idx)) - - # get the slice for a given index 'idx' - label_idx = self.label_slices[idx] - label_patch_transformed = self.label_transform(self.get_label_patch(label_idx)) - if self.weight_internal_path is not None: - weight_idx = self.weight_slices[idx] - weight_patch_transformed = self.weight_transform(self.get_weight_patch(weight_idx)) - return raw_patch_transformed, label_patch_transformed, weight_patch_transformed - # return the transformed raw and label patches - return raw_patch_transformed, label_patch_transformed - - def __len__(self): - return self.patch_count - - def _check_volume_sizes(self): - def _volume_shape(volume): - if volume.ndim == 3: - return volume.shape - return volume.shape[1:] - - with h5py.File(self.file_path, 'r') as f: - raw = f[self.raw_internal_path] - label = f[self.label_internal_path] - assert raw.ndim in [3, 4], 'Raw dataset must be 3D (DxHxW) or 4D (CxDxHxW)' - assert label.ndim in [3, 4], 'Label dataset must be 3D (DxHxW) or 4D (CxDxHxW)' - assert _volume_shape(raw) == _volume_shape(label), 'Raw and labels have to be of the same size' - if self.weight_internal_path is not None: - weight_map = f[self.weight_internal_path] - assert weight_map.ndim in [3, 4], 'Weight map dataset must be 3D (DxHxW) or 4D (CxDxHxW)' - assert _volume_shape(raw) == _volume_shape(weight_map), 'Raw and weight map have to be of the same size' - - @classmethod - def create_datasets(cls, dataset_config, phase): - phase_config = dataset_config[phase] - - # load data augmentation configuration - transformer_config = 
phase_config['transformer'] - # load slice builder config - slice_builder_config = phase_config['slice_builder'] - # load files to process - file_paths = phase_config['file_paths'] - # file_paths may contain both files and directories; if the file_path is a directory all H5 files inside - # are going to be included in the final file_paths - file_paths = traverse_h5_paths(file_paths) - - datasets = [] - for file_path in file_paths: - try: - logger.info(f'Loading {phase} set from: {file_path}...') - dataset = cls(file_path=file_path, - phase=phase, - slice_builder_config=slice_builder_config, - transformer_config=transformer_config, - raw_internal_path=dataset_config.get('raw_internal_path', 'raw'), - label_internal_path=dataset_config.get('label_internal_path', 'label'), - weight_internal_path=dataset_config.get('weight_internal_path', None), - global_normalization=dataset_config.get('global_normalization', None)) - datasets.append(dataset) - except Exception: - logger.error(f'Skipping {phase} set: {file_path}', exc_info=True) - return datasets - - -class StandardHDF5Dataset(AbstractHDF5Dataset): - """ - Implementation of the HDF5 dataset which loads the data from the H5 files into the memory. - Fast but might consume a lot of memory. - """ - - def __init__(self, file_path, phase, slice_builder_config, transformer_config, - raw_internal_path='raw', label_internal_path='label', weight_internal_path=None, - global_normalization=True): - super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config, - transformer_config=transformer_config, raw_internal_path=raw_internal_path, - label_internal_path=label_internal_path, weight_internal_path=weight_internal_path, - global_normalization=global_normalization) - self._raw = None - self._raw_padded = None - self._label = None - self._weight_map = None - - def get_raw_patch(self, idx): - if self._raw is None: - with h5py.File(self.file_path, 'r') as f: - assert self.raw_internal_path in f, f'Dataset {self.raw_internal_path} not found in {self.file_path}' - self._raw = f[self.raw_internal_path][:] - return self._raw[idx] - - def get_label_patch(self, idx): - if self._label is None: - with h5py.File(self.file_path, 'r') as f: - assert self.label_internal_path in f, f'Dataset {self.label_internal_path} not found in {self.file_path}' - self._label = f[self.label_internal_path][:] - return self._label[idx] - - def get_weight_patch(self, idx): - if self._weight_map is None: - with h5py.File(self.file_path, 'r') as f: - assert self.weight_internal_path in f, f'Dataset {self.weight_internal_path} not found in {self.file_path}' - self._weight_map = f[self.weight_internal_path][:] - return self._weight_map[idx] - - def get_raw_padded_patch(self, idx): - if self._raw_padded is None: - with h5py.File(self.file_path, 'r') as f: - assert self.raw_internal_path in f, f'Dataset {self.raw_internal_path} not found in {self.file_path}' - self._raw_padded = mirror_pad(f[self.raw_internal_path][:], self.halo_shape) - return self._raw_padded[idx] - - -class LazyHDF5Dataset(AbstractHDF5Dataset): - """Implementation of the HDF5 dataset which loads the data lazily. 
It's slower, but has a low memory footprint.""" - - def __init__(self, file_path, phase, slice_builder_config, transformer_config, - raw_internal_path='raw', label_internal_path='label', weight_internal_path=None, - global_normalization=False): - super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config, - transformer_config=transformer_config, raw_internal_path=raw_internal_path, - label_internal_path=label_internal_path, weight_internal_path=weight_internal_path, - global_normalization=global_normalization) - - logger.info("Using LazyHDF5Dataset") - - def get_raw_patch(self, idx): - with h5py.File(self.file_path, 'r') as f: - return f[self.raw_internal_path][idx] - - def get_label_patch(self, idx): - with h5py.File(self.file_path, 'r') as f: - return f[self.label_internal_path][idx] - - def get_weight_patch(self, idx): - with h5py.File(self.file_path, 'r') as f: - return f[self.weight_internal_path][idx] - - def get_raw_padded_patch(self, idx): - with h5py.File(self.file_path, 'r+') as f: - if 'raw_padded' in f: - return f['raw_padded'][idx] - - raw = f[self.raw_internal_path][:] - raw_padded = mirror_pad(raw, self.halo_shape) - f.create_dataset('raw_padded', data=raw_padded, compression='gzip') - return raw_padded[idx] diff --git a/build/lib/pytorch3dunet/datasets/utils.py b/build/lib/pytorch3dunet/datasets/utils.py deleted file mode 100644 index 1ffeefe4..00000000 --- a/build/lib/pytorch3dunet/datasets/utils.py +++ /dev/null @@ -1,361 +0,0 @@ -import collections -from typing import Any - -import numpy as np -import torch -from torch.utils.data import DataLoader, ConcatDataset, Dataset - -from pytorch3dunet.unet3d.utils import get_logger, get_class - -logger = get_logger('Dataset') - - -class ConfigDataset(Dataset): - def __getitem__(self, index): - raise NotImplementedError - - def __len__(self): - raise NotImplementedError - - @classmethod - def create_datasets(cls, dataset_config, phase): - """ - Factory method for creating a list of datasets based on the provided config. - - Args: - dataset_config (dict): dataset configuration - phase (str): one of ['train', 'val', 'test'] - - Returns: - list of `Dataset` instances - """ - raise NotImplementedError - - @classmethod - def prediction_collate(cls, batch): - """Default collate_fn. Override in child class for non-standard datasets.""" - return default_prediction_collate(batch) - - -class SliceBuilder: - """ - Builds the position of the patches in a given raw/label/weight ndarray based on the patch and stride shape. 
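-
-    For example (illustrative): for a raw volume of shape (100, 200, 200), patch_shape
-    (64, 128, 128) and stride_shape (32, 96, 96), the z starts are {0, 32, 36} (the last
-    patch is shifted back so it ends flush with the volume) and the y/x starts are
-    {0, 72}, giving 3 * 2 * 2 = 12 patches.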
- - Args: - raw_dataset (ndarray): raw data - label_dataset (ndarray): ground truth labels - weight_dataset (ndarray): weights for the labels - patch_shape (tuple): the shape of the patch DxHxW - stride_shape (tuple): the shape of the stride DxHxW - kwargs: additional metadata - """ - - def __init__(self, raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, **kwargs): - patch_shape = tuple(patch_shape) - stride_shape = tuple(stride_shape) - skip_shape_check = kwargs.get('skip_shape_check', False) - if not skip_shape_check: - self._check_patch_shape(patch_shape) - - self._raw_slices = self._build_slices(raw_dataset, patch_shape, stride_shape) - if label_dataset is None: - self._label_slices = None - else: - # take the first element in the label_dataset to build slices - self._label_slices = self._build_slices(label_dataset, patch_shape, stride_shape) - assert len(self._raw_slices) == len(self._label_slices) - if weight_dataset is None: - self._weight_slices = None - else: - self._weight_slices = self._build_slices(weight_dataset, patch_shape, stride_shape) - assert len(self.raw_slices) == len(self._weight_slices) - - @property - def raw_slices(self): - return self._raw_slices - - @property - def label_slices(self): - return self._label_slices - - @property - def weight_slices(self): - return self._weight_slices - - @staticmethod - def _build_slices(dataset, patch_shape, stride_shape): - """Iterates over a given n-dim dataset patch-by-patch with a given stride - and builds an array of slice positions. - - Returns: - list of slices, i.e. - [(slice, slice, slice, slice), ...] if len(shape) == 4 - [(slice, slice, slice), ...] if len(shape) == 3 - """ - slices = [] - if dataset.ndim == 4: - in_channels, i_z, i_y, i_x = dataset.shape - else: - i_z, i_y, i_x = dataset.shape - - k_z, k_y, k_x = patch_shape - s_z, s_y, s_x = stride_shape - z_steps = SliceBuilder._gen_indices(i_z, k_z, s_z) - for z in z_steps: - y_steps = SliceBuilder._gen_indices(i_y, k_y, s_y) - for y in y_steps: - x_steps = SliceBuilder._gen_indices(i_x, k_x, s_x) - for x in x_steps: - slice_idx = ( - slice(z, z + k_z), - slice(y, y + k_y), - slice(x, x + k_x), - ) - if dataset.ndim == 4: - slice_idx = (slice(0, in_channels),) + slice_idx - slices.append(slice_idx) - return slices - - @staticmethod - def _gen_indices(i, k, s): - assert i >= k, 'Sample size has to be bigger than the patch size' - for j in range(0, i - k + 1, s): - yield j - if j + k < i: - yield i - k - - @staticmethod - def _check_patch_shape(patch_shape): - assert len(patch_shape) == 3, 'patch_shape must be a 3D tuple' - assert patch_shape[1] >= 64 and patch_shape[2] >= 64, 'Height and Width must be greater or equal 64' - - -class FilterSliceBuilder(SliceBuilder): - """ - Filter patches containing more than `1 - threshold` of ignore_index label - """ - - def __init__(self, raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, ignore_index=None, - threshold=0.6, slack_acceptance=0.01, **kwargs): - super().__init__(raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, **kwargs) - if label_dataset is None: - return - - rand_state = np.random.RandomState(47) - - def ignore_predicate(raw_label_idx): - label_idx = raw_label_idx[1] - patch = label_dataset[label_idx] - if ignore_index is not None: - patch = np.copy(patch) - patch[patch == ignore_index] = 0 - non_ignore_counts = np.count_nonzero(patch != 0) - non_ignore_counts = non_ignore_counts / patch.size - return non_ignore_counts > threshold or rand_state.rand() < 
slack_acceptance - - zipped_slices = zip(self.raw_slices, self.label_slices) - # ignore slices containing too much ignore_index - logger.info(f'Filtering slices...') - filtered_slices = list(filter(ignore_predicate, zipped_slices)) - # unzip and save slices - raw_slices, label_slices = zip(*filtered_slices) - self._raw_slices = list(raw_slices) - self._label_slices = list(label_slices) - - -def _loader_classes(class_name): - modules = [ - 'pytorch3dunet.datasets.hdf5', - 'pytorch3dunet.datasets.dsb', - 'pytorch3dunet.datasets.utils' - ] - return get_class(class_name, modules) - - -def get_slice_builder(raws, labels, weight_maps, config): - assert 'name' in config - logger.info(f"Slice builder config: {config}") - slice_builder_cls = _loader_classes(config['name']) - return slice_builder_cls(raws, labels, weight_maps, **config) - - -def get_train_loaders(config): - """ - Returns dictionary containing the training and validation loaders (torch.utils.data.DataLoader). - - :param config: a top level configuration object containing the 'loaders' key - :return: dict { - 'train': - 'val': - } - """ - assert 'loaders' in config, 'Could not find data loaders configuration' - loaders_config = config['loaders'] - - logger.info('Creating training and validation set loaders...') - - # get dataset class - dataset_cls_str = loaders_config.get('dataset', None) - if dataset_cls_str is None: - dataset_cls_str = 'StandardHDF5Dataset' - logger.warning(f"Cannot find dataset class in the config. Using default '{dataset_cls_str}'.") - dataset_class = _loader_classes(dataset_cls_str) - - assert set(loaders_config['train']['file_paths']).isdisjoint(loaders_config['val']['file_paths']), \ - "Train and validation 'file_paths' overlap. One cannot use validation data for training!" - - train_datasets = dataset_class.create_datasets(loaders_config, phase='train') - - val_datasets = dataset_class.create_datasets(loaders_config, phase='val') - - num_workers = loaders_config.get('num_workers', 1) - logger.info(f'Number of workers for train/val dataloader: {num_workers}') - batch_size = loaders_config.get('batch_size', 1) - if torch.cuda.device_count() > 1 and not config['device'] == 'cpu': - logger.info( - f'{torch.cuda.device_count()} GPUs available. Using batch_size = {torch.cuda.device_count()} * {batch_size}') - batch_size = batch_size * torch.cuda.device_count() - - logger.info(f'Batch size for train/val loader: {batch_size}') - # when training with volumetric data use batch_size of 1 due to GPU memory constraints - return { - 'train': DataLoader(ConcatDataset(train_datasets), batch_size=batch_size, shuffle=True, pin_memory=True, - num_workers=num_workers), - # don't shuffle during validation: useful when showing how predictions for a given batch get better over time - 'val': DataLoader(ConcatDataset(val_datasets), batch_size=batch_size, shuffle=False, pin_memory=True, - num_workers=num_workers) - } - - -def get_test_loaders(config): - """ - Returns test DataLoader. - - :return: generator of DataLoader objects - """ - - assert 'loaders' in config, 'Could not find data loaders configuration' - loaders_config = config['loaders'] - - logger.info('Creating test set loaders...') - - # get dataset class - dataset_cls_str = loaders_config.get('dataset', None) - if dataset_cls_str is None: - dataset_cls_str = 'StandardHDF5Dataset' - logger.warning(f"Cannot find dataset class in the config. 
Using default '{dataset_cls_str}'.") - dataset_class = _loader_classes(dataset_cls_str) - - test_datasets = dataset_class.create_datasets(loaders_config, phase='test') - - num_workers = loaders_config.get('num_workers', 1) - logger.info(f'Number of workers for the dataloader: {num_workers}') - - batch_size = loaders_config.get('batch_size', 1) - if torch.cuda.device_count() > 1 and not config['device'] == 'cpu': - logger.info( - f'{torch.cuda.device_count()} GPUs available. Using batch_size = {torch.cuda.device_count()} * {batch_size}') - batch_size = batch_size * torch.cuda.device_count() - - logger.info(f'Batch size for dataloader: {batch_size}') - - # use generator in order to create data loaders lazily one by one - for test_dataset in test_datasets: - logger.info(f'Loading test set from: {test_dataset.file_path}...') - if hasattr(test_dataset, 'prediction_collate'): - collate_fn = test_dataset.prediction_collate - else: - collate_fn = default_prediction_collate - - yield DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True, - collate_fn=collate_fn) - - -def default_prediction_collate(batch): - """ - Default collate_fn to form a mini-batch of Tensor(s) for HDF5 based datasets - """ - error_msg = "batch must contain tensors or slice; found {}" - if isinstance(batch[0], torch.Tensor): - return torch.stack(batch, 0) - elif isinstance(batch[0], tuple) and isinstance(batch[0][0], slice): - return batch - elif isinstance(batch[0], collections.abc.Sequence): - transposed = zip(*batch) - return [default_prediction_collate(samples) for samples in transposed] - - raise TypeError((error_msg.format(type(batch[0])))) - - -def calculate_stats(img: np.array, skip: bool = False) -> dict[str, Any]: - """ - Calculates the minimum percentile, maximum percentile, mean, and standard deviation of the image. - - Args: - img: The input image array. - skip: if True, skip the calculation and return None for all values. - - Returns: - tuple[float, float, float, float]: The minimum percentile, maximum percentile, mean, and std dev - """ - if not skip: - pmin, pmax, mean, std = np.percentile(img, 1), np.percentile(img, 99.6), np.mean(img), np.std(img) - else: - pmin, pmax, mean, std = None, None, None, None - - return { - 'pmin': pmin, - 'pmax': pmax, - 'mean': mean, - 'std': std - } - - -def mirror_pad(image, padding_shape): - """ - Pad the image with a mirror reflection of itself. - - This function is used on data in its original shape before it is split into patches. - - Args: - image (np.ndarray): The input image array to be padded. - padding_shape (tuple of int): Specifies the amount of padding for each dimension, should be YX or ZYX. - - Returns: - np.ndarray: The mirror-padded image. - - Raises: - ValueError: If any element of padding_shape is negative. - """ - assert len(padding_shape) == 3, "Padding shape must be specified for each dimension: ZYX" - - if any(p < 0 for p in padding_shape): - raise ValueError("padding_shape must be non-negative") - - if all(p == 0 for p in padding_shape): - return image - - pad_width = [(p, p) for p in padding_shape] - - if image.ndim == 4: - pad_width = [(0, 0)] + pad_width - return np.pad(image, pad_width, mode='reflect') - - -def remove_padding(m, padding_shape): - """ - Removes padding from the margins of a multi-dimensional array. - - Args: - m (np.ndarray): The input array to be unpadded. - padding_shape (tuple of int, optional): The amount of padding to remove from each dimension. 
- Assumes the tuple length matches the array dimensions. - - Returns: - np.ndarray: The unpadded array. - """ - if padding_shape is None: - return m - - # Correctly construct slice objects for each dimension in padding_shape and apply them to m. - return m[(..., *(slice(p, -p or None) for p in padding_shape))] diff --git a/build/lib/pytorch3dunet/predict.py b/build/lib/pytorch3dunet/predict.py deleted file mode 100644 index cc54fcf7..00000000 --- a/build/lib/pytorch3dunet/predict.py +++ /dev/null @@ -1,59 +0,0 @@ -import importlib -import os - -import torch -import torch.nn as nn - -from pytorch3dunet.datasets.utils import get_test_loaders -from pytorch3dunet.unet3d import utils -from pytorch3dunet.unet3d.config import load_config -from pytorch3dunet.unet3d.model import get_model - -logger = utils.get_logger('UNet3DPredict') - - -def get_predictor(model, config): - output_dir = config['loaders'].get('output_dir', None) - # override output_dir if provided in the 'predictor' section of the config - output_dir = config.get('predictor', {}).get('output_dir', output_dir) - if output_dir is not None: - os.makedirs(output_dir, exist_ok=True) - - predictor_config = config.get('predictor', {}) - class_name = predictor_config.get('name', 'StandardPredictor') - - m = importlib.import_module('pytorch3dunet.unet3d.predictor') - predictor_class = getattr(m, class_name) - out_channels = config['model'].get('out_channels') - return predictor_class(model, output_dir, out_channels, **predictor_config) - - -def main(): - # Load configuration - config, _ = load_config() - - # Create the model - model = get_model(config['model']) - - # Load model state - model_path = config['model_path'] - logger.info(f'Loading model from {model_path}...') - utils.load_checkpoint(model_path, model) - # use DataParallel if more than 1 GPU available - - if torch.cuda.device_count() > 1 and not config['device'] == 'cpu': - model = nn.DataParallel(model) - logger.info(f'Using {torch.cuda.device_count()} GPUs for prediction') - if torch.cuda.is_available() and not config['device'] == 'cpu': - model = model.cuda() - - # create predictor instance - predictor = get_predictor(model, config) - - for test_loader in get_test_loaders(config): - # run the model prediction on the test_loader and save the results in the output_dir - predictor(test_loader) - - -if __name__ == '__main__': - main() diff --git a/build/lib/pytorch3dunet/train.py b/build/lib/pytorch3dunet/train.py deleted file mode 100644 index eceaf719..00000000 --- a/build/lib/pytorch3dunet/train.py +++ /dev/null @@ -1,35 +0,0 @@ -import random - -import torch - -from pytorch3dunet.unet3d.config import load_config, copy_config -from pytorch3dunet.unet3d.trainer import create_trainer -from pytorch3dunet.unet3d.utils import get_logger - -logger = get_logger('TrainingSetup') - - -def main(): - # Load and log experiment configuration - config, config_path = load_config() - logger.info(config) - - manual_seed = config.get('manual_seed', None) - if manual_seed is not None: - logger.info(f'Seed the RNG for all devices with {manual_seed}') - logger.warning('Using CuDNN deterministic setting. 
This may slow down the training!')
-        random.seed(manual_seed)
-        torch.manual_seed(manual_seed)
-        # see https://pytorch.org/docs/stable/notes/randomness.html
-        torch.backends.cudnn.deterministic = True
-
-    # Create trainer
-    trainer = create_trainer(config)
-    # Copy config file
-    copy_config(config, config_path)
-    # Start training
-    trainer.fit()
-
-
-if __name__ == '__main__':
-    main()
diff --git a/build/lib/pytorch3dunet/unet3d/__init__.py b/build/lib/pytorch3dunet/unet3d/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/build/lib/pytorch3dunet/unet3d/buildingblocks.py b/build/lib/pytorch3dunet/unet3d/buildingblocks.py
deleted file mode 100644
index 25679c24..00000000
--- a/build/lib/pytorch3dunet/unet3d/buildingblocks.py
+++ /dev/null
@@ -1,545 +0,0 @@
-from functools import partial
-
-import torch
-from torch import nn as nn
-from torch.nn import functional as F
-
-from pytorch3dunet.unet3d.se import ChannelSELayer3D, ChannelSpatialSELayer3D, SpatialSELayer3D
-
-
-def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding,
-                dropout_prob, is3d):
-    """
-    Create a list of modules which together constitute a single conv layer with non-linearity
-    and optional batchnorm/groupnorm.
-
-    Args:
-        in_channels (int): number of input channels
-        out_channels (int): number of output channels
-        kernel_size (int or tuple): size of the convolving kernel
-        order (string): order of the operations, e.g.
-            'cr' -> conv + ReLU
-            'gcr' -> groupnorm + conv + ReLU
-            'cl' -> conv + LeakyReLU
-            'ce' -> conv + ELU
-            'bcr' -> batchnorm + conv + ReLU
-            'cbrd' -> conv + batchnorm + ReLU + dropout
-            'cbrD' -> conv + batchnorm + ReLU + dropout2d
-        num_groups (int): number of groups for the GroupNorm
-        padding (int or tuple): zero-padding added to all three sides of the input
-        dropout_prob (float): dropout probability
-        is3d (bool): if True use Conv3d, otherwise use Conv2d
-    Returns:
-        list of tuple (name, module)
-    """
-    assert 'c' in order, "Conv layer MUST be present"
-    assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'
-
-    modules = []
-    for i, char in enumerate(order):
-        if char == 'r':
-            modules.append(('ReLU', nn.ReLU(inplace=True)))
-        elif char == 'l':
-            modules.append(('LeakyReLU', nn.LeakyReLU(inplace=True)))
-        elif char == 'e':
-            modules.append(('ELU', nn.ELU(inplace=True)))
-        elif char == 'c':
-            # add learnable bias only in the absence of batchnorm/groupnorm
-            bias = not ('g' in order or 'b' in order)
-            if is3d:
-                conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
-            else:
-                conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
-
-            modules.append(('conv', conv))
-        elif char == 'g':
-            is_before_conv = i < order.index('c')
-            if is_before_conv:
-                num_channels = in_channels
-            else:
-                num_channels = out_channels
-
-            # use only one group if the given number of groups is greater than the number of channels
-            if num_channels < num_groups:
-                num_groups = 1
-
-            assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. 
num_channels={num_channels}, num_groups={num_groups}' - modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels))) - elif char == 'b': - is_before_conv = i < order.index('c') - if is3d: - bn = nn.BatchNorm3d - else: - bn = nn.BatchNorm2d - - if is_before_conv: - modules.append(('batchnorm', bn(in_channels))) - else: - modules.append(('batchnorm', bn(out_channels))) - elif char == 'd': - modules.append(('dropout', nn.Dropout(p=dropout_prob))) - elif char == 'D': - modules.append(('dropout2d', nn.Dropout2d(p=dropout_prob))) - else: - raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c', 'd', 'D']") - - return modules - - -class SingleConv(nn.Sequential): - """ - Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order - of operations can be specified via the `order` parameter - - Args: - in_channels (int): number of input channels - out_channels (int): number of output channels - kernel_size (int or tuple): size of the convolving kernel - order (string): determines the order of layers, e.g. - 'cr' -> conv + ReLU - 'crg' -> conv + ReLU + groupnorm - 'cl' -> conv + LeakyReLU - 'ce' -> conv + ELU - num_groups (int): number of groups for the GroupNorm - padding (int or tuple): add zero-padding - dropout_prob (float): dropout probability, default 0.1 - is3d (bool): if True use Conv3d, otherwise use Conv2d - """ - - def __init__(self, in_channels, out_channels, kernel_size=3, order='gcr', num_groups=8, - padding=1, dropout_prob=0.1, is3d=True): - super(SingleConv, self).__init__() - - for name, module in create_conv(in_channels, out_channels, kernel_size, order, - num_groups, padding, dropout_prob, is3d): - self.add_module(name, module) - - -class DoubleConv(nn.Sequential): - """ - A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d). - We use (Conv3d+ReLU+GroupNorm3d) by default. - This can be changed however by providing the 'order' argument, e.g. in order - to change to Conv3d+BatchNorm3d+ELU use order='cbe'. - Use padded convolutions to make sure that the output (H_out, W_out) is the same - as (H_in, W_in), so that you don't have to crop in the decoder path. - - Args: - in_channels (int): number of input channels - out_channels (int): number of output channels - encoder (bool): if True we're in the encoder path, otherwise we're in the decoder - kernel_size (int or tuple): size of the convolving kernel - order (string): determines the order of layers, e.g. 
- 'cr' -> conv + ReLU - 'crg' -> conv + ReLU + groupnorm - 'cl' -> conv + LeakyReLU - 'ce' -> conv + ELU - num_groups (int): number of groups for the GroupNorm - padding (int or tuple): add zero-padding added to all three sides of the input - upscale (int): number of the convolution to upscale in encoder if DoubleConv, default: 2 - dropout_prob (float or tuple): dropout probability for each convolution, default 0.1 - is3d (bool): if True use Conv3d instead of Conv2d layers - """ - - def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr', - num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True): - super(DoubleConv, self).__init__() - if encoder: - # we're in the encoder path - conv1_in_channels = in_channels - if upscale == 1: - conv1_out_channels = out_channels - else: - conv1_out_channels = out_channels // 2 - if conv1_out_channels < in_channels: - conv1_out_channels = in_channels - conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels - else: - # we're in the decoder path, decrease the number of channels in the 1st convolution - conv1_in_channels, conv1_out_channels = in_channels, out_channels - conv2_in_channels, conv2_out_channels = out_channels, out_channels - - # check if dropout_prob is a tuple and if so - # split it for different dropout probabilities for each convolution. - if isinstance(dropout_prob, list) or isinstance(dropout_prob, tuple): - dropout_prob1 = dropout_prob[0] - dropout_prob2 = dropout_prob[1] - else: - dropout_prob1 = dropout_prob2 = dropout_prob - - # conv1 - self.add_module('SingleConv1', - SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups, - padding=padding, dropout_prob=dropout_prob1, is3d=is3d)) - # conv2 - self.add_module('SingleConv2', - SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups, - padding=padding, dropout_prob=dropout_prob2, is3d=is3d)) - - -class ResNetBlock(nn.Module): - """ - Residual block that can be used instead of standard DoubleConv in the Encoder module. - Motivated by: https://arxiv.org/pdf/1706.00120.pdf - - Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm. 
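-
-    Example (illustrative sketch)::
-
-        block = ResNetBlock(in_channels=32, out_channels=64, order='cge')
-        y = block(x)  # x: (N, 32, D, H, W) -> y: (N, 64, D, H, W)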
- """ - - def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, is3d=True, **kwargs): - super(ResNetBlock, self).__init__() - - if in_channels != out_channels: - # conv1x1 for increasing the number of channels - if is3d: - self.conv1 = nn.Conv3d(in_channels, out_channels, 1) - else: - self.conv1 = nn.Conv2d(in_channels, out_channels, 1) - else: - self.conv1 = nn.Identity() - - # residual block - self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups, - is3d=is3d) - # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual - n_order = order - for c in 'rel': - n_order = n_order.replace(c, '') - self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order, - num_groups=num_groups, is3d=is3d) - - # create non-linearity separately - if 'l' in order: - self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True) - elif 'e' in order: - self.non_linearity = nn.ELU(inplace=True) - else: - self.non_linearity = nn.ReLU(inplace=True) - - def forward(self, x): - # apply first convolution to bring the number of channels to out_channels - residual = self.conv1(x) - - # residual block - out = self.conv2(residual) - out = self.conv3(out) - - out += residual - out = self.non_linearity(out) - - return out - - -class ResNetBlockSE(ResNetBlock): - def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, se_module='scse', **kwargs): - super(ResNetBlockSE, self).__init__( - in_channels, out_channels, kernel_size=kernel_size, order=order, - num_groups=num_groups, **kwargs) - assert se_module in ['scse', 'cse', 'sse'] - if se_module == 'scse': - self.se_module = ChannelSpatialSELayer3D(num_channels=out_channels, reduction_ratio=1) - elif se_module == 'cse': - self.se_module = ChannelSELayer3D(num_channels=out_channels, reduction_ratio=1) - elif se_module == 'sse': - self.se_module = SpatialSELayer3D(num_channels=out_channels) - - def forward(self, x): - out = super().forward(x) - out = self.se_module(out) - return out - - -class Encoder(nn.Module): - """ - A single module from the encoder path consisting of the optional max - pooling layer (one may specify the MaxPool kernel_size to be different - from the standard (2,2,2), e.g. if the volumetric data is anisotropic - (make sure to use complementary scale_factor in the decoder path) followed by - a basic module (DoubleConv or ResNetBlock). - - Args: - in_channels (int): number of input channels - out_channels (int): number of output channels - conv_kernel_size (int or tuple): size of the convolving kernel - apply_pooling (bool): if True use MaxPool3d before DoubleConv - pool_kernel_size (int or tuple): the size of the window - pool_type (str): pooling layer: 'max' or 'avg' - basic_module(nn.Module): either ResNetBlock or DoubleConv - conv_layer_order (string): determines the order of layers - in `DoubleConv` module. See `DoubleConv` for more info. 
- num_groups (int): number of groups for the GroupNorm - padding (int or tuple): add zero-padding added to all three sides of the input - upscale (int): number of the convolution to upscale in encoder if DoubleConv, default: 2 - dropout_prob (float or tuple): dropout probability, default 0.1 - is3d (bool): use 3d or 2d convolutions/pooling operation - """ - - def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True, - pool_kernel_size=2, pool_type='max', basic_module=DoubleConv, conv_layer_order='gcr', - num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True): - super(Encoder, self).__init__() - assert pool_type in ['max', 'avg'] - if apply_pooling: - if pool_type == 'max': - if is3d: - self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size) - else: - self.pooling = nn.MaxPool2d(kernel_size=pool_kernel_size) - else: - if is3d: - self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size) - else: - self.pooling = nn.AvgPool2d(kernel_size=pool_kernel_size) - else: - self.pooling = None - - self.basic_module = basic_module(in_channels, out_channels, - encoder=True, - kernel_size=conv_kernel_size, - order=conv_layer_order, - num_groups=num_groups, - padding=padding, - upscale=upscale, - dropout_prob=dropout_prob, - is3d=is3d) - - def forward(self, x): - if self.pooling is not None: - x = self.pooling(x) - x = self.basic_module(x) - return x - - -class Decoder(nn.Module): - """ - A single module for decoder path consisting of the upsampling layer - (either learned ConvTranspose3d or nearest neighbor interpolation) - followed by a basic module (DoubleConv or ResNetBlock). - - Args: - in_channels (int): number of input channels - out_channels (int): number of output channels - conv_kernel_size (int or tuple): size of the convolving kernel - scale_factor (int or tuple): used as the multiplier for the image H/W/D in - case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation - from the corresponding encoder - basic_module(nn.Module): either ResNetBlock or DoubleConv - conv_layer_order (string): determines the order of layers - in `DoubleConv` module. See `DoubleConv` for more info. 
-        num_groups (int): number of groups for the GroupNorm
-        padding (int or tuple): zero-padding added to all three sides of the input
-        upsample (str): algorithm used for upsampling:
-            InterpolateUpsampling: 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'
-            TransposeConvUpsampling: 'deconv'
-            No upsampling: None
-            Default: 'default' (chooses automatically)
-        dropout_prob (float or tuple): dropout probability, default 0.1
-        is3d (bool): if True use 3d convolutions/upsampling, otherwise 2d, default: True
-    """
-
-    def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=2, basic_module=DoubleConv,
-                 conv_layer_order='gcr', num_groups=8, padding=1, upsample='default',
-                 dropout_prob=0.1, is3d=True):
-        super(Decoder, self).__init__()
-
-        # perform concat joining per default
-        concat = True
-
-        # don't adapt channels after join operation
-        adapt_channels = False
-
-        if upsample is not None and upsample != 'none':
-            if upsample == 'default':
-                if basic_module == DoubleConv:
-                    upsample = 'nearest'  # use nearest neighbor interpolation for upsampling
-                    concat = True  # use concat joining
-                    adapt_channels = False  # don't adapt channels
-                elif basic_module == ResNetBlock or basic_module == ResNetBlockSE:
-                    upsample = 'deconv'  # use deconvolution upsampling
-                    concat = False  # use summation joining
-                    adapt_channels = True  # adapt channels after joining
-
-            # perform deconvolution upsampling if mode is deconv
-            if upsample == 'deconv':
-                self.upsampling = TransposeConvUpsampling(in_channels=in_channels, out_channels=out_channels,
-                                                          kernel_size=conv_kernel_size, scale_factor=scale_factor,
-                                                          is3d=is3d)
-            else:
-                self.upsampling = InterpolateUpsampling(mode=upsample)
-        else:
-            # no upsampling
-            self.upsampling = NoUpsampling()
-
-        # perform the joining operation (concatenation or summation)
-        self.joining = partial(self._joining, concat=concat)
-
-        # adapt the number of in_channels for the ResNetBlock
-        if adapt_channels is True:
-            in_channels = out_channels
-
-        self.basic_module = basic_module(in_channels, out_channels,
-                                         encoder=False,
-                                         kernel_size=conv_kernel_size,
-                                         order=conv_layer_order,
-                                         num_groups=num_groups,
-                                         padding=padding,
-                                         dropout_prob=dropout_prob,
-                                         is3d=is3d)
-
-    def forward(self, encoder_features, x):
-        x = self.upsampling(encoder_features=encoder_features, x=x)
-        x = self.joining(encoder_features, x)
-        x = self.basic_module(x)
-        return x
-
-    @staticmethod
-    def _joining(encoder_features, x, concat):
-        if concat:
-            return torch.cat((encoder_features, x), dim=1)
-        else:
-            return encoder_features + x
-
-
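A minimal usage sketch of the Decoder above, assuming the default DoubleConv configuration (all shapes are illustrative; with DoubleConv the default path is nearest-neighbour upsampling plus channel concatenation, so in_channels must equal the sum of the skip and input channels):

    import torch
    from pytorch3dunet.unet3d.buildingblocks import Decoder, DoubleConv

    # decoder input x has 128 channels, the matching encoder skip has 64,
    # hence in_channels = 128 + 64 = 192 for the default concat joining
    decoder = Decoder(in_channels=192, out_channels=64, basic_module=DoubleConv)
    x = torch.randn(1, 128, 8, 8, 8)        # previous decoder / bottleneck output
    skip = torch.randn(1, 64, 16, 16, 16)   # encoder feature map at the finer level
    out = decoder(skip, x)                  # upsample -> concat -> DoubleConv
    print(out.shape)                        # torch.Size([1, 64, 16, 16, 16])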
-def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding,
-                    conv_upscale, dropout_prob,
-                    layer_order, num_groups, pool_kernel_size, is3d):
-    # create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)`
-    encoders = []
-    for i, out_feature_num in enumerate(f_maps):
-        if i == 0:
-            encoder = Encoder(in_channels, out_feature_num,
-                              apply_pooling=False,  # skip pooling in the first encoder
-                              basic_module=basic_module,
-                              conv_layer_order=layer_order,
-                              conv_kernel_size=conv_kernel_size,
-                              num_groups=num_groups,
-                              padding=conv_padding,
-                              upscale=conv_upscale,
-                              dropout_prob=dropout_prob,
-                              is3d=is3d)
-        else:
-            encoder = Encoder(f_maps[i - 1], out_feature_num,
-                              basic_module=basic_module,
-                              conv_layer_order=layer_order,
-                              conv_kernel_size=conv_kernel_size,
-                              num_groups=num_groups,
-                              pool_kernel_size=pool_kernel_size,
-                              padding=conv_padding,
-                              upscale=conv_upscale,
-                              dropout_prob=dropout_prob,
-                              is3d=is3d)
-
-        encoders.append(encoder)
-
-    return nn.ModuleList(encoders)
-
-
-def create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order,
-                    num_groups, upsample, dropout_prob, is3d):
-    # create decoder path consisting of the Decoder modules. The length of the decoder list is equal to `len(f_maps) - 1`
-    decoders = []
-    reversed_f_maps = list(reversed(f_maps))
-    for i in range(len(reversed_f_maps) - 1):
-        if basic_module == DoubleConv and upsample != 'deconv':
-            in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
-        else:
-            in_feature_num = reversed_f_maps[i]
-
-        out_feature_num = reversed_f_maps[i + 1]
-
-        decoder = Decoder(in_feature_num, out_feature_num,
-                          basic_module=basic_module,
-                          conv_layer_order=layer_order,
-                          conv_kernel_size=conv_kernel_size,
-                          num_groups=num_groups,
-                          padding=conv_padding,
-                          upsample=upsample,
-                          dropout_prob=dropout_prob,
-                          is3d=is3d)
-        decoders.append(decoder)
-    return nn.ModuleList(decoders)
-
-
-class AbstractUpsampling(nn.Module):
-    """
-    Abstract class for upsampling. A given implementation should upsample a given 5D input tensor using either
-    interpolation or learned transposed convolution.
-    """
-
-    def __init__(self, upsample):
-        super(AbstractUpsampling, self).__init__()
-        self.upsample = upsample
-
-    def forward(self, encoder_features, x):
-        # get the spatial dimensions of the output given the encoder_features
-        output_size = encoder_features.size()[2:]
-        # upsample the input and return
-        return self.upsample(x, output_size)
-
-
-class InterpolateUpsampling(AbstractUpsampling):
-    """
-    Args:
-        mode (str): algorithm used for upsampling:
-            'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
-    """
-
-    def __init__(self, mode='nearest'):
-        upsample = partial(self._interpolate, mode=mode)
-        super().__init__(upsample)
-
-    @staticmethod
-    def _interpolate(x, size, mode):
-        return F.interpolate(x, size=size, mode=mode)
-
-
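Because AbstractUpsampling resizes to the skip connection's exact spatial size, interpolation-based upsampling also copes with odd spatial dimensions; a small sketch with illustrative shapes:

    import torch
    from pytorch3dunet.unet3d.buildingblocks import InterpolateUpsampling

    up = InterpolateUpsampling(mode='nearest')
    x = torch.randn(1, 32, 5, 5, 5)         # coarse feature map (odd size after pooling)
    skip = torch.randn(1, 16, 11, 11, 11)   # encoder output at the finer level
    y = up(encoder_features=skip, x=x)      # resized to skip's spatial dims
    print(y.shape)                          # torch.Size([1, 32, 11, 11, 11])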
-class TransposeConvUpsampling(AbstractUpsampling):
-    """
-    Args:
-        in_channels (int): number of input channels for the transposed conv
-        out_channels (int): number of output channels for the transposed conv
-        kernel_size (int or tuple): size of the convolving kernel
-        scale_factor (int or tuple): stride of the convolution
-        is3d (bool): if True use ConvTranspose3d, otherwise use ConvTranspose2d
-    """
-
-    class Upsample(nn.Module):
-        """
-        Workaround for the 'ValueError: requested an output size...' raised in the `_output_padding` method of
-        transposed convolution. Performs the transposed conv followed by interpolation to the correct size
-        if necessary.
-        """
-
-        def __init__(self, conv_transposed, is3d):
-            super().__init__()
-            self.conv_transposed = conv_transposed
-            self.is3d = is3d
-
-        def forward(self, x, size):
-            x = self.conv_transposed(x)
-            return F.interpolate(x, size=size)
-
-    def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=2, is3d=True):
-        # make sure that the output size reverses the MaxPool3d from the corresponding encoder
-        if is3d is True:
-            conv_transposed = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size,
-                                                 stride=scale_factor, padding=1, bias=False)
-        else:
-            conv_transposed = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
-                                                 stride=scale_factor, padding=1, bias=False)
-        upsample = self.Upsample(conv_transposed, is3d)
-        super().__init__(upsample)
-
-
-class NoUpsampling(AbstractUpsampling):
-    def __init__(self):
-        super().__init__(self._no_upsampling)
-
-    @staticmethod
-    def _no_upsampling(x, size):
-        return x
diff --git a/build/lib/pytorch3dunet/unet3d/config.py b/build/lib/pytorch3dunet/unet3d/config.py
deleted file mode 100644
index bb011632..00000000
--- a/build/lib/pytorch3dunet/unet3d/config.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import argparse
-import os
-import shutil
-
-import torch
-import yaml
-
-from pytorch3dunet.unet3d import utils
-
-logger = utils.get_logger('ConfigLoader')
-
-
-def _override_config(args, config):
-    """Overrides config params with the ones given in command line."""
-
-    args_dict = vars(args)
-    # remove the first argument which is the config file path
-    args_dict.pop('config')
-
-    for key, value in args_dict.items():
-        if value is None:
-            continue
-        c = config
-        for k in key.split('.'):
-            if k not in c:
-                raise ValueError(f'Invalid config key: {key}')
-            if isinstance(c[k], dict):
-                c = c[k]
-            else:
-                c[k] = value
-
-
-def load_config():
-    parser = argparse.ArgumentParser(description='UNet3D')
-    parser.add_argument('--config', type=str, help='Path to the YAML config file', required=True)
-    # add additional command line arguments for the prediction that override the ones in the config file
-    parser.add_argument('--model_path', type=str, required=False)
-    parser.add_argument('--loaders.output_dir', type=str, required=False)
-    parser.add_argument('--loaders.test.file_paths', type=str, nargs="+", required=False)
-    parser.add_argument('--loaders.test.slice_builder.patch_shape', type=int, nargs="+", required=False)
-    parser.add_argument('--loaders.test.slice_builder.stride_shape', type=int, nargs="+", required=False)
-
-    args = parser.parse_args()
-    config_path = args.config
-    config = yaml.safe_load(open(config_path, 'r'))
-    _override_config(args, config)
-
-    device = config.get('device', None)
-    if device == 'cpu':
-        logger.warning('CPU mode forced in config, this will likely result in slow training/prediction')
-        config['device'] = 'cpu'
-        return config, config_path
-
-    if torch.cuda.is_available():
-        config['device'] = 'cuda'
-    else:
-        logger.warning('CUDA not available, using CPU')
-        config['device'] = 'cpu'
-    return config, config_path
-
-
-def copy_config(config, config_path):
-    """Copies the config file to the checkpoint folder."""
-
-    def _get_last_subfolder_path(path):
-        subfolders = [f.path for f in os.scandir(path) if f.is_dir()]
-        return max(subfolders, default=None)
-
-    checkpoint_dir = os.path.join(
-        config['trainer'].pop('checkpoint_dir'), 'logs')
-    last_run_dir = _get_last_subfolder_path(checkpoint_dir)
-    config_file_name = 
os.path.basename(config_path) - - if last_run_dir: - shutil.copy2(config_path, os.path.join(last_run_dir, config_file_name)) - - -def _load_config_yaml(config_file): - return yaml.safe_load(open(config_file, 'r')) diff --git a/build/lib/pytorch3dunet/unet3d/losses.py b/build/lib/pytorch3dunet/unet3d/losses.py deleted file mode 100644 index 6a53966f..00000000 --- a/build/lib/pytorch3dunet/unet3d/losses.py +++ /dev/null @@ -1,345 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import nn as nn -from torch.nn import MSELoss, SmoothL1Loss, L1Loss - - -def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None): - """ - Computes DiceCoefficient as defined in https://arxiv.org/abs/1606.04797 given a multi channel input and target. - Assumes the input is a normalized probability, e.g. a result of Sigmoid or Softmax function. - - Args: - input (torch.Tensor): NxCxSpatial input tensor - target (torch.Tensor): NxCxSpatial target tensor - epsilon (float): prevents division by zero - weight (torch.Tensor): Cx1 tensor of weight per channel/class - """ - - # input and target shapes must match - assert input.size() == target.size(), "'input' and 'target' must have the same shape" - - input = flatten(input) - target = flatten(target) - target = target.float() - - # compute per channel Dice Coefficient - intersect = (input * target).sum(-1) - if weight is not None: - intersect = weight * intersect - - # here we can use standard dice (input + target).sum(-1) or extension (see V-Net) (input^2 + target^2).sum(-1) - denominator = (input * input).sum(-1) + (target * target).sum(-1) - return 2 * (intersect / denominator.clamp(min=epsilon)) - - -class _MaskingLossWrapper(nn.Module): - """ - Loss wrapper which prevents the gradient of the loss to be computed where target is equal to `ignore_index`. - """ - - def __init__(self, loss, ignore_index): - super(_MaskingLossWrapper, self).__init__() - assert ignore_index is not None, 'ignore_index cannot be None' - self.loss = loss - self.ignore_index = ignore_index - - def forward(self, input, target): - mask = target.clone().ne_(self.ignore_index) - mask.requires_grad = False - - # mask out input/target so that the gradient is zero where on the mask - input = input * mask - target = target * mask - - # forward masked input and target to the loss - return self.loss(input, target) - - -class SkipLastTargetChannelWrapper(nn.Module): - """ - Loss wrapper which removes additional target channel - """ - - def __init__(self, loss, squeeze_channel=False): - super(SkipLastTargetChannelWrapper, self).__init__() - self.loss = loss - self.squeeze_channel = squeeze_channel - - def forward(self, input, target, weight=None): - assert target.size(1) > 1, 'Target tensor has a singleton channel dimension, cannot remove channel' - - # skips last target channel if needed - target = target[:, :-1, ...] - - if self.squeeze_channel: - # squeeze channel dimension - target = torch.squeeze(target, dim=1) - if weight is not None: - return self.loss(input, target, weight) - return self.loss(input, target) - - -class _AbstractDiceLoss(nn.Module): - """ - Base class for different implementations of Dice loss. - """ - - def __init__(self, weight=None, normalization='sigmoid'): - super(_AbstractDiceLoss, self).__init__() - self.register_buffer('weight', weight) - # The output from the network during training is assumed to be un-normalized probabilities and we would - # like to normalize the logits. 
Since Dice (or soft Dice in this case) is usually used for binary data, - # normalizing the channels with Sigmoid is the default choice even for multi-class segmentation problems. - # However if one would like to apply Softmax in order to get the proper probability distribution from the - # output, just specify `normalization=Softmax` - assert normalization in ['sigmoid', 'softmax', 'none'] - if normalization == 'sigmoid': - self.normalization = nn.Sigmoid() - elif normalization == 'softmax': - self.normalization = nn.Softmax(dim=1) - else: - self.normalization = lambda x: x - - def dice(self, input, target, weight): - # actual Dice score computation; to be implemented by the subclass - raise NotImplementedError - - def forward(self, input, target): - # get probabilities from logits - input = self.normalization(input) - - # compute per channel Dice coefficient - per_channel_dice = self.dice(input, target, weight=self.weight) - - # average Dice score across all channels/classes - return 1. - torch.mean(per_channel_dice) - - -class DiceLoss(_AbstractDiceLoss): - """Computes Dice Loss according to https://arxiv.org/abs/1606.04797. - For multi-class segmentation `weight` parameter can be used to assign different weights per class. - The input to the loss function is assumed to be a logit and will be normalized by the Sigmoid function. - """ - - def __init__(self, weight=None, normalization='sigmoid'): - super().__init__(weight, normalization) - - def dice(self, input, target, weight): - return compute_per_channel_dice(input, target, weight=self.weight) - - -class GeneralizedDiceLoss(_AbstractDiceLoss): - """Computes Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf. - """ - - def __init__(self, normalization='sigmoid', epsilon=1e-6): - super().__init__(weight=None, normalization=normalization) - self.epsilon = epsilon - - def dice(self, input, target, weight): - assert input.size() == target.size(), "'input' and 'target' must have the same shape" - - input = flatten(input) - target = flatten(target) - target = target.float() - - if input.size(0) == 1: - # for GDL to make sense we need at least 2 channels (see https://arxiv.org/pdf/1707.03237.pdf) - # put foreground and background voxels in separate channels - input = torch.cat((input, 1 - input), dim=0) - target = torch.cat((target, 1 - target), dim=0) - - # GDL weighting: the contribution of each label is corrected by the inverse of its volume - w_l = target.sum(-1) - w_l = 1 / (w_l * w_l).clamp(min=self.epsilon) - w_l.requires_grad = False - - intersect = (input * target).sum(-1) - intersect = intersect * w_l - - denominator = (input + target).sum(-1) - denominator = (denominator * w_l).clamp(min=self.epsilon) - - return 2 * (intersect.sum() / denominator.sum()) - - -class BCEDiceLoss(nn.Module): - """Linear combination of BCE and Dice losses""" - - def __init__(self, alpha, beta): - super(BCEDiceLoss, self).__init__() - self.alpha = alpha - self.bce = nn.BCEWithLogitsLoss() - self.beta = beta - self.dice = DiceLoss() - - def forward(self, input, target): - return self.alpha * self.bce(input, target) + self.beta * self.dice(input, target) - - -class WeightedCrossEntropyLoss(nn.Module): - """WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf - """ - - def __init__(self, ignore_index=-1): - super(WeightedCrossEntropyLoss, self).__init__() - self.ignore_index = ignore_index - - def forward(self, input, target): - weight = self._class_weights(input) - return F.cross_entropy(input, 
target, weight=weight, ignore_index=self.ignore_index) - - @staticmethod - def _class_weights(input): - # normalize the input first - input = F.softmax(input, dim=1) - flattened = flatten(input) - nominator = (1. - flattened).sum(-1) - denominator = flattened.sum(-1) - class_weights = nominator / denominator - return class_weights.detach() - - -class PixelWiseCrossEntropyLoss(nn.Module): - def __init__(self, ignore_index=None): - super(PixelWiseCrossEntropyLoss, self).__init__() - self.ignore_index = ignore_index - self.log_softmax = nn.LogSoftmax(dim=1) - - def forward(self, input, target, weights): - assert target.size() == weights.size() - # normalize the input - log_probabilities = self.log_softmax(input) - # standard CrossEntropyLoss requires the target to be (NxDxHxW), so we need to expand it to (NxCxDxHxW) - if self.ignore_index is not None: - mask = target == self.ignore_index - target[mask] = 0 - else: - mask = torch.zeros_like(target) - # add channel dimension and invert the mask - mask = 1 - mask.unsqueeze(1) - # convert target to one-hot encoding - target = F.one_hot(target.long()) - if target.ndim == 5: - # permute target to (NxCxDxHxW) - target = target.permute(0, 4, 1, 2, 3).contiguous() - else: - target = target.permute(0, 3, 1, 2).contiguous() - # apply the mask on the target - target = target * mask - # add channel dimension to the weights - weights = weights.unsqueeze(1) - # compute the losses - result = -weights * target * log_probabilities - return result.mean() - - -class WeightedSmoothL1Loss(nn.SmoothL1Loss): - def __init__(self, threshold, initial_weight, apply_below_threshold=True): - super().__init__(reduction="none") - self.threshold = threshold - self.apply_below_threshold = apply_below_threshold - self.weight = initial_weight - - def forward(self, input, target): - l1 = super().forward(input, target) - - if self.apply_below_threshold: - mask = target < self.threshold - else: - mask = target >= self.threshold - - l1[mask] = l1[mask] * self.weight - - return l1.mean() - - -def flatten(tensor): - """Flattens a given tensor such that the channel axis is first. 
-    The shapes are transformed as follows:
-       (N, C, D, H, W) -> (C, N * D * H * W)
-    """
-    # number of channels
-    C = tensor.size(1)
-    # new axis order
-    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
-    # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
-    transposed = tensor.permute(axis_order)
-    # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
-    return transposed.contiguous().view(C, -1)
-
-
-def get_loss_criterion(config):
-    """
-    Returns the loss function based on provided configuration
-    :param config: (dict) a top level configuration object containing the 'loss' key
-    :return: an instance of the loss function
-    """
-    assert 'loss' in config, 'Could not find loss function configuration'
-    loss_config = config['loss']
-    name = loss_config.pop('name')
-
-    ignore_index = loss_config.pop('ignore_index', None)
-    skip_last_target = loss_config.pop('skip_last_target', False)
-    weight = loss_config.pop('weight', None)
-
-    if weight is not None:
-        weight = torch.tensor(weight)
-
-    pos_weight = loss_config.pop('pos_weight', None)
-    if pos_weight is not None:
-        pos_weight = torch.tensor(pos_weight)
-
-    loss = _create_loss(name, loss_config, weight, ignore_index, pos_weight)
-
-    if not (ignore_index is None or name in ['CrossEntropyLoss', 'WeightedCrossEntropyLoss']):
-        # use MaskingLossWrapper only for non-cross-entropy losses, since CE losses allow specifying 'ignore_index' directly
-        loss = _MaskingLossWrapper(loss, ignore_index)
-
-    if skip_last_target:
-        loss = SkipLastTargetChannelWrapper(loss, loss_config.get('squeeze_channel', False))
-
-    if torch.cuda.is_available():
-        loss = loss.cuda()
-
-    return loss
-
-
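A sketch of how get_loss_criterion is driven by the 'loss' section of a training config; the keys below are illustrative and are consumed by _create_loss, defined next:

    from pytorch3dunet.unet3d.losses import get_loss_criterion

    config = {
        'loss': {
            'name': 'BCEDiceLoss',
            'alpha': 1.0,  # weight of the BCE term
            'beta': 1.0,   # weight of the Dice term
        }
    }
    criterion = get_loss_criterion(config)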
-#######################################################################################################################
-
-def _create_loss(name, loss_config, weight, ignore_index, pos_weight):
-    if name == 'BCEWithLogitsLoss':
-        return nn.BCEWithLogitsLoss(pos_weight=pos_weight)
-    elif name == 'BCEDiceLoss':
-        alpha = loss_config.get('alpha', 1.)
-        beta = loss_config.get('beta', 1.)
-        return BCEDiceLoss(alpha, beta)
-    elif name == 'CrossEntropyLoss':
-        if ignore_index is None:
-            ignore_index = -100  # use the default 'ignore_index' as defined in the CrossEntropyLoss
-        return nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)
-    elif name == 'WeightedCrossEntropyLoss':
-        if ignore_index is None:
-            ignore_index = -100  # use the default 'ignore_index' as defined in the CrossEntropyLoss
-        return WeightedCrossEntropyLoss(ignore_index=ignore_index)
-    elif name == 'PixelWiseCrossEntropyLoss':
-        return PixelWiseCrossEntropyLoss(ignore_index=ignore_index)
-    elif name == 'GeneralizedDiceLoss':
-        normalization = loss_config.get('normalization', 'sigmoid')
-        return GeneralizedDiceLoss(normalization=normalization)
-    elif name == 'DiceLoss':
-        normalization = loss_config.get('normalization', 'sigmoid')
-        return DiceLoss(weight=weight, normalization=normalization)
-    elif name == 'MSELoss':
-        return MSELoss()
-    elif name == 'SmoothL1Loss':
-        return SmoothL1Loss()
-    elif name == 'L1Loss':
-        return L1Loss()
-    elif name == 'WeightedSmoothL1Loss':
-        return WeightedSmoothL1Loss(threshold=loss_config['threshold'],
-                                    initial_weight=loss_config['initial_weight'],
-                                    apply_below_threshold=loss_config.get('apply_below_threshold', True))
-    else:
-        raise RuntimeError(f"Unsupported loss function: '{name}'")
diff --git a/build/lib/pytorch3dunet/unet3d/metrics.py b/build/lib/pytorch3dunet/unet3d/metrics.py
deleted file mode 100644
index 2b60b4b7..00000000
--- a/build/lib/pytorch3dunet/unet3d/metrics.py
+++ /dev/null
@@ -1,445 +0,0 @@
-import importlib
-
-import numpy as np
-import torch
-from skimage import measure
-from skimage.metrics import adapted_rand_error, peak_signal_noise_ratio, mean_squared_error
-
-from pytorch3dunet.unet3d.losses import compute_per_channel_dice
-from pytorch3dunet.unet3d.seg_metrics import AveragePrecision, Accuracy
-from pytorch3dunet.unet3d.utils import get_logger, expand_as_one_hot, convert_to_numpy
-
-logger = get_logger('EvalMetric')
-
-
-class DiceCoefficient:
-    """Computes Dice Coefficient.
-    Generalized to multiple channels by computing per-channel Dice Score
-    (as described in https://arxiv.org/pdf/1707.03237.pdf) and then simply taking the average.
-    Input is expected to be probabilities instead of logits.
-    This metric is mostly useful when channels contain the same semantic class (e.g. affinities computed with different offsets).
-    DO NOT USE this metric when training with DiceLoss, otherwise the results will be biased towards the loss.
-    """
-
-    def __init__(self, epsilon=1e-6, **kwargs):
-        self.epsilon = epsilon
-
-    def __call__(self, input, target):
-        # Average across channels in order to get the final score
-        return torch.mean(compute_per_channel_dice(input, target, epsilon=self.epsilon))
-
-
-class MeanIoU:
-    """
-    Computes IoU for each class separately and then averages over all classes.
-    """
-
-    def __init__(self, skip_channels=(), ignore_index=None, **kwargs):
-        """
-        :param skip_channels: list/tuple of channels to be ignored from the IoU computation
-        :param ignore_index: id of the label to be ignored from IoU computation
-        """
-        self.ignore_index = ignore_index
-        self.skip_channels = skip_channels
-
-    def __call__(self, input, target):
-        """
-        :param input: 5D probability maps torch float tensor (NxCxDxHxW)
-        :param target: 4D or 5D ground truth torch tensor. 4D (NxDxHxW) tensor will be expanded to 5D as one-hot
-        :return: intersection over union averaged over all channels
-        """
-        assert input.dim() == 5
-
-        n_classes = input.size()[1]
-
-        if target.dim() == 4:
-            target = expand_as_one_hot(target, C=n_classes, ignore_index=self.ignore_index)
-
-        assert input.size() == target.size()
-
-        per_batch_iou = []
-        for _input, _target in zip(input, target):
-            binary_prediction = self._binarize_predictions(_input, n_classes)
-
-            if self.ignore_index is not None:
-                # zero out ignore_index
-                mask = _target == self.ignore_index
-                binary_prediction[mask] = 0
-                _target[mask] = 0
-
-            # convert to uint8 just in case
-            binary_prediction = binary_prediction.byte()
-            _target = _target.byte()
-
-            per_channel_iou = []
-            for c in range(n_classes):
-                if c in self.skip_channels:
-                    continue
-
-                per_channel_iou.append(self._jaccard_index(binary_prediction[c], _target[c]))
-
-            assert per_channel_iou, "All channels were ignored from the computation"
-            mean_iou = torch.mean(torch.tensor(per_channel_iou))
-            per_batch_iou.append(mean_iou)
-
-        return torch.mean(torch.tensor(per_batch_iou))
-
-    def _binarize_predictions(self, input, n_classes):
-        """
-        Puts 1 for the class/channel with the highest probability and 0 in other channels. Returns byte tensor of the
-        same size as the input tensor.
-        """
-        if n_classes == 1:
-            # for single channel input just threshold the probability map
-            result = input > 0.5
-            return result.long()
-
-        _, max_index = torch.max(input, dim=0, keepdim=True)
-        return torch.zeros_like(input, dtype=torch.uint8).scatter_(0, max_index, 1)
-
-    def _jaccard_index(self, prediction, target):
-        """
-        Computes IoU for a given target and prediction tensors
-        """
-        return torch.sum(prediction & target).float() / torch.clamp(torch.sum(prediction | target).float(), min=1e-8)
-
-
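A toy invocation of MeanIoU with illustrative tensors (NxCxDxHxW probabilities and a matching one-hot target):

    import torch
    from pytorch3dunet.unet3d.metrics import MeanIoU

    metric = MeanIoU()
    pred = torch.rand(1, 2, 4, 4, 4)
    pred = pred / pred.sum(dim=1, keepdim=True)  # normalize to probabilities
    target = torch.zeros(1, 2, 4, 4, 4)
    target[:, 0] = 1                             # every voxel belongs to class 0
    print(metric(pred, target))                  # scalar tensor in [0, 1]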
-class AdaptedRandError:
-    """
-    A functor which computes an Adapted Rand error as defined by the SNEMI3D contest
-    (http://brainiac2.mit.edu/SNEMI3D/evaluation).
-
-    This is a generic implementation which takes the input, converts it to the segmentation image (see `input_to_segm()`)
-    and then computes the ARand between the segmentation and the ground truth target. Depending on one's use case
-    it's enough to extend this class and implement the `input_to_segm` method.
-
-    Args:
-        use_last_target (bool): if true, use the last channel from the target to compute the ARand, otherwise the first.
-    """
-
-    def __init__(self, use_last_target=False, ignore_index=None, **kwargs):
-        self.use_last_target = use_last_target
-        self.ignore_index = ignore_index
-
-    def __call__(self, input, target):
-        """
-        Compute ARand Error for each input, target pair in the batch and return the mean value.
-
-        Args:
-            input (torch.tensor): 5D (NCDHW) output from the network
-            target (torch.tensor): 5D (NCDHW) ground truth segmentation
-
-        Returns:
-            average ARand Error across the batch
-        """
-
-        # converts input and target to numpy arrays
-        input, target = convert_to_numpy(input, target)
-        if self.use_last_target:
-            target = target[:, -1, ...]  # 4D
-        else:
-            # use 1st target channel
-            target = target[:, 0, ...]  # 4D
-
-        # ensure target is of integer type
-        target = target.astype(np.int32)
-
-        if self.ignore_index is not None:
-            target[target == self.ignore_index] = 0
-
-        per_batch_arand = []
-        for _input, _target in zip(input, target):
-            if np.all(_target == _target.flat[0]):
-                # skip the ARand evaluation if there is only one label in the patch, in order to avoid division by zero
-                logger.info('Skipping ARandError computation: only 1 label present in the ground truth')
-                per_batch_arand.append(0.)
-                continue
-
-            # convert _input to segmentation CDHW
-            segm = self.input_to_segm(_input)
-            assert segm.ndim == 4
-
-            # compute the per-channel arand and keep the minimum value
-            per_channel_arand = [adapted_rand_error(_target, channel_segm)[0] for channel_segm in segm]
-            per_batch_arand.append(np.min(per_channel_arand))
-
-        # return mean arand error
-        mean_arand = torch.mean(torch.tensor(per_batch_arand))
-        logger.info(f'ARand: {mean_arand.item()}')
-        return mean_arand
-
-    def input_to_segm(self, input):
-        """
-        Converts input tensor (output from the network) to the segmentation image. E.g. if the input is the boundary
-        pmaps then one option would be to threshold it and run connected components in order to return the segmentation.
-
-        :param input: 4D tensor (CDHW)
-        :return: segmentation volume either 4D (segmentation per channel)
-        """
-        # by default assume that the input is a segmentation volume itself
-        return input
-
-
-class BoundaryAdaptedRandError(AdaptedRandError):
-    """
-    Compute ARand between the input boundary map and target segmentation.
-    The boundary map is thresholded and connected components are run to get the predicted segmentation.
-    """
-
-    def __init__(self, thresholds=None, use_last_target=True, ignore_index=None, input_channel=None, invert_pmaps=True,
-                 save_plots=False, plots_dir='.', **kwargs):
-        super().__init__(use_last_target=use_last_target, ignore_index=ignore_index, save_plots=save_plots,
-                         plots_dir=plots_dir, **kwargs)
-
-        if thresholds is None:
-            thresholds = [0.3, 0.4, 0.5, 0.6]
-        assert isinstance(thresholds, list)
-        self.thresholds = thresholds
-        self.input_channel = input_channel
-        self.invert_pmaps = invert_pmaps
-
-    def input_to_segm(self, input):
-        if self.input_channel is not None:
-            input = np.expand_dims(input[self.input_channel], axis=0)
-
-        segs = []
-        for predictions in input:
-            for th in self.thresholds:
-                # threshold the probability maps; use a fresh variable so that the original
-                # predictions are thresholded anew for every threshold value
-                mask = predictions > th
-
-                if self.invert_pmaps:
-                    # for connected component analysis we need to treat boundary signal as background
-                    # assign 0-label to boundary mask
-                    mask = np.logical_not(mask)
-
-                mask = mask.astype(np.uint8)
-                # run connected components on the predicted mask; consider only 1-connectivity
-                seg = measure.label(mask, background=0, connectivity=1)
-                segs.append(seg)
-
-        return np.stack(segs)
-
-
-class GenericAdaptedRandError(AdaptedRandError):
-    def __init__(self, input_channels, thresholds=None, use_last_target=True, ignore_index=None, invert_channels=None,
-                 **kwargs):
-
-        super().__init__(use_last_target=use_last_target, ignore_index=ignore_index, **kwargs)
-        assert isinstance(input_channels, list) or isinstance(input_channels, tuple)
-        self.input_channels = input_channels
-        if thresholds is None:
-            thresholds = [0.3, 0.4, 0.5, 0.6]
-        assert isinstance(thresholds, list)
-        self.thresholds = thresholds
-        if invert_channels is None:
-            invert_channels = []
-        self.invert_channels = invert_channels
-
-    def input_to_segm(self, input):
-        # pick only the channels specified in the input_channels
-        
results = [] - for i in self.input_channels: - c = input[i] - # invert channel if necessary - if i in self.invert_channels: - c = 1 - c - results.append(c) - - input = np.stack(results) - - segs = [] - for predictions in input: - for th in self.thresholds: - # run connected components on the predicted mask; consider only 1-connectivity - seg = measure.label((predictions > th).astype(np.uint8), background=0, connectivity=1) - segs.append(seg) - - return np.stack(segs) - - -class GenericAveragePrecision: - def __init__(self, min_instance_size=None, use_last_target=False, metric='ap', **kwargs): - self.min_instance_size = min_instance_size - self.use_last_target = use_last_target - assert metric in ['ap', 'acc'] - if metric == 'ap': - # use AveragePrecision - self.metric = AveragePrecision() - else: - # use Accuracy at 0.5 IoU - self.metric = Accuracy(iou_threshold=0.5) - - def __call__(self, input, target): - if target.dim() == 5: - if self.use_last_target: - target = target[:, -1, ...] # 4D - else: - # use 1st target channel - target = target[:, 0, ...] # 4D - - input1 = input2 = input - multi_head = isinstance(input, tuple) - if multi_head: - input1, input2 = input - - input1, input2, target = convert_to_numpy(input1, input2, target) - - batch_aps = [] - i_batch = 0 - # iterate over the batch - for inp1, inp2, tar in zip(input1, input2, target): - if multi_head: - inp = (inp1, inp2) - else: - inp = inp1 - - segs = self.input_to_seg(inp, tar) # expects 4D - assert segs.ndim == 4 - # convert target to seg - tar = self.target_to_seg(tar) - - # filter small instances if necessary - tar = self._filter_instances(tar) - - # compute average precision per channel - segs_aps = [self.metric(self._filter_instances(seg), tar) for seg in segs] - - logger.info(f'Batch: {i_batch}. Max Average Precision for channel: {np.argmax(segs_aps)}') - # save max AP - batch_aps.append(np.max(segs_aps)) - i_batch += 1 - - return torch.tensor(batch_aps).mean() - - def _filter_instances(self, input): - """ - Filters instances smaller than 'min_instance_size' by overriding them with 0-index - :param input: input instance segmentation - """ - if self.min_instance_size is not None: - labels, counts = np.unique(input, return_counts=True) - for label, count in zip(labels, counts): - if count < self.min_instance_size: - input[input == label] = 0 - return input - - def input_to_seg(self, input, target=None): - raise NotImplementedError - - def target_to_seg(self, target): - return target - - -class BlobsAveragePrecision(GenericAveragePrecision): - """ - Computes Average Precision given foreground prediction and ground truth instance segmentation. - """ - - def __init__(self, thresholds=None, metric='ap', min_instance_size=None, input_channel=0, **kwargs): - super().__init__(min_instance_size=min_instance_size, use_last_target=True, metric=metric) - if thresholds is None: - thresholds = [0.4, 0.5, 0.6, 0.7, 0.8] - assert isinstance(thresholds, list) - self.thresholds = thresholds - self.input_channel = input_channel - - def input_to_seg(self, input, target=None): - input = input[self.input_channel] - segs = [] - for th in self.thresholds: - # threshold and run connected components - mask = (input > th).astype(np.uint8) - seg = measure.label(mask, background=0, connectivity=1) - segs.append(seg) - return np.stack(segs) - - -class BlobsBoundaryAveragePrecision(GenericAveragePrecision): - """ - Computes Average Precision given foreground prediction, boundary prediction and ground truth instance segmentation. 
-    The segmentation mask is computed as (P_mask - P_boundary) > th followed by running connected components.
-    """
-
-    def __init__(self, thresholds=None, metric='ap', min_instance_size=None, **kwargs):
-        super().__init__(min_instance_size=min_instance_size, use_last_target=True, metric=metric)
-        if thresholds is None:
-            thresholds = [0.3, 0.4, 0.5, 0.6, 0.7]
-        assert isinstance(thresholds, list)
-        self.thresholds = thresholds
-
-    def input_to_seg(self, input, target=None):
-        # input = P_mask - P_boundary
-        input = input[0] - input[1]
-        segs = []
-        for th in self.thresholds:
-            # threshold and run connected components
-            mask = (input > th).astype(np.uint8)
-            seg = measure.label(mask, background=0, connectivity=1)
-            segs.append(seg)
-        return np.stack(segs)
-
-
-class BoundaryAveragePrecision(GenericAveragePrecision):
-    """
-    Computes Average Precision given boundary prediction and ground truth instance segmentation.
-    """
-
-    def __init__(self, thresholds=None, min_instance_size=None, input_channel=0, **kwargs):
-        super().__init__(min_instance_size=min_instance_size, use_last_target=True)
-        if thresholds is None:
-            thresholds = [0.3, 0.4, 0.5, 0.6]
-        assert isinstance(thresholds, list)
-        self.thresholds = thresholds
-        self.input_channel = input_channel
-
-    def input_to_seg(self, input, target=None):
-        input = input[self.input_channel]
-        segs = []
-        for th in self.thresholds:
-            seg = measure.label(np.logical_not(input > th).astype(np.uint8), background=0, connectivity=1)
-            segs.append(seg)
-        return np.stack(segs)
-
-
-class PSNR:
-    """
-    Computes Peak Signal to Noise Ratio. Use e.g. as an eval metric for a denoising task
-    """
-
-    def __init__(self, **kwargs):
-        pass
-
-    def __call__(self, input, target):
-        input, target = convert_to_numpy(input, target)
-        return peak_signal_noise_ratio(target, input)
-
-
-class MSE:
-    """
-    Computes MSE between input and target
-    """
-
-    def __init__(self, **kwargs):
-        pass
-
-    def __call__(self, input, target):
-        input, target = convert_to_numpy(input, target)
-        return mean_squared_error(input, target)
-
-
-def get_evaluation_metric(config):
-    """
-    Returns the evaluation metric function based on provided configuration
-    :param config: (dict) a top level configuration object containing the 'eval_metric' key
-    :return: an instance of the evaluation metric
-    """
-
-    def _metric_class(class_name):
-        m = importlib.import_module('pytorch3dunet.unet3d.metrics')
-        clazz = getattr(m, class_name)
-        return clazz
-
-    assert 'eval_metric' in config, 'Could not find evaluation metric configuration'
-    metric_config = config['eval_metric']
-    metric_class = _metric_class(metric_config['name'])
-    return metric_class(**metric_config)
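A sketch of the config-driven construction performed by get_evaluation_metric: 'name' selects a class from this module and the remaining (illustrative) keys are forwarded to its constructor as kwargs:

    from pytorch3dunet.unet3d.metrics import get_evaluation_metric

    config = {
        'eval_metric': {
            'name': 'MeanIoU',
            'ignore_index': None,
        }
    }
    eval_criterion = get_evaluation_metric(config)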
diff --git a/build/lib/pytorch3dunet/unet3d/model.py b/build/lib/pytorch3dunet/unet3d/model.py
deleted file mode 100644
index e4de49a7..00000000
--- a/build/lib/pytorch3dunet/unet3d/model.py
+++ /dev/null
@@ -1,249 +0,0 @@
-import torch.nn as nn
-
-from pytorch3dunet.unet3d.buildingblocks import DoubleConv, ResNetBlock, ResNetBlockSE, \
-    create_decoders, create_encoders
-from pytorch3dunet.unet3d.utils import get_class, number_of_features_per_level
-
-
-class AbstractUNet(nn.Module):
-    """
-    Base class for standard and residual UNet.
-
-    Args:
-        in_channels (int): number of input channels
-        out_channels (int): number of output segmentation masks;
-            Note that out_channels might correspond either to different semantic classes
-            or to different binary segmentation masks.
-            It's up to the user of the class to interpret the out_channels and
-            use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class)
-            or BCEWithLogitsLoss (two-class) respectively)
-        f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number
-            of feature maps is given by the geometric progression: f_maps * 2^k, k=0,1,...,num_levels-1
-        final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the final 1x1 convolution,
-            otherwise apply nn.Softmax. Takes effect only if `self.training == False`, i.e. during validation/testing
-        basic_module: basic model for the encoder/decoder (DoubleConv, ResNetBlock, ....)
-        layer_order (string): determines the order of layers in `SingleConv` module.
-            E.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d. See `SingleConv` for more info
-        num_groups (int): number of groups for the GroupNorm
-        num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int)
-            default: 4
-        is_segmentation (bool): if True and the model is in eval mode, Sigmoid/Softmax normalization is applied
-            after the final convolution; if False (regression problem) the normalization layer is skipped
-        conv_kernel_size (int or tuple): size of the convolving kernel in the basic_module
-        pool_kernel_size (int or tuple): the size of the pooling window
-        conv_padding (int or tuple): zero-padding added to all three sides of the input
-        conv_upscale (int): upscale factor for the number of channels in the first convolution of `DoubleConv`
-            in the encoder, default: 2
-        upsample (str): algorithm used for decoder upsampling:
-            InterpolateUpsampling: 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'
-            TransposeConvUpsampling: 'deconv'
-            No upsampling: None
-            Default: 'default' (chooses automatically)
-        dropout_prob (float or tuple): dropout probability, default: 0.1
-        is3d (bool): if True the model is 3D, otherwise 2D, default: True
-    """
-
-    def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr',
-                 num_groups=8, num_levels=4, is_segmentation=True, conv_kernel_size=3, pool_kernel_size=2,
-                 conv_padding=1, conv_upscale=2, upsample='default', dropout_prob=0.1, is3d=True):
-        super(AbstractUNet, self).__init__()
-
-        if isinstance(f_maps, int):
-            f_maps = number_of_features_per_level(f_maps, num_levels=num_levels)
-
-        assert isinstance(f_maps, list) or isinstance(f_maps, tuple)
-        assert len(f_maps) > 1, "Required at least 2 levels in the U-Net"
-        if 'g' in layer_order:
-            assert num_groups is not None, "num_groups must be specified if GroupNorm is used"
-
-        # create encoder path
-        self.encoders = create_encoders(in_channels, f_maps, basic_module, conv_kernel_size,
-                                        conv_padding, conv_upscale, dropout_prob,
-                                        layer_order, num_groups, pool_kernel_size, is3d)
-
-        # create decoder path
-        self.decoders = create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding,
-                                        layer_order, num_groups, upsample, dropout_prob,
-                                        is3d)
-
-        # in the last layer a 1×1 convolution reduces the number of output channels to the number of labels
-        if is3d:
-            self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1)
-        else:
-            self.final_conv = nn.Conv2d(f_maps[0], out_channels, 1)
-
-        if is_segmentation:
-            # semantic segmentation problem
-            if final_sigmoid:
-                self.final_activation = nn.Sigmoid()
-            else:
-                self.final_activation = nn.Softmax(dim=1)
-        else:
-            # regression problem
-            self.final_activation = None
-
-    def forward(self, x):
-        # encoder part
-        encoders_features = []
-        for encoder in self.encoders:
-            x = encoder(x)
-            # reverse the encoder outputs to be aligned with the decoder
-            encoders_features.insert(0, x)
-
-        # remove the last encoder's output from the list
-        # !!remember: it's the 1st in the list
-        encoders_features = encoders_features[1:]
-
-        # decoder part
-        for decoder, encoder_features in zip(self.decoders, encoders_features):
-            # pass the output from the corresponding encoder and the output
-            # of the previous decoder
-            x = decoder(encoder_features, x)
-
-        x = self.final_conv(x)
-
-        # apply final_activation (i.e. Sigmoid or Softmax) only during prediction.
-        # During training the network outputs logits
-        if not self.training and self.final_activation is not None:
-            x = self.final_activation(x)
-
-        return x
-
-
-class UNet3D(AbstractUNet):
-    """
-    3DUnet model from
-    `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation" <https://arxiv.org/abs/1606.06650>`_.
-
-    Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder
-    """
-
-    def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
-                 num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1,
-                 conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs):
-        super(UNet3D, self).__init__(in_channels=in_channels,
-                                     out_channels=out_channels,
-                                     final_sigmoid=final_sigmoid,
-                                     basic_module=DoubleConv,
-                                     f_maps=f_maps,
-                                     layer_order=layer_order,
-                                     num_groups=num_groups,
-                                     num_levels=num_levels,
-                                     is_segmentation=is_segmentation,
-                                     conv_padding=conv_padding,
-                                     conv_upscale=conv_upscale,
-                                     upsample=upsample,
-                                     dropout_prob=dropout_prob,
-                                     is3d=True)
-
-
-class ResidualUNet3D(AbstractUNet):
-    """
-    Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
-    Uses ResNetBlock as a basic building block, summation joining instead
-    of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts).
-    Since the model effectively becomes a residual net, in theory it allows for deeper UNet.
-    """
-
-    def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
-                 num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1,
-                 conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs):
-        super(ResidualUNet3D, self).__init__(in_channels=in_channels,
-                                             out_channels=out_channels,
-                                             final_sigmoid=final_sigmoid,
-                                             basic_module=ResNetBlock,
-                                             f_maps=f_maps,
-                                             layer_order=layer_order,
-                                             num_groups=num_groups,
-                                             num_levels=num_levels,
-                                             is_segmentation=is_segmentation,
-                                             conv_padding=conv_padding,
-                                             conv_upscale=conv_upscale,
-                                             upsample=upsample,
-                                             dropout_prob=dropout_prob,
-                                             is3d=True)
-
-
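A minimal forward-pass sketch for UNet3D with illustrative arguments (note that the final Softmax/Sigmoid is only applied in eval mode; during training the network outputs logits):

    import torch
    from pytorch3dunet.unet3d.model import UNet3D

    model = UNet3D(in_channels=1, out_channels=2, f_maps=32, final_sigmoid=False)
    model.eval()  # enables the final Softmax
    with torch.no_grad():
        y = model(torch.randn(1, 1, 32, 64, 64))
    print(y.shape)  # torch.Size([1, 2, 32, 64, 64])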
- """ - - def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', - num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1, - conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs): - super(ResidualUNetSE3D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - final_sigmoid=final_sigmoid, - basic_module=ResNetBlockSE, - f_maps=f_maps, - layer_order=layer_order, - num_groups=num_groups, - num_levels=num_levels, - is_segmentation=is_segmentation, - conv_padding=conv_padding, - conv_upscale=conv_upscale, - upsample=upsample, - dropout_prob=dropout_prob, - is3d=True) - - -class UNet2D(AbstractUNet): - """ - 2DUnet model from - `"U-Net: Convolutional Networks for Biomedical Image Segmentation" ` - """ - - def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', - num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1, - conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs): - super(UNet2D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - final_sigmoid=final_sigmoid, - basic_module=DoubleConv, - f_maps=f_maps, - layer_order=layer_order, - num_groups=num_groups, - num_levels=num_levels, - is_segmentation=is_segmentation, - conv_padding=conv_padding, - conv_upscale=conv_upscale, - upsample=upsample, - dropout_prob=dropout_prob, - is3d=False) - - -class ResidualUNet2D(AbstractUNet): - """ - Residual 2DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf. - """ - - def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', - num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1, - conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs): - super(ResidualUNet2D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - final_sigmoid=final_sigmoid, - basic_module=ResNetBlock, - f_maps=f_maps, - layer_order=layer_order, - num_groups=num_groups, - num_levels=num_levels, - is_segmentation=is_segmentation, - conv_padding=conv_padding, - conv_upscale=conv_upscale, - upsample=upsample, - dropout_prob=dropout_prob, - is3d=False) - - -def get_model(model_config): - model_class = get_class(model_config['name'], modules=[ - 'pytorch3dunet.unet3d.model' - ]) - return model_class(**model_config) diff --git a/build/lib/pytorch3dunet/unet3d/predictor.py b/build/lib/pytorch3dunet/unet3d/predictor.py deleted file mode 100644 index c9b4f6eb..00000000 --- a/build/lib/pytorch3dunet/unet3d/predictor.py +++ /dev/null @@ -1,281 +0,0 @@ -import os -import time -from concurrent import futures -from pathlib import Path - -import h5py -import numpy as np -import torch -from skimage import measure -from torch import nn -from tqdm import tqdm - -from pytorch3dunet.datasets.hdf5 import AbstractHDF5Dataset -from pytorch3dunet.datasets.utils import SliceBuilder, remove_padding -from pytorch3dunet.unet3d.model import UNet2D -from pytorch3dunet.unet3d.utils import get_logger - -logger = get_logger('UNetPredictor') - - -def _get_output_file(dataset, suffix='_predictions', output_dir=None): - input_dir, file_name = os.path.split(dataset.file_path) - if output_dir is None: - output_dir = input_dir - output_filename = os.path.splitext(file_name)[0] + suffix + '.h5' - return Path(output_dir) / output_filename - - -def _is_2d_model(model): - if isinstance(model, nn.DataParallel): - model = model.module - return isinstance(model, UNet2D) - - -class _AbstractPredictor: - def __init__(self, - model: 
diff --git a/build/lib/pytorch3dunet/unet3d/predictor.py b/build/lib/pytorch3dunet/unet3d/predictor.py
deleted file mode 100644
index c9b4f6eb..00000000
--- a/build/lib/pytorch3dunet/unet3d/predictor.py
+++ /dev/null
@@ -1,281 +0,0 @@
-import os
-import time
-from concurrent import futures
-from pathlib import Path
-
-import h5py
-import numpy as np
-import torch
-from skimage import measure
-from torch import nn
-from tqdm import tqdm
-
-from pytorch3dunet.datasets.hdf5 import AbstractHDF5Dataset
-from pytorch3dunet.datasets.utils import SliceBuilder, remove_padding
-from pytorch3dunet.unet3d.model import UNet2D
-from pytorch3dunet.unet3d.utils import get_logger
-
-logger = get_logger('UNetPredictor')
-
-
-def _get_output_file(dataset, suffix='_predictions', output_dir=None):
-    input_dir, file_name = os.path.split(dataset.file_path)
-    if output_dir is None:
-        output_dir = input_dir
-    output_filename = os.path.splitext(file_name)[0] + suffix + '.h5'
-    return Path(output_dir) / output_filename
-
-
-def _is_2d_model(model):
-    if isinstance(model, nn.DataParallel):
-        model = model.module
-    return isinstance(model, UNet2D)
-
-
-class _AbstractPredictor:
-    def __init__(self,
-                 model: nn.Module,
-                 output_dir: str,
-                 out_channels: int,
-                 output_dataset: str = 'predictions',
-                 save_segmentation: bool = False,
-                 prediction_channel: int = None,
-                 **kwargs):
-        """
-        Base class for predictors.
-        Args:
-            model: segmentation model
-            output_dir: directory where the predictions will be saved
-            out_channels: number of output channels of the model
-            output_dataset: name of the dataset in the H5 file where the predictions will be saved
-            save_segmentation: if true the segmentation will be saved instead of the probability maps
-            prediction_channel: save only the specified channel from the network output
-        """
-        self.model = model
-        self.output_dir = output_dir
-        self.out_channels = out_channels
-        self.output_dataset = output_dataset
-        self.save_segmentation = save_segmentation
-        self.prediction_channel = prediction_channel
-
-    def __call__(self, test_loader):
-        raise NotImplementedError
-
-
-class StandardPredictor(_AbstractPredictor):
-    """
-    Applies the model on the given dataset and saves the result as an H5 file.
-    Predictions from the network are kept in memory. If the results from the network don't fit into RAM,
-    use `LazyPredictor` instead.
-
-    The name of the output dataset inside the H5 file is given by the `output_dataset` config argument.
-    """
-
-    def __init__(self,
-                 model: nn.Module,
-                 output_dir: str,
-                 out_channels: int,
-                 output_dataset: str = 'predictions',
-                 save_segmentation: bool = False,
-                 prediction_channel: int = None,
-                 **kwargs):
-        super().__init__(model, output_dir, out_channels, output_dataset, save_segmentation, prediction_channel,
-                         **kwargs)
-
-    def __call__(self, test_loader):
-        assert isinstance(test_loader.dataset, AbstractHDF5Dataset)
-        logger.info(f"Processing '{test_loader.dataset.file_path}'...")
-        start = time.perf_counter()
-
-        logger.info(f'Running inference on {len(test_loader)} batches')
-        # dimensionality of the output predictions
-        volume_shape = test_loader.dataset.volume_shape()
-        if self.prediction_channel is not None:
-            # single channel prediction map
-            prediction_maps_shape = (1,) + volume_shape
-        else:
-            prediction_maps_shape = (self.out_channels,) + volume_shape
-
-        # create destination H5 file
-        output_file = _get_output_file(dataset=test_loader.dataset, output_dir=self.output_dir)
-        with h5py.File(output_file, 'w') as h5_output_file:
-            # allocate prediction and normalization arrays
-            logger.info('Allocating prediction and normalization arrays...')
-            prediction_map, normalization_mask = self._allocate_prediction_maps(prediction_maps_shape, h5_output_file)
-
-            # determine halo used for padding
-            patch_halo = test_loader.dataset.halo_shape
-
-            # Sets the module in evaluation mode explicitly
-            # It is necessary for batchnorm/dropout layers if present as well as final Sigmoid/Softmax to be applied
-            self.model.eval()
-            # Run predictions on the entire input dataset
-            with torch.no_grad():
-                for input, indices in tqdm(test_loader):
-                    # send batch to gpu
-                    if torch.cuda.is_available():
-                        input = input.pin_memory().cuda(non_blocking=True)
-
-                    if _is_2d_model(self.model):
-                        # remove the singleton z-dimension from the input
-                        input = torch.squeeze(input, dim=-3)
-                        # forward pass
-                        prediction = self.model(input)
-                        # add the singleton z-dimension to the output
-                        prediction = torch.unsqueeze(prediction, dim=-3)
-                    else:
-                        # forward pass
-                        prediction = self.model(input)
-
-                    # unpad the predicted patch
-                    prediction = remove_padding(prediction, patch_halo)
-                    # convert to numpy array
-                    prediction = prediction.cpu().numpy()
-                    # for each batch sample
-                    for pred, 
index in zip(prediction, indices): - # save patch index: (C,D,H,W) - if self.prediction_channel is None: - channel_slice = slice(0, self.out_channels) - else: - # use only the specified channel - channel_slice = slice(0, 1) - pred = np.expand_dims(pred[self.prediction_channel], axis=0) - - # add channel dimension to the index - index = (channel_slice,) + tuple(index) - # accumulate probabilities into the output prediction array - prediction_map[index] += pred - # count voxel visits for normalization - normalization_mask[index] += 1 - - logger.info(f'Finished inference in {time.perf_counter() - start:.2f} seconds') - # save results - output_type = 'segmentation' if self.save_segmentation else 'probability maps' - logger.info(f'Saving {output_type} to: {output_file}') - self._save_results(prediction_map, normalization_mask, h5_output_file, test_loader.dataset) - - def _allocate_prediction_maps(self, output_shape, output_file): - # initialize the output prediction arrays - prediction_map = np.zeros(output_shape, dtype='float32') - # initialize normalization mask in order to average out probabilities of overlapping patches - normalization_mask = np.zeros(output_shape, dtype='uint8') - return prediction_map, normalization_mask - - def _save_results(self, prediction_map, normalization_mask, output_file, dataset): - result = prediction_map / normalization_mask - if self.save_segmentation: - result = np.argmax(result, axis=0).astype('uint16') - output_file.create_dataset(self.output_dataset, data=result, compression="gzip") - - -class LazyPredictor(StandardPredictor): - """ - Applies the model on the given dataset and saves the result in the `output_file` in the H5 format. - Predicted patches are directly saved into the H5 and they won't be stored in memory. Since this predictor - is slower than the `StandardPredictor` it should only be used when the predicted volume does not fit into RAM. 
- """ - - def __init__(self, - model: nn.Module, - output_dir: str, - out_channels: int, - output_dataset: str = 'predictions', - save_segmentation: bool = False, - prediction_channel: int = None, - **kwargs): - super().__init__(model, output_dir, out_channels, output_dataset, save_segmentation, prediction_channel, - **kwargs) - - def _allocate_prediction_maps(self, output_shape, output_file): - # allocate datasets for probability maps - prediction_map = output_file.create_dataset(self.output_dataset, - shape=output_shape, - dtype='float32', - chunks=True, - compression='gzip') - # allocate datasets for normalization masks - normalization_mask = output_file.create_dataset('normalization', - shape=output_shape, - dtype='uint8', - chunks=True, - compression='gzip') - return prediction_map, normalization_mask - - def _save_results(self, prediction_map, normalization_mask, output_file, dataset): - z, y, x = prediction_map.shape[1:] - # take slices which are 1/27 of the original volume - patch_shape = (z // 3, y // 3, x // 3) - if self.save_segmentation: - output_file.create_dataset('segmentation', shape=(z, y, x), dtype='uint16', chunks=True, compression='gzip') - - for index in SliceBuilder._build_slices(prediction_map, patch_shape=patch_shape, stride_shape=patch_shape): - logger.info(f'Normalizing slice: {index}') - prediction_map[index] /= normalization_mask[index] - # make sure to reset the slice that has been visited already in order to avoid 'double' normalization - # when the patches overlap with each other - normalization_mask[index] = 1 - # save segmentation - if self.save_segmentation: - output_file['segmentation'][index[1:]] = np.argmax(prediction_map[index], axis=0).astype('uint16') - - del output_file['normalization'] - if self.save_segmentation: - del output_file[self.output_dataset] - - -class DSB2018Predictor(_AbstractPredictor): - def __init__(self, model, output_dir, config, save_segmentation=True, pmaps_thershold=0.5, **kwargs): - super().__init__(model, output_dir, config, **kwargs) - self.pmaps_threshold = pmaps_thershold - self.save_segmentation = save_segmentation - - def _slice_from_pad(self, pad): - if pad == 0: - return slice(None, None) - else: - return slice(pad, -pad) - - def __call__(self, test_loader): - # Sets the module in evaluation mode explicitly - self.model.eval() - # initial process pool for saving results to disk - executor = futures.ProcessPoolExecutor(max_workers=32) - # Run predictions on the entire input dataset - with torch.no_grad(): - for img, path in test_loader: - # send batch to gpu - if torch.cuda.is_available(): - img = img.cuda(non_blocking=True) - # forward pass - pred = self.model(img) - - executor.submit( - dsb_save_batch, - self.output_dir, - path - ) - - print('Waiting for all predictions to be saved to disk...') - executor.shutdown(wait=True) - - -def dsb_save_batch(output_dir, path, pred, save_segmentation=True, pmaps_thershold=0.5): - def _pmaps_to_seg(pred): - mask = (pred > pmaps_thershold) - return measure.label(mask).astype('uint16') - - # convert to numpy array - for single_pred, single_path in zip(pred, path): - logger.info(f'Processing {single_path}') - single_pred = single_pred.squeeze() - - # save to h5 file - out_file = os.path.splitext(single_path)[0] + '_predictions.h5' - if output_dir is not None: - out_file = os.path.join(output_dir, os.path.split(out_file)[1]) - - with h5py.File(out_file, 'w') as f: - # logger.info(f'Saving output to {out_file}') - f.create_dataset('predictions', data=single_pred, compression='gzip') - 
-            if save_segmentation:
-                f.create_dataset('segmentation', data=_pmaps_to_seg(single_pred), compression='gzip')
diff --git a/build/lib/pytorch3dunet/unet3d/se.py b/build/lib/pytorch3dunet/unet3d/se.py
deleted file mode 100644
index 23fac3d7..00000000
--- a/build/lib/pytorch3dunet/unet3d/se.py
+++ /dev/null
@@ -1,113 +0,0 @@
-"""
-3D Squeeze and Excitation Modules
-*****************************
-3D Extensions of the following 2D squeeze and excitation blocks:
-    1. `Channel Squeeze and Excitation <https://arxiv.org/abs/1709.01507>`_
-    2. `Spatial Squeeze and Excitation <https://arxiv.org/abs/1803.02579>`_
-    3. `Channel and Spatial Squeeze and Excitation <https://arxiv.org/abs/1803.02579>`_
-New Project & Excite block, designed specifically for 3D inputs
-
-    Coded by -- Anne-Marie Rickmann (https://github.com/arickm)
-"""
-
-import torch
-from torch import nn as nn
-from torch.nn import functional as F
-
-
-class ChannelSELayer3D(nn.Module):
-    """
-    3D extension of Squeeze-and-Excitation (SE) block described in:
-        *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507*
-        *Zhu et al., AnatomyNet, arXiv:1808.05238*
-    """
-
-    def __init__(self, num_channels, reduction_ratio=2):
-        """
-        Args:
-            num_channels (int): number of input channels
-            reduction_ratio (int): by how much num_channels should be reduced
-        """
-        super(ChannelSELayer3D, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool3d(1)
-        num_channels_reduced = num_channels // reduction_ratio
-        self.reduction_ratio = reduction_ratio
-        self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
-        self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
-        self.relu = nn.ReLU()
-        self.sigmoid = nn.Sigmoid()
-
-    def forward(self, x):
-        batch_size, num_channels, D, H, W = x.size()
-        # Average along each channel
-        squeeze_tensor = self.avg_pool(x)
-
-        # channel excitation
-        fc_out_1 = self.relu(self.fc1(squeeze_tensor.view(batch_size, num_channels)))
-        fc_out_2 = self.sigmoid(self.fc2(fc_out_1))
-
-        output_tensor = torch.mul(x, fc_out_2.view(batch_size, num_channels, 1, 1, 1))
-
-        return output_tensor
-
-
-class SpatialSELayer3D(nn.Module):
-    """
-    3D extension of SE block -- squeezing spatially and exciting channel-wise described in:
-        *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018*
-    """
-
-    def __init__(self, num_channels):
-        """
-        Args:
-            num_channels (int): number of input channels
-        """
-        super(SpatialSELayer3D, self).__init__()
-        self.conv = nn.Conv3d(num_channels, 1, 1)
-        self.sigmoid = nn.Sigmoid()
-
-    def forward(self, x, weights=None):
-        """
-        Args:
-            x (torch.Tensor): input tensor of shape (batch_size, num_channels, D, H, W)
-            weights (torch.Tensor): optional 1x1x1 convolution weights used for few-shot learning
-
-        Returns:
-            (torch.Tensor): output_tensor
-        """
-        # channel squeeze
-        batch_size, channel, D, H, W = x.size()
-
-        if weights is not None:
-            # a 5D input requires a 5D weight tensor and conv3d
-            weights = weights.view(1, channel, 1, 1, 1)
-            out = F.conv3d(x, weights)
-        else:
-            out = self.conv(x)
-
-        squeeze_tensor = self.sigmoid(out)
-
-        # spatial excitation
-        output_tensor = torch.mul(x, squeeze_tensor.view(batch_size, 1, D, H, W))
-
-        return output_tensor
-
-
-class ChannelSpatialSELayer3D(nn.Module):
-    """
-    3D extension of concurrent spatial and channel squeeze & excitation:
-        *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, arXiv:1803.02579*
-    """
-
-    def __init__(self, num_channels, reduction_ratio=2):
-        """
-        Args:
-            num_channels (int): number of input channels
-            reduction_ratio (int): by how much num_channels should be reduced
-        """
-        super(ChannelSpatialSELayer3D, self).__init__()
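-        # reuse the channel (cSE) and spatial (sSE) excitation blocks defined above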
-        self.cSE = ChannelSELayer3D(num_channels, reduction_ratio)
-        self.sSE = SpatialSELayer3D(num_channels)
-
-    def forward(self, input_tensor):
-        output_tensor = torch.max(self.cSE(input_tensor), self.sSE(input_tensor))
-        return output_tensor
diff --git a/build/lib/pytorch3dunet/unet3d/seg_metrics.py b/build/lib/pytorch3dunet/unet3d/seg_metrics.py
deleted file mode 100644
index e713ea23..00000000
--- a/build/lib/pytorch3dunet/unet3d/seg_metrics.py
+++ /dev/null
@@ -1,123 +0,0 @@
-import numpy as np
-from skimage.metrics import contingency_table
-
-
-def precision(tp, fp, fn):
-    return tp / (tp + fp) if tp > 0 else 0
-
-
-def recall(tp, fp, fn):
-    return tp / (tp + fn) if tp > 0 else 0
-
-
-def accuracy(tp, fp, fn):
-    return tp / (tp + fp + fn) if tp > 0 else 0
-
-
-def f1(tp, fp, fn):
-    return (2 * tp) / (2 * tp + fp + fn) if tp > 0 else 0
-
-
-def _relabel(input):
-    _, unique_labels = np.unique(input, return_inverse=True)
-    return unique_labels.reshape(input.shape)
-
-
-def _iou_matrix(gt, seg):
-    # relabel gt and seg for smaller memory footprint of contingency table
-    gt = _relabel(gt)
-    seg = _relabel(seg)
-
-    # get number of overlapping pixels between GT and SEG
-    n_inter = contingency_table(gt, seg).A
-
-    # number of pixels for GT instances
-    n_gt = n_inter.sum(axis=1, keepdims=True)
-    # number of pixels for SEG instances
-    n_seg = n_inter.sum(axis=0, keepdims=True)
-
-    # number of pixels in the union between GT and SEG instances
-    n_union = n_gt + n_seg - n_inter
-
-    iou_matrix = n_inter / n_union
-    # make sure that the values are within [0,1] range
-    assert 0 <= np.min(iou_matrix) <= np.max(iou_matrix) <= 1
-
-    return iou_matrix
-
-
-class SegmentationMetrics:
-    """
-    Computes precision, recall, accuracy and f1 score for a given ground truth and predicted segmentation.
-    The IoU matrix for a given ground truth and predicted segmentation is computed eagerly upon construction
-    of a `SegmentationMetrics` instance.
-
-    Args:
-        gt (ndarray): ground truth segmentation
-        seg (ndarray): predicted segmentation
-    """
-
-    def __init__(self, gt, seg):
-        self.iou_matrix = _iou_matrix(gt, seg)
-
-    def metrics(self, iou_threshold):
-        """
-        Computes precision, recall, accuracy and f1 score at a given IoU threshold
-        """
-        # ignore background
-        iou_matrix = self.iou_matrix[1:, 1:]
-        detection_matrix = (iou_matrix > iou_threshold).astype(np.uint8)
-        n_gt, n_seg = detection_matrix.shape
-
-        # if the iou_matrix is empty or all values are 0
-        trivial = min(n_gt, n_seg) == 0 or np.all(detection_matrix == 0)
-        if trivial:
-            tp = fp = fn = 0
-        else:
-            # count non-zero rows to get the number of TP
-            tp = np.count_nonzero(detection_matrix.sum(axis=1))
-            # count zero rows to get the number of FN
-            fn = n_gt - tp
-            # count zero columns to get the number of FP
-            fp = n_seg - np.count_nonzero(detection_matrix.sum(axis=0))
-
-        return {
-            'precision': precision(tp, fp, fn),
-            'recall': recall(tp, fp, fn),
-            'accuracy': accuracy(tp, fp, fn),
-            'f1': f1(tp, fp, fn)
-        }
-
-
-class Accuracy:
-    """
-    Computes accuracy between ground truth and predicted segmentation at a given IoU threshold.
-    Defined as: AC = TP / (TP + FP + FN).
-    Kaggle DSB2018 calls it Precision, see:
-    https://www.kaggle.com/stkbailey/step-by-step-explanation-of-scoring-metric.
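-    E.g. with 2 matched instances (TP=2), 1 spurious prediction (FP=1) and 1 missed instance (FN=1),
-    AC = 2 / (2 + 1 + 1) = 0.5.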
- """ - - def __init__(self, iou_threshold): - self.iou_threshold = iou_threshold - - def __call__(self, input_seg, gt_seg): - metrics = SegmentationMetrics(gt_seg, input_seg).metrics(self.iou_threshold) - return metrics['accuracy'] - - -class AveragePrecision: - """ - Average precision taken for the IoU range (0.5, 0.95) with a step of 0.05 as defined in: - https://www.kaggle.com/stkbailey/step-by-step-explanation-of-scoring-metric - """ - - def __init__(self): - self.iou_range = np.linspace(0.50, 0.95, 10) - - def __call__(self, input_seg, gt_seg): - # compute contingency_table - sm = SegmentationMetrics(gt_seg, input_seg) - # compute accuracy for each threshold - acc = [sm.metrics(iou)['accuracy'] for iou in self.iou_range] - # return the average - return np.mean(acc) diff --git a/build/lib/pytorch3dunet/unet3d/trainer.py b/build/lib/pytorch3dunet/unet3d/trainer.py deleted file mode 100644 index 4b59d568..00000000 --- a/build/lib/pytorch3dunet/unet3d/trainer.py +++ /dev/null @@ -1,404 +0,0 @@ -import os -import torch -import torch.nn as nn -from torch.optim.lr_scheduler import ReduceLROnPlateau -from torch.utils.tensorboard import SummaryWriter -from datetime import datetime - -from pytorch3dunet.datasets.utils import get_train_loaders -from pytorch3dunet.unet3d.losses import get_loss_criterion -from pytorch3dunet.unet3d.metrics import get_evaluation_metric -from pytorch3dunet.unet3d.model import get_model, UNet2D -from pytorch3dunet.unet3d.utils import get_logger, get_tensorboard_formatter, create_optimizer, \ - create_lr_scheduler, get_number_of_learnable_parameters -from . import utils - -logger = get_logger('UNetTrainer') - - -def create_trainer(config): - # Create the model - model = get_model(config['model']) - - if torch.cuda.device_count() > 1 and not config['device'] == 'cpu': - model = nn.DataParallel(model) - logger.info(f'Using {torch.cuda.device_count()} GPUs for prediction') - if torch.cuda.is_available() and not config['device'] == 'cpu': - model = model.cuda() - - # Log the number of learnable parameters - logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}') - - # Create loss criterion - loss_criterion = get_loss_criterion(config) - # Create evaluation metric - eval_criterion = get_evaluation_metric(config) - - # Create data loaders - loaders = get_train_loaders(config) - - # Create the optimizer - optimizer = create_optimizer(config['optimizer'], model) - - # Create learning rate adjustment strategy - lr_scheduler = create_lr_scheduler(config.get('lr_scheduler', None), optimizer) - - trainer_config = config['trainer'] - # Create tensorboard formatter - tensorboard_formatter = get_tensorboard_formatter(trainer_config.pop('tensorboard_formatter', None)) - # Create trainer - resume = trainer_config.pop('resume', None) - pre_trained = trainer_config.pop('pre_trained', None) - - return UNetTrainer(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, loss_criterion=loss_criterion, - eval_criterion=eval_criterion, loaders=loaders, tensorboard_formatter=tensorboard_formatter, - resume=resume, pre_trained=pre_trained, **trainer_config) - - -class UNetTrainer: - """UNet trainer. - - Args: - model (Unet3D): UNet 3D model to be trained - optimizer (nn.optim.Optimizer): optimizer used for training - lr_scheduler (torch.optim.lr_scheduler._LRScheduler): learning rate scheduler - WARN: bear in mind that lr_scheduler.step() is invoked after every validation step - (i.e. validate_after_iters) not after every epoch. So e.g. 
-            if one uses StepLR with step_size=30
-            the learning rate will be adjusted after every 30 * validate_after_iters iterations.
-        loss_criterion (callable): loss function
-        eval_criterion (callable): used to compute training/validation metric (such as Dice, IoU, AP or Rand score)
-            saving the best checkpoint is based on the result of this function on the validation set
-        loaders (dict): 'train' and 'val' loaders
-        checkpoint_dir (string): dir for saving checkpoints and tensorboard logs
-        max_num_epochs (int): maximum number of epochs
-        max_num_iterations (int): maximum number of iterations
-        validate_after_iters (int): validate after that many iterations
-        log_after_iters (int): number of iterations before logging to tensorboard
-        validate_iters (int): number of validation iterations, if None validate
-            on the whole validation set
-        eval_score_higher_is_better (bool): if True higher eval scores are considered better
-        best_eval_score (float): best validation score so far (higher better)
-        num_iterations (int): useful when loading the model from the checkpoint
-        num_epoch (int): useful when loading the model from the checkpoint
-        tensorboard_formatter (callable): converts a given batch of input/output/target image to a series of images
-            that can be displayed in tensorboard
-        skip_train_validation (bool): if True eval_criterion is not evaluated on the training set (used mostly when
-            evaluation is expensive)
-    """
-
-    def __init__(self, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders, checkpoint_dir,
-                 max_num_epochs, max_num_iterations, validate_after_iters=200, log_after_iters=100,
-                 validate_iters=None, num_iterations=1, num_epoch=0, eval_score_higher_is_better=True,
-                 tensorboard_formatter=None, skip_train_validation=False, resume=None, pre_trained=None, **kwargs):
-
-        self.model = model
-        self.optimizer = optimizer
-        self.scheduler = lr_scheduler
-        self.loss_criterion = loss_criterion
-        self.eval_criterion = eval_criterion
-        self.loaders = loaders
-        self.checkpoint_dir = checkpoint_dir
-        self.max_num_epochs = max_num_epochs
-        self.max_num_iterations = max_num_iterations
-        self.validate_after_iters = validate_after_iters
-        self.log_after_iters = log_after_iters
-        self.validate_iters = validate_iters
-        self.eval_score_higher_is_better = eval_score_higher_is_better
-
-        logger.info(model)
-        logger.info(f'eval_score_higher_is_better: {eval_score_higher_is_better}')
-
-        # initialize the best_eval_score
-        if eval_score_higher_is_better:
-            self.best_eval_score = float('-inf')
-        else:
-            self.best_eval_score = float('+inf')
-
-        self.writer = SummaryWriter(
-            log_dir=os.path.join(
-                checkpoint_dir, 'logs',
-                datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-            )
-        )
-
-        assert tensorboard_formatter is not None, 'TensorboardFormatter must be provided'
-        self.tensorboard_formatter = tensorboard_formatter
-
-        self.num_iterations = num_iterations
-        self.num_epochs = num_epoch
-        self.skip_train_validation = skip_train_validation
-
-        if resume is not None:
-            logger.info(f"Loading checkpoint '{resume}'...")
-            state = utils.load_checkpoint(resume, self.model, self.optimizer)
-            logger.info(
-                f"Checkpoint loaded from '{resume}'. Epoch: {state['num_epochs']}. Iteration: {state['num_iterations']}. "
-                f"Best val score: {state['best_eval_score']}."
-            )
-            self.best_eval_score = state['best_eval_score']
-            self.num_iterations = state['num_iterations']
-            self.num_epochs = state['num_epochs']
-            self.checkpoint_dir = os.path.split(resume)[0]
-        elif pre_trained is not None:
-            logger.info(f"Loading pre-trained model from '{pre_trained}'...")
-            utils.load_checkpoint(pre_trained, self.model, None)
-            if 'checkpoint_dir' not in kwargs:
-                self.checkpoint_dir = os.path.split(pre_trained)[0]
-
-    def fit(self):
-        for _ in range(self.num_epochs, self.max_num_epochs):
-            # train for one epoch
-            should_terminate = self.train()
-
-            if should_terminate:
-                logger.info('Stopping criterion is satisfied. Finishing training')
-                return
-
-            self.num_epochs += 1
-        logger.info(f"Reached maximum number of epochs: {self.max_num_epochs}. Finishing training...")
-
-    def train(self):
-        """Trains the model for 1 epoch.
-
-        Returns:
-            True if the training should be terminated immediately, False otherwise
-        """
-        train_losses = utils.RunningAverage()
-        train_eval_scores = utils.RunningAverage()
-
-        # sets the model in training mode
-        self.model.train()
-
-        for t in self.loaders['train']:
-            logger.info(f'Training iteration [{self.num_iterations}/{self.max_num_iterations}]. '
-                        f'Epoch [{self.num_epochs}/{self.max_num_epochs - 1}]')
-
-            input, target, weight = self._split_training_batch(t)
-
-            output, loss = self._forward_pass(input, target, weight)
-
-            train_losses.update(loss.item(), self._batch_size(input))
-
-            # compute gradients and update parameters
-            self.optimizer.zero_grad()
-            loss.backward()
-            self.optimizer.step()
-
-            if self.num_iterations % self.validate_after_iters == 0:
-                # set the model in eval mode
-                self.model.eval()
-                # evaluate on validation set
-                eval_score = self.validate()
-                # set the model back to training mode
-                self.model.train()
-
-                # adjust learning rate if necessary
-                if isinstance(self.scheduler, ReduceLROnPlateau):
-                    self.scheduler.step(eval_score)
-                elif self.scheduler is not None:
-                    self.scheduler.step()
-
-                # log current learning rate in tensorboard
-                self._log_lr()
-                # remember best validation metric
-                is_best = self._is_best_eval_score(eval_score)
-
-                # save checkpoint
-                self._save_checkpoint(is_best)
-
-            if self.num_iterations % self.log_after_iters == 0:
-                # compute eval criterion
-                if not self.skip_train_validation:
-                    # apply final activation before calculating eval score
-                    if isinstance(self.model, nn.DataParallel):
-                        final_activation = self.model.module.final_activation
-                    else:
-                        final_activation = self.model.final_activation
-
-                    if final_activation is not None:
-                        act_output = final_activation(output)
-                    else:
-                        act_output = output
-                    eval_score = self.eval_criterion(act_output, target)
-                    train_eval_scores.update(eval_score.item(), self._batch_size(input))
-
-                # log stats, params and images
-                logger.info(f'Training stats. Loss: {train_losses.avg}. Evaluation score: {train_eval_scores.avg}')
-                self._log_stats('train', train_losses.avg, train_eval_scores.avg)
-                # self._log_params()
-                self._log_images(input, target, output, 'train_')
-
-            if self.should_stop():
-                return True
-
-            self.num_iterations += 1
-
-        return False
-
-    def should_stop(self):
-        """
-        Training will terminate if maximum number of iterations is exceeded or the learning rate drops below
-        some predefined threshold (1e-6 in our case)
-        """
-        if self.max_num_iterations < self.num_iterations:
-            logger.info(f'Maximum number of iterations {self.max_num_iterations} exceeded.')
-            return True
-
-        min_lr = 1e-6
-        lr = self.optimizer.param_groups[0]['lr']
-        if lr < min_lr:
-            logger.info(f'Learning rate below the minimum {min_lr}.')
-            return True
-
-        return False
-
-    def validate(self):
-        logger.info('Validating...')
-
-        val_losses = utils.RunningAverage()
-        val_scores = utils.RunningAverage()
-
-        with torch.no_grad():
-            for i, t in enumerate(self.loaders['val']):
-                logger.info(f'Validation iteration {i}')
-
-                input, target, weight = self._split_training_batch(t)
-
-                output, loss = self._forward_pass(input, target, weight)
-                val_losses.update(loss.item(), self._batch_size(input))
-
-                if i % 100 == 0:
-                    self._log_images(input, target, output, 'val_')
-
-                eval_score = self.eval_criterion(output, target)
-                val_scores.update(eval_score.item(), self._batch_size(input))
-
-                if self.validate_iters is not None and self.validate_iters <= i:
-                    # stop validation
-                    break
-
-            self._log_stats('val', val_losses.avg, val_scores.avg)
-            logger.info(f'Validation finished. Loss: {val_losses.avg}. Evaluation score: {val_scores.avg}')
-            return val_scores.avg
-
-    def _split_training_batch(self, t):
-        def _move_to_gpu(input):
-            if isinstance(input, tuple) or isinstance(input, list):
-                return tuple([_move_to_gpu(x) for x in input])
-            else:
-                if torch.cuda.is_available():
-                    input = input.cuda(non_blocking=True)
-                return input
-
-        t = _move_to_gpu(t)
-        weight = None
-        if len(t) == 2:
-            input, target = t
-        else:
-            input, target, weight = t
-        return input, target, weight
-
-    def _forward_pass(self, input, target, weight=None):
-        if isinstance(self.model, UNet2D):
-            # remove the singleton z-dimension from the input
-            input = torch.squeeze(input, dim=-3)
-            # forward pass
-            output = self.model(input)
-            # add the singleton z-dimension to the output
-            output = torch.unsqueeze(output, dim=-3)
-        else:
-            # forward pass
-            output = self.model(input)
-
-        # compute the loss
-        if weight is None:
-            loss = self.loss_criterion(output, target)
-        else:
-            loss = self.loss_criterion(output, target, weight)
-
-        return output, loss
-
-    def _is_best_eval_score(self, eval_score):
-        if self.eval_score_higher_is_better:
-            is_best = eval_score > self.best_eval_score
-        else:
-            is_best = eval_score < self.best_eval_score
-
-        if is_best:
-            logger.info(f'Saving new best evaluation metric: {eval_score}')
-            self.best_eval_score = eval_score
-
-        return is_best
-
-    def _save_checkpoint(self, is_best):
-        # remove `module` prefix from layer names when using `nn.DataParallel`
-        # see: https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/20
-        if isinstance(self.model, nn.DataParallel):
-            state_dict = self.model.module.state_dict()
-        else:
-            state_dict = self.model.state_dict()
-
-        last_file_path = os.path.join(self.checkpoint_dir, 'last_checkpoint.pytorch')
-        logger.info(f"Saving checkpoint to '{last_file_path}'")
-
-        utils.save_checkpoint({
-            'num_epochs': self.num_epochs + 1,
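-            # +1 so that a resumed run continues from the epoch after the one just completed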
-            'num_iterations': self.num_iterations,
-            'model_state_dict': state_dict,
-            'best_eval_score': self.best_eval_score,
-            'optimizer_state_dict': self.optimizer.state_dict(),
-        }, is_best, checkpoint_dir=self.checkpoint_dir)
-
-    def _log_lr(self):
-        lr = self.optimizer.param_groups[0]['lr']
-        self.writer.add_scalar('learning_rate', lr, self.num_iterations)
-
-    def _log_stats(self, phase, loss_avg, eval_score_avg):
-        tag_value = {
-            f'{phase}_loss_avg': loss_avg,
-            f'{phase}_eval_score_avg': eval_score_avg
-        }
-
-        for tag, value in tag_value.items():
-            self.writer.add_scalar(tag, value, self.num_iterations)
-
-    def _log_params(self):
-        logger.info('Logging model parameters and gradients')
-        for name, value in self.model.named_parameters():
-            self.writer.add_histogram(name, value.data.cpu().numpy(), self.num_iterations)
-            self.writer.add_histogram(name + '/grad', value.grad.data.cpu().numpy(), self.num_iterations)
-
-    def _log_images(self, input, target, prediction, prefix=''):
-
-        if isinstance(self.model, nn.DataParallel):
-            net = self.model.module
-        else:
-            net = self.model
-
-        if net.final_activation is not None:
-            prediction = net.final_activation(prediction)
-
-        inputs_map = {
-            'inputs': input,
-            'targets': target,
-            'predictions': prediction
-        }
-        img_sources = {}
-        for name, batch in inputs_map.items():
-            if isinstance(batch, list) or isinstance(batch, tuple):
-                for i, b in enumerate(batch):
-                    img_sources[f'{name}{i}'] = b.data.cpu().numpy()
-            else:
-                img_sources[name] = batch.data.cpu().numpy()
-
-        for name, batch in img_sources.items():
-            for tag, image in self.tensorboard_formatter(name, batch):
-                self.writer.add_image(prefix + tag, image, self.num_iterations)
-
-    @staticmethod
-    def _batch_size(input):
-        if isinstance(input, list) or isinstance(input, tuple):
-            return input[0].size(0)
-        else:
-            return input.size(0)
diff --git a/build/lib/pytorch3dunet/unet3d/utils.py b/build/lib/pytorch3dunet/unet3d/utils.py
deleted file mode 100644
index 01d5559c..00000000
--- a/build/lib/pytorch3dunet/unet3d/utils.py
+++ /dev/null
@@ -1,366 +0,0 @@
-import importlib
-import logging
-import os
-import shutil
-import sys
-
-import h5py
-import numpy as np
-import torch
-from torch import optim
-
-
-def save_checkpoint(state, is_best, checkpoint_dir):
-    """Saves model and training parameters at '{checkpoint_dir}/last_checkpoint.pytorch'.
-    If is_best==True saves '{checkpoint_dir}/best_checkpoint.pytorch' as well.
-
-    Args:
-        state (dict): contains model's state_dict, optimizer's state_dict, epoch
-            and best evaluation metric value so far
-        is_best (bool): if True state contains the best model seen so far
-        checkpoint_dir (string): directory where the checkpoints are to be saved
-    """
-
-    if not os.path.exists(checkpoint_dir):
-        os.mkdir(checkpoint_dir)
-
-    last_file_path = os.path.join(checkpoint_dir, 'last_checkpoint.pytorch')
-    torch.save(state, last_file_path)
-    if is_best:
-        best_file_path = os.path.join(checkpoint_dir, 'best_checkpoint.pytorch')
-        shutil.copyfile(last_file_path, best_file_path)
-
-
-def load_checkpoint(checkpoint_path, model, optimizer=None,
-                    model_key='model_state_dict', optimizer_key='optimizer_state_dict'):
-    """Loads model and training parameters from a given checkpoint_path.
-    If optimizer is provided, loads the optimizer's state_dict as well.
-
-    Args:
-        checkpoint_path (string): path to the checkpoint to be loaded
-        model (torch.nn.Module): model into which the parameters are to be copied
-        optimizer (torch.optim.Optimizer, optional): optimizer instance into
-            which the parameters are to be copied
-
-    Returns:
-        state
-    """
-    if not os.path.exists(checkpoint_path):
-        raise IOError(f"Checkpoint '{checkpoint_path}' does not exist")
-
-    state = torch.load(checkpoint_path, map_location='cpu')
-    model.load_state_dict(state[model_key])
-
-    if optimizer is not None:
-        optimizer.load_state_dict(state[optimizer_key])
-
-    return state
-
-
-def save_network_output(output_path, output, logger=None):
-    if logger is not None:
-        logger.info(f'Saving network output to: {output_path}...')
-    output = output.detach().cpu()[0]
-    with h5py.File(output_path, 'w') as f:
-        f.create_dataset('predictions', data=output, compression='gzip')
-
-
-loggers = {}
-
-
-def get_logger(name, level=logging.INFO):
-    global loggers
-    if loggers.get(name) is not None:
-        return loggers[name]
-    else:
-        logger = logging.getLogger(name)
-        logger.setLevel(level)
-        # Logging to console
-        stream_handler = logging.StreamHandler(sys.stdout)
-        formatter = logging.Formatter(
-            '%(asctime)s [%(threadName)s] %(levelname)s %(name)s - %(message)s')
-        stream_handler.setFormatter(formatter)
-        logger.addHandler(stream_handler)
-
-        loggers[name] = logger
-
-        return logger
-
-
-def get_number_of_learnable_parameters(model):
-    return sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-
-class RunningAverage:
-    """Computes and stores a running average
-    """
-
-    def __init__(self):
-        self.count = 0
-        self.sum = 0
-        self.avg = 0
-
-    def update(self, value, n=1):
-        self.count += n
-        self.sum += value * n
-        self.avg = self.sum / self.count
-
-
-def number_of_features_per_level(init_channel_number, num_levels):
-    return [init_channel_number * 2 ** k for k in range(num_levels)]
-
-
-class _TensorboardFormatter:
-    """
-    A tensorboard formatter converts a given batch of images (be it the input/output of the network or the target
-    segmentation image) to a series of images that can be displayed in tensorboard. This is the parent class for
-    all tensorboard formatters, which ensures that the returned images are in the 'CHW' format.
-    """
-
-    def __init__(self, **kwargs):
-        pass
-
-    def __call__(self, name, batch):
-        """
-        Transform a batch to a series of tuples of the form (tag, img), where `tag` corresponds to the image tag
-        and `img` is the image itself.
-
-        Args:
-            name (str): one of 'inputs'/'targets'/'predictions'
-            batch (torch.tensor): 4D or 5D torch tensor
-        """
-
-        def _check_img(tag_img):
-            tag, img = tag_img
-
-            assert img.ndim == 2 or img.ndim == 3, 'Only 2D (HW) and 3D (CHW) images are accepted for display'
-
-            if img.ndim == 2:
-                img = np.expand_dims(img, axis=0)
-            else:
-                C = img.shape[0]
-                assert C == 1 or C == 3, 'Only (1, H, W) or (3, H, W) images are supported'
-
-            return tag, img
-
-        tagged_images = self.process_batch(name, batch)
-
-        return list(map(_check_img, tagged_images))
-
-    def process_batch(self, name, batch):
-        raise NotImplementedError
-
-
-class DefaultTensorboardFormatter(_TensorboardFormatter):
-    def __init__(self, skip_last_target=False, **kwargs):
-        super().__init__(**kwargs)
-        self.skip_last_target = skip_last_target
-
-    def process_batch(self, name, batch):
-        if name == 'targets' and self.skip_last_target:
-            batch = batch[:, :-1, ...]
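-            # note: the dropped channel is assumed to be an auxiliary target (e.g. a weight mask)
-            # that is not useful to visualize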
-
-        tag_template = '{}/batch_{}/channel_{}/slice_{}'
-
-        tagged_images = []
-
-        if batch.ndim == 5:
-            # NCDHW
-            slice_idx = batch.shape[2] // 2  # get the middle slice
-            for batch_idx in range(batch.shape[0]):
-                for channel_idx in range(batch.shape[1]):
-                    tag = tag_template.format(name, batch_idx, channel_idx, slice_idx)
-                    img = batch[batch_idx, channel_idx, slice_idx, ...]
-                    tagged_images.append((tag, self._normalize_img(img)))
-        else:
-            # batch has no channel dim: NDHW
-            slice_idx = batch.shape[1] // 2  # get the middle slice
-            for batch_idx in range(batch.shape[0]):
-                tag = tag_template.format(name, batch_idx, 0, slice_idx)
-                img = batch[batch_idx, slice_idx, ...]
-                tagged_images.append((tag, self._normalize_img(img)))
-
-        return tagged_images
-
-    @staticmethod
-    def _normalize_img(img):
-        return np.nan_to_num((img - np.min(img)) / np.ptp(img))
-
-
-def _find_masks(batch, min_size=10):
-    """Center the z-slice in the 'middle' of a given instance, given a batch of instances
-
-    Args:
-        batch (ndarray): 5d numpy tensor (NCDHW)
-    """
-    result = []
-    for b in batch:
-        assert b.shape[0] == 1
-        patch = b[0]
-        z_sum = patch.sum(axis=(1, 2))
-        coords = np.where(z_sum > min_size)[0]
-        if len(coords) > 0:
-            ind = coords[len(coords) // 2]
-            result.append(b[:, ind:ind + 1, ...])
-        else:
-            ind = b.shape[1] // 2
-            result.append(b[:, ind:ind + 1, ...])
-
-    return np.stack(result, axis=0)
-
-
-def get_tensorboard_formatter(formatter_config):
-    if formatter_config is None:
-        return DefaultTensorboardFormatter()
-
-    class_name = formatter_config['name']
-    m = importlib.import_module('pytorch3dunet.unet3d.utils')
-    clazz = getattr(m, class_name)
-    return clazz(**formatter_config)
-
-
-def expand_as_one_hot(input, C, ignore_index=None):
-    """
-    Converts NxSPATIAL label image to NxCxSPATIAL, where each label gets converted to its corresponding one-hot vector.
-    It is assumed that the batch dimension is present.
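-    E.g. a (N, D, H, W) label volume with C=3 becomes a (N, 3, D, H, W) one-hot tensor.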
-
-    Args:
-        input (torch.Tensor): 3D/4D input image
-        C (int): number of channels/labels
-        ignore_index (int): ignore index to be kept during the expansion
-
-    Returns:
-        4D/5D output torch.Tensor (NxCxSPATIAL)
-    """
-    assert input.dim() == 4
-
-    # expand the input tensor to Nx1xSPATIAL before scattering
-    input = input.unsqueeze(1)
-    # create output tensor shape (NxCxSPATIAL)
-    shape = list(input.size())
-    shape[1] = C
-
-    if ignore_index is not None:
-        # create ignore_index mask for the result
-        mask = input.expand(shape) == ignore_index
-        # clone the src tensor and zero out ignore_index in the input
-        input = input.clone()
-        input[input == ignore_index] = 0
-        # scatter to get the one-hot tensor
-        result = torch.zeros(shape).to(input.device).scatter_(1, input, 1)
-        # bring back the ignore_index in the result
-        result[mask] = ignore_index
-        return result
-    else:
-        # scatter to get the one-hot tensor
-        return torch.zeros(shape).to(input.device).scatter_(1, input, 1)
-
-
-def convert_to_numpy(*inputs):
-    """
-    Converts input tensors to numpy ndarrays
-
-    Args:
-        inputs (iterable of torch.Tensor): torch tensors
-
-    Returns:
-        tuple of ndarrays
-    """
-
-    def _to_numpy(i):
-        assert isinstance(i, torch.Tensor), "Expected input to be torch.Tensor"
-        return i.detach().cpu().numpy()
-
-    return (_to_numpy(i) for i in inputs)
-
-
-def create_optimizer(optimizer_config, model):
-    optim_name = optimizer_config.get('name', 'Adam')
-    # common optimizer settings
-    learning_rate = optimizer_config.get('learning_rate', 1e-3)
-    weight_decay = optimizer_config.get('weight_decay', 0)
-
-    # grab optimizer specific settings and init the optimizer
-    if optim_name == 'Adadelta':
-        rho = optimizer_config.get('rho', 0.9)
-        optimizer = optim.Adadelta(model.parameters(), lr=learning_rate, rho=rho,
-                                   weight_decay=weight_decay)
-    elif optim_name == 'Adagrad':
-        lr_decay = optimizer_config.get('lr_decay', 0)
-        optimizer = optim.Adagrad(model.parameters(), lr=learning_rate, lr_decay=lr_decay,
-                                  weight_decay=weight_decay)
-    elif optim_name == 'AdamW':
-        betas = tuple(optimizer_config.get('betas', (0.9, 0.999)))
-        optimizer = optim.AdamW(model.parameters(), lr=learning_rate, betas=betas,
-                                weight_decay=weight_decay)
-    elif optim_name == 'SparseAdam':
-        betas = tuple(optimizer_config.get('betas', (0.9, 0.999)))
-        optimizer = optim.SparseAdam(model.parameters(), lr=learning_rate, betas=betas)
-    elif optim_name == 'Adamax':
-        betas = tuple(optimizer_config.get('betas', (0.9, 0.999)))
-        optimizer = optim.Adamax(model.parameters(), lr=learning_rate, betas=betas,
-                                 weight_decay=weight_decay)
-    elif optim_name == 'ASGD':
-        lambd = optimizer_config.get('lambd', 0.0001)
-        alpha = optimizer_config.get('alpha', 0.75)
-        t0 = optimizer_config.get('t0', 1e6)
-        optimizer = optim.ASGD(model.parameters(), lr=learning_rate, lambd=lambd,
-                               alpha=alpha, t0=t0, weight_decay=weight_decay)
-    elif optim_name == 'LBFGS':
-        max_iter = optimizer_config.get('max_iter', 20)
-        max_eval = optimizer_config.get('max_eval', None)
-        tolerance_grad = optimizer_config.get('tolerance_grad', 1e-7)
-        tolerance_change = optimizer_config.get('tolerance_change', 1e-9)
-        history_size = optimizer_config.get('history_size', 100)
-        optimizer = optim.LBFGS(model.parameters(), lr=learning_rate, max_iter=max_iter,
-                                max_eval=max_eval, tolerance_grad=tolerance_grad,
-                                tolerance_change=tolerance_change, history_size=history_size)
-    elif optim_name == 'NAdam':
-        betas = tuple(optimizer_config.get('betas', (0.9, 0.999)))
-        momentum_decay = optimizer_config.get('momentum_decay', 4e-3)
-        optimizer = optim.NAdam(model.parameters(), lr=learning_rate, betas=betas,
-                                momentum_decay=momentum_decay,
-                                weight_decay=weight_decay)
-    elif optim_name == 'RAdam':
-        betas = tuple(optimizer_config.get('betas', (0.9, 0.999)))
-        optimizer = optim.RAdam(model.parameters(), lr=learning_rate, betas=betas,
-                                weight_decay=weight_decay)
-    elif optim_name == 'RMSprop':
-        alpha = optimizer_config.get('alpha', 0.99)
-        optimizer = optim.RMSprop(model.parameters(), lr=learning_rate, alpha=alpha,
-                                  weight_decay=weight_decay)
-    elif optim_name == 'Rprop':
-        etas = tuple(optimizer_config.get('etas', (0.5, 1.2)))
-        optimizer = optim.Rprop(model.parameters(), lr=learning_rate, etas=etas)
-    elif optim_name == 'SGD':
-        momentum = optimizer_config.get('momentum', 0)
-        dampening = optimizer_config.get('dampening', 0)
-        nesterov = optimizer_config.get('nesterov', False)
-        optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum,
-                              dampening=dampening, nesterov=nesterov,
-                              weight_decay=weight_decay)
-    else:  # Adam is default
-        betas = tuple(optimizer_config.get('betas', (0.9, 0.999)))
-        optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=betas,
-                               weight_decay=weight_decay)
-
-    return optimizer
-
-
-def create_lr_scheduler(lr_config, optimizer):
-    if lr_config is None:
-        return None
-    class_name = lr_config.pop('name')
-    m = importlib.import_module('torch.optim.lr_scheduler')
-    clazz = getattr(m, class_name)
-    # add optimizer to the config
-    lr_config['optimizer'] = optimizer
-    return clazz(**lr_config)
-
-
-def get_class(class_name, modules):
-    for module in modules:
-        m = importlib.import_module(module)
-        clazz = getattr(m, class_name, None)
-        if clazz is not None:
-            return clazz
-    raise RuntimeError(f'Unsupported dataset class: {class_name}')

From e2143ecd491bfd694ecb41d6aa03491ec8d882d9 Mon Sep 17 00:00:00 2001
From: Shota Mizusaki
Date: Tue, 30 Jul 2024 20:32:38 +0900
Subject: [PATCH 4/4] Since csr_matrix does not have .A, change to toarray().

---
 pytorch3dunet/unet3d/seg_metrics.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch3dunet/unet3d/seg_metrics.py b/pytorch3dunet/unet3d/seg_metrics.py
index e713ea23..ddafe59d 100644
--- a/pytorch3dunet/unet3d/seg_metrics.py
+++ b/pytorch3dunet/unet3d/seg_metrics.py
@@ -29,7 +29,7 @@ def _iou_matrix(gt, seg):
     seg = _relabel(seg)
 
     # get number of overlapping pixels between GT and SEG
-    n_inter = contingency_table(gt, seg).A
+    n_inter = contingency_table(gt, seg).toarray()
 
     # number of pixels for GT instances
    n_gt = n_inter.sum(axis=1, keepdims=True)
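Note: `skimage.metrics.contingency_table` returns a SciPy sparse matrix, and, as the commit message above states, the `.A` dense-view shorthand is not available in recent SciPy releases; the dense conversion has to go through `toarray()`. A minimal sketch of the behaviour this hunk relies on (the label arrays below are illustrative only, not taken from the project's data):

    import numpy as np
    from skimage.metrics import contingency_table

    gt = np.array([[0, 0, 1], [0, 2, 2]])   # ground-truth instance labels
    seg = np.array([[0, 1, 1], [0, 2, 2]])  # predicted instance labels

    # contingency_table returns a sparse matrix of overlap counts;
    # toarray() densifies it on every SciPy version, while .A may raise AttributeError
    n_inter = contingency_table(gt, seg).toarray()
    print(n_inter)  # rows index GT labels, columns index predicted labels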