Skip to content

Commit beffa06

Browse files
authored
Feature/improve loss functions (#70)
* Dynamic loss function initialisation - Seperate config for loss and metrics - Added: MAE, RMSE, LogCosh, Huber, CombinedLoss - Refactored to provide abstract BaseWeightedLoss - Added ScaleTensor to enable arbitary scalings
1 parent fe90a9d commit beffa06

File tree

16 files changed

+1622
-88
lines changed

16 files changed

+1622
-88
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ Keep it human-readable, your future self will thank you!
1212

1313
## [0.2.2 - Maintenance: pin python <3.13](https://github.com/ecmwf/anemoi-training/compare/0.2.1...0.2.2) - 2024-10-28
1414

15+
### Added
16+
- Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70)
17+
1518
### Changed
1619

1720
- Lock python version <3.13 [#107](https://github.com/ecmwf/anemoi-training/pull/107)

docs/modules/losses.rst

Lines changed: 136 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,153 @@
33
########
44

55
This module is used to define the loss function used to train the model.
6+
7+
Anemoi-training exposes a couple of loss functions by default to be
8+
used, all of which are subclassed from ``BaseWeightedLoss``. This class
9+
enables scalar multiplication, and graph node weighting.
10+
11+
.. automodule:: anemoi.training.losses.weightedloss
12+
:members:
13+
:no-undoc-members:
14+
:show-inheritance:
15+
16+
************************
17+
Default Loss Functions
18+
************************
19+
620
By default anemoi-training trains the model using a latitude-weighted
721
mean-squared-error, which is defined in the ``WeightedMSELoss`` class in
8-
``aifs/losses/mse.py``.
22+
``anemoi/training/losses/mse.py``. The loss function can be configured
23+
in the config file at ``config.training.training_loss``, and
24+
``config.training.validation_metrics``.
25+
26+
The following loss functions are available by default:
27+
28+
- ``WeightedMSELoss``: Latitude-weighted mean-squared-error.
29+
- ``WeightedMAELoss``: Latitude-weighted mean-absolute-error.
30+
- ``WeightedHuberLoss``: Latitude-weighted Huber loss.
31+
- ``WeightedLogCoshLoss``: Latitude-weighted log-cosh loss.
32+
- ``WeightedRMSELoss``: Latitude-weighted root-mean-squared-error.
33+
- ``CombinedLoss``: Combined component weighted loss.
34+
35+
These are available in the ``anemoi.training.losses`` module, at
36+
``anemoi.training.losses.{short_name}.{class_name}``.
37+
38+
So for example, to use the ``WeightedMSELoss`` class, you would
39+
reference it in the config as follows:
40+
41+
.. code:: yaml
42+
43+
# loss function for the model
44+
training_loss:
45+
# loss class to initialise
46+
_target_: anemoi.training.losses.mse.WeightedMSELoss
47+
# loss function kwargs here
48+
49+
*********
50+
Scalars
51+
*********
52+
53+
In addition to node scaling, the loss function can also be scaled by a
54+
scalar. These are provided by the ``Forecaster`` class, and a user can
55+
define whether to include them in the loss function by setting
56+
``scalars`` in the loss config dictionary.
57+
58+
.. code:: yaml
59+
60+
# loss function for the model
61+
training_loss:
62+
# loss class to initialise
63+
_target_: anemoi.training.losses.mse.WeightedMSELoss
64+
scalars: ['scalar1', 'scalar2']
65+
66+
Currently, the following scalars are available for use:
67+
68+
- ``variable``: Scale by the feature/variable weights as defined in the
69+
config ``config.training.loss_scaling``.
970

10-
The user can define their own loss function using the same structure as
11-
the ``WeightedMSELoss`` class.
71+
********************
72+
Validation Metrics
73+
********************
1274

13-
.. automodule:: anemoi.training.losses.mse
75+
Validation metrics as defined in the config file at
76+
``config.training.validation_metrics`` follow the same initialise
77+
behaviour as the loss function, but can be a list. In this case all
78+
losses are calculated and logged as a dictionary with the corresponding
79+
name
80+
81+
***********************
82+
Custom Loss Functions
83+
***********************
84+
85+
Additionally, you can define your own loss function by subclassing
86+
``BaseWeightedLoss`` and implementing the ``forward`` method, or by
87+
subclassing ``FunctionalWeightedLoss`` and implementing the
88+
``calculate_difference`` function. The latter abstracts the scaling, and
89+
node weighting, and allows you to just specify the difference
90+
calculation.
91+
92+
.. code:: python
93+
94+
from anemoi.training.losses.weightedloss import FunctionalWeightedLoss
95+
96+
class MyLossFunction(FunctionalWeightedLoss):
97+
def calculate_difference(self, pred, target):
98+
return (pred - target) ** 2
99+
100+
Then in the config, set ``_target_`` to the class name, and any
101+
additional kwargs to the loss function.
102+
103+
*****************
104+
Combined Losses
105+
*****************
106+
107+
Building on the simple single loss functions, a user can define a
108+
combined loss, one that weights and combines multiple loss functions.
109+
110+
This can be done by referencing the ``CombinedLoss`` class in the config
111+
file, and setting the ``losses`` key to a list of loss functions to
112+
combine. Each of those losses is then initalised just like the other
113+
losses above.
114+
115+
.. code:: yaml
116+
117+
training_loss:
118+
__target__: anemoi.training.losses.combined.CombinedLoss
119+
losses:
120+
- __target__: anemoi.training.losses.mse.WeightedMSELoss
121+
- __target__: anemoi.training.losses.mae.WeightedMAELoss
122+
scalars: ['variable']
123+
loss_weights: [1.0,0.5]
124+
125+
All kwargs passed to ``CombinedLoss`` are passed to each of the loss
126+
functions, and the loss weights are used to scale the individual losses
127+
before combining them.
128+
129+
.. automodule:: anemoi.training.losses.combined
14130
:members:
15131
:no-undoc-members:
16132
:show-inheritance:
17133

134+
*******************
135+
Utility Functions
136+
*******************
137+
18138
There is also generic functions that are useful for losses in
19-
``aifs/losses/utils.py``.
139+
``anemoi/training/losses/utils.py``.
20140

21141
``grad_scaler`` is used to automatically scale the loss gradients in the
22142
loss function using the formula in https://arxiv.org/pdf/2306.06079.pdf,
23143
section 4.3.2. This can be switched on in the config by setting the
24144
option ``config.training.loss_gradient_scaling=True``.
145+
146+
``ScaleTensor`` is a class that can record and apply arbitrary scaling
147+
factors to tensors. It supports relative indexing, combining multiple
148+
scalars over the same dimensions, and is only constructed at
149+
broadcasting time, so the shape can be resolved to match the tensor
150+
exactly.
151+
152+
.. automodule:: anemoi.training.losses.utils
153+
:members:
154+
:no-undoc-members:
155+
:show-inheritance:

src/anemoi/training/config/training/default.yaml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,35 @@ swa:
3333
# use ZeroRedundancyOptimizer ; saves memory for larger models
3434
zero_optimizer: False
3535

36+
# loss functions
37+
3638
# dynamic rescaling of the loss gradient
3739
# see https://arxiv.org/pdf/2306.06079.pdf, section 4.3.2
3840
# don't enable this by default until it's been tested and proven beneficial
41+
42+
# loss function for the model
43+
training_loss:
44+
# loss class to initialise
45+
_target_: anemoi.training.losses.mse.WeightedMSELoss
46+
# Scalars to include in loss calculation
47+
# Available scalars include, 'variable'
48+
scalars: ['variable']
49+
ignore_nans: False
50+
3951
loss_gradient_scaling: False
4052

53+
# Validation metrics calculation,
54+
# This may be a list, in which case all metrics will be calculated
55+
# and logged according to their name
56+
validation_metrics:
57+
# loss class to initialise
58+
- _target_: anemoi.training.losses.mse.WeightedMSELoss
59+
# Scalars to include in loss calculation
60+
# Available scalars include, 'variable'
61+
scalars: []
62+
# other kwargs
63+
ignore_nans: True
64+
4165
# length of the "rollout" window (see Keisler's paper)
4266
rollout:
4367
start: 1

src/anemoi/training/diagnostics/callbacks/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from anemoi.training.diagnostics.plots import plot_loss
4545
from anemoi.training.diagnostics.plots import plot_power_spectrum
4646
from anemoi.training.diagnostics.plots import plot_predicted_multilevel_flat_sample
47+
from anemoi.training.losses.weightedloss import BaseWeightedLoss
4748

4849
if TYPE_CHECKING:
4950
import pytorch_lightning as pl
@@ -605,6 +606,11 @@ def _plot(
605606
# reorder parameter_names by position
606607
self.parameter_names = [parameter_names[i] for i in np.argsort(parameter_positions)]
607608

609+
if not isinstance(pl_module.loss, BaseWeightedLoss):
610+
logging.warning(
611+
"Loss function must be a subclass of BaseWeightedLoss, or provide `squash`.", RuntimeWarning
612+
)
613+
608614
batch = pl_module.model.pre_processors(batch, in_place=False)
609615
for rollout_step in range(pl_module.rollout):
610616
y_hat = outputs[1][rollout_step]
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# (C) Copyright 2024 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
from __future__ import annotations
11+
12+
import functools
13+
from typing import Any
14+
from typing import Callable
15+
16+
import torch
17+
18+
from anemoi.training.train.forecaster import GraphForecaster
19+
20+
21+
class CombinedLoss(torch.nn.Module):
22+
"""Combined Loss function."""
23+
24+
def __init__(
25+
self,
26+
*extra_losses: dict[str, Any] | Callable,
27+
losses: tuple[dict[str, Any] | Callable] | None = None,
28+
loss_weights: tuple[int, ...],
29+
**kwargs,
30+
):
31+
"""Combined loss function.
32+
33+
Allows multiple losses to be combined into a single loss function,
34+
and the components weighted.
35+
36+
If a sub loss function requires additional weightings or code created tensors,
37+
that must be `included_` for this function, and then controlled by the underlying
38+
loss function configuration.
39+
40+
Parameters
41+
----------
42+
losses: tuple[dict[str, Any]| Callable]
43+
Tuple of losses to initialise with `GraphForecaster.get_loss_function`.
44+
Allows for kwargs to be passed, and weighings controlled.
45+
*extra_losses: dict[str, Any] | Callable
46+
Additional arg form of losses to include in the combined loss.
47+
loss_weights : tuple[int, ...]
48+
Weights of each loss function in the combined loss.
49+
kwargs: Any
50+
Additional arguments to pass to the loss functions
51+
52+
Examples
53+
--------
54+
>>> CombinedLoss(
55+
{"__target__": "anemoi.training.losses.mse.WeightedMSELoss"},
56+
loss_weights=(1.0,),
57+
node_weights=node_weights
58+
)
59+
--------
60+
>>> CombinedLoss(
61+
losses = [anemoi.training.losses.mse.WeightedMSELoss],
62+
loss_weights=(1.0,),
63+
node_weights=node_weights
64+
)
65+
Or from the config,
66+
67+
```
68+
training_loss:
69+
__target__: anemoi.training.losses.combined.CombinedLoss
70+
losses:
71+
- __target__: anemoi.training.losses.mse.WeightedMSELoss
72+
- __target__: anemoi.training.losses.mae.WeightedMAELoss
73+
scalars: ['variable']
74+
loss_weights: [1.0,0.5]
75+
```
76+
"""
77+
super().__init__()
78+
79+
losses = (*(losses or []), *extra_losses)
80+
81+
assert len(losses) == len(loss_weights), "Number of losses and weights must match"
82+
assert len(losses) > 0, "At least one loss must be provided"
83+
84+
self.losses = [
85+
GraphForecaster.get_loss_function(loss, **kwargs) if isinstance(loss, dict) else loss(**kwargs)
86+
for loss in losses
87+
]
88+
self.loss_weights = loss_weights
89+
90+
def forward(
91+
self,
92+
pred: torch.Tensor,
93+
target: torch.Tensor,
94+
**kwargs,
95+
) -> torch.Tensor:
96+
"""Calculates the combined loss.
97+
98+
Parameters
99+
----------
100+
pred : torch.Tensor
101+
Prediction tensor, shape (bs, ensemble, lat*lon, n_outputs)
102+
target : torch.Tensor
103+
Target tensor, shape (bs, ensemble, lat*lon, n_outputs)
104+
kwargs: Any
105+
Additional arguments to pass to the loss functions
106+
Will be passed to all loss functions
107+
108+
Returns
109+
-------
110+
torch.Tensor
111+
Combined loss
112+
"""
113+
loss = None
114+
for i, loss_fn in enumerate(self.losses):
115+
if loss is not None:
116+
loss += self.loss_weights[i] * loss_fn(pred, target, **kwargs)
117+
else:
118+
loss = self.loss_weights[i] * loss_fn(pred, target, **kwargs)
119+
return loss
120+
121+
@property
122+
def name(self) -> str:
123+
return "combined_" + "_".join(getattr(loss, "name", loss.__class__.__name__.lower()) for loss in self.losses)
124+
125+
def __getattr__(self, name: str) -> Callable:
126+
"""Allow access to underlying attributes of the loss functions."""
127+
if not all(hasattr(loss, name) for loss in self.losses):
128+
error_msg = f"Attribute {name} not found in all loss functions"
129+
raise AttributeError(error_msg)
130+
131+
@functools.wraps(getattr(self.losses[0], name))
132+
def hidden_func(*args, **kwargs) -> list[Any]:
133+
return [getattr(loss, name)(*args, **kwargs) for loss in self.losses]
134+
135+
return hidden_func

0 commit comments

Comments
 (0)