Skip to content

Commit f7737e7

Browse files
authored
Merge pull request wolny#108 from hzdr-MedImaging/feature-channelwise-Normalize
implemented channelwise support for Normalize transform.
2 parents d919e1a + a740089 commit f7737e7

File tree

1 file changed

+39
-10
lines changed

1 file changed

+39
-10
lines changed

pytorch3dunet/augment/transforms.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -546,28 +546,57 @@ def __call__(self, m):
546546

547547
class Normalize:
548548
"""
549-
Apply simple min-max scaling to a given input tensor, i.e. shrinks the range of the data in a fixed range of [-1, 1].
549+
Apply simple min-max scaling to a given input tensor, i.e. shrinks the range of the data
550+
in a fixed range of [-1, 1] or in case of norm01==True to [0, 1]. In addition, data can be
551+
clipped by specifying min_value/max_value either globally using single values or via a
552+
list/tuple channelwise if enabled.
550553
"""
551554

552-
def __init__(self, min_value=None, max_value=None, norm01=False, eps=1e-10, **kwargs):
555+
def __init__(self, min_value=None, max_value=None, norm01=False, channelwise=False,
556+
eps=1e-10, **kwargs):
553557
if min_value is not None and max_value is not None:
554-
assert max_value > min_value
558+
assert max_value > min_value
555559
self.min_value = min_value
556560
self.max_value = max_value
557561
self.norm01 = norm01
562+
self.channelwise = channelwise
558563
self.eps = eps
559564

560565
def __call__(self, m):
561-
if self.min_value is None:
562-
min_value = np.min(m)
566+
if self.channelwise:
567+
# get min/max channelwise
568+
axes = list(range(m.ndim))
569+
axes = tuple(axes[1:])
570+
if self.min_value is None or 'None' in self.min_value:
571+
min_value = np.min(m, axis=axes, keepdims=True)
572+
573+
if self.max_value is None or 'None' in self.max_value:
574+
max_value = np.max(m, axis=axes, keepdims=True)
575+
576+
# check if non None in self.min_value/self.max_value
577+
# if present and if so copy value to min_value
578+
if self.min_value is not None:
579+
for i,v in enumerate(self.min_value):
580+
if v != 'None':
581+
min_value[i] = v
582+
583+
if self.max_value is not None:
584+
for i,v in enumerate(self.max_value):
585+
if v != 'None':
586+
max_value[i] = v
563587
else:
564-
min_value = self.min_value
588+
if self.min_value is None:
589+
min_value = np.min(m)
590+
else:
591+
min_value = self.min_value
565592

566-
if self.max_value is None:
567-
max_value = np.max(m)
568-
else:
569-
max_value = self.max_value
593+
if self.max_value is None:
594+
max_value = np.max(m)
595+
else:
596+
max_value = self.max_value
570597

598+
# calculate norm_0_1 with min_value / max_value with the same dimension
599+
# in case of channelwise application
571600
norm_0_1 = (m - min_value) / (max_value - min_value + self.eps)
572601

573602
if self.norm01 is True:

0 commit comments

Comments
 (0)