diff --git a/pytorch3dunet/augment/transforms.py b/pytorch3dunet/augment/transforms.py index 8d57ea6e..f47061b9 100644 --- a/pytorch3dunet/augment/transforms.py +++ b/pytorch3dunet/augment/transforms.py @@ -549,14 +549,31 @@ class Normalize: Apply simple min-max scaling to a given input tensor, i.e. shrinks the range of the data in a fixed range of [-1, 1]. """ - def __init__(self, min_value, max_value, **kwargs): - assert max_value > min_value + def __init__(self, min_value=None, max_value=None, norm01=False, eps=1e-10, **kwargs): + if min_value is not None and max_value is not None: + assert max_value > min_value self.min_value = min_value - self.value_range = max_value - min_value + self.max_value = max_value + self.norm01 = norm01 + self.eps = eps def __call__(self, m): - norm_0_1 = (m - self.min_value) / self.value_range - return np.clip(2 * norm_0_1 - 1, -1, 1) + if self.min_value is None: + min_value = np.min(m) + else: + min_value = self.min_value + + if self.max_value is None: + max_value = np.max(m) + else: + max_value = self.max_value + + norm_0_1 = (m - min_value) / (max_value - min_value + self.eps) + + if self.norm01 is True: + return np.clip(norm_0_1, 0, 1) + else: + return np.clip(2 * norm_0_1 - 1, -1, 1) class AdditiveGaussianNoise: