24 commits:
- b0dbf92 Temporal interpolation training recipe (jleinonen, Oct 3, 2025)
- af84a7b Add README (jleinonen, Oct 8, 2025)
- d9841bc Merge branch 'NVIDIA:main' into interp-model-example (jleinonen, Oct 8, 2025)
- f56991e Docs changes based on comments (jleinonen, Oct 13, 2025)
- ae6eed1 Update docstrings and README (jleinonen, Oct 14, 2025)
- 5430eb4 Add temporal interpolation animation (jleinonen, Oct 14, 2025)
- 84289f5 Merge branch 'NVIDIA:main' into interp-model-example (jleinonen, Oct 14, 2025)
- 2b2c81e Add animation link (jleinonen, Oct 14, 2025)
- 6f09aa1 Add shape check in loss (jleinonen, Oct 14, 2025)
- 811a38a Updates of configs + trainer (jleinonen, Oct 15, 2025)
- cb23fe6 Update config comments (jleinonen, Oct 15, 2025)
- c32a01b Merge branch 'NVIDIA:main' into interp-model-example (jleinonen, Oct 20, 2025)
- f642d78 Merge branch 'main' into interp-model-example (CharlelieLrt, Oct 21, 2025)
- 8d6ae42 Update README.md (megnvidia, Oct 21, 2025)
- e1c202b Added wandb logging (CharlelieLrt, Oct 22, 2025)
- dc2b215 Merge branch 'interp-model-example' of https://github.com/jleinonen/m… (CharlelieLrt, Oct 22, 2025)
- a253d6e Reformated sections in docstring for GeometricL2Loss (CharlelieLrt, Oct 22, 2025)
- aad2683 Merge branch 'main' into interp-model-example (CharlelieLrt, Oct 22, 2025)
- e25aaeb Update README and configs (jleinonen, Oct 22, 2025)
- 2ea2897 Merge branch 'interp-model-example' of (jleinonen, Oct 22, 2025)
- 8475e1d README changes + type hint fixes (jleinonen, Oct 22, 2025)
- 547f10d Update README.md (jleinonen, Oct 22, 2025)
- 41560fc Merge branch 'NVIDIA:main' into interp-model-example (jleinonen, Oct 23, 2025)
- 5912555 Draft of validation script (jleinonen, Oct 23, 2025)
Binary file added docs/img/temporal_interpolation.gif
122 changes: 122 additions & 0 deletions examples/weather/temporal_interpolation/README.md
@@ -0,0 +1,122 @@
# Earth-2 Temporal Interpolation Model

The temporal interpolation model is used to increase the temporal resolution of AI-based
forecast models. These typically have a native temporal resolution of six hours; the
interpolation allows this to be improved to one hour. With appropriate training data, even
higher temporal resolutions might be achievable.

This PhysicsNeMo example shows how to train a ModAFNO-based temporal interpolation model
with a custom dataset. This architecture uses an embedding network to determine
parameters for a shift and scale operation that is used to modify the behavior of the AFNO
network depending on a given conditioning variable. For temporal
interpolation, the atmospheric states at both ends of the interpolation interval are
passed as inputs along with some auxiliary data, such as orography, and the conditioning
indicates which time step between the endpoints will be generated by the model. The
interpolation is deterministic and trained with a latitude-weighted L2 loss. However, it
can still produce probabilistic forecasts when applied to the outputs of probabilistic
forecast models. More formally, the ModAFNO $f_{\theta}$ is a conditional
expected-value model that approximates:

$$
f_{\theta} (x_{t}, x_{t+T}, \Delta t) \approx
\mathbb{E} \left[ x_{t + \Delta t} | x_{t}, x_{t+T}, \Delta t \right]
$$

where $0 \leq \Delta t \leq T$. In the pre-trained model, $T = 6$ hours and
$\Delta t \in \{0, 1, 2, 3, 4, 5, 6\}$ hours.
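
To make this concrete, the sketch below generates all intermediate states for one
interval by sweeping the conditioning time. The `model` callable is a hypothetical
stand-in for the trained network, and the input shapes and the normalization of
$\Delta t$ are assumptions; the exact call signature is defined by the ModAFNO wrapper
referenced below.

```python
import torch


def interpolate_interval(model, x_t, x_tpT, T_hours=6):
    """Generate the hourly states between x_t and x_(t+T).

    `model` is a hypothetical callable (x_start, x_end, t) -> x_(t+dt);
    x_t and x_tpT have shape (batch, channels, lat, lon).
    """
    frames = []
    for dt in range(T_hours + 1):  # dt = 0, 1, ..., 6 hours
        # The conditioning input selects which intermediate state is produced.
        t = torch.full((x_t.shape[0], 1), dt / T_hours)
        frames.append(model(x_t, x_tpT, t))
    return torch.stack(frames, dim=1)  # (batch, time, channels, lat, lon)
```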

For access to the pre-trained model, refer to the [wrapper in
Earth2Studio](https://nvidia.github.io/earth2studio/modules/generated/models/px/earth2studio.models.px.InterpModAFNO.html#earth2studio.models.px.InterpModAFNO).
A technical description of the model can be found in the paper ["Modulated Adaptive
Fourier Neural Operators for Temporal Interpolation of Weather
Forecasts"](https://arxiv.org/abs/2410.18904).

![Example of temporal interpolation of wind speed](../../../docs/img/temporal_interpolation.gif)

## Requirements

### Environment

You must have PhysicsNeMo installed on a GPU system. In practice, training useful
models requires a multi-GPU system; the original model was trained on 64 H100 GPUs.
Using the [PhysicsNeMo
container](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/physicsnemo/containers/physicsnemo)
is recommended.

Install the additional packages (MLflow) needed by this example:

```bash
pip install -r requirements.txt
```

### Data

To train a temporal interpolation model, ensure you have the following:

* A dataset of yearly HDF5 files at one-hour resolution. For more details, refer to the
section ["Data Format and Structure" in the diagnostic model
example](https://github.com/NVIDIA/physicsnemo/blob/5a64525c40eada2248cd3eacee0a6ac4735ae380/examples/weather/diagnostic/README.md#data-format-and-structure).
These datasets can be very large: the dataset used to train the original model, with
73 variables from 1980 to 2017, is approximately 100 TB in size. These data are on the
ERA5 0.25 degree grid with shape `(721, 1440)`, but other resolutions can work too.
ERA5 data is freely accessible; a recommended way to download it is the [ERA5
interface in
Earth2Studio](https://nvidia.github.io/earth2studio/modules/generated/data/earth2studio.data.CDS.html).
Data downloaded through this interface must then be inserted into the yearly HDF5
files.
* Statistics files containing the mean and standard deviation of each channel in the
data files. These must be located at `stats/global_means.npy` and
`stats/global_stds.npy` in your data directory. Each is a `.npy` file containing a 1D
array with length equal to the number of variables in the dataset, where each value
gives the mean (for `global_means.npy`) or standard deviation (for `global_stds.npy`)
of the corresponding variable; one way to compute them is sketched after this list.
* A JSON file with metadata about the contents of the HDF5 files. Refer to the [data
sample](https://github.com/NVIDIA/physicsnemo/blob/main/examples/weather/temporal_interpolation/data/data.json)
for an example describing the dataset used to train the original model.
* Optional: NetCDF4 files containing the orography and land-sea mask for the grid
contained in the data. Each file should contain a variable with the same spatial
shape as the data.
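
A minimal sketch of computing the statistics files from the yearly HDF5 files is shown
below. It assumes the layout described in `data.json` (a `fields` dataset of shape
`(time, channel, lat, lon)`); the glob pattern and output paths are illustrative
assumptions, not fixed conventions of this example.

```python
import glob

import h5py
import numpy as np

# Illustrative paths; adjust to your data layout.
files = sorted(glob.glob("/data/era5-73varQ-hourly/train/*.h5"))
count = 0
sum_x = None   # running per-channel sum of snapshot means
sum_x2 = None  # running per-channel sum of snapshot mean squares
for path in files:
    with h5py.File(path, "r") as f:
        fields = f["fields"]  # shape (time, channel, lat, lon)
        for i in range(fields.shape[0]):
            x = fields[i].astype(np.float64)  # (channel, lat, lon)
            if sum_x is None:
                sum_x = np.zeros(x.shape[0])
                sum_x2 = np.zeros(x.shape[0])
            # Every snapshot has the same number of grid points, so
            # averaging per-snapshot means yields the global mean.
            sum_x += x.mean(axis=(1, 2))
            sum_x2 += (x**2).mean(axis=(1, 2))
            count += 1
mean = sum_x / count
std = np.sqrt(sum_x2 / count - mean**2)
np.save("stats/global_means.npy", mean.astype(np.float32))
np.save("stats/global_stds.npy", std.astype(np.float32))
```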

## Configuration

Model training is controlled by YAML configuration files, located in the `config`
directory and managed by [Hydra](https://hydra.cc/). The full configuration used to
train the original model is
[`train_interp.yaml`](https://github.com/NVIDIA/physicsnemo/blob/main/examples/weather/temporal_interpolation/config/train_interp.yaml).
[`train_interp_lite.yaml`](https://github.com/NVIDIA/physicsnemo/blob/main/examples/weather/temporal_interpolation/config/train_interp_lite.yaml)
runs a short test with a lightweight model, which is not expected to produce useful
checkpoints but can be used to verify that training runs without errors.

See the comments in the configuration files for an explanation of each configuration
parameter. To replicate the model from the paper, you only need to change the file and
directory paths to correspond to those on your system. If you train it with a custom
dataset, you might also need to change the `model.in_channels` and `model.out_channels`
parameters.
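
The training script consumes these files through Hydra's standard entry point. As a
rough illustration (this is not the actual contents of `train.py`), a script structured
like the following can read any parameter discussed above, using the same dotted paths
as the command-line override syntax:

```python
import hydra
from omegaconf import DictConfig


@hydra.main(version_base=None, config_path="config", config_name="train_interp_lite")
def main(cfg: DictConfig) -> None:
    # Nested parameters are addressed the same way as in the override syntax.
    print(cfg.model.patch_size, cfg.training.optimizer_params.lr)


if __name__ == "__main__":
    main()
```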

## Starting Training

Test the training by running the `train.py` script using the "lite" configuration file
on a system with a GPU:

```bash
python train.py --config-name=train_interp_lite.yaml
```

For a multi-GPU or multi-node training job, launch the training with the
`train_interp.yaml` configuration file using `torchrun` or MPI. For example, to train on
eight nodes with eight GPUs each, for a total of 64 GPUs, start a distributed compute
job (for example, using SLURM or Run:ai) and use:

```bash
torchrun --nnodes=8 --nproc-per-node=8 train.py --config-name=train_interp.yaml
```

Alternatively, use the equivalent `mpirun` command. The code automatically uses all
GPUs available to the job. Remember that `datapipe.batch_size_train` in the
configuration file sets the batch size *per process*.
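
Because the configured batch size is per process, the global batch size scales with the
number of GPUs. As a minimal sketch, assuming PhysicsNeMo's `DistributedManager`
utility (check the API against your installed version), the relationship can be
verified at runtime:

```python
from physicsnemo.distributed import DistributedManager

# Reads the rank and world size from the environment set up by torchrun or MPI.
DistributedManager.initialize()
dist = DistributedManager()

per_process_batch = 1  # datapipe.batch_size_train in the config
global_batch = per_process_batch * dist.world_size
print(f"rank {dist.rank}: {dist.world_size} processes, global batch {global_batch}")
```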

Configuration parameters can be overridden from the command line using the Hydra syntax.
For instance, to set the optimizer learning rate to 0.0001 for the current run, you
can use:

```bash
torchrun --nnodes=8 --nproc-per-node=8 train.py --config-name=train_interp.yaml ++training.optimizer_params.lr=0.0001
```
64 changes: 64 additions & 0 deletions examples/weather/temporal_interpolation/config/train_interp.yaml
@@ -0,0 +1,64 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

model:
  model_type: "modafno" # should always be "modafno"
  model_name: "modafno-cplxscale-smallpatch" # name for the model
  inp_shape: [720, 1440] # should be [720, 1440], must be divisible by patch_size
  in_channels: null # number of input channels to the model, null determines it from datapipe
  out_channels: null # number of output channels from the model, null determines it from datapipe
  patch_size: [2, 2] # size of AFNO patches
  embed_dim: 512 # embedding dimension
  mlp_ratio: 2.0 # multiplier for MLP hidden layer size (may be a non-integer value, e.g. 2.5)
  num_blocks: 12 # number of AFNO blocks

  scale_shift_mode: complex # type of numbers used for the ModAFNO modulation, "real" or "complex"
  embed_model:
    dim: 64 # width of time embedding net
    depth: 1 # depth of time embedding net
    method: sinusoidal # embedding type used in time embedding net, "sinusoidal" or "learned"

datapipe:
  data_dir: "/data/era5-73varQ-hourly" # directory where data files are located
  metadata_path: "/data/era5-73varQ-hourly/metadata/data.json" # path to the metadata JSON file
  geopotential_filename: "/data/era5-wind_gust/invariants/orography.nc" # location of orography file
  lsm_filename: "/data/era5-wind_gust/invariants/land_sea_mask.nc" # location of land-sea mask file
  use_latlon: True # when True, return latitude and longitude from datapipe
  num_samples_per_year_train: null # number of training samples per year, null uses all available
  num_samples_per_year_valid: 64 # number of validation samples per year
  batch_size_train: 1 # batch size per GPU

training:
  max_epoch: 120 # number of data "epochs" (each epoch we save a checkpoint, run validation, update LR)
  samples_per_epoch: 50000 # number of samples per "epoch"
  load_epoch: "latest" # int, null or "latest"; "latest" loads the most recent checkpoint in checkpoint_dir
  checkpoint_dir: "/checkpoints/fcinterp/" # location where checkpoints are saved

  optimizer_params:
    lr: 5e-4 # learning rate
    betas: [0.9, 0.95] # beta parameters for Adam

logging:
  mlflow:
    use_mlflow: True # when True, produce logs with mlflow
    experiment_name: "Forecast interpolation model" # experiment name, can be set freely
    user_name: "PhysicsNeMo User" # user name, can be set freely
  wandb:
    use_wandb: False # when True, produce logs with wandb
    mode: "offline" # "online", "offline", or "disabled"
    project: "Temporal-Interpolation-Training" # project name for wandb
    entity: null # entity (username or team) for Weights & Biases
    results_dir: "./wandb/" # directory to save wandb logs
68 changes: 68 additions & 0 deletions examples/weather/temporal_interpolation/config/train_interp_lite.yaml
@@ -0,0 +1,68 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Config file for testing training. Does a very short run with a small model.
# Can be used to test that training runs without errors, not expected to
# produce useful checkpoints.

model:
  model_type: "modafno" # should always be "modafno"
  model_name: "modafno-test" # name for the model
  inp_shape: [720, 1440] # should be [720, 1440], must be divisible by patch_size
  in_channels: null # number of input channels to the model, null determines it from datapipe
  out_channels: null # number of output channels from the model, null determines it from datapipe
  patch_size: [8, 8] # size of AFNO patches
  embed_dim: 64 # embedding dimension
  mlp_ratio: 2.0 # multiplier for MLP hidden layer size (may be a non-integer value, e.g. 2.5)
  num_blocks: 2 # number of AFNO blocks

  scale_shift_mode: complex # type of numbers used for the ModAFNO modulation, "real" or "complex"
  embed_model:
    dim: 64 # width of time embedding net
    depth: 1 # depth of time embedding net
    method: sinusoidal # embedding type used in time embedding net, "sinusoidal" or "learned"

datapipe:
  data_dir: "/data/era5-73varQ-hourly" # directory where data files are located
  metadata_path: "/data/era5-73varQ-hourly/metadata/data.json" # path to the metadata JSON file
  geopotential_filename: "/data/era5-wind_gust/invariants/orography.nc" # location of orography file
  lsm_filename: "/data/era5-wind_gust/invariants/land_sea_mask.nc" # location of land-sea mask file
  use_latlon: True # when True, return latitude and longitude from datapipe
  num_samples_per_year_train: null # number of training samples per year, null uses all available
  num_samples_per_year_valid: 64 # number of validation samples per year
  batch_size_train: 1 # batch size per GPU

training:
  max_epoch: 4 # number of data "epochs" (each epoch we save a checkpoint, run validation, update LR)
  samples_per_epoch: 50 # number of samples per "epoch"
  load_epoch: "latest" # int, null or "latest"; "latest" loads the most recent checkpoint in checkpoint_dir
  checkpoint_dir: "/checkpoints/fcinterp/" # location where checkpoints are saved

  optimizer_params:
    lr: 5e-4 # learning rate
    betas: [0.9, 0.95] # beta parameters for Adam

logging:
  mlflow:
    use_mlflow: True # when True, produce logs with mlflow
    experiment_name: "Forecast interpolation model" # experiment name, can be set freely
    user_name: "PhysicsNeMo User" # user name, can be set freely
  wandb:
    use_wandb: False # when True, produce logs with wandb
    mode: "offline" # "online", "offline", or "disabled"
    project: "Temporal-Interpolation-Training" # project name for wandb
    entity: null # entity (username or team) for Weights & Biases
    results_dir: "./wandb/" # directory to save wandb logs
90 changes: 90 additions & 0 deletions examples/weather/temporal_interpolation/data/data.json
@@ -0,0 +1,90 @@
{
  "dataset_name": "73ch-hourly",
  "attrs": {
    "description": "ERA5 data at 1 hourly frequency with snapshots at every hour 0000, 0100, 0200, 0300, ..., 2300 UTC. First snapshot in each file is Jan 01 0000 UTC. "
  },
  "h5_path": "fields",
  "dims": [
    "time",
    "channel",
    "lat",
    "lon"
  ],
  "coords": {
    "channel": [
      "u10m",
      "v10m",
      "u100m",
      "v100m",
      "t2m",
      "sp",
      "msl",
      "tcwv",
      "u50",
      "u100",
      "u150",
      "u200",
      "u250",
      "u300",
      "u400",
      "u500",
      "u600",
      "u700",
      "u850",
      "u925",
      "u1000",
      "v50",
      "v100",
      "v150",
      "v200",
      "v250",
      "v300",
      "v400",
      "v500",
      "v600",
      "v700",
      "v850",
      "v925",
      "v1000",
      "z50",
      "z100",
      "z150",
      "z200",
      "z250",
      "z300",
      "z400",
      "z500",
      "z600",
      "z700",
      "z850",
      "z925",
      "z1000",
      "t50",
      "t100",
      "t150",
      "t200",
      "t250",
      "t300",
      "t400",
      "t500",
      "t600",
      "t700",
      "t850",
      "t925",
      "t1000",
      "q50",
      "q100",
      "q150",
      "q200",
      "q250",
      "q300",
      "q400",
      "q500",
      "q600",
      "q700",
      "q850",
      "q925",
      "q1000"
    ]
  }
}