From dfae71eb72c33a32dccbf4ddbf8ee0dddbd77cf7 Mon Sep 17 00:00:00 2001 From: Lorenzo Zampieri Date: Mon, 16 Dec 2024 15:35:00 +0000 Subject: [PATCH 01/10] Implementation of NormalizedReluBounding for all the types of normalizations. --- src/anemoi/models/layers/bounding.py | 102 +++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/src/anemoi/models/layers/bounding.py b/src/anemoi/models/layers/bounding.py index b168591d..c0e8e459 100644 --- a/src/anemoi/models/layers/bounding.py +++ b/src/anemoi/models/layers/bounding.py @@ -11,6 +11,7 @@ from abc import ABC from abc import abstractmethod +from typing import Optional import torch from torch import nn @@ -30,12 +31,28 @@ def __init__( *, variables: list[str], name_to_index: dict, + statistics: Optional[dict] = None, + name_to_index_stats: Optional[dict] = None, ) -> None: + """Initializes the bounding strategy. + Parameters + ---------- + variables : list[str] + A list of strings representing the variables that will be bounded. + name_to_index : dict + A dictionary mapping the variable names to their corresponding indices. + statistics : dict, optional + A dictionary containing the statistics of the variables. + name_to_index_stats : dict, optional + A dictionary mapping the variable names to their corresponding indices in the statistics dictionary + """ super().__init__() self.name_to_index = name_to_index self.variables = variables self.data_index = self._create_index(variables=self.variables) + self.statistics = statistics + self.name_to_index_stats = name_to_index_stats def _create_index(self, variables: list[str]) -> InputTensorIndex: return InputTensorIndex(includes=variables, excludes=[], name_to_index=self.name_to_index)._only @@ -63,7 +80,92 @@ class ReluBounding(BaseBounding): def forward(self, x: torch.Tensor) -> torch.Tensor: x[..., self.data_index] = torch.nn.functional.relu(x[..., self.data_index]) return x + + +class NormalizedReluBounding(BaseBounding): + """Bounding variable with a ReLU activation and customizable normalized thresholds.""" + def __init__( + self, + *, + variables: list[str], + name_to_index: dict, + min_val: list[float], + normalizer: list[str], + statistics: dict, + name_to_index_stats: dict, + ) -> None: + """Initializes the NormalizedReluBounding with the specified parameters. + + Parameters + ---------- + variables : list[str] + A list of strings representing the variables that will be bounded. + name_to_index : dict + A dictionary mapping the variable names to their corresponding indices. + statistics : dict + A dictionary containing the statistics of the variables (mean, std, min, max, etc.). + min_val : list[float] + The minimum values for the ReLU activation. It should be given in the same order as the variables. + normalizer : list[str] + A list of normalization types to apply, one per variable. Options: 'mean-std', 'min-max', 'max', 'std'. + name_to_index_stats : dict + A dictionary mapping the variable names to their corresponding indices in the statistics dictionary. + """ + super().__init__( + variables=variables, + name_to_index=name_to_index, + statistics=statistics, + name_to_index_stats=name_to_index_stats, + ) + self.min_val = min_val + self.normalizer = normalizer + + # Validate normalizer input + if not all(norm in {"mean-std", "min-max", "max", "std"} for norm in self.normalizer): + raise ValueError("Each normalizer must be one of: 'mean-std', 'min-max', 'max', 'std' in NormalizedReluBounding.") + if len(self.normalizer) != len(variables): + raise ValueError("The length of the normalizer list must match the number of variables in NormalizedReluBounding.") + if len(self.min_val) != len(variables): + raise ValueError("The length of the min_val list must match the number of variables in NormalizedReluBounding.") + + self.norm_min_val = torch.zeros(len(variables)) + for ii, variable in enumerate(variables): + stat_index = self.name_to_index_stats[variable] + if self.normalizer[ii] == "mean-std": + mean = self.statistics["mean"][stat_index] + std = self.statistics["stdev"][stat_index] + self.norm_min_val[ii] = (min_val[ii] - mean) / std + elif self.normalizer[ii] == "min-max": + min_stat = self.statistics["min"][stat_index] + max_stat = self.statistics["max"][stat_index] + self.norm_min_val[ii] = (min_val[ii] - min_stat) / (max_stat - min_stat) + elif self.normalizer[ii] == "max": + max_stat = self.statistics["max"][stat_index] + self.norm_min_val[ii] = min_val[ii] / max_stat + elif self.normalizer[ii] == "std": + std = self.statistics["stdev"][stat_index] + self.norm_min_val[ii] = min_val[ii] / std + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies the ReLU activation with the normalized minimum values to the input tensor. + + Parameters + ---------- + x : torch.Tensor + The input tensor to process. + + Returns + ------- + torch.Tensor + The processed tensor with bounding applied. + """ + self.norm_min_val = self.norm_min_val.to(x.device) + x[..., self.data_index] = ( + torch.nn.functional.relu(x[..., self.data_index] - self.norm_min_val) + + self.norm_min_val + ) + return x class HardtanhBounding(BaseBounding): """Initializes the bounding with specified minimum and maximum values for bounding. From 54f4ef83521ef3390bdb50299e368a906eade698 Mon Sep 17 00:00:00 2001 From: Lorenzo Zampieri Date: Mon, 16 Dec 2024 16:04:16 +0000 Subject: [PATCH 02/10] Modify instantiation of bounding function to have statistics --- src/anemoi/models/models/encoder_processor_decoder.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index c67c8c03..08c0c59c 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -100,7 +100,12 @@ def __init__( # Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite) self.boundings = nn.ModuleList( [ - instantiate(cfg, name_to_index=self.data_indices.internal_model.output.name_to_index) + instantiate( + cfg, + name_to_index=self.data_indices.internal_model.output.name_to_index, + statistics=self.statistics, + name_to_index_stats=self.data_indices.data.input.name_to_index, + ) for cfg in getattr(model_config.model, "bounding", []) ] ) From abb0afd7dd3a0a3fca66b7747688b706a8201159 Mon Sep 17 00:00:00 2001 From: Lorenzo Zampieri Date: Mon, 16 Dec 2024 16:29:58 +0000 Subject: [PATCH 03/10] Fix some issues and typos --- src/anemoi/models/interface/__init__.py | 1 + src/anemoi/models/layers/bounding.py | 21 ++++++++++++------- .../models/encoder_processor_decoder.py | 2 ++ 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/anemoi/models/interface/__init__.py b/src/anemoi/models/interface/__init__.py index 261dec29..8e9a568d 100644 --- a/src/anemoi/models/interface/__init__.py +++ b/src/anemoi/models/interface/__init__.py @@ -87,6 +87,7 @@ def _build_model(self) -> None: self.config.model.model, model_config=self.config, data_indices=self.data_indices, + statistics=self.statistics, graph_data=self.graph_data, _recursive_=False, # Disables recursive instantiation by Hydra ) diff --git a/src/anemoi/models/layers/bounding.py b/src/anemoi/models/layers/bounding.py index c0e8e459..738b4af6 100644 --- a/src/anemoi/models/layers/bounding.py +++ b/src/anemoi/models/layers/bounding.py @@ -123,11 +123,17 @@ def __init__( # Validate normalizer input if not all(norm in {"mean-std", "min-max", "max", "std"} for norm in self.normalizer): - raise ValueError("Each normalizer must be one of: 'mean-std', 'min-max', 'max', 'std' in NormalizedReluBounding.") + raise ValueError( + "Each normalizer must be one of: 'mean-std', 'min-max', 'max', 'std' in NormalizedReluBounding." + ) if len(self.normalizer) != len(variables): - raise ValueError("The length of the normalizer list must match the number of variables in NormalizedReluBounding.") + raise ValueError( + "The length of the normalizer list must match the number of variables in NormalizedReluBounding." + ) if len(self.min_val) != len(variables): - raise ValueError("The length of the min_val list must match the number of variables in NormalizedReluBounding.") + raise ValueError( + "The length of the min_val list must match the number of variables in NormalizedReluBounding." + ) self.norm_min_val = torch.zeros(len(variables)) for ii, variable in enumerate(variables): @@ -162,8 +168,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ self.norm_min_val = self.norm_min_val.to(x.device) x[..., self.data_index] = ( - torch.nn.functional.relu(x[..., self.data_index] - self.norm_min_val) - + self.norm_min_val + torch.nn.functional.relu(x[..., self.data_index] - self.norm_min_val) + self.norm_min_val ) return x @@ -182,7 +187,7 @@ class HardtanhBounding(BaseBounding): The maximum value for the HardTanh activation. """ - def __init__(self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float) -> None: + def __init__(self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float, statistics: Optional[dict] = None, name_to_index_stats: Optional[dict] = None,) -> None: super().__init__(variables=variables, name_to_index=name_to_index) self.min_val = min_val self.max_val = max_val @@ -213,8 +218,8 @@ class FractionBounding(HardtanhBounding): """ def __init__( - self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float, total_var: str - ) -> None: + self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float, total_var: str, statistics: Optional[dict] = None, name_to_index_stats: Optional[dict] = None, + ) -> None: super().__init__(variables=variables, name_to_index=name_to_index, min_val=min_val, max_val=max_val) self.total_variable = self._create_index(variables=[total_var]) diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index 08c0c59c..532fff6b 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -35,6 +35,7 @@ def __init__( *, model_config: DotDict, data_indices: dict, + statistics: dict, graph_data: HeteroData, ) -> None: """Initializes the graph neural network. @@ -57,6 +58,7 @@ def __init__( self._calculate_shapes_and_indices(data_indices) self._assert_matching_indices(data_indices) self.data_indices = data_indices + self.statistics = statistics self.multi_step = model_config.training.multistep_input self.num_channels = model_config.model.num_channels From c808dcfc9ba1d2151d75a891837e985c48db2382 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Dec 2024 16:32:41 +0000 Subject: [PATCH 04/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/models/layers/bounding.py | 32 ++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/src/anemoi/models/layers/bounding.py b/src/anemoi/models/layers/bounding.py index 738b4af6..50a0d08e 100644 --- a/src/anemoi/models/layers/bounding.py +++ b/src/anemoi/models/layers/bounding.py @@ -80,7 +80,7 @@ class ReluBounding(BaseBounding): def forward(self, x: torch.Tensor) -> torch.Tensor: x[..., self.data_index] = torch.nn.functional.relu(x[..., self.data_index]) return x - + class NormalizedReluBounding(BaseBounding): """Bounding variable with a ReLU activation and customizable normalized thresholds.""" @@ -125,15 +125,15 @@ def __init__( if not all(norm in {"mean-std", "min-max", "max", "std"} for norm in self.normalizer): raise ValueError( "Each normalizer must be one of: 'mean-std', 'min-max', 'max', 'std' in NormalizedReluBounding." - ) + ) if len(self.normalizer) != len(variables): raise ValueError( "The length of the normalizer list must match the number of variables in NormalizedReluBounding." - ) + ) if len(self.min_val) != len(variables): raise ValueError( "The length of the min_val list must match the number of variables in NormalizedReluBounding." - ) + ) self.norm_min_val = torch.zeros(len(variables)) for ii, variable in enumerate(variables): @@ -172,6 +172,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) return x + class HardtanhBounding(BaseBounding): """Initializes the bounding with specified minimum and maximum values for bounding. @@ -187,7 +188,16 @@ class HardtanhBounding(BaseBounding): The maximum value for the HardTanh activation. """ - def __init__(self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float, statistics: Optional[dict] = None, name_to_index_stats: Optional[dict] = None,) -> None: + def __init__( + self, + *, + variables: list[str], + name_to_index: dict, + min_val: float, + max_val: float, + statistics: Optional[dict] = None, + name_to_index_stats: Optional[dict] = None, + ) -> None: super().__init__(variables=variables, name_to_index=name_to_index) self.min_val = min_val self.max_val = max_val @@ -218,8 +228,16 @@ class FractionBounding(HardtanhBounding): """ def __init__( - self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float, total_var: str, statistics: Optional[dict] = None, name_to_index_stats: Optional[dict] = None, - ) -> None: + self, + *, + variables: list[str], + name_to_index: dict, + min_val: float, + max_val: float, + total_var: str, + statistics: Optional[dict] = None, + name_to_index_stats: Optional[dict] = None, + ) -> None: super().__init__(variables=variables, name_to_index=name_to_index, min_val=min_val, max_val=max_val) self.total_variable = self._create_index(variables=[total_var]) From bc20548c05525da4ae083a957fd4fbf0b31d06c7 Mon Sep 17 00:00:00 2001 From: Lorenzo Zampieri Date: Tue, 17 Dec 2024 11:00:33 +0000 Subject: [PATCH 05/10] Add tests for normalized relu bounding --- tests/layers/test_bounding.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/layers/test_bounding.py b/tests/layers/test_bounding.py index eddbae2d..29fb3b8d 100644 --- a/tests/layers/test_bounding.py +++ b/tests/layers/test_bounding.py @@ -15,7 +15,8 @@ from anemoi.models.layers.bounding import FractionBounding from anemoi.models.layers.bounding import HardtanhBounding from anemoi.models.layers.bounding import ReluBounding - +from anemoi.models.layers.bounding import NormalizedReluBounding +import numpy as np @pytest.fixture def config(): @@ -31,6 +32,15 @@ def name_to_index(): def input_tensor(): return torch.tensor([[-1.0, 2.0, 3.0], [4.0, -5.0, 6.0], [0.5, 0.5, 0.5]]) +@pytest.fixture +def statistics(): + statistics = { + "mean": np.array([1.0, 2.0, 3.0]), + "stdev": np.array([0.5, 0.5, 0.5]), + "minimum": np.array([1.0, 1.0, 1.0]), + "maximum": np.array([11.0, 10.0, 10.0]), + } + return statistics def test_relu_bounding(config, name_to_index, input_tensor): bounding = ReluBounding(variables=config.variables, name_to_index=name_to_index) @@ -38,6 +48,13 @@ def test_relu_bounding(config, name_to_index, input_tensor): expected_output = torch.tensor([[0.0, 2.0, 3.0], [4.0, 0.0, 6.0], [0.5, 0.5, 0.5]]) assert torch.equal(output, expected_output) +def test_normalized_relu_bounding(config, name_to_index, input_tensor, statistics): + bounding = NormalizedReluBounding(variables=config.variables, name_to_index=name_to_index, min_val= [2.0, 2.0], normalizer= ["mean-std","min-max"], statistics= statistics) + output = bounding(input_tensor.clone()) + breakpoint() + expected_output = torch.tensor([[2.0, 2.0, 3.0], [4.0, 0.0, 6.0], [0.5, 0.5, 0.5]]) + assert torch.equal(output, expected_output) + def test_hardtanh_bounding(config, name_to_index, input_tensor): minimum, maximum = -1.0, 1.0 From be8f32f528b34176e46b7253391adb257378e37d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Dec 2024 11:04:02 +0000 Subject: [PATCH 06/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/layers/test_bounding.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/layers/test_bounding.py b/tests/layers/test_bounding.py index 29fb3b8d..f2022534 100644 --- a/tests/layers/test_bounding.py +++ b/tests/layers/test_bounding.py @@ -7,6 +7,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import numpy as np import pytest import torch from anemoi.utils.config import DotDict @@ -14,9 +15,9 @@ from anemoi.models.layers.bounding import FractionBounding from anemoi.models.layers.bounding import HardtanhBounding -from anemoi.models.layers.bounding import ReluBounding from anemoi.models.layers.bounding import NormalizedReluBounding -import numpy as np +from anemoi.models.layers.bounding import ReluBounding + @pytest.fixture def config(): @@ -32,6 +33,7 @@ def name_to_index(): def input_tensor(): return torch.tensor([[-1.0, 2.0, 3.0], [4.0, -5.0, 6.0], [0.5, 0.5, 0.5]]) + @pytest.fixture def statistics(): statistics = { @@ -42,14 +44,22 @@ def statistics(): } return statistics + def test_relu_bounding(config, name_to_index, input_tensor): bounding = ReluBounding(variables=config.variables, name_to_index=name_to_index) output = bounding(input_tensor.clone()) expected_output = torch.tensor([[0.0, 2.0, 3.0], [4.0, 0.0, 6.0], [0.5, 0.5, 0.5]]) assert torch.equal(output, expected_output) + def test_normalized_relu_bounding(config, name_to_index, input_tensor, statistics): - bounding = NormalizedReluBounding(variables=config.variables, name_to_index=name_to_index, min_val= [2.0, 2.0], normalizer= ["mean-std","min-max"], statistics= statistics) + bounding = NormalizedReluBounding( + variables=config.variables, + name_to_index=name_to_index, + min_val=[2.0, 2.0], + normalizer=["mean-std", "min-max"], + statistics=statistics, + ) output = bounding(input_tensor.clone()) breakpoint() expected_output = torch.tensor([[2.0, 2.0, 3.0], [4.0, 0.0, 6.0], [0.5, 0.5, 0.5]]) From 28cd77e79bb361a0e12db086343573537d8c7f99 Mon Sep 17 00:00:00 2001 From: Lorenzo Zampieri Date: Tue, 7 Jan 2025 16:17:08 +0000 Subject: [PATCH 07/10] Fix tests for norm_relu_bounding normalization --- models/tests/layers/test_bounding.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/models/tests/layers/test_bounding.py b/models/tests/layers/test_bounding.py index c3b46d59..0bce83c7 100644 --- a/models/tests/layers/test_bounding.py +++ b/models/tests/layers/test_bounding.py @@ -28,6 +28,9 @@ def config(): def name_to_index(): return {"var1": 0, "var2": 1, "total_var": 2} +@pytest.fixture +def name_to_index_stats(): + return {"var1": 0, "var2": 1, "total_var": 2} @pytest.fixture def input_tensor(): @@ -39,8 +42,8 @@ def statistics(): statistics = { "mean": np.array([1.0, 2.0, 3.0]), "stdev": np.array([0.5, 0.5, 0.5]), - "minimum": np.array([1.0, 1.0, 1.0]), - "maximum": np.array([11.0, 10.0, 10.0]), + "min": np.array([1.0, 1.0, 1.0]), + "max": np.array([11.0, 10.0, 10.0]), } return statistics @@ -52,18 +55,19 @@ def test_relu_bounding(config, name_to_index, input_tensor): assert torch.equal(output, expected_output) -def test_normalized_relu_bounding(config, name_to_index, input_tensor, statistics): +def test_normalized_relu_bounding(config, name_to_index, name_to_index_stats, input_tensor, statistics): bounding = NormalizedReluBounding( variables=config.variables, name_to_index=name_to_index, min_val=[2.0, 2.0], normalizer=["mean-std", "min-max"], statistics=statistics, + name_to_index_stats=name_to_index_stats, ) output = bounding(input_tensor.clone()) - breakpoint() - expected_output = torch.tensor([[2.0, 2.0, 3.0], [4.0, 0.0, 6.0], [0.5, 0.5, 0.5]]) - assert torch.equal(output, expected_output) + #breakpoint() + expected_output = torch.tensor([[2.0, 2.0, 3.0], [4.0, 0.1111, 6.0], [2.0, 0.5, 0.5]]) + assert torch.allclose(output, expected_output, atol=1e-4) def test_hardtanh_bounding(config, name_to_index, input_tensor): From a38e3dec5781eff94e901d1f1f91517b646e532f Mon Sep 17 00:00:00 2001 From: Lorenzo Zampieri Date: Fri, 10 Jan 2025 14:44:14 +0000 Subject: [PATCH 08/10] Commented example for NormalizedReluBounding added to the config --- training/src/anemoi/training/config/model/gnn.yaml | 11 +++++++++++ .../training/config/model/graphtransformer.yaml | 11 +++++++++++ .../anemoi/training/config/model/transformer.yaml | 13 +++++++++++++ 3 files changed, 35 insertions(+) diff --git a/training/src/anemoi/training/config/model/gnn.yaml b/training/src/anemoi/training/config/model/gnn.yaml index 92a17fd4..66e761ce 100644 --- a/training/src/anemoi/training/config/model/gnn.yaml +++ b/training/src/anemoi/training/config/model/gnn.yaml @@ -67,3 +67,14 @@ bounding: #These are applied in order # min_val: 0 # max_val: 1 # total_var: tp + + # [OPTIONAL] NormalizedReluBounding + # This is an extension of the Relu bounding in case the thrshold to be used + # is not 0. For example, in case of the sea surface temperature we don't use + # [0, infinity), buth rather [-2C, infinity). We do not want the water + # temperature to be below the freezing temperature. + + # - _target_: anemoi.models.layers.bounding.NormalizedReluBounding + # variables: [sst] + # min_val: [-2] + # normalizer: ['mean-std'] \ No newline at end of file diff --git a/training/src/anemoi/training/config/model/graphtransformer.yaml b/training/src/anemoi/training/config/model/graphtransformer.yaml index 9c48967b..44cfa029 100644 --- a/training/src/anemoi/training/config/model/graphtransformer.yaml +++ b/training/src/anemoi/training/config/model/graphtransformer.yaml @@ -72,3 +72,14 @@ bounding: #These are applied in order # min_val: 0 # max_val: 1 # total_var: tp + + # [OPTIONAL] NormalizedReluBounding + # This is an extension of the Relu bounding in case the thrshold to be used + # is not 0. For example, in case of the sea surface temperature we don't use + # [0, infinity), buth rather [-2C, infinity). We do not want the water + # temperature to be below the freezing temperature. + + # - _target_: anemoi.models.layers.bounding.NormalizedReluBounding + # variables: [sst] + # min_val: [-2] + # normalizer: ['mean-std'] diff --git a/training/src/anemoi/training/config/model/transformer.yaml b/training/src/anemoi/training/config/model/transformer.yaml index cd6a1e7b..643ed878 100644 --- a/training/src/anemoi/training/config/model/transformer.yaml +++ b/training/src/anemoi/training/config/model/transformer.yaml @@ -71,3 +71,16 @@ bounding: #These are applied in order # min_val: 0 # max_val: 1 # total_var: tp + + # [OPTIONAL] NormalizedReluBounding + # This is an extension of the Relu bounding in case the thrshold to be used + # is not 0. For example, in case of the sea surface temperature we don't use + # [0, infinity), buth rather [-2C, infinity). We do not want the water + # temperature to be below the freezing temperature. + + # - _target_: anemoi.models.layers.bounding.NormalizedReluBounding + # variables: [sst] + # min_val: [-2] + # normalizer: ['mean-std'] + + From 22345215c2401e96456107d4294bfd901fb5d2e5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Jan 2025 14:44:51 +0000 Subject: [PATCH 09/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- models/tests/layers/test_bounding.py | 4 +++- training/src/anemoi/training/config/model/gnn.yaml | 6 +++--- .../src/anemoi/training/config/model/graphtransformer.yaml | 4 ++-- training/src/anemoi/training/config/model/transformer.yaml | 6 ++---- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/models/tests/layers/test_bounding.py b/models/tests/layers/test_bounding.py index 0bce83c7..a7f188be 100644 --- a/models/tests/layers/test_bounding.py +++ b/models/tests/layers/test_bounding.py @@ -28,10 +28,12 @@ def config(): def name_to_index(): return {"var1": 0, "var2": 1, "total_var": 2} + @pytest.fixture def name_to_index_stats(): return {"var1": 0, "var2": 1, "total_var": 2} + @pytest.fixture def input_tensor(): return torch.tensor([[-1.0, 2.0, 3.0], [4.0, -5.0, 6.0], [0.5, 0.5, 0.5]]) @@ -65,7 +67,7 @@ def test_normalized_relu_bounding(config, name_to_index, name_to_index_stats, in name_to_index_stats=name_to_index_stats, ) output = bounding(input_tensor.clone()) - #breakpoint() + # breakpoint() expected_output = torch.tensor([[2.0, 2.0, 3.0], [4.0, 0.1111, 6.0], [2.0, 0.5, 0.5]]) assert torch.allclose(output, expected_output, atol=1e-4) diff --git a/training/src/anemoi/training/config/model/gnn.yaml b/training/src/anemoi/training/config/model/gnn.yaml index 66e761ce..e9bb3686 100644 --- a/training/src/anemoi/training/config/model/gnn.yaml +++ b/training/src/anemoi/training/config/model/gnn.yaml @@ -70,11 +70,11 @@ bounding: #These are applied in order # [OPTIONAL] NormalizedReluBounding # This is an extension of the Relu bounding in case the thrshold to be used - # is not 0. For example, in case of the sea surface temperature we don't use - # [0, infinity), buth rather [-2C, infinity). We do not want the water + # is not 0. For example, in case of the sea surface temperature we don't use + # [0, infinity), buth rather [-2C, infinity). We do not want the water # temperature to be below the freezing temperature. # - _target_: anemoi.models.layers.bounding.NormalizedReluBounding # variables: [sst] # min_val: [-2] - # normalizer: ['mean-std'] \ No newline at end of file + # normalizer: ['mean-std'] diff --git a/training/src/anemoi/training/config/model/graphtransformer.yaml b/training/src/anemoi/training/config/model/graphtransformer.yaml index 44cfa029..715b819d 100644 --- a/training/src/anemoi/training/config/model/graphtransformer.yaml +++ b/training/src/anemoi/training/config/model/graphtransformer.yaml @@ -75,8 +75,8 @@ bounding: #These are applied in order # [OPTIONAL] NormalizedReluBounding # This is an extension of the Relu bounding in case the thrshold to be used - # is not 0. For example, in case of the sea surface temperature we don't use - # [0, infinity), buth rather [-2C, infinity). We do not want the water + # is not 0. For example, in case of the sea surface temperature we don't use + # [0, infinity), buth rather [-2C, infinity). We do not want the water # temperature to be below the freezing temperature. # - _target_: anemoi.models.layers.bounding.NormalizedReluBounding diff --git a/training/src/anemoi/training/config/model/transformer.yaml b/training/src/anemoi/training/config/model/transformer.yaml index 643ed878..bcdfeaaf 100644 --- a/training/src/anemoi/training/config/model/transformer.yaml +++ b/training/src/anemoi/training/config/model/transformer.yaml @@ -74,13 +74,11 @@ bounding: #These are applied in order # [OPTIONAL] NormalizedReluBounding # This is an extension of the Relu bounding in case the thrshold to be used - # is not 0. For example, in case of the sea surface temperature we don't use - # [0, infinity), buth rather [-2C, infinity). We do not want the water + # is not 0. For example, in case of the sea surface temperature we don't use + # [0, infinity), buth rather [-2C, infinity). We do not want the water # temperature to be below the freezing temperature. # - _target_: anemoi.models.layers.bounding.NormalizedReluBounding # variables: [sst] # min_val: [-2] # normalizer: ['mean-std'] - - From a74be0a1538246179c35e569ee18397b74801bbc Mon Sep 17 00:00:00 2001 From: Lorenzo Zampieri Date: Mon, 20 Jan 2025 16:15:29 +0000 Subject: [PATCH 10/10] Add commit description in CHANGELOG.md --- models/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/models/CHANGELOG.md b/models/CHANGELOG.md index ba61ebe3..9490c593 100644 --- a/models/CHANGELOG.md +++ b/models/CHANGELOG.md @@ -18,6 +18,7 @@ Keep it human-readable, your future self will thank you! - Reduced memory usage when using chunking in the mapper [#84](https://github.com/ecmwf/anemoi-models/pull/84) - Added `supporting_arrays` argument, which contains arrays to store in checkpoints. [#97](https://github.com/ecmwf/anemoi-models/pull/97) - Add remappers, e.g. link functions to apply during training to facilitate learning of variables with a difficult distribution [#88](https://github.com/ecmwf/anemoi-models/pull/88) +- Add Normalized Relu Bounding for minimum bounding thresholds different than 0 [#64](https://github.com/ecmwf/anemoi-core/pull/64) ## [0.4.0](https://github.com/ecmwf/anemoi-models/compare/0.3.0...0.4.0) - Improvements to Model Design