
Conversation

jleinonen (Collaborator)

PhysicsNeMo Pull Request

Adds an example of training a temporal interpolation model with `physicsnemo.models.afno.ModAFNO` to `examples/weather/temporal_interpolation/`. See the README.md located there for details.

Description

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

CharlelieLrt (Collaborator) left a comment:

First pass on the PR.
Overall it looks fine, except for a few things that are a little unclear here and there.
My biggest concern is the Trainer. I know this class is used in many of our examples, but here I don't really see the point. I would support having a dedicated abstraction for a trainer only if you expect users to import it into their own codebase and use it in their own external application.

install it by running

```bash
pip install mlflow
```

CharlelieLrt (Collaborator):

Can you add a requirements.txt at the root dir of this example and put mlflow in it with the minimum version constraint that you used?

jleinonen (Collaborator Author):

requirements.txt is now included, along with a note in the README to install dependencies using it.

Comment on lines 8 to 13
This PhysicsNeMo example shows how to train a ModAFNO-based temporal interpolation model
with a custom dataset. For access to the pre-trained model, see the [wrapper in
Earth2Studio](https://nvidia.github.io/earth2studio/modules/generated/models/px/earth2studio.models.px.InterpModAFNO.html#earth2studio.models.px.InterpModAFNO).
A technical description of the model can be found in the paper ["Modulated Adaptive
Fourier Neural Operators for Temporal Interpolation of Weather
Forecasts"](https://arxiv.org/abs/2410.18904).
CharlelieLrt (Collaborator):

Can you be more specific and give a few more lines of technical background there?
As a potential user I would like to have a short paragraph that gives me the big technical picture of what the model can do, without having to go to the paper.

For example, one question I have is whether this is a deterministic interpolation model or a probabilistic one?
If it's deterministic, isn't it a shortcoming?
If I want to apply it to a higher temporal super-resolution ratio, a fully deterministic model is most likely not going to work, correct?
Another point that might be worth clarifying here:
How many snapshots after/before the targeted interpolation time is the model using?

jleinonen (Collaborator Author):

Added a longer description in ae6eed1.
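For readers of this thread, a schematic view of the interpolation task may help. Assuming (consistent with the (t, t+6h) data tuple handled by the datapipe further down in this PR) that the model consumes the two bracketing snapshots plus a target time offset, the task is roughly

$$
\hat{x}(t+\tau) = f_\theta\big(x(t),\, x(t+6\,\mathrm{h}),\, \tau\big), \qquad 0 < \tau < 6\,\mathrm{h},
$$

where $f_\theta$ is the ModAFNO network and $\tau$ enters through the time embedding discussed later in this review. This is an illustration only; see the README update in ae6eed1 and the paper for the exact formulation.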

A technical description of the model can be found in the paper ["Modulated Adaptive
Fourier Neural Operators for Temporal Interpolation of Weather
Forecasts"](https://arxiv.org/abs/2410.18904).

CharlelieLrt (Collaborator):

We need a (preferably nice) picture with the results provided by the trained model

jleinonen (Collaborator Author):

Picture added in 5430eb4.

containing a 1D array with length equal to the number of variables in the dataset,
with each value giving the mean (for `global_means.npy`) or standard deviation (for
`global_stds.npy`) of the corresponding variable.
* A JSON file with metadata about the contents of the HDF5 files. See [here](data.json)
CharlelieLrt (Collaborator):

This link: [here](data.json) is going to break in the online docs. Use the full GitHub URL instead. This applies to all the links below.

jleinonen (Collaborator Author):

How do I get the correct URL it'll have on Github after it's merged?

CharlelieLrt (Collaborator):

For this file it should be something like https://github.com/NVIDIA/physicsnemo/blob/main/examples/weather/temporal_interpolation/data.json

jleinonen (Collaborator Author):

Links changed to full URLs in ae6eed1.
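As an aside for users preparing their own dataset, here is a minimal sketch of how the `global_means.npy` / `global_stds.npy` files described in the excerpt above could be generated. The `fields` dataset name and the `(time, variable, lat, lon)` layout are assumptions for illustration, not taken from this example.

```python
# Compute per-variable mean and std over all training files and save them as
# 1D arrays of length n_variables, as the README excerpt above describes.
import glob

import h5py
import numpy as np

sums, sq_sums, count = 0.0, 0.0, 0
for path in sorted(glob.glob("train/*.h5")):
    with h5py.File(path, "r") as f:
        x = f["fields"][:].astype(np.float64)  # assumed shape: (time, variable, lat, lon)
    sums += x.sum(axis=(0, 2, 3))
    sq_sums += (x**2).sum(axis=(0, 2, 3))
    count += x.shape[0] * x.shape[2] * x.shape[3]

means = sums / count
stds = np.sqrt(sq_sums / count - means**2)
np.save("global_means.npy", means.astype(np.float32))
np.save("global_stds.npy", stds.astype(np.float32))
```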

dataset, you may also need to change the `model.in_channels` and `model.out_channels`
parameters.

## Starting training
CharlelieLrt (Collaborator):

I think we also need a script to validate the trained model, and a corresponding section in this README.md.

jleinonen (Collaborator Author):

What do you propose to include in the validation script? If it were my project, I would be quite happy having just the simple online validation here and would prefer to do more extensive validation using Earth2Studio.

CharlelieLrt (Collaborator):

Yes, I understand that E2S should be the main target for inference, but I also think that PhysicsNeMo examples should be "standalone". AFAIK all our examples have separate validation scripts.

Ideally, it would be good to take two timesteps T1 and T2, get all input snapshots between T1 and T2, and super-resolve the trajectory between the two. If that is too complicated, it could be something very simple, like basically repeating what is done in the online validation in train.py, but decoupled from the training itself. Being able to do validation without relaunching the training script is still nice, I think.
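To make the suggestion concrete, below is a rough, self-contained sketch of the idea (not the actual example code): take the two snapshots bracketing a 6 h window, interpolate the trajectory between them at hourly steps, and score each interpolated field against the held-out hourly data. Random tensors and a linear interpolator stand in for the real datapipe and the trained ModAFNO model, whose call signature differs.

```python
# Sketch of offline validation over one 6 h window. The tensors, grid size and
# the stand-in interpolator are placeholders; a real script would load the
# trained checkpoint and the datapipe used in train.py.
import torch


def interpolator(x0: torch.Tensor, x1: torch.Tensor, tau: torch.Tensor) -> torch.Tensor:
    """Stand-in for the trained model: plain linear interpolation in time."""
    return (1.0 - tau) * x0 + tau * x1


x0 = torch.randn(1, 8, 181, 360)  # snapshot at t (channel/grid sizes illustrative)
x1 = torch.randn(1, 8, 181, 360)  # snapshot at t + 6 h

with torch.no_grad():
    for step in range(1, 6):  # targets at t + 1 h ... t + 5 h
        tau = torch.tensor(step / 6.0)  # normalized interpolation time
        x_hat = interpolator(x0, x1, tau)
        # here: compare x_hat to the true hourly snapshot, e.g. accumulate
        # per-channel RMSE, and write the results to disk
```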

embed_model:
  dim: 64 # width of time embedding net
  depth: 1 # depth of time embedding net
  method: sinusoidal # embedding type used in time embedding net
CharlelieLrt (Collaborator):

Other options available? If so, list them.

jleinonen (Collaborator Author):

The `embed_model` config maps directly to

```python
def __init__(
    self,
    max_time: float = 1.0,
    dim: int = 64,
    depth: int = 1,
    activation_fn: Type[nn.Module] = nn.GELU,
    method: str = "sinusoidal",
):
```
and I think these are the only options it makes sense to expose to the config. I added a mention of the possible values of method in f56991e, though.

CharlelieLrt (Collaborator):

Oh, I just meant for `method: sinusoidal`: are there any other options? But it seems it's the only option accepted, so maybe add a comment like "Only supports 'sinusoidal' for now".

jleinonen (Collaborator Author):

It now says

```yaml
  method: sinusoidal # embedding type used in time embedding net, "sinusoidal" or "learned"
```
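For context on what the `embed_model` block parametrizes, here is a minimal sketch of a sinusoidal time embedding feeding a small MLP of the configured `dim` and `depth`. This illustrates the general technique, not the exact physicsnemo implementation.

```python
# Sinusoidal time embedding followed by a small MLP, roughly what
# dim/depth/method configure. Illustration only.
import math

import torch
import torch.nn as nn


def sinusoidal_embedding(t: torch.Tensor, dim: int, max_time: float = 1.0) -> torch.Tensor:
    """Map times in [0, max_time] to `dim` sin/cos features (dim assumed even)."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=t.dtype) / half)
    args = (t[:, None] / max_time) * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)


class TimeEmbedding(nn.Module):
    def __init__(self, dim: int = 64, depth: int = 1, max_time: float = 1.0):
        super().__init__()
        self.dim, self.max_time = dim, max_time
        layers = []
        for _ in range(depth):
            layers += [nn.Linear(dim, dim), nn.GELU()]
        self.net = nn.Sequential(*layers)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        return self.net(sinusoidal_embedding(t, self.dim, self.max_time))


emb = TimeEmbedding(dim=64, depth=1)
print(emb(torch.tensor([0.25, 0.5])).shape)  # torch.Size([2, 64])
```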

Comment on lines 45 to 46
max_epoch: 120 # number of data "epochs"
samples_per_epoch: 50000 # number of samples per "epoch"
CharlelieLrt (Collaborator):

I don't understand these 2 fields. Is it not looping over the entire dataloader at every epoch?

jleinonen (Collaborator Author), Oct 10, 2025:

Yes, "epochs" are basically fictional, they just indicate how often we save checkpoints. Could reword this if it feels confusing.

CharlelieLrt (Collaborator):

If it's just the frequency at which checkpoints are saved, maybe something like this would be clearer?

```yaml
# I/O configuration
io:
  checkpoint_freq: 10
  log_freq: 10
```

jleinonen (Collaborator Author):

Sorry, I misstated this a bit in my previous comment. Actually, after each "epoch" we do these chores:

  1. Save checkpoint
  2. Run validation
  3. Update learning rate

We could add options to control the timing of each of these separately, but it would add more complexity and the refactoring could potentially introduce bugs, so I'd prefer to avoid it. I did add a comment in the config file to specify what we mean by "epoch", though.
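For future readers of the config, here is the "epoch" semantics described above written out schematically (all functions are trivial stubs, not the actual train.py code):

```python
# One "epoch" = samples_per_epoch training samples, after which the three
# chores listed above run. Stubs only, for illustration.
samples_per_epoch, max_epoch, batch_size = 50_000, 120, 16

def train_step(batch): ...            # stub: one optimizer step
def save_checkpoint(epoch): ...       # stub: chore 1
def run_validation(epoch): ...        # stub: chore 2
def update_learning_rate(): ...       # stub: chore 3

for epoch in range(max_epoch):
    for _ in range(samples_per_epoch // batch_size):
        train_step(batch=None)
    save_checkpoint(epoch)
    run_validation(epoch)
    update_learning_rate()
```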

training:
  max_epoch: 120 # number of data "epochs"
  samples_per_epoch: 50000 # number of samples per "epoch"
  checkpoint_dir: "/checkpoints/fcinterp/" # location where checkpoints are saved
CharlelieLrt (Collaborator):

As a user, how would I restart a training? Is there anything I need to change in this config?

jleinonen (Collaborator Author):

It's the default behavior to restart from the latest checkpoint, but it's probably a good idea to add the option here.

CharlelieLrt (Collaborator):

Yeah, this implicit behavior of restarting by default is confusing, and it has led me to make mistakes before. Usually I prefer to put an explicit flag:

```yaml
# I/O configuration
io:
  # Whether to load a checkpoint
  load_checkpoint: false
```

jleinonen (Collaborator Author):

I added `load_epoch: "latest"` in the config file as the explicit default, with a comment explaining the options. ("latest" will start from scratch if there are no checkpoint files in the directory, so I think it's a good default: we can use the same config options for the first training run and subsequent ones.)
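For reference, the resulting option might look roughly like this in the config (the section placement and comment wording here are illustrative, reconstructed from this thread rather than copied from the file):

```yaml
training:
  # "latest" resumes from the newest checkpoint in checkpoint_dir, or starts
  # from scratch if no checkpoint files exist there
  load_epoch: "latest"
```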

geopotential_filename: "/data/era5-wind_gust/invariants/orography.nc" # location of orography file
lsm_filename: "/data/era5-wind_gust/invariants/land_sea_mask.nc" # location of lsm file
use_latlon: True # when True, return latitude and longitude from datapipe
num_samples_per_year_train: 8748 # number of training samples per year (8748 == 365 * 24 - 12)
CharlelieLrt (Collaborator):

Why the - 12?

jleinonen (Collaborator Author), Oct 10, 2025:

The datapipe doesn't support loading data from two annual files. This is to keep the (t, t+6h) data tuple from going out of bounds.

CharlelieLrt (Collaborator):

Ok got it. I'm not sure it's possible, but could we do the same as for this comment? That is, just give the raw number of samples in the config and somewhere in the datapipe do the conversion to make sure it doesn't go out of bounds?

jleinonen (Collaborator Author):

I added an option of setting it to null, in which case it is determined in the code, and made that the default.
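A hypothetical illustration of that null behavior: derive the number of usable samples per annual file from the file length and the trailing margin that the (t, t+6h) tuple needs, so indexing never goes out of bounds. The helper name and the exact margin are placeholders, not the actual datapipe code.

```python
def usable_samples_per_year(num_time_steps_in_file: int, trailing_margin_steps: int) -> int:
    # keep enough room at the end of the annual file for the latest snapshot needed
    return num_time_steps_in_file - trailing_margin_steps


# Hourly data for a 365-day year with a 12-step margin reproduces the 8748
# quoted in the config comment above:
print(usable_samples_per_year(365 * 24, 12))  # 8748
```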
