diff --git a/torchstain/numpy/normalizers/reinhard.py b/torchstain/numpy/normalizers/reinhard.py index d299dd3..035c49f 100644 --- a/torchstain/numpy/normalizers/reinhard.py +++ b/torchstain/numpy/normalizers/reinhard.py @@ -25,7 +25,9 @@ def fit(self, target): lab = rgb2lab(target) # get summary statistics - stack_ = np.array([get_mean_std(x) for x in lab_split(lab)]) +# stack_ = np.apply_along_axis(get_mean_std, 1, lab_split(lab)) + stack_ = np.apply_along_axis(get_mean_std, axis=1, arr=lab_split(lab)) + self.target_means = stack_[:, 0] self.target_stds = stack_[:, 1] @@ -38,7 +40,9 @@ def normalize(self, I): labs = lab_split(lab) # get summary statistics from LAB - stack_ = np.array([get_mean_std(x) for x in labs]) +# stack_ = np.apply_along_axis(get_mean_std, 1, labs) + stack_ = np.apply_along_axis(get_mean_std, axis=1, arr=labs) + mus = stack_[:, 0] stds = stack_[:, 1] @@ -62,7 +66,7 @@ def normalize(self, I): else: raise ValueError("Unsupported 'method' was chosen. Choose either {None, 'modified'}.") - + # rebuild LAB lab = lab_merge(*result)