Stormcast Customization #799

Open
wants to merge 25 commits into base: main
Commits (25)
df109aa  Add common dataloader interface (jleinonen, Feb 12, 2025)
e18291d  Merge branch 'NVIDIA:main' into stormcast-customization (jleinonen, Feb 12, 2025)
6ecfe69  Training script runs with refactored dataloader (jleinonen, Feb 12, 2025)
9b9e2e8  More trainer refactoring (jleinonen, Feb 13, 2025)
288f949  Refactor inference (jleinonen, Feb 21, 2025)
87daa62  Merge branch 'NVIDIA:main' into stormcast-customization (jleinonen, Feb 21, 2025)
0555a43  Add support for gradient accumulation (jleinonen, Feb 25, 2025)
f7f3dd9  Add support for AMP 16-bit training (jleinonen, Feb 25, 2025)
b215093  Align training parameters with StormCast paper (jleinonen, Feb 26, 2025)
cead183  Add comments to inference.py (jleinonen, Mar 4, 2025)
f7432e6  Add lite configs (jleinonen, Mar 4, 2025)
7f4dd21  Merge branch 'NVIDIA:main' into stormcast-customization (jleinonen, Mar 18, 2025)
18b0190  Add lite configs (jleinonen, Mar 31, 2025)
82e1638  Merge PhysicsNemo rename (jleinonen, Mar 31, 2025)
d65d1ad  Merge branch 'NVIDIA:main' into stormcast-customization (jleinonen, Apr 1, 2025)
fb361c4  Small bug fixes (jleinonen, Apr 2, 2025)
04ecddb  Merge branch 'stormcast-customization' of github.com:jleinonen/modulu… (jleinonen, Apr 2, 2025)
6e8afa2  Merge branch 'NVIDIA:main' into stormcast-customization (jleinonen, Apr 2, 2025)
8ba4e89  Add support for compiling model (jleinonen, Apr 2, 2025)
412653d  Merge branch 'stormcast-customization' of github.com:jleinonen/modulu… (jleinonen, Apr 2, 2025)
21ea0cc  Validation fixes (jleinonen, Apr 2, 2025)
1b3af75  Refactor checkpoint loading at startup (jleinonen, Apr 4, 2025)
96bcd2b  Support wandb offline mode (jleinonen, Apr 4, 2025)
2db8ac1  Fix regression_model_forward (jleinonen, Apr 4, 2025)
5efd529  Merge branch 'NVIDIA:main' into stormcast-customization (jleinonen, Apr 8, 2025)
138 changes: 103 additions & 35 deletions examples/generative/stormcast/config/dataset/hrrr_era5.yaml
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

name: data_loader_hrrr_era5.HrrrEra5Dataset # Dataset class, given as <module>.<ClassName>; the module can be any file in the `datasets` folder

# Main dataset
location: 'data' # Path to the dataset
conus_dataset_name: 'hrrr_v3' # Version name for the dataset
@@ -33,38 +35,104 @@ valid_years: [2022] # Years to use for validation
invariants: ["lsm", "orog"] # Invariant quantities to include
input_channels: 'all' #'all' or list of channels to condition on
diffusion_channels: "all" #'all' or list of channels to condition on
exclude_channels: # Dataset channels to exclude from inputs/predictions
- u35
- u40
- v35
- v40
- t35
- t40
- q35
- q40
- w1
- w2
- w3
- w4
- w5
- w6
- w7
- w8
- w9
- w10
- w11
- w13
- w15
- w20
- w25
- w30
- w35
- w40
- p25
- p30
- p35
- p40
- z35
- z40
- tcwv
- vil
kept_era5_channels: "all"
kept_hrrr_channels:
- u10m
- v10m
- t2m
- msl
- u1
- u2
- u3
- u4
- u5
- u6
- u7
- u8
- u9
- u10
- u11
- u13
- u15
- u20
- u25
- u30
- v1
- v2
- v3
- v4
- v5
- v6
- v7
- v8
- v9
- v10
- v11
- v13
- v15
- v20
- v25
- v30
- t1
- t2
- t3
- t4
- t5
- t6
- t7
- t8
- t9
- t10
- t11
- t13
- t15
- t20
- t25
- t30
- q1
- q2
- q3
- q4
- q5
- q6
- q7
- q8
- q9
- q10
- q11
- q13
- q15
- q20
- q25
- q30
- z1
- z2
- z3
- z4
- z5
- z6
- z7
- z8
- z9
- z10
- z11
- z13
- z15
- z20
- z25
- z30
- p1
- p2
- p3
- p4
- p5
- p6
- p7
- p8
- p9
- p10
- p11
- p13
- p15
- p20
- refc
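
This hunk replaces the old exclude_channels list with explicit kept_era5_channels / kept_hrrr_channels lists and adds the name key that selects the dataset class. As a rough illustration only (not part of the PR), a trimmed configuration that keeps just a few surface channels could look like the sketch below; the valid channel names are whatever the underlying HRRR/ERA5 archive provides, and the comments in the diff indicate these keys accept either 'all' or an explicit list:

    name: data_loader_hrrr_era5.HrrrEra5Dataset  # dataset class resolved from the datasets/ folder
    location: 'data'

    kept_era5_channels: "all"
    kept_hrrr_channels:            # illustrative surface-only subset
      - u10m
      - v10m
      - t2m
      - msl
      - refc

    input_channels: 'all'          # condition on everything that was kept
    diffusion_channels: 'all'      # predict everything that was kept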
5 changes: 4 additions & 1 deletion examples/generative/stormcast/config/diffusion.yaml
@@ -42,4 +42,7 @@ model:
  spatial_pos_embed: True

training:
  loss: 'edm'
  total_train_steps: 450000 # use more training samples for diffusion training; follows StormCast paper
  checkpoint_freq: 10000 # How often to save the checkpoints, measured in number of training steps
  validation_freq: 10000 # how often to record the validation loss, measured in number of training steps
30 changes: 30 additions & 0 deletions examples/generative/stormcast/config/diffusion_lite.yaml
@@ -0,0 +1,30 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Defaults
defaults:
- diffusion

# Minimal training config with frequent printouts and checkpoint saving.
# Can be used to test that training runs without errors.
# Do not use for real training runs.

training:
  print_progress_freq: 5 # How often to print progress, measured in number of training steps
  checkpoint_freq: 5 # How often to save the checkpoints, measured in number of training steps
  validation_freq: 5 # how often to record the validation loss, measured in number of training steps
  batch_size: 2
  total_train_steps: 20
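
Because the lite config only lists diffusion in its Hydra defaults and then overrides a handful of training fields, the composed result is the full diffusion setup with a tiny training schedule. The sketch below shows the effective training block, assuming diffusion.yaml in turn composes config/training/default.yaml as the directory layout suggests (the merge order is an assumption, not shown in this diff):

    training:
      loss: 'edm'              # from diffusion.yaml
      total_train_steps: 20    # lite override (450000 in diffusion.yaml)
      checkpoint_freq: 5       # lite override (10000 in diffusion.yaml)
      validation_freq: 5       # lite override (10000 in diffusion.yaml)
      print_progress_freq: 5   # lite override (100 in training/default.yaml)
      batch_size: 2            # lite override (64 in training/default.yaml)
      # remaining fields (lr, lr_rampup_steps, fp_optimizations, ...) fall through
      # from config/training/default.yaml unchanged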
9 changes: 6 additions & 3 deletions examples/generative/stormcast/config/inference/stormcast.yaml
@@ -28,6 +28,9 @@ initial_time: "2022-11-04T21:00:00" # datetime to initialize forecast with (YYYY-
n_steps: 12 # number of steps (in units of 1hr timesteps) to forecast

# I/O
plot_var_hrrr: "refc" # HRRR variable to plot
plot_var_era5: "t2m" # ERA5 variable to plot
output_hrrr_channels: [] # HRRR variables to save to disk (empty list == all channels saved)
plot_var_state: "refc" # state variable to plot
plot_var_background: "t2m" # background variable to plot
output_state_channels: [] # state variables to save to disk (empty list == all channels saved)
save_vertical_vars: ["u", "v", "t", "q", "z", "p", "w"] # variables with multiple vertical levels
save_vertical_levels: ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "13", "15", "20", "25", "30", "35", "40"] # level names for vertical variables
save_horizontal_vars: ["msl", "refc", "u10m", "v10m"] # single-level variables
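
The renamed output options split the saved channels into vertical variables, written once per level, and single-level variables. Since the dataset config above names channels by concatenating the variable and the level index (u1, q30, ...), the set of channels written to disk can be pictured by expanding the two lists; the snippet below is only an illustration of that naming convention, not code from this PR:

    # Expand the channel names implied by the inference config (illustrative only).
    save_vertical_vars = ["u", "v", "t", "q", "z", "p", "w"]
    save_vertical_levels = ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
                            "11", "13", "15", "20", "25", "30", "35", "40"]
    save_horizontal_vars = ["msl", "refc", "u10m", "v10m"]

    vertical_channels = [f"{v}{lev}" for v in save_vertical_vars for lev in save_vertical_levels]
    saved_channels = vertical_channels + save_horizontal_vars
    print(len(vertical_channels), len(saved_channels))  # 7 * 18 = 126 vertical channels, 130 in total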
30 changes: 30 additions & 0 deletions examples/generative/stormcast/config/regression_lite.yaml
@@ -0,0 +1,30 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Defaults
defaults:
- regression

# Minimal training config with frequent printouts and checkpoint saving.
# Can be used to test that training runs without errors.
# Do not use for real training runs.

training:
  print_progress_freq: 5 # How often to print progress, measured in number of training steps
  checkpoint_freq: 5 # How often to save the checkpoints, measured in number of training steps
  validation_freq: 5 # how often to record the validation loss, measured in number of training steps
  batch_size: 2
  total_train_steps: 20
19 changes: 13 additions & 6 deletions examples/generative/stormcast/config/training/default.yaml
@@ -22,20 +22,27 @@ run_id: '0' # Unique ID to use for this training run
rundir: ./${training.outdir}/${training.experiment_name}/${training.run_id} # Path where experiment outputs will be saved
num_data_workers: 4 # Number of dataloader worker threads per proc
log_to_wandb: False # Whether or not to log to Weights & Biases (requires wandb account)
wandb_mode: "online" # logging mode, "online" or "offline"
seed: -1 # Specify a random seed by setting this to an int > 0
cudnn_benchmark: True # Enable/disable CuDNN benchmark mode
resume_checkpoint: null # Specify a path to a training checkpoint to resume from
resume_checkpoint: "latest" # epoch number to continue training from, or "latest" for the latest checkpoint
initial_weights: null # if not null, a .mdlus checkpoint to load weights at the start of training; no effect if training continues from a checkpoint

# Logging frequency
print_progress_freq: 5 # How often to print progress, measured in number of training steps
checkpoint_freq: 5 # How often to save the checkpoints, measured in number of training steps
validation_freq: 5 # how often to record the validation loss, measured in number of training steps
print_progress_freq: 100 # How often to print progress, measured in number of training steps
checkpoint_freq: 1000 # How often to save the checkpoints, measured in number of training steps
validation_freq: 1000 # how often to record the validation loss, measured in number of training steps

# Optimization hyperparameters
batch_size: 1 # Total training batch size -- must be >= (and divisible by) number of GPUs being used
batch_size: 64 # Total training batch size -- must be >= (and divisible by) number of GPUs being used
batch_size_per_gpu: "auto" # Batch size on each GPU, set to an int to force smaller local batch with gradient accumulation
lr: 4E-4 # Initial learning rate
lr_rampup_steps: 1000 # Number of training steps over which to perform linear LR warmup
total_train_steps: 20 # Number of total training steps
total_train_steps: 16000 # Number of total training steps, 16000 with batch size 64 corresponds to StormCast paper regression
clip_grad_norm: -1 # Threshold for gradient clipping, set to -1 to disable
loss: 'regression' # Loss type; use 'regression' or 'edm' for the regression and diffusion, respectively
fp_optimizations: fp32 # Floating point mode, one of ["fp32", "amp-fp16", "amp-bf16"]
compile_model: False # use torch.compile to compile model

# Validation options
validation_plot_variables: ["u10m", "v10m", "t2m", "refc", "q1", "q5", "q10"]
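
Several of the new knobs interact: batch_size is the global batch, batch_size_per_gpu can force a smaller local batch that is made up with gradient accumulation, and fp_optimizations selects the AMP mode. The sketch below is not the trainer code from this PR; it is a generic PyTorch loop showing how these options typically combine, with the mapping from config strings to dtypes assumed:

    import torch

    def accumulation_steps(batch_size: int, batch_size_per_gpu: int, world_size: int) -> int:
        # Micro-batches each rank accumulates before one optimizer step.
        assert batch_size % (batch_size_per_gpu * world_size) == 0
        return batch_size // (batch_size_per_gpu * world_size)

    n_accum = accumulation_steps(batch_size=64, batch_size_per_gpu=4, world_size=8)  # -> 2

    # Assumed mapping of the fp_optimizations strings to autocast dtypes ("fp32" -> None).
    amp_dtype = {"amp-fp16": torch.float16, "amp-bf16": torch.bfloat16}.get("amp-fp16")
    scaler = torch.cuda.amp.GradScaler(enabled=(amp_dtype == torch.float16))

    def train_step(model, optimizer, micro_batches, loss_fn):
        optimizer.zero_grad(set_to_none=True)
        for x, y in micro_batches:  # n_accum micro-batches per optimizer step
            with torch.autocast("cuda", dtype=amp_dtype, enabled=amp_dtype is not None):
                loss = loss_fn(model(x), y) / len(micro_batches)  # average over micro-batches
            scaler.scale(loss).backward()  # gradients accumulate across the micro-batches
        scaler.step(optimizer)  # unscales gradients first when fp16 loss scaling is enabled
        scaler.update()

Setting compile_model: True would additionally wrap the model with torch.compile before the first step.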
36 changes: 36 additions & 0 deletions examples/generative/stormcast/datasets/__init__.py
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import pkgutil

from .dataset import StormCastDataset


# Find StormCastDataset implementations in files in the datasets directory
# and list them by module and name in the dataset_classes dict
dataset_modules = pkgutil.iter_modules(["datasets"])
dataset_modules = [mod.name for mod in dataset_modules if mod.name != "dataset"]
dataset_classes = {}
for mod_name in dataset_modules:
    module = importlib.import_module(f"datasets.{mod_name}")
    for (name, member) in module.__dict__.items():
        if (
            name != "StormCastDataset"
            and isinstance(member, type)
            and issubclass(member, StormCastDataset)
        ):
            dataset_classes[f"{mod_name}.{name}"] = member
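
The registry keys each discovered class as "<module>.<ClassName>", which matches the format of the new name field in config/dataset/hrrr_era5.yaml (data_loader_hrrr_era5.HrrrEra5Dataset). As a hedged sketch only, since neither the StormCastDataset interface nor the trainer's lookup code appears in this diff, adding and selecting a custom data source could look roughly like this:

    # datasets/my_dataset.py -- hypothetical module; any file dropped into datasets/ is scanned
    import numpy as np

    from .dataset import StormCastDataset


    class MyDataset(StormCastDataset):
        """Registered automatically under the key 'my_dataset.MyDataset'."""

        def __len__(self):
            return 1000  # placeholder sample count

        def __getitem__(self, idx):
            # Placeholder array; the sample structure the trainer actually expects is
            # defined by StormCastDataset, which is outside this diff.
            return np.zeros((1, 512, 640), dtype=np.float32)

    # The config would then select it with `name: my_dataset.MyDataset`, resolved via
    # the registry built above, e.g. dataset_classes["my_dataset.MyDataset"].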