diff --git a/pytorch3dunet/augment/transforms.py b/pytorch3dunet/augment/transforms.py index f47061b9..527d596b 100644 --- a/pytorch3dunet/augment/transforms.py +++ b/pytorch3dunet/augment/transforms.py @@ -546,28 +546,57 @@ def __call__(self, m): class Normalize: """ - Apply simple min-max scaling to a given input tensor, i.e. shrinks the range of the data in a fixed range of [-1, 1]. + Apply simple min-max scaling to a given input tensor, i.e. shrinks the range of the data + in a fixed range of [-1, 1] or in case of norm01==True to [0, 1]. In addition, data can be + clipped by specifying min_value/max_value either globally using single values or via a + list/tuple channelwise if enabled. """ - def __init__(self, min_value=None, max_value=None, norm01=False, eps=1e-10, **kwargs): + def __init__(self, min_value=None, max_value=None, norm01=False, channelwise=False, + eps=1e-10, **kwargs): if min_value is not None and max_value is not None: - assert max_value > min_value + assert max_value > min_value self.min_value = min_value self.max_value = max_value self.norm01 = norm01 + self.channelwise = channelwise self.eps = eps def __call__(self, m): - if self.min_value is None: - min_value = np.min(m) + if self.channelwise: + # get min/max channelwise + axes = list(range(m.ndim)) + axes = tuple(axes[1:]) + if self.min_value is None or 'None' in self.min_value: + min_value = np.min(m, axis=axes, keepdims=True) + + if self.max_value is None or 'None' in self.max_value: + max_value = np.max(m, axis=axes, keepdims=True) + + # check if non None in self.min_value/self.max_value + # if present and if so copy value to min_value + if self.min_value is not None: + for i,v in enumerate(self.min_value): + if v != 'None': + min_value[i] = v + + if self.max_value is not None: + for i,v in enumerate(self.max_value): + if v != 'None': + max_value[i] = v else: - min_value = self.min_value + if self.min_value is None: + min_value = np.min(m) + else: + min_value = self.min_value - if self.max_value is None: - max_value = np.max(m) - else: - max_value = self.max_value + if self.max_value is None: + max_value = np.max(m) + else: + max_value = self.max_value + # calculate norm_0_1 with min_value / max_value with the same dimension + # in case of channelwise application norm_0_1 = (m - min_value) / (max_value - min_value + self.eps) if self.norm01 is True: