Skip to content

Commit

Permalink
revised Normalize function according to reviewer comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
jens-maus committed Nov 27, 2023
1 parent aaa860a commit 12b0b66
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions pytorch3dunet/augment/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,22 +549,27 @@ class Normalize:
Apply simple min-max scaling to a given input tensor, i.e. shrinks the range of the data in a fixed range of [-1, 1].
"""

def __init__(self, min_value=None, max_value=None, norm01=False, **kwargs):
def __init__(self, min_value=None, max_value=None, norm01=False, eps=1e-10, **kwargs):
if min_value is not None and max_value is not None:
assert max_value > min_value
self.min_value = min_value
self.max_value = max_value
self.norm01 = norm01
self.eps = eps

def __call__(self, m):
if self.min_value is None:
self.min_value = np.min(m)
min_value = np.min(m)
else:
min_value = self.min_value

if self.max_value is None:
self.max_value = np.max(m)
self.value_range = self.max_value - self.min_value
if self.value_range == 0:
self.value_range += 10**-100
norm_0_1 = (m - self.min_value) / self.value_range
max_value = np.max(m)
else:
max_value = self.max_value

norm_0_1 = (m - min_value) / (max_value - min_value + self.eps)

if self.norm01 is True:
return np.clip(norm_0_1, 0, 1)
else:
Expand Down

0 comments on commit 12b0b66

Please sign in to comment.