From 5bc2450ee553ccf675b10bdefcffbec62010d811 Mon Sep 17 00:00:00 2001 From: Jens Maus Date: Mon, 15 Jan 2024 11:24:36 +0100 Subject: [PATCH 1/3] implemented channelwise support for Normalize transform. Now, Normalize can be applied to multiclass/multichannel data with normalization between min/max values either derived from each class/channel separately or by specifying min/max arrays so that data clipping can be applied at the same time. --- pytorch3dunet/augment/transforms.py | 49 +++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/pytorch3dunet/augment/transforms.py b/pytorch3dunet/augment/transforms.py index f47061b9..527d596b 100644 --- a/pytorch3dunet/augment/transforms.py +++ b/pytorch3dunet/augment/transforms.py @@ -546,28 +546,57 @@ def __call__(self, m): class Normalize: """ - Apply simple min-max scaling to a given input tensor, i.e. shrinks the range of the data in a fixed range of [-1, 1]. + Apply simple min-max scaling to a given input tensor, i.e. shrinks the range of the data + in a fixed range of [-1, 1] or in case of norm01==True to [0, 1]. In addition, data can be + clipped by specifying min_value/max_value either globally using single values or via a + list/tuple channelwise if enabled. """ - def __init__(self, min_value=None, max_value=None, norm01=False, eps=1e-10, **kwargs): + def __init__(self, min_value=None, max_value=None, norm01=False, channelwise=False, + eps=1e-10, **kwargs): if min_value is not None and max_value is not None: - assert max_value > min_value + assert max_value > min_value self.min_value = min_value self.max_value = max_value self.norm01 = norm01 + self.channelwise = channelwise self.eps = eps def __call__(self, m): - if self.min_value is None: - min_value = np.min(m) + if self.channelwise: + # get min/max channelwise + axes = list(range(m.ndim)) + axes = tuple(axes[1:]) + if self.min_value is None or 'None' in self.min_value: + min_value = np.min(m, axis=axes, keepdims=True) + + if self.max_value is None or 'None' in self.max_value: + max_value = np.max(m, axis=axes, keepdims=True) + + # check if non None in self.min_value/self.max_value + # if present and if so copy value to min_value + if self.min_value is not None: + for i,v in enumerate(self.min_value): + if v != 'None': + min_value[i] = v + + if self.max_value is not None: + for i,v in enumerate(self.max_value): + if v != 'None': + max_value[i] = v else: - min_value = self.min_value + if self.min_value is None: + min_value = np.min(m) + else: + min_value = self.min_value - if self.max_value is None: - max_value = np.max(m) - else: - max_value = self.max_value + if self.max_value is None: + max_value = np.max(m) + else: + max_value = self.max_value + # calculate norm_0_1 with min_value / max_value with the same dimension + # in case of channelwise application norm_0_1 = (m - min_value) / (max_value - min_value + self.eps) if self.norm01 is True: From cfeeb28291a10ec3e00a87b97507907a06126171 Mon Sep 17 00:00:00 2001 From: Jens Maus Date: Mon, 15 Apr 2024 10:18:36 +0200 Subject: [PATCH 2/3] fix None value use for channelwise normalization. --- pytorch3dunet/augment/transforms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch3dunet/augment/transforms.py b/pytorch3dunet/augment/transforms.py index 527d596b..0dc41f57 100644 --- a/pytorch3dunet/augment/transforms.py +++ b/pytorch3dunet/augment/transforms.py @@ -567,22 +567,22 @@ def __call__(self, m): # get min/max channelwise axes = list(range(m.ndim)) axes = tuple(axes[1:]) - if self.min_value is None or 'None' in self.min_value: + if self.min_value is None: min_value = np.min(m, axis=axes, keepdims=True) - if self.max_value is None or 'None' in self.max_value: + if self.max_value is None: max_value = np.max(m, axis=axes, keepdims=True) # check if non None in self.min_value/self.max_value # if present and if so copy value to min_value if self.min_value is not None: for i,v in enumerate(self.min_value): - if v != 'None': + if v is not None: min_value[i] = v if self.max_value is not None: for i,v in enumerate(self.max_value): - if v != 'None': + if v is not None: max_value[i] = v else: if self.min_value is None: From 41493f5e2836e85204f29c9fe20f7b53f1702c1d Mon Sep 17 00:00:00 2001 From: Jens Maus Date: Mon, 15 Apr 2024 13:55:56 +0200 Subject: [PATCH 3/3] Revert "fix None value use for channelwise normalization." This reverts commit 79b8df219231ece92606cbf08e4c9cedfdbd7636. --- pytorch3dunet/augment/transforms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch3dunet/augment/transforms.py b/pytorch3dunet/augment/transforms.py index 0dc41f57..527d596b 100644 --- a/pytorch3dunet/augment/transforms.py +++ b/pytorch3dunet/augment/transforms.py @@ -567,22 +567,22 @@ def __call__(self, m): # get min/max channelwise axes = list(range(m.ndim)) axes = tuple(axes[1:]) - if self.min_value is None: + if self.min_value is None or 'None' in self.min_value: min_value = np.min(m, axis=axes, keepdims=True) - if self.max_value is None: + if self.max_value is None or 'None' in self.max_value: max_value = np.max(m, axis=axes, keepdims=True) # check if non None in self.min_value/self.max_value # if present and if so copy value to min_value if self.min_value is not None: for i,v in enumerate(self.min_value): - if v is not None: + if v != 'None': min_value[i] = v if self.max_value is not None: for i,v in enumerate(self.max_value): - if v is not None: + if v != 'None': max_value[i] = v else: if self.min_value is None: