Interpolation model example #1149
Open
jleinonen wants to merge 24 commits into NVIDIA:main from jleinonen:interp-model-example
+1,798
−0
b0dbf92 Temporal interpolation training recipe (jleinonen)
af84a7b Add README (jleinonen)
d9841bc Merge branch 'NVIDIA:main' into interp-model-example (jleinonen)
f56991e Docs changes based on comments (jleinonen)
ae6eed1 Update docstrings and README (jleinonen)
5430eb4 Add temporal interpolation animation (jleinonen)
84289f5 Merge branch 'NVIDIA:main' into interp-model-example (jleinonen)
2b2c81e Add animation link (jleinonen)
6f09aa1 Add shape check in loss (jleinonen)
811a38a Updates of configs + trainer (jleinonen)
cb23fe6 Update config comments (jleinonen)
c32a01b Merge branch 'NVIDIA:main' into interp-model-example (jleinonen)
f642d78 Merge branch 'main' into interp-model-example (CharlelieLrt)
8d6ae42 Update README.md (megnvidia)
e1c202b Added wandb logging (CharlelieLrt)
dc2b215 Merge branch 'interp-model-example' of https://github.com/jleinonen/m… (CharlelieLrt)
a253d6e Reformated sections in docstring for GeometricL2Loss (CharlelieLrt)
aad2683 Merge branch 'main' into interp-model-example (CharlelieLrt)
e25aaeb Update README and configs (jleinonen)
2ea2897 Merge branch 'interp-model-example' of (jleinonen)
8475e1d README changes + type hint fixes (jleinonen)
547f10d Update README.md (jleinonen)
41560fc Merge branch 'NVIDIA:main' into interp-model-example (jleinonen)
5912555 Draft of validation script (jleinonen)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# Earth-2 Temporal Interpolation Model | ||
|
||
The temporal interpolation model is used to increase the temporal resolution of AI-based | ||
forecast models. These typically have a native temporal resolution of six hours; the | ||
interpolation allows this to be improved to one hour. With appropriate training data, even | ||
higher temporal resolutions might be achievable. | ||
|
||
This PhysicsNeMo example shows how to train a ModAFNO-based temporal interpolation model | ||
with a custom dataset. This architecture uses an embedding network to determine | ||
parameters for a shift and scale operation that is used to modify the behavior of the AFNO | ||
network depending on a given conditioning variable. For temporal | ||
interpolation, the atmospheric states at both ends of the interpolation interval are | ||
passed as inputs along with some auxiliary data, such as orography, and the conditioning | ||
indicates which time step between the endpoints will be generated by the model. The | ||
interpolation is deterministic and trained with a latitude-weighted L2 loss. However, it | ||
can still be used to produce probabilistic forecasts, if used to interpolate results of | ||
probabilistic forecast models. More formally, the ModAFNO $f_{\theta}$ is a conditional | ||
expected-value model that approximates: | ||
|
||
$$ | ||
f_{\theta} (x_{t}, x_{t+T}, \Delta t) \approx | ||
\mathbb{E} \left[ x_{t + \Delta t} | x_{t}, x_{t+T}, \Delta t \right] | ||
$$ | ||
|
||
where $0 \leq \Delta t \leq T$. In the pre-trained model, $T = 6$ hours and | ||
$\Delta t \in \{0, 1, 2, 3, 4, 5, 6\}$ hours. | ||
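
The latitude-weighted L2 loss mentioned above can be sketched in NumPy as follows. This is a minimal illustration, assuming cosine-of-latitude weights normalized to unit mean, which is one common choice; the loss implemented in this example may differ in its details. The function names and the small grid are hypothetical.

```python
import numpy as np


def latitude_weights(lat_deg: np.ndarray) -> np.ndarray:
    """Return cos(latitude) weights normalized to unit mean (assumed weighting)."""
    w = np.cos(np.deg2rad(lat_deg))
    return w / w.mean()


def lat_weighted_l2(pred: np.ndarray, target: np.ndarray, lat_deg: np.ndarray) -> float:
    """Latitude-weighted mean squared error for arrays shaped (channel, lat, lon)."""
    if pred.shape != target.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {target.shape}")
    # Broadcast the per-latitude weights over the channel and longitude axes.
    w = latitude_weights(lat_deg)[None, :, None]
    return float(np.mean(w * (pred - target) ** 2))


# Hypothetical small grid: 3 channels, 8 latitudes, 16 longitudes.
lat = np.linspace(-87.5, 87.5, 8)
rng = np.random.default_rng(0)
target = rng.normal(size=(3, 8, 16))
pred = target + 0.1  # constant error of 0.1 everywhere

loss = lat_weighted_l2(pred, target, lat)
print(loss)
```

Because the weights have unit mean, a spatially uniform error gives the same loss as an unweighted mean squared error, while errors near the poles count less than errors near the equator.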
|
||
For access to the pre-trained model, refer to the [wrapper in | ||
Earth2Studio](https://nvidia.github.io/earth2studio/modules/generated/models/px/earth2studio.models.px.InterpModAFNO.html#earth2studio.models.px.InterpModAFNO). | ||
A technical description of the model can be found in the paper ["Modulated Adaptive | ||
Fourier Neural Operators for Temporal Interpolation of Weather | ||
Forecasts"](https://arxiv.org/abs/2410.18904). | ||
|
||
 | ||
|
||
## Requirements | ||
|
||
### Environment | ||
|
||
You must have PhysicsNeMo installed on a GPU system. Training useful models, in | ||
practice, requires a multi-GPU system; for the original model, 64 H100 GPUs were used. | ||
Using the [PhysicsNeMo | ||
container](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/physicsnemo/containers/physicsnemo) | ||
is recommended. | ||
|
||
Install the additional packages (MLflow) needed by this example: | ||
|
||
```bash | ||
pip install -r requirements.txt | ||
``` | ||
|
||
### Data | ||
|
||
|
||
To train a temporal interpolation model, ensure you have the following: | ||
|
||
* A dataset of yearly HDF5 files at one-hour resolution. For more details, refer to the | ||
section ["Data Format and Structure" in the diagnostic model | ||
example](https://github.com/NVIDIA/physicsnemo/blob/5a64525c40eada2248cd3eacee0a6ac4735ae380/examples/weather/diagnostic/README.md#data-format-and-structure). | ||
These datasets can be very large. The dataset used to train the original model, with | ||
73 variables from 1980 to 2017, is approximately 100 TB in size. The data used to | ||
train the original model are on the ERA5 0.25 degree grid with shape `(721, 1440)` but | ||
other resolutions can work too. The ERA5 data is freely accessible; a recommended | ||
method to download it is the [ERA5 interface in | ||
Earth2Studio](https://nvidia.github.io/earth2studio/modules/generated/data/earth2studio.data.CDS.html). | ||
The data downloaded from this interface must then be inserted into the HDF5 files. | ||
* Statistics files containing the mean and standard deviation of each channel in the | ||
data files. They must be in the `stats/global_means.npy` and | ||
`stats/global_stds.npy` files in your data directory. They must be `.npy` files | ||
containing a 1D array with length equal to the number of variables in the dataset, | ||
with each value giving the mean (for `global_means.npy`) or standard deviation (for | ||
`global_stds.npy`) of the corresponding variable. | ||
* A JSON file with metadata about the contents of the HDF5 files. Refer to the [data | ||
sample](https://github.com/NVIDIA/physicsnemo/blob/main/examples/weather/temporal_interpolation/data/data.json) | ||
for an example describing the dataset used to train the original model. | ||
* Optional: NetCDF4 files containing the orography and land-sea mask for the grid | ||
contained in the data. These should contain a variable of the same shape as the data. | ||
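
As an illustration of the statistics files described in the list above, the following sketch accumulates per-channel means and standard deviations over chunks of data and writes them as 1D `.npy` arrays. It assumes the `(time, channel, lat, lon)` dimension order from the metadata sample and uses random arrays as a stand-in for data read from the HDF5 files; the helper name `streaming_channel_stats` is hypothetical and not part of this example's code.

```python
import os
import tempfile

import numpy as np


def streaming_channel_stats(chunks):
    """Accumulate per-channel mean and std over chunks shaped (time, channel, lat, lon)."""
    count = 0
    total = None
    total_sq = None
    for chunk in chunks:
        chunk = np.asarray(chunk, dtype=np.float64)
        s = chunk.sum(axis=(0, 2, 3))
        sq = (chunk ** 2).sum(axis=(0, 2, 3))
        total = s if total is None else total + s
        total_sq = sq if total_sq is None else total_sq + sq
        count += chunk.shape[0] * chunk.shape[2] * chunk.shape[3]
    mean = total / count
    # Sum-of-squares formula; clip tiny negative values caused by rounding.
    std = np.sqrt(np.maximum(total_sq / count - mean ** 2, 0.0))
    return mean, std


# Stand-in data: two chunks of 4 time steps, 3 channels, on an 8 x 16 grid.
rng = np.random.default_rng(0)
chunks = [rng.normal(5.0, 2.0, size=(4, 3, 8, 16)) for _ in range(2)]
means, stds = streaming_channel_stats(chunks)

# Write the files to <data_dir>/stats/, using a temporary directory here.
data_dir = tempfile.mkdtemp()
stats_dir = os.path.join(data_dir, "stats")
os.makedirs(stats_dir)
np.save(os.path.join(stats_dir, "global_means.npy"), means)
np.save(os.path.join(stats_dir, "global_stds.npy"), stds)
```

Processing the data in chunks avoids loading a full year into memory. For channels whose mean is very large compared to their spread, a numerically more stable accumulation scheme may be preferable to the sum-of-squares formula used here.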
|
||
## Configuration | ||
|
||
The model training is controlled by YAML configuration files, found in the `config` | ||
directory and managed by [Hydra](https://hydra.cc/). The full | ||
configuration for training of the original model is | ||
[`train_interp.yaml`](https://github.com/NVIDIA/physicsnemo/blob/main/examples/weather/temporal_interpolation/config/train_interp.yaml). | ||
[`train_interp_lite.yaml`](https://github.com/NVIDIA/physicsnemo/blob/main/examples/weather/temporal_interpolation/config/train_interp_lite.yaml) | ||
runs a short test with a lightweight model, which is not expected to produce useful | ||
checkpoints but can be used to verify that training runs without errors. | ||
|
||
See the comments in the configuration files for an explanation of each configuration | ||
parameter. To replicate the model from the paper, you only need to change the file and | ||
directory paths to correspond to those on your system. If you train it with a custom | ||
dataset, you might also need to change the `model.in_channels` and `model.out_channels` | ||
parameters. | ||
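
The comments in the configuration files in this PR state that `model.inp_shape` must be divisible by `model.patch_size`. The following sketch shows that check and the resulting patch grid for the two provided configurations. The helper `num_patches` is hypothetical and only illustrates the constraint; it is not part of the training code.

```python
def num_patches(inp_shape, patch_size):
    """Return the patch grid for an input shape, raising if it does not divide evenly."""
    grid = []
    for n, p in zip(inp_shape, patch_size):
        if n % p != 0:
            raise ValueError(f"input size {n} is not divisible by patch size {p}")
        grid.append(n // p)
    return tuple(grid)


# Values taken from train_interp.yaml and train_interp_lite.yaml.
full = num_patches([720, 1440], [2, 2])  # (360, 720) -> 259200 patches
lite = num_patches([720, 1440], [8, 8])  # (90, 180) -> 16200 patches
print(full, full[0] * full[1])
print(lite, lite[0] * lite[1])
```

Note that a 721-row grid would fail this check for both patch sizes, since 721 is odd.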
|
||
## Starting Training | ||
|
||
Test the training by running the `train.py` script using the "lite" configuration file | ||
on a system with a GPU: | ||
|
||
```bash | ||
python train.py --config-name=train_interp_lite.yaml | ||
``` | ||
|
||
For a multi-GPU or multi-node training job, launch the training with the | ||
`train_interp.yaml` configuration file using `torchrun` or MPI. For example, to train on | ||
eight nodes with eight GPUs each, for a total of 64 GPUs, start a distributed compute | ||
job (for example, using SLURM or Run:ai) and use: | ||
|
||
```bash | ||
torchrun --nnodes=8 --nproc-per-node=8 train.py --config-name=train_interp.yaml | ||
``` | ||
|
||
Alternatively, use the equivalent `mpirun` command. The code automatically uses all GPUs | ||
available to the job. Remember to set `datapipe.batch_size_train` in the configuration file to | ||
the batch size *per process*. | ||
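
Since the batch size is set per process, the effective global batch size scales with the number of processes. The arithmetic for the 64-GPU example above can be sketched as follows, using the `samples_per_epoch` value from `train_interp.yaml`. How the trainer handles a remainder that does not fill a whole global batch is not specified here, so the sketch only reports it.

```python
def global_batch_size(nnodes: int, nproc_per_node: int, batch_size_per_process: int) -> int:
    """Number of samples processed per optimizer step across all processes."""
    return nnodes * nproc_per_node * batch_size_per_process


gbs = global_batch_size(nnodes=8, nproc_per_node=8, batch_size_per_process=1)

samples_per_epoch = 50000  # from train_interp.yaml
full_steps, remainder = divmod(samples_per_epoch, gbs)
print(gbs, full_steps, remainder)  # 64 781 16
```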
|
||
Configuration parameters can be overridden from the command line using the Hydra syntax. | ||
For instance, to set the optimizer learning rate to 0.0001 for the current run, you | ||
can use: | ||
|
||
```bash | ||
torchrun --nnodes=8 --nproc-per-node=8 train.py --config-name=train_interp.yaml ++training.optimizer_params.lr=0.0001 | ||
``` |
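
To make the override syntax concrete, the following sketch mimics what a dotted override does to a nested configuration. It is only an illustration of the semantics using plain dictionaries, not Hydra's implementation: the `++` prefix tells Hydra to override the key if it exists and to add it otherwise. The helper `apply_override` is hypothetical and handles numeric values only, and the dictionary assumes that `optimizer_params` is nested under `training`, as the command above implies.

```python
def apply_override(cfg: dict, override: str) -> None:
    """Set a dotted key in a nested dict, creating missing levels along the way."""
    key, raw = override.lstrip("+").split("=", 1)
    *parents, leaf = key.split(".")
    node = cfg
    for name in parents:
        node = node.setdefault(name, {})
    node[leaf] = float(raw)  # this sketch parses numeric values only


cfg = {"training": {"optimizer_params": {"lr": 5e-4, "betas": [0.9, 0.95]}}}
apply_override(cfg, "++training.optimizer_params.lr=0.0001")
print(cfg["training"]["optimizer_params"])
```

Only the addressed key changes; sibling keys such as `betas` keep their configured values.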
64 changes: 64 additions & 0 deletions
examples/weather/temporal_interpolation/config/train_interp.yaml
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-FileCopyrightText: All rights reserved. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
model: | ||
model_type: "modafno" # should always be "modafno" | ||
model_name: "modafno-cplxscale-smallpatch" # name for the model | ||
inp_shape: [720, 1440] # should be [720, 1440], must be divisible by patch_size | ||
in_channels: null # number of input channels to the model, null determines it from datapipe | ||
out_channels: null # number of output channels from the model, null determines it from datapipe | ||
patch_size: [2,2] # size of AFNO patches | ||
embed_dim: 512 # embedding dimension | ||
mlp_ratio: 2.0 # multiplier for MLP hidden layer size (may be a non-integer value, e.g. 2.5) | ||
num_blocks: 12 # number of AFNO blocks | ||
|
||
scale_shift_mode: complex # type of numbers used for the ModAFNO modulation, "real" or "complex" | ||
embed_model: | ||
dim: 64 # width of time embedding net | ||
depth: 1 # depth of time embedding net | ||
method: sinusoidal # embedding type used in time embedding net, "sinusoidal" or "learned" | ||
|
||
datapipe: | ||
data_dir: "/data/era5-73varQ-hourly" # directory where data files are located | ||
metadata_path: "/data/era5-73varQ-hourly/metadata/data.json" # path to metadata JSON file | ||
geopotential_filename: "/data/era5-wind_gust/invariants/orography.nc" # location of orography file | ||
lsm_filename: "/data/era5-wind_gust/invariants/land_sea_mask.nc" # location of lsm file | ||
use_latlon: True # when True, return latitude and longitude from datapipe | ||
num_samples_per_year_train: null # number of training samples per year, null uses all available | ||
num_samples_per_year_valid: 64 # number of validation samples per year | ||
batch_size_train: 1 # batch size per GPU | ||
|
||
training: | ||
max_epoch: 120 # number of data "epochs" (each epoch we save a checkpoint, run validation, update LR) | ||
samples_per_epoch: 50000 # number of samples per "epoch" | ||
load_epoch: "latest" # int, null or "latest"; "latest" loads the most recent checkpoint in checkpoint_dir | ||
checkpoint_dir: "/checkpoints/fcinterp/" # location where checkpoints are saved | ||
|
||
|
||
optimizer_params: | ||
lr: 5e-4 # learning rate | ||
betas: [0.9, 0.95] # beta parameters for Adam | ||
|
||
logging: | ||
mlflow: | ||
use_mlflow: True # when True, produce logs with mlflow | ||
experiment_name: "Forecast interpolation model" # experiment name, can be set freely | ||
user_name: "PhysicsNeMo User" # user name, can be set freely | ||
wandb: | ||
use_wandb: False # when True, produce logs with wandb | ||
mode: "offline" # "online", "offline", or "disabled" | ||
project: "Temporal-Interpolation-Training" # project name for wandb | ||
entity: null # entity (username or team) for Weights & Biases | ||
results_dir: "./wandb/" # directory to save wandb logs |
68 changes: 68 additions & 0 deletions
examples/weather/temporal_interpolation/config/train_interp_lite.yaml
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-FileCopyrightText: All rights reserved. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Config file for testing training. Does a very short run with a small model. | ||
# Can be used to test that training runs without errors, not expected to | ||
# produce useful checkpoints. | ||
|
||
model: | ||
model_type: "modafno" # should always be "modafno" | ||
model_name: "modafno-test" # name for the model | ||
inp_shape: [720, 1440] # should be [720, 1440], must be divisible by patch_size | ||
in_channels: null # number of input channels to the model, null determines it from datapipe | ||
out_channels: null # number of output channels from the model, null determines it from datapipe | ||
patch_size: [8,8] # size of AFNO patches | ||
embed_dim: 64 # embedding dimension | ||
mlp_ratio: 2.0 # multiplier for MLP hidden layer size (may be a non-integer value, e.g. 2.5) | ||
num_blocks: 2 # number of AFNO blocks | ||
|
||
scale_shift_mode: complex # type of numbers used for the ModAFNO modulation, "real" or "complex" | ||
embed_model: | ||
dim: 64 # width of time embedding net | ||
depth: 1 # depth of time embedding net | ||
method: sinusoidal # embedding type used in time embedding net, "sinusoidal" or "learned" | ||
|
||
datapipe: | ||
data_dir: "/data/era5-73varQ-hourly" # directory where data files are located | ||
metadata_path: "/data/era5-73varQ-hourly/metadata/data.json" # path to metadata JSON file | ||
geopotential_filename: "/data/era5-wind_gust/invariants/orography.nc" # location of orography file | ||
lsm_filename: "/data/era5-wind_gust/invariants/land_sea_mask.nc" # location of lsm file | ||
use_latlon: True # when True, return latitude and longitude from datapipe | ||
num_samples_per_year_train: null # number of training samples per year, null uses all available | ||
num_samples_per_year_valid: 64 # number of validation samples per year | ||
batch_size_train: 1 # batch size per GPU | ||
|
||
training: | ||
max_epoch: 4 # number of data "epochs" (each epoch we save a checkpoint, run validation, update LR) | ||
samples_per_epoch: 50 # number of samples per "epoch" | ||
load_epoch: "latest" # int, null or "latest"; "latest" loads the most recent checkpoint in checkpoint_dir | ||
checkpoint_dir: "/checkpoints/fcinterp/" # location where checkpoints are saved | ||
|
||
optimizer_params: | ||
lr: 5e-4 # learning rate | ||
betas: [0.9, 0.95] # beta parameters for Adam | ||
|
||
logging: | ||
mlflow: | ||
use_mlflow: True # when True, produce logs with mlflow | ||
experiment_name: "Forecast interpolation model" # experiment name, can be set freely | ||
user_name: "PhysicsNeMo User" # user name, can be set freely | ||
wandb: | ||
use_wandb: False # when True, produce logs with wandb | ||
mode: "offline" # "online", "offline", or "disabled" | ||
project: "Temporal-Interpolation-Training" # project name for wandb | ||
entity: null # entity (username or team) for Weights & Biases | ||
results_dir: "./wandb/" # directory to save wandb logs |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
{ | ||
|
||
"dataset_name": "73ch-hourly", | ||
"attrs": { | ||
"description": "ERA5 data at 1 hourly frequency with snapshots at every hour 0000, 0100, 0200, 0300, ..., 2300 UTC. First snapshot in each file is Jan 01 0000 UTC. " | ||
}, | ||
"h5_path": "fields", | ||
"dims": [ | ||
"time", | ||
"channel", | ||
"lat", | ||
"lon" | ||
], | ||
"coords": { | ||
"channel": [ | ||
"u10m", | ||
"v10m", | ||
"u100m", | ||
"v100m", | ||
"t2m", | ||
"sp", | ||
"msl", | ||
"tcwv", | ||
"u50", | ||
"u100", | ||
"u150", | ||
"u200", | ||
"u250", | ||
"u300", | ||
"u400", | ||
"u500", | ||
"u600", | ||
"u700", | ||
"u850", | ||
"u925", | ||
"u1000", | ||
"v50", | ||
"v100", | ||
"v150", | ||
"v200", | ||
"v250", | ||
"v300", | ||
"v400", | ||
"v500", | ||
"v600", | ||
"v700", | ||
"v850", | ||
"v925", | ||
"v1000", | ||
"z50", | ||
"z100", | ||
"z150", | ||
"z200", | ||
"z250", | ||
"z300", | ||
"z400", | ||
"z500", | ||
"z600", | ||
"z700", | ||
"z850", | ||
"z925", | ||
"z1000", | ||
"t50", | ||
"t100", | ||
"t150", | ||
"t200", | ||
"t250", | ||
"t300", | ||
"t400", | ||
"t500", | ||
"t600", | ||
"t700", | ||
"t850", | ||
"t925", | ||
"t1000", | ||
"q50", | ||
"q100", | ||
"q150", | ||
"q200", | ||
"q250", | ||
"q300", | ||
"q400", | ||
"q500", | ||
"q600", | ||
"q700", | ||
"q850", | ||
"q925", | ||
"q1000" | ||
] | ||
} | ||
} |
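
As a usage illustration for the metadata file above, the following sketch reads the JSON and builds a lookup from channel name to index along the channel dimension. The channel list is abbreviated to the first eight entries to keep the example short, and the way the datapipe actually consumes this file may differ.

```python
import json

# Abbreviated copy of the metadata above (first eight channels only).
metadata_text = """
{
  "dataset_name": "73ch-hourly",
  "h5_path": "fields",
  "dims": ["time", "channel", "lat", "lon"],
  "coords": {
    "channel": ["u10m", "v10m", "u100m", "v100m", "t2m", "sp", "msl", "tcwv"]
  }
}
"""

metadata = json.loads(metadata_text)
channels = metadata["coords"]["channel"]

# Map each variable name to its position along the channel axis.
channel_index = {name: i for i, name in enumerate(channels)}
channel_axis = metadata["dims"].index("channel")

print(channel_axis, channel_index["t2m"])  # 1 4
```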