Skip to content

Commit

Permalink
fix: normalise in place to reduce memory (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
japols authored Jan 21, 2025
1 parent 600f01e commit 40dd1a1
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions models/src/anemoi/models/preprocessing/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,12 @@ def transform(
x = x.clone()

if data_index is not None:
x[..., :] = x[..., :] * self._norm_mul[data_index] + self._norm_add[data_index]
x.mul_(self._norm_mul[data_index]).add_(self._norm_add[data_index])
elif x.shape[-1] == len(self._input_idx):
x[..., :] = x[..., :] * self._norm_mul[self._input_idx] + self._norm_add[self._input_idx]
x.mul_(self._norm_mul[self._input_idx]).add_(self._norm_add[self._input_idx])
else:
x[..., :] = x[..., :] * self._norm_mul + self._norm_add
x.mul_(self._norm_mul).add_(self._norm_add)

return x

def inverse_transform(
Expand Down Expand Up @@ -197,9 +198,9 @@ def inverse_transform(
# input and predicted tensors have different shapes
# hence, we mask out the forcing indices
if data_index is not None:
x[..., :] = (x[..., :] - self._norm_add[data_index]) / self._norm_mul[data_index]
x.subtract_(self._norm_add[data_index]).div_(self._norm_mul[data_index])
elif x.shape[-1] == len(self._output_idx):
x[..., :] = (x[..., :] - self._norm_add[self._output_idx]) / self._norm_mul[self._output_idx]
x.subtract_(self._norm_add[self._output_idx]).div_(self._norm_mul[self._output_idx])
else:
x[..., :] = (x[..., :] - self._norm_add) / self._norm_mul
x.subtract_(self._norm_add).div_(self._norm_mul)
return x

0 comments on commit 40dd1a1

Please sign in to comment.