Skip to content

Commit

Permalink
Initialize properly self.norm_min_val – Remove comment from test_boun…
Browse files Browse the repository at this point in the history
…ding.py
  • Loading branch information
lzampier committed Jan 28, 2025
1 parent e47d86b commit cd24351
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
5 changes: 3 additions & 2 deletions models/src/anemoi/models/layers/bounding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Optional

import torch
import numpy as np
from torch import nn

from anemoi.models.data_indices.tensor import InputTensorIndex
Expand Down Expand Up @@ -137,7 +138,7 @@ def __init__(

# Compute normalized min values
self.data_index = [name_to_index[var] for var in variables]
norm_min_val = torch.zeros(len(variables))
self.norm_min_val = torch.zeros(len(variables))
for ii, variable in enumerate(variables):
stat_index = self.name_to_index_stats[variable]
if self.normalizer[ii] == "mean-std":
Expand All @@ -156,7 +157,7 @@ def __init__(
self.norm_min_val[ii] = min_val[ii] / std

# Reorder normalized min values based on data_index
self.norm_min_val = norm_min_val[torch.argsort(torch.tensor(self.data_index))]
self.norm_min_val = self.norm_min_val[np.argsort(np.array(self.data_index))]

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Applies the ReLU activation with the normalized minimum values to the input tensor.
Expand Down
1 change: 0 additions & 1 deletion models/tests/layers/test_bounding.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def test_normalized_relu_bounding(config, name_to_index, name_to_index_stats, in
name_to_index_stats=name_to_index_stats,
)
output = bounding(input_tensor.clone())
# breakpoint()
expected_output = torch.tensor([[2.0, 2.0, 3.0], [4.0, 0.1111, 6.0], [2.0, 0.5, 0.5]])
assert torch.allclose(output, expected_output, atol=1e-4)

Expand Down

0 comments on commit cd24351

Please sign in to comment.