########

This module is used to define the loss functions used to train the model.

Anemoi-training exposes several loss functions by default, all of which
are subclassed from ``BaseWeightedLoss``. This class enables scalar
multiplication and graph node weighting.

.. automodule:: anemoi.training.losses.weightedloss
   :members:
   :no-undoc-members:
   :show-inheritance:

************************
 Default Loss Functions
************************

By default anemoi-training trains the model using a latitude-weighted
mean-squared-error, which is defined in the ``WeightedMSELoss`` class in
``anemoi/training/losses/mse.py``. The loss function can be configured
in the config file at ``config.training.training_loss`` and
``config.training.validation_metrics``.

The following loss functions are available by default:

- ``WeightedMSELoss``: Latitude-weighted mean-squared-error.
- ``WeightedMAELoss``: Latitude-weighted mean-absolute-error.
- ``WeightedHuberLoss``: Latitude-weighted Huber loss.
- ``WeightedLogCoshLoss``: Latitude-weighted log-cosh loss.
- ``WeightedRMSELoss``: Latitude-weighted root-mean-squared-error.
- ``CombinedLoss``: Combined component weighted loss.

These are available in the ``anemoi.training.losses`` module, at
``anemoi.training.losses.{short_name}.{class_name}``.

For example, to use the ``WeightedMSELoss`` class, you would reference
it in the config as follows:

.. code:: yaml

   # loss function for the model
   training_loss:
      # loss class to initialise
      _target_: anemoi.training.losses.mse.WeightedMSELoss
      # loss function kwargs here

*********
 Scalars
*********

In addition to node scaling, the loss function can also be scaled by a
scalar. These are provided by the ``Forecaster`` class, and a user can
define whether to include them in the loss function by setting
``scalars`` in the loss config dictionary.

.. code:: yaml

   # loss function for the model
   training_loss:
      # loss class to initialise
      _target_: anemoi.training.losses.mse.WeightedMSELoss
      scalars: ['scalar1', 'scalar2']

Currently, the following scalars are available for use:

- ``variable``: Scale by the feature/variable weights as defined in the
  config ``config.training.loss_scaling``.

********************
 Validation Metrics
********************

Validation metrics, as defined in the config file at
``config.training.validation_metrics``, follow the same initialisation
behaviour as the loss function, but can be a list. In this case all
losses are calculated and logged as a dictionary with the corresponding
name.

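For example, a list of two metrics could be configured as follows (any
kwargs depend on the loss classes chosen):

.. code:: yaml

   # validation metrics, each initialised like a training loss
   validation_metrics:
      - _target_: anemoi.training.losses.mse.WeightedMSELoss
      - _target_: anemoi.training.losses.mae.WeightedMAELoss
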
***********************
 Custom Loss Functions
***********************

Additionally, you can define your own loss function by subclassing
``BaseWeightedLoss`` and implementing the ``forward`` method, or by
subclassing ``FunctionalWeightedLoss`` and implementing the
``calculate_difference`` function. The latter abstracts the scaling and
node weighting, allowing you to specify just the difference calculation.

.. code:: python

   from anemoi.training.losses.weightedloss import FunctionalWeightedLoss

   class MyLossFunction(FunctionalWeightedLoss):
       def calculate_difference(self, pred, target):
           return (pred - target) ** 2

Then in the config, set ``_target_`` to the class path, and pass any
additional kwargs to the loss function.

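For example, assuming the ``MyLossFunction`` class above is importable
from a module on your path (the module path ``my_package.losses`` here
is purely illustrative), the config entry might look like:

.. code:: yaml

   training_loss:
      # illustrative path -- point this at wherever your class is defined
      _target_: my_package.losses.MyLossFunction
      # any additional kwargs are passed to the loss constructor
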
*****************
 Combined Losses
*****************

Building on the simple single loss functions, a user can define a
combined loss, one that weights and combines multiple loss functions.

This can be done by referencing the ``CombinedLoss`` class in the config
file, and setting the ``losses`` key to a list of loss functions to
combine. Each of those losses is then initialised just like the other
losses above.

.. code:: yaml

   training_loss:
      _target_: anemoi.training.losses.combined.CombinedLoss
      losses:
         - _target_: anemoi.training.losses.mse.WeightedMSELoss
         - _target_: anemoi.training.losses.mae.WeightedMAELoss
      scalars: ['variable']
      loss_weights: [1.0, 0.5]

All kwargs passed to ``CombinedLoss`` are passed to each of the loss
functions, and the loss weights are used to scale the individual losses
before combining them.

.. automodule:: anemoi.training.losses.combined
   :members:
   :no-undoc-members:
   :show-inheritance:

*******************
 Utility Functions
*******************

There are also generic functions that are useful for losses in
``anemoi/training/losses/utils.py``.

``grad_scaler`` is used to automatically scale the loss gradients in the
loss function using the formula in https://arxiv.org/pdf/2306.06079.pdf,
section 4.3.2. This can be switched on in the config by setting the
option ``config.training.loss_gradient_scaling=True``.

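Based on the option name above, this corresponds to a config entry along
the lines of:

.. code:: yaml

   training:
      loss_gradient_scaling: True
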
``ScaleTensor`` is a class that can record and apply arbitrary scaling
factors to tensors. It supports relative indexing and combining multiple
scalars over the same dimensions, and the scaling is only constructed at
broadcasting time, so its shape can be resolved to match the tensor
exactly.

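To illustrate the idea of lazy, broadcast-time scaling, here is a
standalone sketch. Note this is not the actual ``ScaleTensor`` API; the
class and method names below are invented for illustration, and the real
implementation in ``anemoi.training.losses.utils`` may differ.

```python
import numpy as np

# Sketch of the *idea* behind lazy broadcast-time scaling: factors are
# recorded per dimension up front, but only resolved against a tensor's
# actual shape when applied. Not the real anemoi.training API.
class LazyScaler:
    def __init__(self):
        # (dim, factor) pairs, recorded before any tensor shape is known
        self._scalars = []

    def add(self, dim, factor):
        """Record a 1-D scaling factor for dimension `dim` (may be negative)."""
        self._scalars.append((dim, np.asarray(factor, dtype=float)))

    def apply(self, tensor):
        """Resolve all recorded factors against `tensor`'s shape and multiply."""
        out = np.asarray(tensor, dtype=float)
        for dim, factor in self._scalars:
            dim = dim % out.ndim          # relative (negative) indexing
            shape = [1] * out.ndim
            shape[dim] = factor.size
            out = out * factor.reshape(shape)  # broadcasts over other dims
        return out

scaler = LazyScaler()
scaler.add(-1, [1.0, 0.5])   # scale the last dim (e.g. per-variable weights)
scaler.add(-1, [2.0, 2.0])   # factors on the same dim combine multiplicatively
print(scaler.apply(np.ones((3, 2))))  # each row becomes [2.0, 1.0]
```

The key design point this sketch captures is that ``add`` never needs to
know the tensor shape; the factor is only reshaped for broadcasting
inside ``apply``, once the target tensor is available.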
.. automodule:: anemoi.training.losses.utils
   :members:
   :no-undoc-members:
   :show-inheritance: