@@ -546,28 +546,57 @@ def __call__(self, m):
546546
547547class Normalize :
548548 """
549- Apply simple min-max scaling to a given input tensor, i.e. shrinks the range of the data in a fixed range of [-1, 1].
549+ Apply simple min-max scaling to a given input tensor, i.e. shrinks the range of the data
550+ in a fixed range of [-1, 1] or in case of norm01==True to [0, 1]. In addition, data can be
551+ clipped by specifying min_value/max_value either globally using single values or via a
552+ list/tuple channelwise if enabled.
550553 """
551554
552- def __init__ (self , min_value = None , max_value = None , norm01 = False , eps = 1e-10 , ** kwargs ):
555+ def __init__ (self , min_value = None , max_value = None , norm01 = False , channelwise = False ,
556+ eps = 1e-10 , ** kwargs ):
553557 if min_value is not None and max_value is not None :
554- assert max_value > min_value
558+ assert max_value > min_value
555559 self .min_value = min_value
556560 self .max_value = max_value
557561 self .norm01 = norm01
562+ self .channelwise = channelwise
558563 self .eps = eps
559564
560565 def __call__ (self , m ):
561- if self .min_value is None :
562- min_value = np .min (m )
566+ if self .channelwise :
567+ # get min/max channelwise
568+ axes = list (range (m .ndim ))
569+ axes = tuple (axes [1 :])
570+ if self .min_value is None or 'None' in self .min_value :
571+ min_value = np .min (m , axis = axes , keepdims = True )
572+
573+ if self .max_value is None or 'None' in self .max_value :
574+ max_value = np .max (m , axis = axes , keepdims = True )
575+
576+ # check if non None in self.min_value/self.max_value
577+ # if present and if so copy value to min_value
578+ if self .min_value is not None :
579+ for i ,v in enumerate (self .min_value ):
580+ if v != 'None' :
581+ min_value [i ] = v
582+
583+ if self .max_value is not None :
584+ for i ,v in enumerate (self .max_value ):
585+ if v != 'None' :
586+ max_value [i ] = v
563587 else :
564- min_value = self .min_value
588+ if self .min_value is None :
589+ min_value = np .min (m )
590+ else :
591+ min_value = self .min_value
565592
566- if self .max_value is None :
567- max_value = np .max (m )
568- else :
569- max_value = self .max_value
593+ if self .max_value is None :
594+ max_value = np .max (m )
595+ else :
596+ max_value = self .max_value
570597
598+ # calculate norm_0_1 with min_value / max_value with the same dimension
599+ # in case of channelwise application
571600 norm_0_1 = (m - min_value ) / (max_value - min_value + self .eps )
572601
573602 if self .norm01 is True :
0 commit comments