diff --git a/CHANGELOG.md b/CHANGELOG.md index 6044555ece..e4f493de58 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 to compose and define an active learning workflow is provided in `examples/active_learning`. The `moons` example provides a minimal (pedagogical) composition that is meant to illustrate how to define the necessary parts of the workflow. +- Added a new example for temporal interpolation of weather forecasts using ModAFNO. + Accessible in `examples/weather/temporal_interpolation`. ### Changed diff --git a/docs/examples_weather.rst b/docs/examples_weather.rst index 7033b1e10d..1bd2c81ddf 100644 --- a/docs/examples_weather.rst +++ b/docs/examples_weather.rst @@ -16,4 +16,5 @@ Weather and climate modeling examples using PhysicsNeMo. examples/weather/diagnostic/README.rst examples/weather/unified_recipe/README.rst examples/weather/corrdiff/README.rst - examples/weather/stormcast/README.rst \ No newline at end of file + examples/weather/stormcast/README.rst + examples/weather/temporal_interpolation/README.rst \ No newline at end of file diff --git a/docs/img/temporal_interpolation.gif b/docs/img/temporal_interpolation.gif new file mode 100644 index 0000000000..7df5db2176 Binary files /dev/null and b/docs/img/temporal_interpolation.gif differ diff --git a/examples/README.md b/examples/README.md index 1cc794c380..a3e482afb1 100644 --- a/examples/README.md +++ b/examples/README.md @@ -70,6 +70,7 @@ The several examples inside PhysicsNeMo can be classified based on their domains |[Medium-range global weather forecast using Mixture of Experts](./weather/mixture_of_experts/)|MoE Model| |[Generative Data Assimilation of Sparse Weather Observations](./weather/regen/)|Denoising Diffusion Model| |[Flood Forecasting](./weather/flood_modeling/)|GNN + KAN| +|[Temporal Interpolation of Weather Forecasts](./weather/temporal_interpolation/)|ModAFNO| ### Structural Mechanics diff --git a/examples/weather/temporal_interpolation/README.md b/examples/weather/temporal_interpolation/README.md new file mode 100644 index 0000000000..56080b1cd4 --- /dev/null +++ b/examples/weather/temporal_interpolation/README.md @@ -0,0 +1,138 @@ +# Earth-2 Temporal Interpolation Model + +The temporal interpolation model is used to increase the temporal resolution of AI-based +forecast models. These typically have a native temporal resolution of six hours; the +interpolation allows this to be improved to one hour. With appropriate training data, even +higher temporal resolutions might be achievable. + +This PhysicsNeMo example shows how to train a ModAFNO-based temporal interpolation model +with a custom dataset. This architecture uses an embedding network to determine +parameters for a shift and scale operation that is used to modify the behavior of the AFNO +network depending on a given conditioning variable. For temporal +interpolation, the atmospheric states at both ends of the interpolation interval are +passed as inputs along with some auxiliary data, such as orography, and the conditioning +indicates which time step between the endpoints will be generated by the model. The +interpolation is deterministic and trained with a latitude-weighted L2 loss. However, it +can still be used to produce probabilistic forecasts, if used to interpolate results of +probabilistic forecast models. 
More formally, the ModAFNO $f_{\theta}$ is a conditional +expected-value model that approximates: + +$$ +f_{\theta} (x_{t}, x_{t+T}, \Delta t) \approx +\mathbb{E} \left[ x_{t + \Delta t} | x_{t}, x_{t+T}, \Delta t \right] +$$ + +$0 \leq \Delta t \leq T$. In the pre-trained model, $T = 6$ hours and +$\Delta t \in \{0, 1, 2, 3, 4, 5, 6\}$ hours. + +For access to the pre-trained model, refer to the [wrapper in +Earth2Studio](https://nvidia.github.io/earth2studio/modules/generated/models/px/earth2studio.models.px.InterpModAFNO.html#earth2studio.models.px.InterpModAFNO). +A technical description of the model can be found in the paper ["Modulated Adaptive +Fourier Neural Operators for Temporal Interpolation of Weather +Forecasts"](https://arxiv.org/abs/2410.18904). + +![Example of temporal interpolation of wind speed](../../../docs/img/temporal_interpolation.gif) + +## Requirements + +### Environment + +You must have PhysicsNeMo installed on a GPU system. Training useful models, in +practice, requires a multi-GPU system; for the original model, 64 H100 GPUs were used. +Using the [PhysicsNeMo +container](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/physicsnemo/containers/physicsnemo) +is recommended. + +Install the additional packages (MLFlow) needed by this example: + +```bash +pip install -r requirements.txt +``` + +### Data + +To train a temporal interpolation model, ensure you have the following: + +* A dataset of yearly HDF5 files at one-hour resolution. For more details, refer to the + section ["Data Format and Structure" in the diagnostic model + example](https://github.com/NVIDIA/physicsnemo/blob/5a64525c40eada2248cd3eacee0a6ac4735ae380/examples/weather/diagnostic/README.md#data-format-and-structure). + These datasets can be very large. The dataset used to train the original model, with + 73 variables from 1980 to 2017, is approximately 100 TB in size. The data used to + train the original model are on the ERA5 0.25 degree grid with shape `(721, 1440)` but + other resolutions can work too. The ERA5 data is freely accessible; a recommended + method to download it is the [ERA5 interface in + Earth2Studio](https://nvidia.github.io/earth2studio/modules/generated/data/earth2studio.data.CDS.html). + The data downloaded from this interface must then be inserted into the HDF5 file. +* Statistics files containing the mean and standard deviation of each channel in the + data files. They must be in the `stats/global_means.npy` and + `stats/global_stds.npy` files in your data directory. They must be `.npy` files + containing a 1D array with length equal to the number of variables in the dataset, + with each value giving the mean (for `global_means.npy`) or standard deviation (for + `global_stds.npy`) of the corresponding variable. +* A JSON file with metadata about the contents of the HDF5 files. Refer to the [data + sample](https://github.com/NVIDIA/physicsnemo/blob/main/examples/weather/temporal_interpolation/data/data.json) + for an example describing the dataset used to train the original model. +* Optional: NetCDF4 files containing the orography and land-sea mask for the grid + contained in the data. These should contain a variable of the same shape as the data. + +## Configuration + +The model training is controlled by YAML configuration files that are managed by +[Hydra](https://hydra.cc/), which is found in the `config` directory. 
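+For reference, the files added by this example are laid out as follows:
+
+```text
+temporal_interpolation/
+├── README.md
+├── requirements.txt
+├── config/
+│   ├── train_interp.yaml        # full training configuration
+│   └── train_interp_lite.yaml   # lightweight test configuration
+├── data/
+│   └── data.json                # example metadata file
+├── datapipe/
+│   └── climate_interp.py        # interpolation datapipe
+├── utils/
+│   ├── distribute.py            # DDP helper
+│   ├── loss.py                  # latitude-weighted L2 loss
+│   └── trainer.py               # training loop
+├── train.py                     # training entry point
+└── validate.py                  # validation entry point
+```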
The full +configuration for training of the original model is +[`train_interp.yaml`](https://github.com/NVIDIA/physicsnemo/blob/main/examples/weather/temporal_interpolation/config/train_interp.yaml). +[`train_interp_lite.yaml`](https://github.com/NVIDIA/physicsnemo/blob/main/examples/weather/temporal_interpolation/config/train_interp_lite.yaml) +runs a short test with a lightweight model, which is not expected to produce useful +checkpoints but can be used to verify that training runs without errors. + +See the comments in the configuration files for an explanation of each configuration +parameter. To replicate the model from the paper, you only need to change the file and +directory paths to correspond to those on your system. If you train it with a custom +dataset, you might also need to change the `model.in_channels` and `model.out_channels` +parameters. + +## Starting Training + +Test the training by running the `train.py` script using the "lite" configuration file +on a system with a GPU: + +```bash +python train.py --config-name=train_interp_lite.yaml +``` + +For a multi-GPU or multi-node training job, launch the training with the +`train_interp.yaml` configuration file using `torchrun` or MPI. For example, to train on +eight nodes with eight GPUs each, for a total of 64 GPUs, start a distributed compute +job (for example, using SLURM or Run:ai) and use: + +```bash +torchrun --nnodes=8 --nproc-per-node=8 train.py --config-name=train_interp.yaml +``` + +Or the equivalent `mpirun` command. The code will automatically use all GPUs +available to the job. Remember to set `training.batch_size` in the configuration file to +the batch size *per process*. + +Configuration parameters can be overridden from the command line using the Hydra syntax. +For instance, to set the optimizer learning rate to 0.0001 for the current run, you +can use: + +```bash +torchrun --nnodes=8 --nproc-per-node=8 train.py --config-name=train_interp.yaml ++training.optimizer_params.lr=0.0001 +``` + +## Validation + +To evaluate checkpoints, you can use the `validate.py` script. The script computes a +histogram of squared errors as a function of the interpolation step (+0 h to +6 h), +which can be used to produce a plot similar to Figure 3 of the paper. The validation +uses the same configuration files as training, with validation-specific options passed +through the `validation` configuration group. Refer to the docstring of `error_by_time` +in `validate.py` for the recognized options. + +For example, to run the validation of a model trained with `train_interp.yaml` and save +the resulting error histogram to `validation.nc`: + +```bash +python validate.py --config-name="train_interp" ++validation.output_path=validation.nc +``` diff --git a/examples/weather/temporal_interpolation/config/train_interp.yaml b/examples/weather/temporal_interpolation/config/train_interp.yaml new file mode 100644 index 0000000000..54df9c9ee4 --- /dev/null +++ b/examples/weather/temporal_interpolation/config/train_interp.yaml @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +model: + model_type: "modafno" # should always be "modafno" + model_name: "modafno-cplxscale-smallpatch" # name for the model + inp_shape: [720, 1440] # should be [720, 1440], must be divisible by patch_size + in_channels: null # number of input channels to the model, null determines it from datapipe + out_channels: null # number of output channels from the model, null determines it from datapipe + patch_size: [2,2] # size of AFNO patches + embed_dim: 512 # embedding dimension + mlp_ratio: 2.0 # multiplier for MLP hidden layer size (may be a non-integer value, e.g. 2.5) + num_blocks: 12 # number of AFNO blocks + + scale_shift_mode: complex # type of numbers used for the ModAFNO modulation, "real" or "complex" + embed_model: + dim: 64 # width of time embedding net + depth: 1 # depth of time embedding net + method: sinusoidal # embedding type used in time embedding net, "sinusoidal" or "learned" + +datapipe: + data_dir: "/data/era5-73varQ-hourly" # directory where data files are located + metadata_path: "/data/era5-73varQ-hourly/metadata/data.json" # directory to metadata json file + geopotential_filename: "/data/era5-wind_gust/invariants/orography.nc" # location of orography file + lsm_filename: "/data/era5-wind_gust/invariants/land_sea_mask.nc" # location of lsm file + use_latlon: True # when True, return latitude and longitude from datapipe + num_samples_per_year_train: null # number of training samples per year, null uses all available + num_samples_per_year_valid: 64 # number of validation samples per year + batch_size_train: 1 # batch size per GPU + +training: + max_epoch: 120 # number of data "epochs" (each epoch we save a checkpoint, run validation, update LR) + samples_per_epoch: 50000 # number of samples per "epoch" + load_epoch: "latest" # int, null or "latest"; "latest" loads the most recent checkpoint in checkpoint_dir + checkpoint_dir: "/checkpoints/fcinterp/" # location where checkpoints are saved + +optimizer_params: + lr: 5e-4 # learning rate + betas: [0.9, 0.95] # beta parameters for Adam + +logging: + mlflow: + use_mlflow: True # when True, produce logs with mlflow + experiment_name: "Forecast interpolation model" # experiment name, can be set freely + user_name: "PhysicsNeMo User" # user name, can be set freely + wandb: + use_wandb: False # when True, produce logs with wandb + mode: "offline" # "online", "offline", or "disabled" + project: "Temporal-Interpolation-Training" # project name for wandb + entity: null # entity (username or team) for Weights & Biases + results_dir: "./wandb/" # directory to save wandb logs diff --git a/examples/weather/temporal_interpolation/config/train_interp_lite.yaml b/examples/weather/temporal_interpolation/config/train_interp_lite.yaml new file mode 100644 index 0000000000..bb4a210993 --- /dev/null +++ b/examples/weather/temporal_interpolation/config/train_interp_lite.yaml @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Config file for testing training. Does a very short run with a small model. +# Can be used to test that training runs without errors, not expected to +# produce useful checkpoints. + +model: + model_type: "modafno" # should always be "modafno" + model_name: "modafno-test" # name for the model + inp_shape: [720, 1440] # should be [720, 1440], must be divisible by patch_size + in_channels: null # number of input channels to the model, null determines it from datapipe + out_channels: null # number of output channels from the model, null determines it from datapipe + patch_size: [8,8] # size of AFNO patches + embed_dim: 64 # embedding dimension + mlp_ratio: 2.0 # multiplier for MLP hidden layer size (may be a non-integer value, e.g. 2.5) + num_blocks: 2 # number of AFNO blocks + + scale_shift_mode: complex # type of numbers used for the ModAFNO modulation, "real" or "complex" + embed_model: + dim: 64 # width of time embedding net + depth: 1 # depth of time embedding net + method: sinusoidal # embedding type used in time embedding net, "sinusoidal" or "learned" + +datapipe: + data_dir: "/data/era5-73varQ-hourly" # directory where data files are located + metadata_path: "/data/era5-73varQ-hourly/metadata/data.json" # directory to metadata json file + geopotential_filename: "/data/era5-wind_gust/invariants/orography.nc" # location of orography file + lsm_filename: "/data/era5-wind_gust/invariants/land_sea_mask.nc" # location of lsm file + use_latlon: True # when True, return latitude and longitude from datapipe + num_samples_per_year_train: null # number of training samples per year, null uses all available + num_samples_per_year_valid: 64 # number of validation samples per year + batch_size_train: 1 # batch size per GPU + +training: + max_epoch: 4 # number of data "epochs" (each epoch we save a checkpoint, run validation, update LR) + samples_per_epoch: 50 # number of samples per "epoch" + load_epoch: "latest" # int, null or "latest"; "latest" loads the most recent checkpoint in checkpoint_dir + checkpoint_dir: "/checkpoints/fcinterp/" # location where checkpoints are saved + +optimizer_params: + lr: 5e-4 # learning rate + betas: [0.9, 0.95] # beta parameters for Adam + +logging: + mlflow: + use_mlflow: True # when True, produce logs with mlflow + experiment_name: "Forecast interpolation model" # experiment name, can be set freely + user_name: "PhysicsNeMo User" # user name, can be set freely + wandb: + use_wandb: False # when True, produce logs with wandb + mode: "offline" # "online", "offline", or "disabled" + project: "Temporal-Interpolation-Training" # project name for wandb + entity: null # entity (username or team) for Weights & Biases + results_dir: "./wandb/" # directory to save wandb logs diff --git a/examples/weather/temporal_interpolation/data/data.json b/examples/weather/temporal_interpolation/data/data.json new file mode 100644 index 0000000000..e9593b8449 --- /dev/null +++ 
b/examples/weather/temporal_interpolation/data/data.json @@ -0,0 +1,90 @@ +{ + "dataset_name": "73ch-hourly", + "attrs": { + "description": "ERA5 data at 1 hourly frequency with snapshots at every hour 0000, 0100, 0200, 0300, ..., 2300 UTC. First snapshot in each file is Jan 01 0000 UTC. " + }, + "h5_path": "fields", + "dims": [ + "time", + "channel", + "lat", + "lon" + ], + "coords": { + "channel": [ + "u10m", + "v10m", + "u100m", + "v100m", + "t2m", + "sp", + "msl", + "tcwv", + "u50", + "u100", + "u150", + "u200", + "u250", + "u300", + "u400", + "u500", + "u600", + "u700", + "u850", + "u925", + "u1000", + "v50", + "v100", + "v150", + "v200", + "v250", + "v300", + "v400", + "v500", + "v600", + "v700", + "v850", + "v925", + "v1000", + "z50", + "z100", + "z150", + "z200", + "z250", + "z300", + "z400", + "z500", + "z600", + "z700", + "z850", + "z925", + "z1000", + "t50", + "t100", + "t150", + "t200", + "t250", + "t300", + "t400", + "t500", + "t600", + "t700", + "t850", + "t925", + "t1000", + "q50", + "q100", + "q150", + "q200", + "q250", + "q300", + "q400", + "q500", + "q600", + "q700", + "q850", + "q925", + "q1000" + ] + } +} diff --git a/examples/weather/temporal_interpolation/datapipe/climate_interp.py b/examples/weather/temporal_interpolation/datapipe/climate_interp.py new file mode 100644 index 0000000000..635ef2ea25 --- /dev/null +++ b/examples/weather/temporal_interpolation/datapipe/climate_interp.py @@ -0,0 +1,205 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime, timedelta +import time + +import numpy as np +import nvidia.dali as dali + +from physicsnemo.datapipes.climate.climate import ( + ClimateDatapipe, + ClimateDaliExternalSource, + ClimateHDF5DaliExternalSource, +) + + +class InterpDaliExternalSource(ClimateDaliExternalSource): + """ + Data source specialized for interpolation training. + + Parameters + ---------- + *args : tuple + Positional arguments passed to parent class. + all_steps : bool, optional + Whether to return all steps in the sequence. Default is False. + **kwargs : dict + Keyword arguments passed to parent class. + """ + + def __init__(self, *args, all_steps: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.all_steps = all_steps + + def __call__( + self, sample_info: dali.types.SampleInfo + ) -> tuple[np.ndarray, np.ndarray]: + """ + Get data from source. + + Parameters + ---------- + sample_info : dali.types.SampleInfo + Information about the sample to retrieve. + + Returns + ------- + state_seq : np.ndarray + Sequence of training data. + timestamps : np.ndarray + Accompanying timestamps for the sequence. 
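+
+        Notes
+        -----
+        With ``all_steps=False`` (the default), the returned sequence contains
+        the two interpolation endpoints followed by one randomly chosen
+        intermediate step, i.e. index offsets ``[0, stride, interp_idx]``.
+        With ``all_steps=True`` it contains every step from 0 to ``stride``.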
+ """ + + if sample_info.iteration >= self.num_batches: + raise StopIteration() + + # Shuffle before the next epoch starts + if self.shuffle and sample_info.epoch_idx != self.last_epoch: + print("Shuffling indices") + np.random.shuffle(self.indices) + self.last_epoch = sample_info.epoch_idx + + # Get local indices from global index + # TODO: This is very hacky, but it works for now + idx = self.indices[sample_info.idx_in_epoch] + year_idx = idx // self.num_samples_per_year + in_idx = idx % self.num_samples_per_year + + # quasi-unique deterministic seed for each sample + seed = ( + (sample_info.epoch_idx << 32) + + (sample_info.idx_in_epoch << 16) + + sample_info.idx_in_batch + ) + + interp_idx = np.random.default_rng(seed=seed).integers(self.stride + 1) + if self.all_steps: + steps = np.arange(self.stride + 1) + else: + steps = np.array([0, self.stride, interp_idx]) + state_seq = self._load_sequence(year_idx, in_idx, steps) + + # Load sequence of timestamps + year = self.start_year + year_idx + start_time = datetime(year, 1, 1) + timedelta(hours=int(in_idx) * self.dt) + timestamps = np.array( + [(start_time + timedelta(hours=i * self.dt)).timestamp() for i in steps] + ) + return state_seq, timestamps + + def __len__(self) -> int: + return len(self.indices) // self.stride + + +class InterpHDF5DaliExternalSource( + ClimateHDF5DaliExternalSource, InterpDaliExternalSource +): + """ + DALI source for reading HDF5 formatted climate data files. + + Specialized for interpolation training with HDF5 climate data. + + Parameters + ---------- + *args : tuple + Positional arguments passed to parent classes. + **kwargs : dict + Keyword arguments passed to parent classes. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _get_read_buffer(self, steps: list[int], data) -> np.ndarray: + """Get memory buffer for reading data.""" + shape = (len(steps), len(self.chans)) + data.shape[-2:] + return np.empty(shape, dtype=np.float32) + + def _load_sequence( + self, year_idx: int, idx: int, steps: np.ndarray, num_retries: int = 10 + ) -> np.ndarray: + """ + Load sequence of data for interpolation training. + + Parameters + ---------- + year_idx : int + The index of the yearly data file. + idx : int + The starting index of the data sequence in the yearly file. + steps : np.ndarray + Array of index offsets relative to idx (e.g. [0, 6, 2]). + num_retries : int, optional + Number of times to retry in case of IO failure. Default is 10. + + Returns + ------- + np.ndarray + Data of shape (len(steps), num_channels, height, width). + """ + + # the data is returned in a (time, channels, height, width) shape + data = self._get_data_file(year_idx)["fields"] + + seq = self._get_read_buffer(steps, data) + steps = list(steps) # so we can use .index() + for step_idx, s in enumerate(steps): + first_step_idx = steps.index(s) + if first_step_idx != step_idx: + # when two steps are the same, copy previous to avoid redundant data I/O + seq[step_idx] = seq[first_step_idx] + else: + for retry_num in range(num_retries + 1): + try: + # equivalent to: seq[step_idx] = data[idx + s] + data.read_direct(seq, np.s_[idx + s], np.s_[step_idx]) + break + except BlockingIOError: + # Some systems have had occasional IO issues that can often be + # resolved by retrying + if retry_num == num_retries: + raise + else: + print( + f"IO error reading year_idx={year_idx} idx={idx}, retrying in 5 sec..." 
+ ) + time.sleep(5) + return seq + + +class InterpClimateDatapipe(ClimateDatapipe): + """ + Extends ClimateDatapipe to use interpolation source. + """ + + def _source_cls_from_type(self, source_type: str) -> type[InterpDaliExternalSource]: + """ + Get the external source class based on a string descriptor. + + Parameters + ---------- + source_type : str + String identifier for the source type (e.g., 'hdf5'). + + Returns + ------- + type[InterpDaliExternalSource] + The appropriate external source class for the given type. + """ + return { + "hdf5": InterpHDF5DaliExternalSource, + }[source_type] diff --git a/examples/weather/temporal_interpolation/requirements.txt b/examples/weather/temporal_interpolation/requirements.txt new file mode 100644 index 0000000000..cf7cc3b60f --- /dev/null +++ b/examples/weather/temporal_interpolation/requirements.txt @@ -0,0 +1,2 @@ +mlflow>=2.1.1 +wandb>=0.13.7 \ No newline at end of file diff --git a/examples/weather/temporal_interpolation/train.py b/examples/weather/temporal_interpolation/train.py new file mode 100644 index 0000000000..dcc94ed069 --- /dev/null +++ b/examples/weather/temporal_interpolation/train.py @@ -0,0 +1,440 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import datetime +from typing import Any, Dict +import warnings + +import hydra +from omegaconf import OmegaConf +import torch +import wandb + +from physicsnemo import Module +from physicsnemo.datapipes.climate.climate import ClimateDataSourceSpec +from physicsnemo.datapipes.climate.utils import invariant +from physicsnemo.distributed import DistributedManager +from physicsnemo.launch.logging import LaunchLogger +from physicsnemo.launch.logging.mlflow import initialize_mlflow +from physicsnemo.launch.logging.wandb import initialize_wandb +from physicsnemo.models.afno import ModAFNO +from physicsnemo.launch.utils import load_checkpoint + +from datapipe.climate_interp import InterpClimateDatapipe +from utils import distribute, loss +from utils.trainer import Trainer + +try: + from apex.optimizers import FusedAdam +except ImportError: + warnings.warn("Apex is not installed, defaulting to PyTorch optimizers.") + + +def setup_datapipes( + *, + data_dir: str, + dist_manager: DistributedManager, + metadata_path: str, + geopotential_filename: str | None = None, + lsm_filename: str | None = None, + use_latlon: bool = True, + num_samples_per_year_train: int | None = None, + num_samples_per_year_valid: int = 4, + batch_size_train: int = 4, + batch_size_valid: int | None = None, + num_workers: int = 4, + valid_subdir: str = "test", + valid_start_year: int = 2017, + valid_shuffle: bool = False, +) -> tuple[InterpClimateDatapipe, InterpClimateDatapipe, int]: + """ + Setup datapipes for training. + + The arguments passed to this function can be modified in the 'datapipe' section + of the config. 
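+    The datapipe is configured with a 6-hour stride between the interpolation
+    endpoints and a 1-hour native time step (``stride=6``, ``dt=1.0``).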
+ + Parameters + ---------- + data_dir : str + Path to data directory. + dist_manager : DistributedManager + An initialized DistributedManager instance. + metadata_path : str + Path to metadata file. + geopotential_filename : str or None, optional + Path to NetCDF file with global geopotential on the 0.25 deg grid. + lsm_filename : str or None, optional + Path to NetCDF file with global land-sea mask on the 0.25 deg grid. + use_latlon : bool, optional + If True, will return latitude and longitude from the datapipe. + num_samples_per_year_train : int or None, optional + Number of training samples per year, if None will use all available samples. + num_samples_per_year_valid : int, optional + Number of validation samples per year. + batch_size_train : int, optional + Batch size per GPU for training. + batch_size_valid : int or None, optional + Batch size per GPU for validation, when None equal to batch_size_train. + num_workers : int, optional + Number of datapipe workers per training process. + valid_subdir : str, optional + Subdirectory in data_dir where validation data is found. + valid_start_year : int, optional + Starting year for validation data. + valid_shuffle : bool, optional + When True, shuffle order of validation set; recommend setting to False + for consistent validation results. + + Returns + ------- + tuple of (InterpClimateDatapipe, InterpClimateDatapipe, int) + Tuple of training datapipe and validation datapipe, and the number of auxiliary channels. + """ + if batch_size_valid is None: + batch_size_valid = batch_size_train + + train_dir = os.path.join(data_dir, "train") + valid_dir = os.path.join(data_dir, valid_subdir) + mean_file = os.path.join(data_dir, "stats/global_means.npy") + std_file = os.path.join(data_dir, "stats/global_stds.npy") + + spec_kwargs: Dict[str, Any] = dict( + stats_files={"mean": mean_file, "std": std_file}, + use_cos_zenith=True, + name="atmos", + metadata_path=metadata_path, + stride=6, + ) + + spec_train = ClimateDataSourceSpec(data_dir=train_dir, **spec_kwargs) + spec_valid = ClimateDataSourceSpec(data_dir=valid_dir, **spec_kwargs) + + invariants = {} + num_aux_channels = 3 # 3 channels for cos_zenith + if use_latlon: + invariants["latlon"] = invariant.LatLon() + num_aux_channels += 4 + if geopotential_filename is not None: + invariants["geopotential"] = invariant.FileInvariant(geopotential_filename, "Z") + num_aux_channels += 1 + if lsm_filename is not None: + invariants["land_sea_mask"] = invariant.FileInvariant(lsm_filename, "LSM") + num_aux_channels += 1 + + pipe_kwargs = dict( + invariants=invariants, + crop_window=((0, 720), (0, 1440)), + num_workers=num_workers, + device=dist_manager.device, + dt=1.0, + ) + + if num_samples_per_year_train is None: + num_samples_per_year_train = 365 * 24 - 12 # -12 to prevent overflow + + pipe_train = InterpClimateDatapipe( + [spec_train], + batch_size=batch_size_train, + num_samples_per_year=num_samples_per_year_train, + process_rank=dist_manager.rank, + world_size=dist_manager.world_size, + **pipe_kwargs, + ) + + pipe_valid = InterpClimateDatapipe( + [spec_valid], + batch_size=batch_size_valid, + num_samples_per_year=num_samples_per_year_valid, + shuffle=valid_shuffle, + start_year=valid_start_year, + **pipe_kwargs, + ) + + return (pipe_train, pipe_valid, num_aux_channels) + + +# Default parameters if not overridden by config +default_model_params = { + "modafno": { + "inp_shape": (720, 1440), + "in_channels": 155, + "out_channels": 73, + "patch_size": (8, 8), + "embed_dim": 768, + "depth": 12, + 
"num_blocks": 8, + } +} + + +def setup_model( + num_variables: int, num_auxiliaries: int, model_cfg: dict | None = None +) -> Module: + """ + Setup interpolation model. + + Parameters + ---------- + num_variables : int + Number of atmospheric variables in the model. + num_auxiliaries : int + Number of auxiliary input channels. + model_cfg : dict or None, optional + Model configuration dict. + + Returns + ------- + Module + Model object. + """ + if model_cfg is None: + model_cfg = {} + model_type = model_cfg.pop("model_type", "modafno") + if model_type != "modafno": + raise ValueError( + "Model types other than 'modafno' are not currently supported." + ) + if model_cfg.get("in_channels") is None: + model_cfg["in_channels"] = 2 * num_variables + num_auxiliaries + if model_cfg.get("out_channels") is None: + model_cfg["out_channels"] = num_variables + model_name = model_cfg.pop("model_name") + model_kwargs = default_model_params[model_type].copy() + model_kwargs.update(model_cfg) + if model_type == "modafno": + model = ModAFNO(**model_kwargs) + + if model_name is not None: + model.meta.name = model_name + + return model + + +def setup_optimizer( + model: torch.nn.Module, + max_epoch: int, + opt_cls: type[torch.optim.Optimizer] | None = None, + opt_params: dict | None = None, + scheduler_cls: type[torch.optim.lr_scheduler.LRScheduler] | None = None, + scheduler_params: dict[str, Any] | None = None, +) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: + """Setup optimizer. + + Parameters + ---------- + model : torch.nn.Module + Model that optimizer is applied to. + max_epoch : int + Maximum number of training epochs (used for scheduler setup). + opt_cls : type[torch.optim.Optimizer] or None, optional + Optimizer class. When None, will setup apex.optimizers.FusedAdam + if available, otherwise PyTorch Adam. + opt_params : dict or None, optional + Dict of parameters (e.g. learning rate) to pass to optimizer. + scheduler_cls : type[torch.optim.lr_scheduler.LRScheduler] or None, optional + Scheduler class. When None, will setup CosineAnnealingLR. + scheduler_params : dict[str, Any] or None, optional + Dict of parameters to pass to scheduler. + + Returns + ------- + tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler] + The initialized optimizer and learning rate scheduler. + """ + + opt_kwargs = {"lr": 0.0005} + if opt_params is not None: + opt_kwargs.update(opt_params) + if opt_cls is None: + try: + opt_cls = FusedAdam + except NameError: # in case we don't have apex + opt_cls = torch.optim.Adam + + scheduler_kwargs = {} + if scheduler_cls is None: + scheduler_cls = torch.optim.lr_scheduler.CosineAnnealingLR + scheduler_kwargs["T_max"] = max_epoch + if scheduler_params is not None: + scheduler_kwargs.update(scheduler_params) + + optimizer = opt_cls(model.parameters(), **opt_kwargs) + scheduler = scheduler_cls(optimizer, **scheduler_kwargs) + return (optimizer, scheduler) + + +@torch.no_grad() +def input_output_from_batch_data( + batch: list[dict[str, torch.Tensor]], time_scale: float = 6 * 3600.0 +) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """ + Convert the datapipe output dict to model input and output batches. + + Parameters + ---------- + batch : list[dict[str, torch.Tensor]] + The list data dicts returned by the datapipe. + time_scale : float, optional + Number of seconds between the interpolation endpoints (default 6 hours). + + Returns + ------- + tuple + Nested tuple in the form ((input, time), output). 
+ """ + batch = batch[0] + # Concatenate all input variables to a single tensor + atmos_vars = batch["state_seq-atmos"] + + atmos_vars_in = [atmos_vars[:, 0], atmos_vars[:, 1]] + if "cos_zenith-atmos" in batch: + atmos_vars_in = atmos_vars_in + [batch["cos_zenith-atmos"].squeeze(dim=2)] + if "latlon" in batch: + atmos_vars_in = atmos_vars_in + [batch["latlon"]] + if "geopotential" in batch: + atmos_vars_in = atmos_vars_in + [batch["geopotential"]] + if "land_sea_mask" in batch: + atmos_vars_in = atmos_vars_in + [batch["land_sea_mask"]] + atmos_vars_in = torch.cat(atmos_vars_in, dim=1) + + atmos_vars_out = atmos_vars[:, 2] + + time = batch["timestamps-atmos"] + # Normalize time coordinate + time = (time[:, -1:] - time[:, :1]).to(dtype=torch.float32) / time_scale + + return ((atmos_vars_in, time), atmos_vars_out) + + +def setup_trainer(**cfg: dict) -> Trainer: + """ + Setup training environment. + + Parameters + ---------- + **cfg : dict + The configuration dict passed from hydra. + + Returns + ------- + Trainer + The Trainer object for training the interpolation model. + """ + + DistributedManager.initialize() + + # Setup datapipes + (train_datapipe, valid_datapipe, num_aux_channels) = setup_datapipes( + **cfg["datapipe"], + dist_manager=DistributedManager(), + ) + + # Setup model + model = setup_model( + num_variables=len(train_datapipe.sources[0].variables), + num_auxiliaries=num_aux_channels, + model_cfg=cfg["model"], + ) + (model, dist_manager) = distribute.distribute_model(model) + + # Setup optimizer and learning rate scheduler + (optimizer, scheduler) = setup_optimizer( + model, + cfg["training"].get("max_epoch", 1), + opt_params=cfg.get("optimizer_params", {}), + scheduler_params=cfg.get("scheduler_params", {}), + ) + + # Initialize mlflow + mlflow_cfg = cfg.get("logging", {}).get("mlflow", {}) + if mlflow_cfg.pop("use_mlflow", False): + initialize_mlflow(**mlflow_cfg) + LaunchLogger.initialize(use_mlflow=True) + + # Initialize wandb + use_wandb = False + wandb_cfg = cfg.get("logging", {}).get("wandb", {}) + if wandb_cfg.get("use_wandb", False): + use_wandb = True + timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + # Get checkpoint directory + checkpoint_dir = cfg.get("training", {}).get("checkpoint_dir") + # Check if we need to resume from checkpoint + wandb_id = None + resume = None + load_epoch = cfg.get("training", {}).get("load_epoch") + if checkpoint_dir is not None and load_epoch is not None: + metadata = {"wandb_id": None} + load_checkpoint(checkpoint_dir, metadata_dict=metadata) + wandb_id = metadata.get("wandb_id") + if wandb_id is not None: + resume = "must" + + initialize_wandb( + project=wandb_cfg.get("project", "Temporal-Interpolation-Training"), + entity=wandb_cfg.get("entity"), + mode=wandb_cfg.get("mode", "offline"), + config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=False), + results_dir=wandb_cfg.get("results_dir", "./wandb/"), + wandb_id=wandb_id, + resume=resume, + save_code=True, + name=f"train-{timestamp}", + init_timeout=600, + ) + + # Setup training loop + loss_func = loss.GeometricL2Loss(num_lats_cropped=cfg["model"]["inp_shape"][0]).to( + device=dist_manager.device + ) + trainer = Trainer( + model, + dist_manager=dist_manager, + loss=loss_func, + train_datapipe=train_datapipe, + valid_datapipe=valid_datapipe, + input_output_from_batch_data=input_output_from_batch_data, + optimizer=optimizer, + scheduler=scheduler, + use_wandb=use_wandb, + **cfg["training"], + ) + + return trainer + + +@hydra.main(version_base=None, 
config_path="config") +def main(cfg): + """ + Main entry point for training the interpolation model. + + Parameters + ---------- + cfg : DictConfig + Hydra configuration object. + """ + trainer = setup_trainer(**OmegaConf.to_container(cfg)) + trainer.fit() + + # Finish wandb logging if it was used + use_wandb = cfg.get("logging", {}).get("wandb", {}).get("use_wandb", False) + if use_wandb: + wandb.finish() + + +if __name__ == "__main__": + main() diff --git a/examples/weather/temporal_interpolation/utils/distribute.py b/examples/weather/temporal_interpolation/utils/distribute.py new file mode 100644 index 0000000000..c409b9008e --- /dev/null +++ b/examples/weather/temporal_interpolation/utils/distribute.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from physicsnemo.distributed import DistributedManager +import torch + + +def distribute_model( + model: torch.nn.Module, +) -> tuple[torch.nn.Module, DistributedManager]: + """ + Initialize DistributedManager and distribute model to multiple processes with DDP. + + Parameters + ---------- + model : torch.nn.Module + The PyTorch model to be distributed across multiple processes. + + Returns + ------- + tuple[torch.nn.Module, DistributedManager] + A tuple containing: + - model : torch.nn.Module + The model, wrapped with DistributedDataParallel if needed. + - dist : DistributedManager + The initialized DistributedManager instance. + """ + if not DistributedManager.is_initialized(): + DistributedManager.initialize() + + dist = DistributedManager() + model = model.to(dist.device) + + if dist.world_size > 1: + ddps = torch.cuda.Stream() + with torch.cuda.stream(ddps): + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[dist.local_rank], + output_device=dist.device, + broadcast_buffers=dist.broadcast_buffers, + find_unused_parameters=dist.find_unused_parameters, + ) + torch.cuda.current_stream().wait_stream(ddps) + + return (model, dist) diff --git a/examples/weather/temporal_interpolation/utils/loss.py b/examples/weather/temporal_interpolation/utils/loss.py new file mode 100644 index 0000000000..ffe17474cf --- /dev/null +++ b/examples/weather/temporal_interpolation/utils/loss.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + + +class GeometricL2Loss(nn.Module): + """Latitude weighted L2 (MSE) loss. + + Parameters + ---------- + lat_range : tuple[float, float], optional + Range of latitudes to cover, by default (-90.0, 90.0) + num_lats : int, optional + Number of latitudes in lat_range, by default 721 + num_lats_cropped : int, optional + Use the first num_lats_cropped latitudes, by default 720 + input_dims : int, optional + Number of dimensions in the input tensors passed to `forward`, by + default 4. + + Forward + ------- + pred : torch.Tensor + Predicted values, shape (..., num_lats_cropped, num_lons), + number of dimensions must equal ``input_dims`` + true : torch.Tensor + True values, shape equal to pred + + Outputs + ------- + torch.Tensor + The computed loss + """ + + def __init__( + self, + lat_range: tuple[float, float] = (-90.0, 90.0), + num_lats: int = 721, + num_lats_cropped: int = 720, + input_dims: int = 4, + ): + super().__init__() + + lats = torch.linspace(lat_range[0], lat_range[1], num_lats) + if lat_range[0] == -90: # special handling for poles + lats[0] = 0.5 * (lats[0] + lats[1]) + if lat_range[1] == 90: + lats[-1] = 0.5 * (lats[-2] + lats[-1]) + lats = torch.deg2rad(lats[:num_lats_cropped]) + weights = torch.cos(lats) + weights = weights / torch.sum(weights) + weights = torch.reshape( + weights, (1,) * (input_dims - 2) + (num_lats_cropped, 1) + ) + self.register_buffer("weights", weights) + + def forward(self, pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor: + if not (pred.ndim == true.ndim == self.weights.ndim): + raise ValueError( + "Shape mismatch: pred, true and weights must have the same number of dimensions." + ) + if pred.shape != true.shape: + raise ValueError("Shape mismatch: pred and true must have the same shape") + err = torch.square(pred - true) + err = torch.sum(err * self.weights, dim=-2) + return torch.mean(err) diff --git a/examples/weather/temporal_interpolation/utils/trainer.py b/examples/weather/temporal_interpolation/utils/trainer.py new file mode 100644 index 0000000000..4558589397 --- /dev/null +++ b/examples/weather/temporal_interpolation/utils/trainer.py @@ -0,0 +1,338 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
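+
+"""Training loop utilities used by ``train.py`` for the temporal interpolation example."""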
+ +from collections.abc import Callable, Sequence +from typing import Any, Literal +import time + +import torch +import wandb + +from physicsnemo import Module +from physicsnemo.datapipes.climate.climate import ClimateDatapipe +from physicsnemo.distributed.manager import DistributedManager +from physicsnemo.utils import StaticCaptureTraining, StaticCaptureEvaluateNoGrad +from physicsnemo.launch.logging import LaunchLogger, PythonLogger +from physicsnemo.launch.utils import load_checkpoint, save_checkpoint + + +class Trainer: + """Training loop. + + Parameters + ---------- + model : Module + Model to train. + dist_manager : DistributedManager + Initialized DistributedManager. + loss : torch.nn.Module + Loss function. + train_datapipe : ClimateDatapipe + ClimateDatapipe providing training data. + valid_datapipe : ClimateDatapipe + ClimateDatapipe providing validation data. + samples_per_epoch : int + Number of samples to draw from the datapipe per 'epoch'. + optimizer : torch.optim.Optimizer + Optimizer used for training. + scheduler : torch.optim.lr_scheduler.LRScheduler + Learning rate scheduler. + input_output_from_batch_data : Callable, optional + Function that converts datapipe outputs to training batches. + If not provided, will try to use outputs as-is. + max_epoch : int, optional + The last training epoch. + load_epoch : int, "latest", or None, optional + Which epoch to load. Options: + - "latest": continue from latest checkpoint in checkpoint_dir + - int: continue from the specified epoch + - None: start from scratch + checkpoint_every : int, optional + Save checkpoint every N epochs. + checkpoint_dir : str or None, optional + The directory where checkpoints are saved. + validation_callbacks : Sequence[Callable], optional + Optional callables to execute on validation. Signature: + callback(outvar_true, outvar_pred, epoch=epoch, batch_idx=batch_idx). + use_wandb : bool, optional + When True, log metrics to Weights & Biases. 
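+
+    Notes
+    -----
+    See ``setup_trainer`` in ``train.py`` for how this class is constructed in
+    this example. Checkpoints saved by the trainer include the model, optimizer,
+    scheduler and the number of samples trained so far.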
+ """ + + def __init__( + self, + model: Module, + dist_manager: DistributedManager, + loss: torch.nn.Module, + train_datapipe: ClimateDatapipe, + valid_datapipe: ClimateDatapipe, + samples_per_epoch: int, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + input_output_from_batch_data: Callable = lambda x: x, + max_epoch: int = 1, + load_epoch: int | Literal["latest"] | None = "latest", + checkpoint_every: int = 1, + checkpoint_dir: str | None = None, + validation_callbacks: Sequence[Callable] = (), + use_wandb: bool = False, + ): + self.model = model + self.dist_manager = dist_manager + self.loss = loss + self.train_datapipe = train_datapipe + self.train_iterator = iter(self.train_datapipe) + self.valid_datapipe = valid_datapipe + self.max_epoch = max_epoch + self.input_output_from_batch_data = input_output_from_batch_data + self.optimizer = optimizer + self.lr_scheduler = scheduler + self.validation_callbacks = validation_callbacks + self.device = self.dist_manager.device + self.logger = PythonLogger() + self.use_wandb = use_wandb + + self.checkpoint_every = checkpoint_every + self.checkpoint_dir = checkpoint_dir + self.epoch = 1 + self.total_samples_trained = 0 + if load_epoch is not None: + epoch = None if load_epoch == "latest" else load_epoch + self.load_checkpoint(epoch=epoch) + self.epoch += 1 + + # wrap capture here instead of using decorator so it'll still be wrapped if + # overridden by a subclass + self.train_step_forward = StaticCaptureTraining( + model=self.model, + optim=self.optimizer, + logger=self.logger, + use_graphs=False, # use_graphs=True seems crash prone + )(self._train_step_forward) + + self.eval_step = StaticCaptureEvaluateNoGrad( + model=self.model, logger=self.logger, use_graphs=False + )(self._eval_step) + + self.local_batches_per_epoch = samples_per_epoch // ( + train_datapipe.world_size * train_datapipe.batch_size + ) + + def _eval_step(self, invar: tuple) -> torch.Tensor: + """Evaluate model for one step. + + Parameters + ---------- + invar : tuple + The inputs to the model, packed into a tuple. + + Returns + ------- + torch.Tensor + The output of the model. + """ + return self.model(*invar) + + def _train_step_forward( + self, invar: tuple, outvar_true: torch.Tensor + ) -> torch.Tensor: + """Training step. + + Parameters + ---------- + invar : tuple + Model inputs packed into a tuple. + outvar_true : torch.Tensor + Correct output value. + + Returns + ------- + torch.Tensor + Model loss on the given data. 
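+
+        Notes
+        -----
+        This method is wrapped with ``StaticCaptureTraining`` in ``__init__``,
+        which takes care of the backward pass and optimizer step, so only the
+        forward pass and loss computation appear here.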
+ """ + outvar_pred = self.model(*invar) + return self.loss(outvar_pred, outvar_true) + + def fit(self): + """Main function for training loop.""" + # Log initial learning rate to wandb + use_wandb_log = self.use_wandb and self.dist_manager.rank == 0 + if use_wandb_log: + current_lr = self.optimizer.param_groups[0]["lr"] + wandb.log({"lr": current_lr, "epoch": self.epoch - 1}) + + for self.epoch in range(self.epoch, self.max_epoch + 1): + epoch_loss = 0.0 + epoch_samples = 0 + time_start = time.time() + + with LaunchLogger( + "train", + epoch=self.epoch, + num_mini_batch=self.local_batches_per_epoch, + epoch_alert_freq=10, + ) as log: + for _ in range(self.local_batches_per_epoch): + try: + batch = next(self.train_iterator) + except StopIteration: + self.train_iterator = iter(self.train_datapipe) + batch = next(self.train_iterator) + loss = self.train_step_forward( + *self.input_output_from_batch_data(batch) + ) + log.log_minibatch({"loss": loss.detach()}) + + # Track loss for epoch average + batch_size = self.train_datapipe.batch_size + epoch_loss += loss.item() * batch_size + epoch_samples += batch_size + + # Log batch-level metrics to wandb + if use_wandb_log: + current_lr = self.optimizer.param_groups[0]["lr"] + wandb.log({"batch_loss": loss.item(), "lr": current_lr}) + + log.log_epoch({"Learning Rate": self.optimizer.param_groups[0]["lr"]}) + + # Compute epoch statistics + time_end = time.time() + mean_loss = epoch_loss / epoch_samples if epoch_samples > 0 else 0.0 + self.total_samples_trained += epoch_samples + + # Log epoch-level metrics to wandb + if use_wandb_log: + current_lr = self.optimizer.param_groups[0]["lr"] + metrics = { + "epoch": self.epoch, + "mean_loss": mean_loss, + "time_per_epoch": time_end - time_start, + "lr": current_lr, + "total_samples_trained": self.total_samples_trained, + "epoch_samples": epoch_samples, + } + wandb.log(metrics) + + # Validation + if self.dist_manager.rank == 0: + with LaunchLogger("valid", epoch=self.epoch) as log: + error = self.validate_on_epoch() + log.log_epoch({"Validation error": error}) + + # Log validation metrics to wandb + if use_wandb_log: + val_loss = error.item() if torch.is_tensor(error) else error + val_metrics = { + "val_loss": val_loss, + "epoch": self.epoch, + "total_samples_trained": (self.total_samples_trained), + } + wandb.log(val_metrics) + + if self.dist_manager.world_size > 1: + torch.distributed.barrier() + + self.lr_scheduler.step() + + checkpoint_epoch = (self.checkpoint_dir is not None) and ( + (self.epoch % self.checkpoint_every == 0) + or (self.epoch == self.max_epoch) + ) + if checkpoint_epoch and self.dist_manager.rank == 0: + # Save Modulus Launch checkpoint + self.save_checkpoint() + + if self.dist_manager.rank == 0: + self.logger.info("Finished training!") + + @torch.no_grad() + def validate_on_epoch(self) -> torch.Tensor: + """Compute loss and metrics over one validation epoch. + + Returns + ------- + torch.Tensor + Validation loss as a tensor. 
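+
+        Notes
+        -----
+        The returned value is the loss averaged over all validation batches.
+        In ``fit`` this method is executed only on rank 0.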
+ """ + loss_epoch = 0 + num_examples = 0 # Number of validation examples + # Dealing with DDP wrapper + if hasattr(self.model, "module"): + model = self.model.module + else: + model = self.model + + try: + model.eval() + for i, batch in enumerate(self.valid_datapipe): + (invar, outvar_true) = self.input_output_from_batch_data(batch) + invar = tuple(v.detach() for v in invar) + outvar_true = outvar_true.detach() + outvar_pred = self.eval_step(invar) + + loss_epoch += self.loss(outvar_pred, outvar_true) + num_examples += 1 + + for callback in self.validation_callbacks: + callback(outvar_true, outvar_pred, epoch=self.epoch, batch_idx=i) + finally: # restore train state even if exception occurs + model.train() + return loss_epoch / num_examples + + def load_checkpoint(self, epoch: int | None = None) -> int: + """Try to load model state from a checkpoint. + + Do nothing if a checkpoint is not found in self.checkpoint_dir. + + Parameters + ---------- + epoch : int or None, optional + The number of epoch to load. When None, the latest epoch is loaded. + + Returns + ------- + int + The epoch of the loaded checkpoint, or 0 if no checkpoint was found. + """ + if self.checkpoint_dir is None: + raise ValueError("checkpoint_dir must be set in order to load checkpoints.") + metadata = {} + self.epoch = load_checkpoint( + self.checkpoint_dir, + models=self.model, + optimizer=self.optimizer, + scheduler=self.lr_scheduler, + device=self.device, + epoch=epoch, + metadata_dict=metadata, + ) + self.total_samples_trained = metadata.get("total_samples_trained", 0) + return self.epoch + + def save_checkpoint(self): + """Save current model state as a checkpoint.""" + if self.checkpoint_dir is None: + raise ValueError("checkpoint_dir must be set in order to save checkpoints.") + metadata = {"total_samples_trained": self.total_samples_trained} + if self.use_wandb and wandb.run is not None: + metadata["wandb_id"] = wandb.run.id + save_checkpoint( + self.checkpoint_dir, + models=self.model, + optimizer=self.optimizer, + scheduler=self.lr_scheduler, + epoch=self.epoch, + metadata=metadata, + ) diff --git a/examples/weather/temporal_interpolation/validate.py b/examples/weather/temporal_interpolation/validate.py new file mode 100644 index 0000000000..695ecedb8c --- /dev/null +++ b/examples/weather/temporal_interpolation/validate.py @@ -0,0 +1,293 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Generator, Literal + +import hydra +from omegaconf import DictConfig, OmegaConf +import numpy as np +import torch +import xarray as xr + +from train import input_output_from_batch_data, setup_trainer, Trainer + + +def setup_analysis( + cfg: dict, checkpoint: str | None = None, shuffle: bool = False +) -> Trainer: + """ + Setup trainer for validation analysis. + + Parameters + ---------- + cfg : dict + Configuration dictionary. 
+ checkpoint : str or None, optional + Path to model checkpoint file. + shuffle : bool, optional + Whether to shuffle validation data. + + Returns + ------- + Trainer + Configured trainer instance. + """ + cfg["datapipe"]["num_samples_per_year_valid"] = cfg["datapipe"][ + "num_samples_per_year_train" + ] + cfg["datapipe"]["batch_size_valid"] = 1 + cfg["datapipe"]["valid_shuffle"] = shuffle + + trainer = setup_trainer(**cfg) + if checkpoint is not None: + trainer.model.load(checkpoint) + + return trainer + + +@torch.no_grad() +def inference_model( + trainer: Trainer, + timesteps: int = 6, + denorm: bool = True, + method: Literal["fcinterp", "linear"] = "fcinterp", +) -> Generator[tuple[torch.Tensor, torch.Tensor, int], None, None]: + """ + Run inference on validation data. + + Parameters + ---------- + trainer : Trainer + Trainer instance containing model and datapipe. + timesteps : int, optional + Number of timesteps between interpolation endpoints. + denorm : bool, optional + Whether to denormalize outputs. + method : {"fcinterp", "linear"}, optional + Interpolation method to use. + + Yields + ------ + tuple[torch.Tensor, torch.Tensor, int] + True values, predicted values, and timestep index for each batch. + """ + for batch in trainer.valid_datapipe: + y_true_step = [] + y_pred_step = [] + (invar, outvar_true) = input_output_from_batch_data(batch) + invar = tuple(v.detach() for v in invar) + outvar_true = outvar_true.detach() + y_true_step.append(outvar_true) + step = min(int(round(invar[1].item() * timesteps)), timesteps) + if method == "fcinterp": + y_pred_step.append(trainer.eval_step(invar)) + elif method == "linear": + y_pred_step.append(linear_interp_batch_data(batch, step)) + + y_true = torch.stack(y_true_step, dim=1) + y_pred = torch.stack(y_pred_step, dim=1) + if denorm: + y_true = denormalize(trainer, y_true) + y_pred = denormalize(trainer, y_pred) + + yield (y_true, y_pred, step) + + +def linear_interp_batch_data( + batch: list[dict[str, torch.Tensor]], step: int +) -> torch.Tensor: + """ + Perform linear interpolation on batch data. + + Parameters + ---------- + batch : list[dict[str, torch.Tensor]] + Batch data from datapipe (list containing a dictionary). + step : int + Timestep index for interpolation. + + Returns + ------- + torch.Tensor + Linearly interpolated atmospheric variables. + """ + atmos_vars = batch[0]["state_seq-atmos"] + x0 = atmos_vars[:, 0] + x1 = atmos_vars[:, -1] + alpha = step / (atmos_vars.shape[1] - 1) + return (1 - alpha) * x0 + alpha * x1 + + +def denormalize(trainer: Trainer, y: torch.Tensor) -> torch.Tensor: + """ + Denormalize predictions using dataset statistics. + + Parameters + ---------- + trainer : Trainer + Trainer instance containing datapipe with statistics. + y : torch.Tensor + Normalized tensor to denormalize. + + Returns + ------- + torch.Tensor + Denormalized tensor. + """ + mean = torch.Tensor(trainer.valid_datapipe.sources[0].mu).to(device=y.device)[ + :, None, ... + ] + std = torch.Tensor(trainer.valid_datapipe.sources[0].sd).to(device=y.device)[ + :, None, ... + ] + return y * std + mean + + +def error_by_time( + cfg: dict, + checkpoint: str | None = None, + timesteps: int = 6, + method: Literal["fcinterp", "linear"] = "fcinterp", + max_error: float = 1.0, + nbins: int = 10000, + n_samples: int = 1000, +) -> tuple[list[torch.Tensor], torch.Tensor]: + """ + Compute error statistics for each interpolation step. The error + is computed as the squared difference of the prediction and truth + and is area-weighted (i.e. 
multiplied by the cosine of the latitude).
+    It is calculated on the values normalized to zero mean and unit variance,
+    so that errors of all variables are comparable.
+
+    Parameters
+    ----------
+    cfg : dict
+        The configuration dict passed from hydra.
+    checkpoint : str or None, optional
+        Path to model checkpoint file.
+    timesteps : int, optional
+        Number of timesteps between interpolation endpoints.
+    method : {"fcinterp", "linear"}, optional
+        Interpolation method to use.
+    max_error : float, optional
+        Maximum error value for histogram bins.
+    nbins : int, optional
+        Number of histogram bins.
+    n_samples : int, optional
+        Number of samples to process.
+
+    Returns
+    -------
+    tuple[list[torch.Tensor], torch.Tensor]
+        Histogram counts for each timestep and bin edges.
+    """
+    trainer = setup_analysis(cfg=cfg, checkpoint=checkpoint)
+
+    lat = torch.linspace(90, -90, 721)[:-1].to(device=trainer.model.device)
+    lat[0] = 0.5 * (lat[0] + lat[1])
+    cos_lat = torch.cos(lat * (torch.pi / 180))[None, None, :, None]
+
+    bins = torch.linspace(0, max_error, nbins + 1)
+
+    def _hist(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
+        err = (y_true - y_pred) ** 2
+        weights = torch.ones_like(err) * cos_lat
+        return torch.histogram(
+            err.ravel().cpu(), bins=bins, weight=weights.ravel().cpu()
+        )[0]
+
+    hist_counts = [
+        torch.zeros(nbins, dtype=torch.float64) for _ in range(timesteps + 1)
+    ]
+
+    for i_sample, (y_true, y_pred, step) in enumerate(
+        inference_model(trainer, timesteps=timesteps, denorm=False, method=method)
+    ):
+        if i_sample % 100 == 0:
+            print(f"{i_sample}/{n_samples}")
+
+        hist_counts_step = _hist(y_true[:, -1, ...], y_pred[:, -1, ...])
+        hist_counts[step] += hist_counts_step
+
+        if i_sample + 1 >= n_samples:
+            break
+
+    return (hist_counts, bins)
+
+
+def save_histogram(
+    hist_counts: list[torch.Tensor], bins: torch.Tensor, output_path: str
+) -> None:
+    """
+    Save histogram data to netCDF4 file.
+
+    Parameters
+    ----------
+    hist_counts : list[torch.Tensor]
+        List of histogram counts for each timestep.
+    bins : torch.Tensor
+        Bin edges for the histogram.
+    output_path : str
+        Path to output netCDF4 file.
+    """
+    # Local import: datetime is not imported at the module level of this script
+    from datetime import datetime
+
+    # Convert torch tensors to numpy
+    hist_counts_np = np.stack([h.cpu().numpy() for h in hist_counts], axis=0)
+    bins_np = bins.cpu().numpy()
+
+    # Compute bin centers from edges
+    bin_centers = (bins_np[:-1] + bins_np[1:]) / 2
+
+    # Create xarray Dataset
+    ds = xr.Dataset(
+        {
+            "hist_counts": (["timestep", "bin"], hist_counts_np),
+            "bin_edges": (["bin_edge"], bins_np),
+        },
+        coords={
+            "timestep": np.arange(len(hist_counts)),
+            "bin": bin_centers,
+            "bin_edge": bins_np,
+        },
+        attrs={
+            "description": "Histogram of squared errors for temporal interpolation",
+            "created": datetime.now().isoformat(),
+        },
+    )
+
+    # Save to netCDF4
+    ds.to_netcdf(output_path, format="NETCDF4")
+    print(f"Histogram saved to {output_path}")
+
+
+@hydra.main(version_base=None, config_path="config")
+def main(cfg: DictConfig):
+    """
+    Run validation for interpolation error as a function of step.
+
+    Parameters
+    ----------
+    cfg : DictConfig
+        Hydra configuration object.
+    """
+    cfg = OmegaConf.to_container(cfg)
+    validation_cfg = cfg.pop("validation")
+    output_path = validation_cfg.pop("output_path")
+    (hist_counts, bins) = error_by_time(cfg, **validation_cfg)
+    save_histogram(hist_counts, bins, output_path)
+
+
+if __name__ == "__main__":
+    main()