Commit 27c858d

Authored by HCookie, JPXKQX, and sahahner
Refactor Callbacks (#60)
* Refactor Callbacks
  - Split into separate files
  - Use list in config to add callbacks
  - Split out plotting callbacks config

* Refactor rollout (#87)
  - New rollout central function

Co-authored-by: Mario Santa Cruz <[email protected]>
Co-authored-by: Sara Hahner <[email protected]>
1 parent beffa06 commit 27c858d

File tree

24 files changed: +2129 -1194 lines changed

CHANGELOG.md

Lines changed: 7 additions & 1 deletion

@@ -12,6 +12,10 @@ Keep it human-readable, your future self will thank you!
 
 ## [0.2.2 - Maintenance: pin python <3.13](https://github.com/ecmwf/anemoi-training/compare/0.2.1...0.2.2) - 2024-10-28
 
+### Fixed
+- Refactored callbacks. [#60](https://github.com/ecmwf/anemoi-training/pulls/60)
+- Refactored rollout [#87](https://github.com/ecmwf/anemoi-training/pulls/87)
+- Enable longer validation rollout than training
 ### Added
 - Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70)

@@ -29,7 +33,9 @@ Keep it human-readable, your future self will thank you!
 - Feature: New `Boolean1DMask` class. Enables rollout training for limited area models. [#79](https://github.com/ecmwf/anemoi-training/pulls/79)
 
 ### Fixed
-
+- Mlflow-sync to handle creation of new experiments in the remote server [#83] (https://github.com/ecmwf/anemoi-training/pull/83)
+- Fix for multi-gpu when using mlflow due to refactoring of _get_mlflow_run_params function [#99] (https://github.com/ecmwf/anemoi-training/pull/99)
+- ci: fix pyshtools install error (#100) https://github.com/ecmwf/anemoi-training/pull/100
 - Mlflow-sync to handle creation of new experiments in the remote server [#83](https://github.com/ecmwf/anemoi-training/pull/83)
 - Fix for multi-gpu when using mlflow due to refactoring of _get_mlflow_run_params function [#99](https://github.com/ecmwf/anemoi-training/pull/99)
 - ci: fix pyshtools install error [#100](https://github.com/ecmwf/anemoi-training/pull/100)

docs/modules/diagnostics.rst

Lines changed: 84 additions & 13 deletions

@@ -21,23 +21,94 @@ functionality to use both Weights & Biases and Tensorboard.
 
 The callbacks can also be used to evaluate forecasts over longer
 rollouts beyond the forecast time that the model is trained on. The
-number of rollout steps (or forecast iteration steps) is set using
-``config.eval.rollout = *num_of_rollout_steps*``.
-
-Note the user has the option to evaluate the callbacks asynchronously
-(using the following config option
-``config.diagnostics.plot.asynchronous``, which means that the model
-training doesn't stop whilst the callbacks are being evaluated).
-However, note that callbacks can still be slow, and therefore the
-plotting callbacks can be switched off by setting
-``config.diagnostics.plot.enabled`` to ``False`` or all the callbacks
-can be completely switched off by setting
-``config.diagnostics.eval.enabled`` to ``False``.
+number of rollout steps for verification (or forecast iteration steps)
+is set using ``config.dataloader.validation_rollout =
+*num_of_rollout_steps*``.
+
+Callbacks are configured in the config file under the
+``config.diagnostics`` key.
+
+Regular callbacks can be provided as a list of dictionaries
+underneath the ``config.diagnostics.callbacks`` key. Each dictionary
+must have a ``_target_`` key, which is used by hydra to instantiate the
+callback; any other kwarg is passed to the callback's constructor.
+
+.. code:: yaml
+
+   callbacks:
+     - _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval
+       rollout: ${dataloader.validation_rollout}
+       frequency: 20
+
+Plotting callbacks are configured in a similar way, but they are
+specified underneath the ``config.diagnostics.plot.callbacks`` key.
+
+This is done to ensure separation and ease of configuration between
+experiments.
+
+``config.diagnostics.plot`` is a broader config file specifying the
+parameters to plot, as well as the plotting frequency and
+asynchronicity.
+
+Setting ``config.diagnostics.plot.asynchronous`` means that model
+training does not stop whilst the callbacks are being evaluated.
+
+.. code:: yaml
+
+   plot:
+     asynchronous: True # Whether to plot asynchronously
+     frequency: # Frequency of the plotting
+       batch: 750
+       epoch: 5
+
+     # Parameters to plot
+     parameters:
+       - z_500
+       - t_850
+       - u_850
+
+     # Sample index
+     sample_idx: 0
+
+     # Precipitation and related fields
+     precip_and_related_fields: [tp, cp]
+
+     callbacks:
+       - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss
+         # group parameters by categories when visualizing contributions to the loss
+         # one-parameter groups are possible to highlight individual parameters
+         parameter_groups:
+           moisture: [tp, cp, tcw]
+           sfc_wind: [10u, 10v]
+       - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample
+         sample_idx: ${diagnostics.plot.sample_idx}
+         per_sample: 6
+         parameters: ${diagnostics.plot.parameters}
 
 Below is the documentation for the default callbacks provided, but it is
 also possible for users to add callbacks using the same structure:
 
-.. automodule:: anemoi.training.diagnostics.callbacks
+.. automodule:: anemoi.training.diagnostics.callbacks.checkpoint
+   :members:
+   :no-undoc-members:
+   :show-inheritance:
+
+.. automodule:: anemoi.training.diagnostics.callbacks.evaluation
+   :members:
+   :no-undoc-members:
+   :show-inheritance:
+
+.. automodule:: anemoi.training.diagnostics.callbacks.optimiser
+   :members:
+   :no-undoc-members:
+   :show-inheritance:
+
+.. automodule:: anemoi.training.diagnostics.callbacks.plot
+   :members:
+   :no-undoc-members:
+   :show-inheritance:
+
+.. automodule:: anemoi.training.diagnostics.callbacks.provenance
+   :members:
+   :no-undoc-members:
+   :show-inheritance:
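The ``_target_`` mechanism documented above can be sketched in plain Python: hydra's ``instantiate`` imports the dotted path and passes the remaining keys as constructor kwargs. This is an illustrative stand-alone version, not anemoi-training's or hydra's actual code, and ``collections.Counter`` is only a stand-in for a real callback class:

```python
# Minimal sketch of hydra-style `_target_` instantiation (illustrative):
# import the dotted path, drop `_target_`, pass the rest as kwargs.
import importlib


def instantiate(conf: dict):
    """Build an object from a config dict containing a `_target_` key."""
    module_path, _, class_name = conf["_target_"].rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    kwargs = {k: v for k, v in conf.items() if k != "_target_"}
    return cls(**kwargs)


# Stand-in for a callback entry; a real one would target a callback class.
callback_cfgs = [{"_target_": "collections.Counter", "red": 2, "blue": 1}]
callbacks = [instantiate(c) for c in callback_cfgs]
print(type(callbacks[0]).__name__)  # Counter
print(callbacks[0]["red"])          # 2
```

Because any extra key becomes a constructor kwarg, adding a new callback is purely a config change: no training code needs to be edited.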

docs/user-guide/configuring.rst

Lines changed: 2 additions & 2 deletions

@@ -21,7 +21,7 @@ settings at the top as follows:
 defaults:
   - data: zarr
   - dataloader: native_grid
-  - diagnostics: eval_rollout
+  - diagnostics: evaluation
   - hardware: example
   - graph: multi_scale
   - model: gnn

@@ -100,7 +100,7 @@ match the dataset you provide.
 defaults:
   - data: zarr
   - dataloader: native_grid
-  - diagnostics: eval_rollout
+  - diagnostics: evaluation
   - hardware: example
   - graph: multi_scale
   - model: transformer # Change from default group

docs/user-guide/tracking.rst

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ the same experiment.
 Within the MLflow experiments tab, it is possible to define different
 namespaces. To create a new namespace, the user just needs to pass an
 'experiment_name'
-(``config.diagnostics.eval_rollout.log.mlflow.experiment_name``) to the
+(``config.diagnostics.evaluation.log.mlflow.experiment_name``) to the
 mlflow logger.
 
 **Parent-Child Runs**

src/anemoi/training/config/config.yaml

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 defaults:
   - data: zarr
   - dataloader: native_grid
-  - diagnostics: eval_rollout
+  - diagnostics: evaluation
   - hardware: example
   - graph: multi_scale
   - model: gnn

src/anemoi/training/config/dataloader/native_grid.yaml

Lines changed: 2 additions & 0 deletions

@@ -45,6 +45,8 @@ training:
   frequency: ${data.frequency}
   drop: []
 
+validation_rollout: 1 # number of rollouts to use for validation, must be equal or greater than rollout expected by callbacks
+
 validation:
   dataset: ${dataloader.dataset}
   start: 2021
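The comment on ``validation_rollout`` states a constraint that can be expressed as a small check: the dataloader must provide at least as many rollout steps as the largest rollout any evaluation callback requests. The helper below is an illustrative sketch of that invariant, not anemoi-training's actual validation code:

```python
# Illustrative consistency check (assumed semantics, not anemoi-training
# code): validation_rollout must cover the largest rollout requested by
# any callback, otherwise a callback such as RolloutEval would ask for
# more forecast steps than a validation batch contains.
def validation_rollout_ok(validation_rollout: int, callback_rollouts: list) -> bool:
    """Return True when the dataloader covers every callback's rollout."""
    return all(validation_rollout >= r for r in callback_rollouts)


print(validation_rollout_ok(1, [1]))   # True: the defaults are consistent
print(validation_rollout_ok(1, [12]))  # False: a callback wants 12 steps
```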

src/anemoi/training/config/debug.yaml

Lines changed: 2 additions & 2 deletions

@@ -1,7 +1,7 @@
 defaults:
   - data: zarr
   - dataloader: native_grid
-  - diagnostics: eval_rollout
+  - diagnostics: evaluation
   - hardware: example
   - graph: multi_scale
   - model: gnn

@@ -18,7 +18,7 @@ defaults:
 
 diagnostics:
   plot:
-    enabled: False
+    callbacks: []
 hardware:
   files:
     graph: ???
(new file; name not shown)

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+# Add callbacks here

(new file; name not shown)

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+# Add callbacks here
+- _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval
+  rollout: ${dataloader.validation_rollout}
+  frequency: 20
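The ``frequency: 20`` entry in this callback config controls how often the callback fires. A plausible reading, sketched below, is a modulo gate over validation batches; this is an assumption about the mechanism, not a copy of RolloutEval's implementation:

```python
# Illustrative sketch of a `frequency: 20` gate: the callback body runs
# only on every N-th validation batch (assumed mechanism, for intuition).
frequency = 20
evaluated_batches = [i for i in range(100) if i % frequency == 0]
print(evaluated_batches)  # [0, 20, 40, 60, 80]
```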

src/anemoi/training/config/diagnostics/eval_rollout.yaml renamed to src/anemoi/training/config/diagnostics/evaluation.yaml

Lines changed: 5 additions & 49 deletions

@@ -1,53 +1,8 @@
 ---
-eval:
-  enabled: False
-  # use this to evaluate the model over longer rollouts, every so many validation batches
-  rollout: 12
-  frequency: 20
-plot:
-  enabled: True
-  asynchronous: True
-  frequency: 750
-  sample_idx: 0
-  per_sample: 6
-  parameters:
-  - z_500
-  - t_850
-  - u_850
-  - v_850
-  - 2t
-  - 10u
-  - 10v
-  - sp
-  - tp
-  - cp
-  # Defining the accumulation levels for precipitation related fields and the colormap
-  accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm
-  cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", "#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"]
-  precip_and_related_fields: [tp, cp]
-  # Histogram and Spectrum plots
-  parameters_histogram:
-  - z_500
-  - tp
-  - 2t
-  - 10u
-  - 10v
-  parameters_spectrum:
-  - z_500
-  - tp
-  - 2t
-  - 10u
-  - 10v
-  # group parameters by categories when visualizing contributions to the loss
-  # one-parameter groups are possible to highlight individual parameters
-  parameter_groups:
-    moisture: [tp, cp, tcw]
-    sfc_wind: [10u, 10v]
-  learned_features: False
-  longrollout:
-    enabled: False
-    rollout: [60]
-    frequency: 20 # every X epochs
+defaults:
+  - plot: detailed
+  - callbacks: pretraining
+
 
 debug:
   # this will detect and trace back NaNs / Infs etc. but will slow down training

@@ -57,6 +12,7 @@ debug:
   # remember to also activate the tensorboard logger (below)
   profiler: False
 
+enable_checkpointing: True
 checkpoint:
   every_n_minutes:
     save_frequency: 30 # Approximate, as this is checked at the end of training steps
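The rename above replaces one monolithic diagnostics file with a hydra-style ``defaults`` list, where ``- plot: detailed`` and ``- callbacks: pretraining`` each pull a named option out of a config group and merge it into the final config. The sketch below illustrates that composition pattern with plain dicts; the group contents are abbreviated and hypothetical, and hydra performs the real composition:

```python
# Minimal sketch of hydra-style `defaults` composition (illustrative).
# Group contents here are abbreviated stand-ins for the real YAML files.
groups = {
    ("plot", "detailed"): {
        "plot": {"asynchronous": True, "sample_idx": 0},
    },
    ("callbacks", "pretraining"): {
        "callbacks": [
            {
                "_target_": "anemoi.training.diagnostics.callbacks.evaluation.RolloutEval",
                "rollout": "${dataloader.validation_rollout}",
                "frequency": 20,
            }
        ],
    },
}

# Mirrors `defaults: - plot: detailed  - callbacks: pretraining`.
defaults = [("plot", "detailed"), ("callbacks", "pretraining")]

config = {}
for group_option in defaults:
    config.update(groups[group_option])  # later entries override earlier keys

print(sorted(config))  # ['callbacks', 'plot']
```

Splitting the diagnostics config this way lets an experiment swap, say, the plotting group alone (``plot: detailed`` vs. a lighter option) without touching the callback list.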
