CorrDiff usability and performance enhancements for custom dataset training #790

Open · wants to merge 45 commits into base: main

Commits (45)
eadd8f5 Add recent checkpoints option, adjust configs (pzharrington, Jan 28, 2025)
1ae9e7f Doc for deterministic_sampler (CharlelieLrt, Feb 4, 2025)
d61aa08 Typo fix (CharlelieLrt, Feb 5, 2025)
934c1f3 Bugfix and cleanup of corrdiff regression loss and UNet (CharlelieLrt, Feb 6, 2025)
f120055 Minor fix in docstrings (CharlelieLrt, Feb 6, 2025)
a7f0836 Bugfix + doc for corrdiff regression CE loss (CharlelieLrt, Feb 6, 2025)
984adae Refactor corrdiff configs for custom dataset (CharlelieLrt, Feb 8, 2025)
11207f7 Bugfix in configs (CharlelieLrt, Feb 10, 2025)
344ab6c Added info in corrdiff docs for custom training (CharlelieLrt, Feb 11, 2025)
a0c59b0 Minor change in corrdiff config (CharlelieLrt, Feb 11, 2025)
c244e53 Bring back base config file removed by mistake (CharlelieLrt, Feb 11, 2025)
b6a7c2d Added config for generation on custom dataset (CharlelieLrt, Feb 12, 2025)
a6c40e1 Forgot some config files (CharlelieLrt, Feb 12, 2025)
62e6e50 Fixed overlap pixel in custom config based on discussion in PR #703 (CharlelieLrt, Feb 12, 2025)
c1d082c Corrdiff fixes to enable non-squared images and/or non-square patches… (CharlelieLrt, Feb 12, 2025)
f8a1c17 Fix small bug in config (CharlelieLrt, Feb 12, 2025)
d7588ac Removed arguments redundancy in patching utilities + fixed hight-widt… (CharlelieLrt, Feb 13, 2025)
3d30e2a Cleanup (CharlelieLrt, Feb 14, 2025)
47a054d Added tests for rectangle images and patches (CharlelieLrt, Feb 14, 2025)
ddd2f4d Added wandb logging for corrdiff training (CharlelieLrt, Feb 14, 2025)
fede749 Implements patching API. Refactors corrdiff train and generate to us… (CharlelieLrt, Feb 20, 2025)
0ad3c01 Corrdiff function to register new custom dataset (CharlelieLrt, Feb 20, 2025)
2f906da Reorganize configs again (CharlelieLrt, Feb 22, 2025)
3c7f80a Correction in configs: training duration is NOT in kilo images (CharlelieLrt, Feb 24, 2025)
d366de0 Readme re-write (CharlelieLrt, Feb 25, 2025)
b0ad80f Merge branch 'origin/main' (CharlelieLrt, Feb 25, 2025)
ae4692f Updated CHANGELOG (CharlelieLrt, Feb 25, 2025)
0365019 Fixed formatting (CharlelieLrt, Feb 26, 2025)
8dff626 Test fixes (CharlelieLrt, Feb 26, 2025)
a1e5f13 Typo fix (CharlelieLrt, Feb 26, 2025)
bee4727 Fixes on patching API (CharlelieLrt, Feb 28, 2025)
aa1d969 Fixed patching bug and tests (CharlelieLrt, Feb 28, 2025)
e799df9 Simplifications in corrdiff diffusion step (CharlelieLrt, Mar 1, 2025)
e871773 Forgot to propagate change to test for corrdiff diffusion step (CharlelieLrt, Mar 1, 2025)
9cdabb8 Renamed patching API to explicit 2D (CharlelieLrt, Mar 1, 2025)
a57c6dc Fixed shape in test (CharlelieLrt, Mar 3, 2025)
2e9ce25 Replace loops with fold/unfold patching for perf (CharlelieLrt, Mar 4, 2025)
2049e7e Added method to dynamically change number of patches in RandomPatching (CharlelieLrt, Mar 4, 2025)
638b24e Adds safety checks for patch shapes in patching function. Fixes tests (CharlelieLrt, Mar 5, 2025)
f5d3bca Fixes docs (CharlelieLrt, Mar 5, 2025)
706f614 Forgot a fix in docs (CharlelieLrt, Mar 5, 2025)
bc49e05 New embedding selection strategy in CorrDiff UNet models (CharlelieLrt, Mar 6, 2025)
43cbfff Updated CHANGELOG.md (CharlelieLrt, Mar 6, 2025)
deb7bec Fixed tests for SongUNet position embeddings (CharlelieLrt, Mar 7, 2025)
1c70ade More robust tests for patching (CharlelieLrt, Mar 7, 2025)
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- DrivAerML dataset support in FIGConvNet example.
- Retraining recipe for DoMINO from a pretrained model checkpoint
- Added Datacenter CFD use case.
- General purpose patching API for patch-based diffusion
- New positional embedding selection strategy for CorrDiff SongUNet models

### Changed

@@ -25,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Updated utils in `modulus.launch.logging` to avoid unnecessary `wandb` and `mlflow` imports
- Moved to experiment-based Hydra config in Lagrangian-MGN example
- Make data caching optional in `MeshDatapipe`
- Simplified CorrDiff config files, updated default values
- Refactored CorrDiff losses and samplers to use the patching API
- Support for non-square images and patches in patch-based diffusion

### Deprecated

15 changes: 13 additions & 2 deletions docs/api/modulus.utils.rst
@@ -40,7 +40,11 @@ Filesystem utils
Generative utils
----------------

.. automodule:: modulus.utils.generative.sampler
.. automodule:: modulus.utils.generative.deterministic_sampler
:members:
:show-inheritance:

.. automodule:: modulus.utils.generative.stochastic_sampler
:members:
:show-inheritance:

@@ -66,4 +70,11 @@ Weather / Climate utils
:show-inheritance:

.. automodule:: modulus.utils.zenith_angle
:show-inheritance:

Patching utils
--------------

.. automodule:: modulus.utils.patching
:members:
:show-inheritance:
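As background for these new patching utilities: commit 2e9ce25 in this PR replaces per-patch loops with batched fold/unfold calls. Below is a minimal, self-contained sketch of that technique; it is illustrative only, and any resemblance to the actual `modulus.utils.patching` classes (e.g. a `GridPatching2D`-style API) is an assumption to verify against the module docs.

```python
# Minimal sketch of batched patch extraction/fusion with unfold/fold,
# in the spirit of "Replace loops with fold/unfold patching for perf".
# Illustrative only, not the modulus.utils.patching implementation.
import torch
import torch.nn.functional as F

def extract_patches(x: torch.Tensor, patch: int, overlap: int) -> torch.Tensor:
    """Split (B, C, H, W) into overlapping (B*L, C, patch, patch) patches."""
    stride = patch - overlap
    cols = F.unfold(x, kernel_size=patch, stride=stride)  # (B, C*patch^2, L)
    b, _, l = cols.shape
    return cols.transpose(1, 2).reshape(b * l, x.shape[1], patch, patch)

def fuse_patches(p: torch.Tensor, out_hw, patch: int, overlap: int) -> torch.Tensor:
    """Inverse of extract_patches: average overlapping pixels back together."""
    stride = patch - overlap
    n_h = (out_hw[0] - patch) // stride + 1
    n_w = (out_hw[1] - patch) // stride + 1
    l = n_h * n_w
    b, c = p.shape[0] // l, p.shape[1]
    cols = p.reshape(b, l, c * patch * patch).transpose(1, 2)
    summed = F.fold(cols, out_hw, kernel_size=patch, stride=stride)
    counts = F.fold(torch.ones_like(cols), out_hw, kernel_size=patch, stride=stride)
    return summed / counts  # normalize overlapping contributions

x = torch.randn(2, 3, 128, 128)
patches = extract_patches(x, patch=64, overlap=32)  # 9 patches per image
assert torch.allclose(fuse_patches(patches, (128, 128), 64, 32), x, atol=1e-5)
```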
Binary file added docs/img/corrdiff_training_loss.png
578 changes: 428 additions & 150 deletions examples/generative/corrdiff/README.md

Large diffs are not rendered by default.

@@ -13,8 +13,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

type: hrrr_mini
data_path: /data/corrdiff-mini/hrrr_mini_train.nc
stats_path: /data/corrdiff-mini/stats.json
output_variables: ['10u', '10v']
@@ -14,23 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
job:
chdir: true
name: generation
run:
dir: ./outputs/${hydra:job.name}

# Get defaults
defaults:

# Dataset
- dataset/cwb_generate

# Sampler
- sampler/stochastic
#- sampler/deterministic

# Generation
- generation/base
#- generation/patched_based
# Dataset type. Must be overridden.
type: ???
# Path to .nc data file. Must be overridden.
data_path: ???
# Path to json stats file. Must be overridden.
stats_path: ???
# Names of input channels. Must be overridden.
input_variables: ???
# Names of output channels. Must be overridden.
output_variables: ???
# Names of invariant variables. Optional.
invariant_variables: ???
@@ -1,4 +1,3 @@

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
@@ -15,15 +14,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Dataset type. Do not modify.
type: cwb
data_path: /code/2023-01-24-cwb-4years.zarr
# Path to data file. Must be overridden.
data_path: ???
# Indices of input channels
in_channels: [0, 1, 2, 3, 4, 9, 10, 11, 12, 17, 18, 19]
# Indices of output channels
out_channels: [0, 1, 2, 3]
# Shape of the image
img_shape_x: 448
img_shape_y: 448
# Add grid coordinates to the image
add_grid: true
# Factor to downscale the image
ds_factor: 4
# Path to min and max values of the data
min_path: null
max_path: null
# Path to global means of the data
global_means_path: null
# Path to global stds of the data
global_stds_path: null
@@ -14,13 +14,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Dataset type. Do not modify.
type: gefs_hrrr
data_path: /data
stats_path: /data/stats.json
# Path to .nc data file. Must be overridden.
data_path: ???
# Path to json stats file. Must be overridden.
stats_path: ???
# Names of output channels.
output_variables: ["u10m", "v10m", "t2m", "precip", "cat_snow", "cat_ice", "cat_freez", "cat_rain", "cat_none"]
# Names of probability variables.
prob_variables: ["cat_snow", "cat_ice", "cat_freez", "cat_rain"]
# Names of input surface variables.
input_surface_variables: ["u10m", "v10m", "t2m", "q2m", "sp", "msl", "precipitable_water"]
# Names of input isobaric variables.
input_isobaric_variables: ['u1000', 'u925', 'u850', 'u700', 'u500', 'u250', 'v1000', 'v925', 'v850', 'v700', 'v500', 'v250', 'z1000', 'z925', 'z850', 'z700', 'z500', 'z200', 't1000', 't925', 't850', 't700', 't500', 't100', 'r1000', 'r925', 'r850', 'r700', 'r500', 'r100']
# Factor to downscale the image.
ds_factor: 4
train: False
hrrr_window: [[1,1057], [4,1796]] # need dims to be divisible by 16 [[0,1024], [0,1024]]
# Years to train the model.
train_years: [2020, 2021, 2022, 2023]
# Years to validate the model.
valid_years: [2024]
# Whether to normalize the data.
normalize: True
# Whether to shard the data.
shard: False
overfit: False
# Whether to use all the data.
use_all: False
sample_shape: [-1, -1]
hrrr_window: [[1,1057], [4,1796]] # need dims to be divisible by 16
@@ -14,8 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

name: lt_aware_ce_regression
# Name of the preconditioner
hr_mean_conditioning: False
# High-res mean (regression's output) as additional condition

# Dataset type
type: hrrr_mini
# Path to .nc data file. Must be overridden.
data_path: ???
# Path to json stats file. Must be overridden.
stats_path: ???
# Names of output channels. Must be overridden.
output_variables: ['10u', '10v']
@@ -14,35 +14,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

num_ensembles: 64
# Number of ensembles to generate per input
seed_batch_size: 4
# Size of the batched inference
defaults:
- sampler: stochastic
# The stochastic sampler is recommended. Change to deterministic if needed.

num_ensembles: ???
# Number of ensembles to generate per input. Should be overridden.
seed_batch_size: ???
# Size of the batched inference. Should be overridden.
inference_mode: all
# Choose between "all" (regression + diffusion), "regression" or "diffusion"
patch_size: 448
patch_shape_x: 448
patch_shape_y: 448
# Patch size. Patch-based sampling will be utilized if these dimensions differ from
# img_shape_x and img_shape_y
overlap_pixels: 4
# Number of overlapping pixels between adjacent patches
boundary_pixels: 2
# Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary
# artifact.
# Choose between "all" (regression + diffusion), "regression" or "diffusion"
hr_mean_conditioning: true
gridtype: learnable
N_grid_channels: 100
sample_res: full
# Sampling resolution
times_range: null
times:
- 2021-02-02T00:00:00
- 2021-03-02T00:00:00
- 2021-04-02T00:00:00
# hurricane
- 2021-09-12T00:00:00
- 2021-09-12T12:00:00
# Whether to use hr_mean_conditioning
times_range: ???
# Time range to generate. Should be overridden.
has_lead_time: False
# Whether the model has lead time.

perf:
force_fp16: false
@@ -55,9 +42,3 @@ perf:
num_writer_workers: 1
# number of workers to use for writing file
# To support multiple workers a threadsafe version of the netCDF library must be used

io:
res_ckpt_filename: diffusion_checkpoint.mdlus
# Checkpoint filename for the diffusion model
reg_ckpt_filename: regression_checkpoint.mdlus
# Checkpoint filename for the mean predictor model
@@ -15,9 +15,6 @@
# limitations under the License.

defaults:
- corrdiff_regression
- base_all

model_args:
model_channels: 64
channel_mult: [1, 2, 2]
attn_resolutions: [16]
patching: False
@@ -14,21 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
job:
chdir: true
name: gefs_hrrr_regression
run:
dir: ./outputs/${hydra:job.name}

# Get defaults
defaults:
- base_all

# Dataset
- dataset/gefs_hrrr

# Model
- model/corrdiff_regression_gefs_hrrr

# Training
- training/corrdiff_regression_gefs_hrrr
patching: True
# Use patch-based sampling
overlap_pix: 4
# Number of overlapping pixels between adjacent patches
boundary_pix: 2
# Number of boundary pixels to be cropped out. 2 is recommended to address the boundary
# artifact.
patch_shape_x: ???
patch_shape_y: ???
# Patch size. Patch-based sampling will be utilized if these dimensions
# differ from img_shape_x and img_shape_y. Needs to be overridden.
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# @package _global_.sampler

type: deterministic
num_steps: 9
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# @package _global_.sampler

type: stochastic
boundary_pix: 2
overlap_pix: 4
# overlap_pix must be at least 2 * boundary_pix
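The constraint in the last comment follows from geometry: `boundary_pix` pixels are cropped from each side of every generated patch, so adjacent patches must overlap by at least `2 * boundary_pix` for the cropped patches to still tile the image without gaps. A hedged sketch of that check (illustrative, not the repository's actual validation code):

```python
def check_patch_overlap(patch_shape: int, overlap_pix: int, boundary_pix: int) -> None:
    # Each patch loses boundary_pix on both sides after cropping, so the
    # usable patch width is patch_shape - 2 * boundary_pix. Patches advance
    # by patch_shape - overlap_pix; gap-free tiling of the cropped patches
    # requires the stride not to exceed the usable width, which reduces to
    # overlap_pix >= 2 * boundary_pix.
    stride = patch_shape - overlap_pix
    usable = patch_shape - 2 * boundary_pix
    if stride > usable:
        raise ValueError(
            f"overlap_pix={overlap_pix} must be >= 2*boundary_pix={2 * boundary_pix}"
        )

check_patch_overlap(patch_shape=448, overlap_pix=4, boundary_pix=2)  # passes
```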
37 changes: 37 additions & 0 deletions examples/generative/corrdiff/conf/base/model/diffusion.yaml
@@ -0,0 +1,37 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: diffusion
# Model type.
hr_mean_conditioning: True
# Recommended to use high-res conditioning for diffusion.
scale_cond_input: False
# If true, also scales the input conditioning. False is recommended.

# Standard model parameters.
model_args:
gridtype: "sinusoidal"
# Type of positional grid to use: 'sinusoidal', 'learnable', 'linear'.
# Controls how positional information is encoded.
N_grid_channels: 4
# Number of channels for positional grid embeddings
embedding_type: "zero"
# Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++,
# 'zero' for none
model_type: "SongUNetPosEmbd"
# Type of model architecture: 'SongUNetPosLtEmbd' for lead-time aware UNet
# with positional embeddings, 'SongUNetPosEmbd' for UNet with positional
# embeddings, 'DhariwalUNet' for UNet with Fourier embeddings
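For orientation, the `model_args` above are plausibly forwarded to the UNet constructor selected by `model_type`. A hedged sketch, assuming `SongUNetPosEmbd` is importable from `modulus.models.diffusion` with these keyword arguments (verify against the installed Modulus version; the channel counts here are purely illustrative):

```python
# Hedged sketch only: class location and keyword names are assumptions.
from modulus.models.diffusion import SongUNetPosEmbd

model = SongUNetPosEmbd(
    img_resolution=448,     # illustrative image size
    in_channels=16,         # illustrative: conditioning + noise channels
    out_channels=4,
    N_grid_channels=4,      # matches N_grid_channels above
    gridtype="sinusoidal",  # matches gridtype above
    embedding_type="zero",  # no timestep embedding, per the config
)
```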
@@ -0,0 +1,39 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


name: lt_aware_ce_regression
# Model type.
hr_mean_conditioning: False
# No high-res conditioning for regression.

# Default model parameters.
model_args:
img_channels: 4
# Number of color channels in the model
N_grid_channels: 4
# Number of channels for positional grid embeddings
embedding_type: "zero"
# Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++,
# 'zero' for none
lead_time_channels: 4
# Number of channels for lead-time embeddings
lead_time_steps: 9
# Number of lead-time steps
model_type: "SongUNetPosLtEmbd"
# Type of model architecture: 'SongUNetPosLtEmbd' for lead-time aware UNet with
# positional embeddings, 'SongUNetPosEmbd' for UNet with positional embeddings,
# 'DhariwalUNet' for UNet with Fourier embeddings
@@ -0,0 +1,38 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: lt_aware_patched_diffusion
# Model type.
hr_mean_conditioning: True
# Recommended to use high-res conditioning for diffusion.
scale_cond_input: False
# If true, also scales the input conditioning. False is recommended.

# Standard model parameters.
model_args:
gridtype: "learnable"
# Type of positional grid to use: 'sinusoidal', 'learnable', 'linear'.
# Controls how positional information is encoded.
N_grid_channels: 100
# Number of channels for positional grid embeddings
lead_time_channels: 20
# Number of channels for lead-time embeddings
lead_time_steps: 9
# Number of lead-time steps
model_type: "SongUNetPosLtEmbd"
# Type of model architecture: 'SongUNetPosLtEmbd' for lead-time aware UNet
# with positional embeddings, 'SongUNetPosEmbd' for UNet with positional
# embeddings, 'DhariwalUNet' for UNet with Fourier embeddings
@@ -14,11 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

name: patched_diffusion
# Name of the preconditioner
name: diffusion
# Model type.
hr_mean_conditioning: True
# High-res mean (regression's output) as additional condition
scale_cond_input: True
# If true, also scales the input conditioning
# For backward compatibility, this is true by default
# We recommend setting this to false for new training runs
# Recommended to use high-res conditioning for diffusion.
scale_cond_input: False
# If true, also scales the input conditioning. False is recommended.

# Standard model parameters.
model_args:
gridtype: "learnable"
# Type of positional grid to use: 'sinusoidal', 'learnable', 'linear'.
# Controls how positional information is encoded.
N_grid_channels: 100
# Number of channels for positional grid embeddings
@@ -14,11 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

name: diffusion
# Name of the preconditioner
name: regression
# Model type.
hr_mean_conditioning: false
# High-res mean (regression's output) as additional condition
scale_cond_input: True
# If true, also scales the input conditioning
# For backward compatibility, this is true by default
# We recommend setting this to false for new training runs
# No high-res conditioning for regression.

# Default regression model parameters. Do not modify.
model_args:
"img_channels": 4
# Number of color channels in the model
"N_grid_channels": 4
# Number of channels for positional grid embeddings
"embedding_type": "zero"
# Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++,
# 'zero' for none
@@ -1,3 +1,5 @@
# @package _global_.model

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
@@ -14,16 +16,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

defaults:
- corrdiff_diffusion

hr_mean_conditioning: True
scale_cond_input: false
# If true, also scales the input conditioning
# For backward compatibility, this is true by default
# We recommend setting this to false for new training runs

model_args:
# Base multiplier for the number of channels across the network.
model_channels: 64
# Per-resolution multipliers for the number of channels.
channel_mult: [1, 2, 2]
# Resolutions at which self-attention layers are applied.
attn_resolutions: [16]
@@ -1,3 +1,5 @@
# @package _global_.model

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
@@ -14,23 +16,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
job:
chdir: true
name: mini_generation
run:
dir: ./outputs/${hydra:job.name}

# Get defaults
defaults:

# Dataset
- dataset/hrrrmini

# Sampler
- sampler/stochastic
#- sampler/deterministic

# Generation
- generation/mini
#- generation/patched_based
model_args:
# Base multiplier for the number of channels across the network.
model_channels: 128
# Per-resolution multipliers for the number of channels.
channel_mult: [1, 2, 2, 2, 2]
# Resolutions at which self-attention layers are applied.
attention_levels: [28]
@@ -20,34 +20,38 @@ hp:
# Training duration based on the number of processed samples
total_batch_size: 256
# Total batch size
batch_size_per_gpu: 2
batch_size_per_gpu: "auto"
# Batch size per GPU
lr: 0.0002
# Learning rate
grad_clip_threshold: null
# no gradient clipping for defualt non-patch-based training
# no gradient clipping for default non-patch-based training
lr_decay: 1
# LR decay rate
lr_rampup: 10000000
lr_rampup: 0
# Rampup for learning rate, in number of samples

# Performance
perf:
fp_optimizations: fp32
fp_optimizations: amp-bf16
# Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"]
# "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16}
dataloader_workers: 4
# DataLoader worker processes
songunet_checkpoint_level: 0 # 0 means no checkpointing
# Gradient checkpointing level, value is number of layers to checkpoint

# I/O
# IO
io:
regression_checkpoint_path: ???
# Where to load the regression checkpoint. Should be overridden.
print_progress_freq: 1000
# How often to print progress
save_checkpoint_freq: 5000
# How often to save the checkpoints, measured in number of processed samples
save_n_recent_checkpoints: -1
# Set to a positive integer to only keep the most recent n checkpoints
validation_freq: 5000
# how often to record the validation loss, measured in number of processed samples
validation_steps: 10
# how many loss evaluations are used to compute the validation loss per checkpoint
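The `batch_size_per_gpu: "auto"` setting above implies deriving the per-GPU batch size from the total batch size and the process count. A hedged sketch of that arithmetic (illustrative; the actual training script may additionally handle gradient accumulation when the total is not evenly divisible):

```python
import torch.distributed as dist

def resolve_batch_size_per_gpu(total_batch_size: int, requested="auto") -> int:
    # World size falls back to 1 for single-GPU runs without distributed init.
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    if requested == "auto":
        if total_batch_size % world_size != 0:
            raise ValueError("total_batch_size must be divisible by world size")
        return total_batch_size // world_size
    return int(requested)

print(resolve_batch_size_per_gpu(256))  # 256 on one GPU, 4 on 64 GPUs
```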
@@ -14,7 +14,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Validation dataset options
# (need to set dataset.train_test_split == true to have an effect)
train: false
all_times: false
defaults:
- base_all
@@ -15,32 +15,23 @@
# limitations under the License.

defaults:
- corrdiff_regression

- base_all

# Hyperparameters
hp:
training_duration: 2000000
training_duration: 1000000
# Training duration based on the number of processed samples
total_batch_size: 1
# Total batch size
batch_size_per_gpu: 1
# Batch size per GPU
lr_rampup: 0
# Rampup for learning rate, in number of samples

# Performance
perf:
fp_optimizations: amp-bf16
# Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"]
# "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16}
dataloader_workers: 1
# DataLoader worker processes
songunet_checkpoint_level: 2 # 0 means no checkpointing
songunet_checkpoint_level: 2

# I/O
io:
print_progress_freq: 1
# How often to print progress
save_checkpoint_freq: 5
# How often to save the checkpoints, measured in number of processed samples
@@ -14,47 +14,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.

defaults:
- base_all

# Hyperparameters
hp:
training_duration: 200000
# Training duration based on the number of processed images, measured in kilo images (thousands of images)
training_duration: 10000000
# Training duration based on the number of processed images
total_batch_size: 1
# Total batch size
batch_size_per_gpu: 1
# Batch size per GPU
lr: 0.0002
# Learning rate
grad_clip_threshold: 1e6
# no gradient clipping for defualt non-patch-based training
# no gradient clipping for default non-patch-based training
lr_decay: 0.7
# LR decay rate
patch_shape_x: 448
patch_shape_y: 448
# Patch size. Patch training is used if these dimensions differ from img_shape_x and img_shape_y
patch_num: 4
# Number of patches from a single sample. Total number of patches is patch_num * batch_size_global
patch_shape_x: ???
patch_shape_y: ???
# Patch size. Patch training is used if these dimensions differ from
# img_shape_x and img_shape_y. Should be overridden.
patch_num: ???
# Number of patches from a single sample. Total number of patches is
# patch_num * batch_size_global. Should be overridden.
lr_rampup: 1000000
# Rampup for learning rate, in number of samples

# Performance
perf:
fp_optimizations: amp-bf16
# Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"]
# "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16}
dataloader_workers: 4
# DataLoader worker processes
songunet_checkpoint_level: 1 # 0 means no checkpointing
# Gradient checkpointing level, value is number of layers to checkpoint

# I/O
io:
regression_checkpoint_path: /lustre/fsw/portfolios/coreai/projects/coreai_climate_earth2/tge/gefs_regression/checkpoints_lt_aware_ce_regression/UNet.0.15.mdlus
# Where to load the regression checkpoint
io:
print_progress_freq: 1
# How often to print progress
save_checkpoint_freq: 5
# How often to save the checkpoints, measured in number of processed samples
validation_freq: 1
# how often to record the validation loss, measured in number of processed samples
validation_steps: 1000
# how many loss evaluations are used to compute the validation loss per checkpoint
@@ -15,29 +15,28 @@
# limitations under the License.

defaults:
- corrdiff_regression

- base_all

# Hyperparameters
hp:
training_duration: 2000000
training_duration: 10000000
# Training duration based on the number of processed samples
total_batch_size: 256
# Total batch size
batch_size_per_gpu: "auto"
# Batch size per GPU
lr_rampup: 0
# Rampup for learning rate, in number of samples

# Performance
perf:
fp_optimizations: amp-bf16
# Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"]
# "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16}
dataloader_workers: 1
# DataLoader worker processes
grad_clip_threshold: 1e6
# no gradient clipping for default non-patch-based training
lr_decay: 0.7
# LR decay rate
patch_shape_x: ???
patch_shape_y: ???
# Patch size. Patch training is used if these dimensions differ from
# img_shape_x and img_shape_y. Should be overridden.
patch_num: ???
# Number of patches from a single sample. Total number of patches is
# patch_num * batch_size_global. Should be overridden.

# I/O
io:
# Where to load the regression checkpoint
print_progress_freq: 10000
save_checkpoint_freq: 500000
# How often to save the checkpoints, measured in number of processed samples
validation_freq: 50000
# how often to record the validation loss, measured in number of processed samples

@@ -14,8 +14,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

name: regression
# Name of the preconditioner
hr_mean_conditioning: False
# High-res mean (regression's output) as additional condition

defaults:
- base_all
81 changes: 81 additions & 0 deletions examples/generative/corrdiff/conf/config_generate_custom.yaml
@@ -0,0 +1,81 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
job:
chdir: true
name: <my_job_name> # Change `my_job_name`
run:
dir: ./<my_output_dir>/${hydra:job.name} # Change `my_output_dir`
searchpath:
- pkg://conf/base # Do not modify

# Base parameters for dataset, model, and generation
defaults:

- dataset: custom
# The dataset type for training.
# Accepted values:
# `gefs_hrrr`: full GEFS-HRRR dataset for continental US.
# `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments.
# `cwb`: full CWB dataset for Taiwan.
# `custom`: user-defined dataset. Parameters need to be specified below.

- generation: patched
# The base generation parameters.
# Accepted values:
# `patched`: base parameters for a patch-based model
# `non_patched`: base parameters for a non-patched model


# Dataset parameters. Used for `custom` dataset type.
# Modify or add the parameters below that should be passed as arguments to
# the user-defined dataset class.
dataset:
type: <path/to/dataset.py::DatasetClass>
# Path to the user-defined dataset class. The user-defined dataset class is
# automatically loaded from the path. The user-defined class "DatasetClass"
# must be defined in the path "path/to/dataset.py".
data_path: <path_to_data_file>
# Path to .nc data file
stats_path: <path_to_stats_file>
# Path to json stats file
input_variables: []
# Names or indices of input channels
output_variables: []
# Names or indices of output channels
invariant_variables: null
# Names or indices of invariant channels. Optional.

# Generation parameters to specialize
generation:
num_ensembles: 64
# int, number of ensembles to generate per input
seed_batch_size: 4
# int, size of the batched inference
patch_shape_x: 448
patch_shape_y: 448
# int, patch size. Only used for `generation: patched`. For custom dataset,
# this should be determined based on an autocorrelation plot.
times:
- YYYY-MM-DDThh:mm:ss # Replace with target value
# List[str], time stamps in ISO 8601 format. Replace and list desired target
# time stamps.
io:
res_ckpt_filename: <diffusion_checkpoint.mdlus>
# Path to checkpoint file for the diffusion model
reg_ckpt_filename: <regression_checkpoint.mdlus>
# Path to checkpoint file for the mean predictor model
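For the `custom` dataset type referenced in this config, `type: <path/to/dataset.py::DatasetClass>` points at a user-defined class that receives the `dataset:` keys as constructor arguments. A minimal hypothetical sketch follows; the real interface expected by CorrDiff (e.g. extra metadata methods for channel names, image shape, or normalization) is documented in the README updated by this PR, so treat everything here as an assumption:

```python
# my_dataset.py -- hypothetical user-defined dataset; names are placeholders.
import numpy as np
import xarray as xr

class DatasetClass:
    """Constructor kwargs mirror the keys of the `dataset:` block above."""

    def __init__(self, data_path, stats_path, input_variables,
                 output_variables, invariant_variables=None):
        self.ds = xr.open_dataset(data_path)  # .nc data file
        self.input_variables = list(input_variables)
        self.output_variables = list(output_variables)

    def __len__(self):
        return self.ds.sizes["time"]

    def __getitem__(self, idx):
        # Return (input, target) arrays of shape (C, H, W) for one time step.
        inp = np.stack([self.ds[v].isel(time=idx).values
                        for v in self.input_variables])
        out = np.stack([self.ds[v].isel(time=idx).values
                        for v in self.output_variables])
        return inp, out
```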
74 changes: 60 additions & 14 deletions examples/generative/corrdiff/conf/config_generate_gefs_hrrr.yaml
@@ -15,22 +15,68 @@
# limitations under the License.

hydra:
job:
chdir: true
name: gefs_hrrr_generation
run:
dir: output/${hydra:job.name}
job:
chdir: true
name: generate_gefs_hrrr
run:
dir: ./outputs/${hydra:job.name}
searchpath:
- pkg://conf/base # Do not modify

# Get defaults
# Base parameters for dataset, model, and generation
defaults:

# Dataset
- dataset/gefs_hrrr
- dataset: gefs_hrrr
# The dataset type for training.
# Accepted values:
# `gefs_hrrr`: full GEFS-HRRR dataset for continental US.
# `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments.
# `cwb`: full CWB dataset for Taiwan.
# `custom`: user-defined dataset. Parameters need to be specified below.

# Sampler
- sampler/stochastic
#- sampler/deterministic
- generation: patched
# The base generation parameters.
# Accepted values:
# `patched`: base parameters for a patch-based model
# `non_patched`: base parameters for a non-patched model

# Generation
- generation/patched_based_gefs_hrrr
#- generation/patched_based

# Dataset parameters. Used for `custom` dataset type.
# Modify or add the parameters below that should be passed as arguments to
# the user-defined dataset class.
dataset:
data_path: /data
# Path to .nc data file
stats_path: /data/stats.json
# Path to json stats file


# Generation parameters to specialize
generation:
num_ensembles: 1
# int, number of ensembles to generate per input
seed_batch_size: 1
# int, size of the batched inference
patch_shape_x: 448
patch_shape_y: 448
# int, patch size. Only used for `generation: patched`. For custom dataset,
# this should be determined based on an autocorrelation plot.
times:
- "2024011212f00"
- "2024011212f03"
- "2024011212f06"
- "2024011212f09"
- "2024011212f12"
- "2024011212f15"
- "2024011212f18"
- "2024011212f21"
- "2024011212f24"
# List[str], GEFS-HRRR time stamps: initialization time (YYYYMMDDHH)
# followed by forecast lead hour (fXX). Replace with desired targets.
has_lead_time: True

io:
res_ckpt_filename: <diffusion_checkpoint.mdlus>
# Path to checkpoint file for the diffusion model
reg_ckpt_filename: <regression_checkpoint.mdlus>
# Path to checkpoint file for the mean predictor model
67 changes: 67 additions & 0 deletions examples/generative/corrdiff/conf/config_generate_hrrr_mini.yaml
@@ -0,0 +1,67 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
job:
chdir: true
name: generate_hrrr_mini
run:
dir: ./outputs/${hydra:job.name}
searchpath:
- pkg://conf/base # Do not modify

# Base parameters for dataset, model, and generation
defaults:

- dataset: hrrr_mini
# The dataset type for training.
# Accepted values:
# `gefs_hrrr`: full GEFS-HRRR dataset for continental US.
# `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments.
# `cwb`: full CWB dataset for Taiwan.
# `custom`: user-defined dataset. Parameters need to be specified below.

- generation: non_patched
# The base generation parameters.
# Accepted values:
# `patched`: base parameters for a patch-based model
# `non_patched`: base parameters for a non-patched model


# Dataset parameters. Used for `custom` dataset type.
# Modify or add the parameters below that should be passed as arguments to
# the user-defined dataset class.
dataset:
data_path: /data/corrdiff-mini/hrrr_mini_train.nc
# Path to .nc data file
stats_path: /data/corrdiff-mini/stats.json
# Path to json stats file

# Generation parameters to specialize
generation:
num_ensembles: 2
# int, number of ensembles to generate per input
seed_batch_size: 1
# int, size of the batched inference
times:
- 2020-02-02T00:00:00
# List[str], time stamps in ISO 8601 format. Replace and list desired target
# time stamps.
io:
res_ckpt_filename: <diffusion_checkpoint.mdlus>
# Path to checkpoint file for the diffusion model
reg_ckpt_filename: <regression_checkpoint.mdlus>
# Path to checkpoint file for the mean predictor model
74 changes: 74 additions & 0 deletions examples/generative/corrdiff/conf/config_generate_taiwan.yaml
@@ -0,0 +1,74 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
job:
chdir: true
name: generate_taiwan
run:
dir: ./outputs/${hydra:job.name}
searchpath:
- pkg://conf/base # Do not modify

# Base parameters for dataset, model, and generation
defaults:

- dataset: cwb
# The dataset type for training.
# Accepted values:
# `gefs_hrrr`: full GEFS-HRRR dataset for continental US.
# `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments.
# `cwb`: full CWB dataset for Taiwan.
# `custom`: user-defined dataset. Parameters need to be specified below.

- generation: non_patched
# The base generation parameters.
# Accepted values:
# `patched`: base parameters for a patch-based model
# `non_patched`: base parameters for a non-patched model


# Dataset parameters. Used for `custom` dataset type.
# Modify or add the parameters below that should be passed as arguments to
# the user-defined dataset class.
dataset:
data_path: /code/2023-01-24-cwb-4years.zarr
train: False
all_times: True


# Generation parameters to specialize
generation:
num_ensembles: 64
# int, number of ensembles to generate per input
seed_batch_size: 1
# int, size of the batched inference
hr_mean_conditioning: false
# Whether to use hr_mean_conditioning
times:
- 2021-02-02T00:00:00
- 2021-03-02T00:00:00
- 2021-04-02T00:00:00
# hurricane
- 2021-09-12T00:00:00
- 2021-09-12T12:00:00
# List[str], time stamps in ISO 8601 format. Replace and list desired target
# time stamps.
io:
res_ckpt_filename: <diffusion_checkpoint.mdlus>
# Path to checkpoint file for the diffusion model
reg_ckpt_filename: <regression_checkpoint.mdlus>
# Path to checkpoint file for the mean predictor model
41 changes: 0 additions & 41 deletions examples/generative/corrdiff/conf/config_training.yaml

This file was deleted.

117 changes: 117 additions & 0 deletions examples/generative/corrdiff/conf/config_training_custom.yaml
@@ -0,0 +1,117 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
job:
chdir: true
name: <my_job_name> # Change `my_job_name`
run:
dir: ./<my_output_dir>/${hydra:job.name} # Change `my_output_dir`
searchpath:
- pkg://conf/base # Do not modify

# Base parameters for dataset, model, training, and validation
defaults:

- dataset: custom
# The dataset type for training.
# Accepted values:
# `gefs_hrrr`: full GEFS-HRRR dataset for continental US.
# `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments.
# `cwb`: full CWB dataset for Taiwan.
# `custom`: user-defined dataset. Parameters need to be specified below.

- model: diffusion
# The model type.
# Accepted values:
# `regression`: a regression UNet for deterministic predictions
# `lt_aware_ce_regression`: similar to `regression` but with lead time
# conditioning
# `diffusion`: a diffusion UNet for residual predictions
# `patched_diffusion`: a more memory-efficient diffusion model
# `lt_aware_patched_diffusion`: similar to `patched_diffusion` but
# with lead time conditioning

- model_size: normal
# The model size configuration.
# Accepted values:
# `normal`: normal model size
# `mini`: smaller model size for fast experiments

- training: ${model}
# The base training parameters. Determined by the model type.


# Dataset parameters. Used for `custom` dataset type.
# Modify or add the parameters below that should be passed as arguments to
# the user-defined dataset class.
dataset:
type: <path/to/dataset.py::DatasetClass>
# Path to the user-defined dataset class. The user-defined dataset class is
# automatically loaded from the path. The user-defined class "DatasetClass"
# must be defined in the path "path/to/dataset.py".
data_path: <path_to_data_file>
# Path to .nc data file
stats_path: <path_to_stats_file>
# Path to json stats file
input_variables: []
# Names or indices of input channels
output_variables: []
# Names or indices of output channels
invariant_variables: null
# Names or indices of invariant channels. Optional.

# Training parameters
training:
hp:
training_duration: 10000000
# Training duration based on the number of processed samples
total_batch_size: 256
# Total batch size
batch_size_per_gpu: "auto"
# Batch size per GPU. Set to "auto" to automatically determine the batch
# size based on the number of GPUs.
patch_shape_x: 448
patch_shape_y: 448
# Patch size. Only used for `model: patched_diffusion` or `model:
# lt_aware_patched_diffusion`. For custom dataset, this should be
# determined based on an autocorrelation plot.
patch_num: 10
# Number of patches from a single sample. Total number of patches is
# patch_num * total_batch_size. Only used for `model: patched_diffusion`
# or `model: lt_aware_patched_diffusion`.
lr: 0.0002
# Learning rate
lr_rampup: 0
# Rampup for learning rate, in number of samples
io:
regression_checkpoint_path: <path/to/checkpoint.mdlus>
# Path to load the regression checkpoint

# Parameters for wandb logging
wandb:
mode: online
# Logging mode. Accepted values: "offline", "online" or "disabled"
key: <your_api_key>
# Your wandb API key
project: <your_project_name>
# Your wandb project name
entity: <your_entity_name>
# Name of your wandb entity
name: <your_experiment_name>
# Name of your experiment
watch_model: true
# Bool: if true, then wandb logs statistics about your model parameters
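
Since this config resolves its `defaults` through the `pkg://conf/base` search path, it can help to inspect the fully composed configuration before launching a run. A minimal sketch using Hydra's compose API (assumes Hydra >= 1.2 and that the snippet runs next to the `conf/` directory; paths and overrides are placeholders):

```python
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose conf/config_training_custom.yaml with a few overrides, then dump
# the resolved training hyperparameters for inspection.
with initialize(version_base=None, config_path="conf"):
    cfg = compose(
        config_name="config_training_custom",
        overrides=[
            "dataset.data_path=/data/my_data.nc",  # placeholder paths
            "dataset.stats_path=/data/my_stats.json",
            "training.hp.total_batch_size=64",
        ],
    )
print(OmegaConf.to_yaml(cfg.training.hp))
```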

This file was deleted.

@@ -0,0 +1,101 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
job:
chdir: true
name: gefs_hrrr_diffusion
run:
dir: ./output/${hydra:job.name}
searchpath:
- pkg://conf/base # Do not modify

# Base parameters for dataset, model, training, and validation
defaults:

- dataset: gefs_hrrr
# The dataset type for training.
# Accepted values:
# `gefs_hrrr`: full GEFS-HRRR dataset for continental US.
# `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments.
# `cwb`: full CWB dataset for Taiwan.
# `custom`: user-defined dataset. Parameters need to be specified below.

- model: lt_aware_patched_diffusion
# The model type.
# Accepted values:
# `regression`: a regression UNet for deterministic predictions
# `lt_aware_ce_regression`: similar to `regression` but with lead time
# conditioning
# `diffusion`: a diffusion UNet for residual predictions
# `patched_diffusion`: a more memory-efficient diffusion model
# `lt_aware_patched_diffusion`: similar to `patched_diffusion` but
# with lead time conditioning

- model_size: normal
# The model size configuration.
# Accepted values:
# `normal`: normal model size
# `mini`: smaller model size for fast experiments

- training: ${model}
# The base training parameters. Determined by the model type.


# Dataset parameters. Used for `custom` dataset type.
# Modify or add the parameters below that should be passed as arguments to
# the user-defined dataset class.
dataset:
data_path: /data
# Path to .nc data file
stats_path: /data/stats.json
# Path to json stats file

model:
scale_cond_input: true
# If true, also scales the input conditioning. Set to True for backward
# compatibility.

# Training parameters
training:
hp:
training_duration: 10000000
# Training duration based on the number of processed samples
patch_shape_x: 448
patch_shape_y: 448
# Patch size. Patch training is used if these dimensions differ from
# img_shape_x and img_shape_y.
patch_num: 4
# Number of patches from a single sample. Total number of patches is
# patch_num * batch_size_global.
io:
regression_checkpoint_path: <path/to/checkpoint.mdlus>
# Path to load the regression checkpoint

# Parameters for wandb logging
wandb:
mode: online
# Logging mode. Accepted values: "offline", "online" or "disabled"
key: <your_api_key>
# Your wandb API key
project: <your_project_name>
# Your wandb project name
entity: <your_entity_name>
# Name of your wandb entity
name: <your_experiment_name>
# Name of your experiment
watch_model: true
# Bool: if true, then wandb logs statistics about your model parameters
@@ -0,0 +1,87 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
job:
chdir: true
name: gefs_hrrr_regression
run:
dir: ./output/${hydra:job.name}
searchpath:
- pkg://conf/base # Do not modify

# Base parameters for dataset, model, training, and validation
defaults:

- dataset: gefs_hrrr
# The dataset type for training.
# Accepted values:
# `gefs_hrrr`: full GEFS-HRRR dataset for continental US.
# `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments.
# `cwb`: full CWB dataset for Taiwan.
# `custom`: user-defined dataset. Parameters need to be specified below.

- model: lt_aware_ce_regression
# The model type.
# Accepted values:
# `regression`: a regression UNet for deterministic predictions
# `lt_aware_ce_regression`: similar to `regression` but with lead time
# conditioning
# `diffusion`: a diffusion UNet for residual predictions
# `patched_diffusion`: a more memory-efficient diffusion model
# `lt_aware_patched_diffusion`: similar to `patched_diffusion` but
# with lead time conditioning

- model_size: normal
# The model size configuration.
# Accepted values:
# `normal`: normal model size
# `mini`: smaller model size for fast experiments

- training: ${model}
# The base training parameters. Determined by the model type.


# Dataset parameters. Used for `custom` dataset type.
# Modify or add the parameters below that should be passed as arguments to
# the user-defined dataset class.
dataset:
data_path: /data
# Path to .nc data file
stats_path: /data/stats.json
# Path to json stats file


# Training parameters
training:
hp:
training_duration: 1000000
# Training duration based on the number of processed samples

# Parameters for wandb logging
wandb:
mode: online
# Logging mode. Accepted values: "offline", "online" or "disabled"
key: <your_api_key>
# Your wandb API key
project: <your_project_name>
# Your wandb project name
entity: <your_entity_name>
# Name of your wandb entity
name: <your_experiment_name>
# Name of your experiment
watch_model: true
# Bool: if true, then wandb logs statistics about your model parameters
@@ -0,0 +1,90 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
job:
chdir: true
name: hrrr_mini_diffusion
run:
dir: ./output/${hydra:job.name}
searchpath:
- pkg://conf/base # Do not modify

# Base parameters for dataset, model, training, and validation
defaults:

- dataset: hrrr_mini
# The dataset type for training.
# Accepted values:
# `gefs_hrrr`: full GEFS-HRRR dataset for continental US.
# `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments.
# `cwb`: full CWB dataset for Taiwan.
# `custom`: user-defined dataset. Parameters need to be specified below.

- model: diffusion
# The model type.
# Accepted values:
# `regression`: a regression UNet for deterministic predictions
# `lt_aware_ce_regression`: similar to `regression` but with lead time
# conditioning
# `diffusion`: a diffusion UNet for residual predictions
# `patched_diffusion`: a more memory-efficient diffusion model
# `lt_aware_patched_diffusion`: similar to `patched_diffusion` but
# with lead time conditioning

- model_size: mini
# The model size configuration.
# Accepted values:
# `normal`: normal model size
# `mini`: smaller model size for fast experiments

- training: ${model}
# The base training parameters. Determined by the model type.


# Dataset parameters. Used for `custom` dataset type.
# Modify or add the parameters below that should be passed as arguments to
# the user-defined dataset class.
dataset:
data_path: /data/corrdiff-mini/hrrr_mini_train.nc
# Path to .nc data file
stats_path: /data/corrdiff-mini/stats.json
# Path to json stats file

# Training parameters
training:
hp:
training_duration: 8000000
# Training duration based on the number of processed samples
io:
print_progress_freq: 10000
regression_checkpoint_path: <path/to/checkpoint.mdlus>
# Path to load the regression checkpoint

# Parameters for wandb logging
wandb:
mode: online
# Logging mode. Accepted values: "offline", "online" or "disabled"
key: <your_api_key>
# Your wandb API key
project: <your_project_name>
# Your wandb project name
entity: <your_entity_name>
# Name of your wandb entity
name: <your_experiment_name>
# Name of your experiment
watch_model: true
# Bool: if true, then wandb logs statistics about your model parameters
@@ -0,0 +1,88 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
job:
chdir: true
name: hrrr_mini_regression
run:
dir: ./output/${hydra:job.name}
searchpath:
- pkg://conf/base # Do not modify

# Base parameters for dataset, model, training, and validation
defaults:

- dataset: hrrr_mini
# The dataset type for training.
# Accepted values:
# `gefs_hrrr`: full GEFS-HRRR dataset for continental US.
# `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments.
# `cwb`: full CWB dataset for Taiwan.
# `custom`: user-defined dataset. Parameters need to be specified below.

- model: regression
# The model type.
# Accepted values:
# `regression`: a regression UNet for deterministic predictions
# `lt_aware_ce_regression`: similar to `regression` but with lead time
# conditioning
# `diffusion`: a diffusion UNet for residual predictions
# `patched_diffusion`: a more memory-efficient diffusion model
# `lt_aware_patched_diffusion`: similar to `patched_diffusion` but
# with lead time conditioning

- model_size: mini
# The model size configuration.
# Accepted values:
# `normal`: normal model size
# `mini`: smaller model size for fast experiments

- training: ${model}
# The base training parameters. Determined by the model type.


# Dataset parameters. Used for `custom` dataset type.
# Modify or add the parameters below that should be passed as arguments to
# the user-defined dataset class.
dataset:
data_path: /data/corrdiff-mini/hrrr_mini_train.nc
# Path to .nc data file
stats_path: /data/corrdiff-mini/stats.json
# Path to json stats file

# Training parameters
training:
hp:
training_duration: 2000000
# Training duration based on the number of processed samples
io:
print_progress_freq: 10000

# Parameters for wandb logging
wandb:
mode: online
# Logging mode. Accepted values: "offline", "online" or "disabled"
key: <your_api_key>
# Your wandb API key
project: <your_project_name>
# Your wandb project name
entity: <your_entity_name>
# Name of your wandb entity
name: <your_experiment_name>
# Name of your experiment
watch_model: true
# Bool: if true, then wandb logs statistics about your model parameters

This file was deleted.

This file was deleted.

@@ -0,0 +1,100 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
job:
chdir: true
name: taiwan_diffusion
run:
dir: ./output/${hydra:job.name}
searchpath:
- pkg://conf/base # Do not modify

# Base parameters for dataset, model, training, and validation
defaults:

- dataset: cwb
# The dataset type for training.
# Accepted values:
# `gefs_hrrr`: full GEFS-HRRR dataset for continental US.
# `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments.
# `cwb`: full CWB dataset for Taiwan.
# `custom`: user-defined dataset. Parameters need to be specified below.

- model: diffusion
# The model type.
# Accepted values:
# `regression`: a regression UNet for deterministic predictions
# `lt_aware_ce_regression`: similar to `regression` but with lead time
# conditioning
# `diffusion`: a diffusion UNet for residual predictions
# `patched_diffusion`: a more memory-efficient diffusion model
# `lt_aware_patched_diffusion`: similar to `patched_diffusion` but
# with lead time conditioning

- model_size: normal
# The model size configuration.
# Accepted values:
# `normal`: normal model size
# `mini`: smaller model size for fast experiments

- training: ${model}
# The base training parameters. Determined by the model type.


# Dataset parameters. Used for `custom` dataset type.
# Modify or add below the parameters that should be passed as arguments to
# the user-defined dataset class.
dataset:
data_path: /code/2023-01-24-cwb-4years.zarr

model:
scale_cond_input: True
# If true, also scales the input conditioning. True for backward
# compatibility.
hr_mean_conditioning: false
# High-res mean (regression's output) as additional condition

# Training parameters
training:
hp:
training_duration: 200000000
# Training duration based on the number of processed samples
lr_rampup: 10000000
# Rampup for learning rate, in number of samples
io:
regression_checkpoint_path: <path/to/checkpoint.mdlus>
# Path to load the regression checkpoint

# Additional parameters for validation
validation:
train: false
all_times: false

# Parameters for wandb logging
wandb:
  mode: online
  # Logging mode. Accepted values: "offline", "online" or "disabled"
  key: <your_api_key>
  # Your wandb API key
  project: <your_project_name>
  # Your wandb project name
  entity: <your_entity_name>
  # Name of your wandb entity
  name: <your_experiment_name>
  # Name of your experiment
  watch_model: true
  # Bool: if true, wandb logs statistics about your model parameters
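
The `regression_checkpoint_path` above must point to a regression model trained beforehand, since the diffusion model learns the residual against its output. As a rough sketch of what consuming it looks like (`Module.from_checkpoint` is the usual Modulus entry point for `.mdlus` files, though the exact handling in train.py may differ):

# Rough sketch: load the frozen regression network that the diffusion model
# trains against. The eval/requires_grad handling here is illustrative only.
from modulus import Module

regression_net = Module.from_checkpoint("<path/to/checkpoint.mdlus>")
regression_net.eval().requires_grad_(False)  # used only to predict the mean
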
@@ -0,0 +1,85 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
job:
chdir: true
name: taiwan_regression
run:
dir: ./output/${hydra:job.name}
searchpath:
- pkg://conf/base # Do not modify

# Base parameters for dataset, model, training, and validation
defaults:

- dataset: cwb
# The dataset type for training.
# Accepted values:
# `gefs_hrrr`: full GEFS-HRRR dataset for continental US.
# `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments.
# `cwb`: full CWB dataset for Taiwan.
# `custom`: user-defined dataset. Parameters need to be specified below.

- model: regression
# The model type.
# Accepted values:
# `regression`: a regression UNet for deterministic predictions
# `lt_aware_ce_regression`: similar to `regression` but with lead time
# conditioning
# `diffusion`: a diffusion UNet for residual predictions
# `patched_diffusion`: a more memory-efficient diffusion model
# `lt_aware_patched_diffusion`: similar to `patched_diffusion` but
# with lead time conditioning

- model_size: normal
# The model size configuration.
# Accepted values:
# `normal`: normal model size
# `mini`: smaller model size for fast experiments

- training: ${model}
# The base training parameters. Determined by the model type.


# Dataset parameters. Used for `custom` dataset type.
# Modify or add below the parameters that should be passed as arguments to
# the user-defined dataset class.
dataset:
data_path: /code/2023-01-24-cwb-4years.zarr

# Training parameters
training:
hp:
training_duration: 200000000
# Training duration based on the number of processed samples
lr_rampup: 10000000
# Rampup for learning rate, in number of samples

# Parameters for wandb logging
wandb:
  mode: online
  # Logging mode. Accepted values: "offline", "online" or "disabled"
  key: <your_api_key>
  # Your wandb API key
  project: <your_project_name>
  # Your wandb project name
  entity: <your_entity_name>
  # Name of your wandb entity
  name: <your_experiment_name>
  # Name of your experiment
  watch_model: true
  # Bool: if true, wandb logs statistics about your model parameters
31 changes: 0 additions & 31 deletions examples/generative/corrdiff/conf/dataset/cwb_generate.yaml

This file was deleted.

63 changes: 0 additions & 63 deletions examples/generative/corrdiff/conf/generation/base.yaml

This file was deleted.

39 changes: 0 additions & 39 deletions examples/generative/corrdiff/conf/generation/mini.yaml

This file was deleted.

597 changes: 0 additions & 597 deletions examples/generative/corrdiff/conf/references/config_data_ref.yaml

This file was deleted.

55 changes: 0 additions & 55 deletions examples/generative/corrdiff/conf/training/corrdiff_diffusion.yaml

This file was deleted.

54 changes: 54 additions & 0 deletions examples/generative/corrdiff/datasets/dataset.py
@@ -17,6 +17,8 @@
from typing import Iterable, Tuple, Union
import copy
import torch
import importlib.util
from pathlib import Path

from modulus.utils.generative import InfiniteSampler
from modulus.distributed import DistributedManager
@@ -32,6 +34,58 @@
}


def register_dataset(dataset_spec: str) -> None:
"""
    Register a new dataset class from a file path specification.

    Parameters
    ----------
    dataset_spec : str
        String specification in the format "path_to_file.py::dataset_class"

    Raises
------
ValueError
If the dataset_spec format is invalid or if the file doesn't exist
ImportError
If the dataset class cannot be imported
"""
try:
file_path, class_name = dataset_spec.split("::")
except ValueError:
raise ValueError(
"Invalid dataset specification. Expected format: "
"'path_to_file.py::dataset_class'"
)

    if dataset_spec in known_datasets:
        return  # Dataset already registered

# Convert to Path and validate
file_path = Path(file_path)
if not file_path.exists():
raise ValueError(f"Dataset file not found: {file_path}")
    if file_path.suffix != ".py":
raise ValueError(f"Dataset file must be a Python file: {file_path}")

# Import the module and get the class
spec = importlib.util.spec_from_file_location(file_path.stem, str(file_path))
if spec is None or spec.loader is None:
raise ImportError(f"Could not load spec for {file_path}")

module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

try:
dataset_class = getattr(module, class_name)
except AttributeError:
raise ImportError(f"Could not find dataset class '{class_name}' in {file_path}")

# Register the dataset
known_datasets[dataset_spec] = dataset_class
return
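
Once defined, a custom dataset hooks into both training and generation through this registry. A minimal usage sketch, where the file path and class name are placeholders for a user-defined module:

# Minimal sketch: register a hypothetical user-defined dataset class.
# "my_dataset.py" and "MyDataset" are placeholders for your own module.
from datasets.dataset import register_dataset, known_datasets

spec = "path/to/my_dataset.py::MyDataset"
register_dataset(spec)
assert spec in known_datasets  # the spec string itself becomes the registry key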


def init_train_valid_datasets_from_config(
dataset_cfg: dict,
dataloader_cfg: Union[dict, None] = None,
6 changes: 4 additions & 2 deletions examples/generative/corrdiff/datasets/gefs_hrrr.py
@@ -126,6 +126,8 @@ class HrrrForecastGEFSDataset(DownscalingDataset):
Expects data to be stored under directory specified by 'location'
GEFS under <root_dir>/gefs/
HRRR under <root_dir>/hrrr/
Within each directory, there should be one zarr file per
year containing the data of interest.
"""

def __init__(
@@ -142,7 +144,7 @@ def __init__(
train_years: Iterable[int] = (2020, 2021, 2022, 2023),
valid_years: Iterable[int] = (2024,),
hrrr_window: Union[Tuple[Tuple[int, int], Tuple[int, int]], None] = None,
sample_shape: Tuple[int, int] = None,
sample_shape: Tuple[int, int] = [-1, -1],
ds_factor: int = 1,
shard: bool = False,
overfit: bool = False,
@@ -468,7 +470,7 @@ def image_shape(self) -> Tuple[int, int]:
return (y_end - y_start, x_end - x_start)

def _get_crop_box(self):
if self.sample_shape == None:
if self.sample_shape == [-1, -1]:
return self.hrrr_window

((y_start, y_end), (x_start, x_end)) = self.hrrr_window
65 changes: 25 additions & 40 deletions examples/generative/corrdiff/generate.py
@@ -23,10 +23,10 @@
import netCDF4 as nc
from modulus.distributed import DistributedManager
from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper
from modulus.utils.patching import GridPatching2D
from modulus import Module
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from einops import rearrange
from torch.distributed import gather


@@ -45,6 +45,7 @@
save_images,
)
from helpers.train_helpers import set_patch_shape
from datasets.dataset import register_dataset


@hydra.main(version_base="1.2", config_path="conf", config_name="config_generate")
@@ -85,6 +86,11 @@ def main(cfg: DictConfig) -> None:

# Create dataset object
dataset_cfg = OmegaConf.to_container(cfg.dataset)

# Register dataset (if custom dataset)
    register_dataset(cfg.dataset.type)
    logger0.info(f"Using dataset: {cfg.dataset.type}")

if "has_lead_time" in cfg.generation:
has_lead_time = cfg.generation["has_lead_time"]
else:
@@ -96,19 +102,23 @@ def main(cfg: DictConfig) -> None:
img_out_channels = len(dataset.output_channels())

# Parse the patch shape
if hasattr(cfg.generation, "patch_shape_x"): # TODO better config handling
if cfg.generation.patching:
patch_shape_x = cfg.generation.patch_shape_x
else:
patch_shape_x = None
if hasattr(cfg.generation, "patch_shape_y"):
patch_shape_y = cfg.generation.patch_shape_y
else:
patch_shape_y = None
patch_shape_x, patch_shape_y = None, None
patch_shape = (patch_shape_y, patch_shape_x)
img_shape, patch_shape = set_patch_shape(img_shape, patch_shape)
if patch_shape != img_shape:
use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape)
if use_patching:
patching = GridPatching2D(
img_shape=img_shape,
patch_shape=patch_shape,
boundary_pix=cfg.generation.boundary_pix,
overlap_pix=cfg.generation.overlap_pix,
)
logger0.info("Patch-based training enabled")
else:
patching = None
logger0.info("Patch-based training disabled")

# Parse the inference mode
@@ -164,44 +174,27 @@ def main(cfg: DictConfig) -> None:
solver=cfg.sampler.solver,
)
elif cfg.sampler.type == "stochastic":
sampler_fn = partial(
stochastic_sampler,
img_shape=img_shape[1],
patch_shape=patch_shape[1],
boundary_pix=cfg.sampler.boundary_pix,
overlap_pix=cfg.sampler.overlap_pix,
)
sampler_fn = partial(stochastic_sampler, patching=patching)
else:
raise ValueError(f"Unknown sampling method {cfg.sampling.type}")

# Main generation definition
def generate_fn():
img_shape_y, img_shape_x = img_shape
with nvtx.annotate("generate_fn", color="green"):
if cfg.generation.sample_res == "full":
image_lr_patch = image_lr
else:
torch.cuda.nvtx.range_push("rearrange")
image_lr_patch = rearrange(
image_lr,
"b c (h1 h) (w1 w) -> (b h1 w1) c h w",
h1=img_shape_y // patch_shape[0],
w1=img_shape_x // patch_shape[1],
)
torch.cuda.nvtx.range_pop()
image_lr_patch = image_lr_patch.to(memory_format=torch.channels_last)
# (1, C, H, W)
img_lr = image_lr.to(memory_format=torch.channels_last)

if net_reg:
with nvtx.annotate("regression_model", color="yellow"):
image_reg = regression_step(
net=net_reg,
img_lr=image_lr_patch,
img_lr=img_lr,
latents_shape=(
cfg.generation.seed_batch_size,
img_out_channels,
img_shape[0],
img_shape[1],
),
), # (batch_size, C, H, W)
lead_time_label=lead_time_label,
)
if net_res:
@@ -213,16 +206,15 @@ def generate_fn():
image_res = diffusion_step(
net=net_res,
sampler_fn=sampler_fn,
seed_batch_size=cfg.generation.seed_batch_size,
img_shape=img_shape,
img_out_channels=img_out_channels,
rank_batches=rank_batches,
img_lr=image_lr_patch.expand(
img_lr=img_lr.expand(
cfg.generation.seed_batch_size, -1, -1, -1
).to(memory_format=torch.channels_last),
rank=dist.rank,
device=device,
hr_mean=mean_hr,
mean_hr=mean_hr,
lead_time_label=lead_time_label,
)
if cfg.generation.inference_mode == "regression":
@@ -232,13 +224,6 @@ def generate_fn():
else:
image_out = image_reg + image_res

if cfg.generation.sample_res != "full":
image_out = rearrange(
image_out,
"(b h1 w1) c h w -> b c (h1 h) (w1 w)",
h1=img_shape_y // patch_shape[0],
w1=img_shape_x // patch_shape[1],
)
# Gather tensors on rank 0
if dist.world_size > 1:
if dist.rank == 0:
14 changes: 12 additions & 2 deletions examples/generative/corrdiff/helpers/train_helpers.py
@@ -17,6 +17,7 @@
import torch
import numpy as np
from omegaconf import ListConfig
import warnings


def set_patch_shape(img_shape, patch_shape):
@@ -26,12 +27,21 @@ def set_patch_shape(img_shape, patch_shape):
patch_shape_x = img_shape_x
if (patch_shape_y is None) or (patch_shape_y > img_shape_y):
patch_shape_y = img_shape_y
if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y:
if patch_shape_x == img_shape_x and patch_shape_y == img_shape_y:
use_patching = False
else:
use_patching = True
if use_patching:
if patch_shape_x != patch_shape_y:
warnings.warn(
f"You are using rectangular patches "
f"of shape {(patch_shape_y, patch_shape_x)}, "
f"which are an experimental feature."
)
raise NotImplementedError("Rectangular patch not supported yet")
if patch_shape_x % 32 != 0 or patch_shape_y % 32 != 0:
raise ValueError("Patch shape needs to be a multiple of 32")
return (img_shape_y, img_shape_x), (patch_shape_y, patch_shape_x)
return use_patching, (img_shape_y, img_shape_x), (patch_shape_y, patch_shape_x)
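
The new three-value return lets callers branch on `use_patching` instead of comparing shapes themselves. A small usage sketch, with illustrative shape values:

# Small usage sketch of the new return signature; shapes are illustrative.
from helpers.train_helpers import set_patch_shape

use_patching, img_shape, patch_shape = set_patch_shape((512, 640), (448, 448))
assert use_patching and patch_shape == (448, 448)

# A (None, None) patch shape falls back to the full image and disables patching.
use_patching, img_shape, patch_shape = set_patch_shape((512, 640), (None, None))
assert not use_patching and patch_shape == (512, 640)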


def set_seed(rank):
105 changes: 74 additions & 31 deletions examples/generative/corrdiff/train.py
@@ -22,11 +22,16 @@
from modulus import Module
from modulus.models.diffusion import UNet, EDMPrecondSR
from modulus.distributed import DistributedManager
from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper
from modulus.metrics.diffusion import RegressionLoss, ResLoss, RegressionLossCE
from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper
from modulus.metrics.diffusion import RegressionLoss, ResidualLoss, RegressionLossCE
from modulus.utils.patching import RandomPatching2D
from modulus.launch.logging import (
PythonLogger,
RankZeroLoggingWrapper,
initialize_wandb,
)
import wandb
from modulus.launch.utils import load_checkpoint, save_checkpoint
from datasets.dataset import init_train_valid_datasets_from_config
from datasets.dataset import init_train_valid_datasets_from_config, register_dataset
from helpers.train_helpers import (
set_patch_shape,
set_seed,
@@ -37,6 +42,23 @@
)


def checkpoint_list(path, suffix=".mdlus"):
"""Helper function to return sorted list, in ascending order, of checkpoints in a path"""
checkpoints = []
for file in os.listdir(path):
if file.endswith(suffix):
# Split the filename and extract the index
try:
index = int(file.split(".")[-2])
checkpoints.append((index, file))
except ValueError:
continue

# Sort by index and return filenames
checkpoints.sort(key=lambda x: x[0])
return [file for _, file in checkpoints]
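
Sorting on the parsed integer matters because plain lexicographic order would place index 10 before index 2. A quick illustration with hypothetical checkpoint filenames:

# Quick illustration with hypothetical checkpoint names: the integer field
# before the suffix drives the order, so 2 sorts before 10 (unlike str order).
import os
import tempfile

with tempfile.TemporaryDirectory() as path:
    for name in ["UNet.10.mdlus", "UNet.2.mdlus", "notes.txt"]:
        open(os.path.join(path, name), "w").close()
    print(checkpoint_list(path))  # ['UNet.2.mdlus', 'UNet.10.mdlus']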


# Train the CorrDiff model using the configurations in "conf/config_training.yaml"
@hydra.main(version_base="1.2", config_path="conf", config_name="config_training")
def main(cfg: DictConfig) -> None:
@@ -50,10 +72,22 @@ def main(cfg: DictConfig) -> None:
writer = SummaryWriter(log_dir="tensorboard")
logger = PythonLogger("main") # General python logger
logger0 = RankZeroLoggingWrapper(logger, dist) # Rank 0 logger
wandb.login(key=cfg.wandb.key)
initialize_wandb(
project=cfg.wandb.project,
entity=cfg.wandb.entity,
name=cfg.wandb.name,
mode=cfg.wandb.mode,
)

# Resolve and parse configs
OmegaConf.resolve(cfg)
dataset_cfg = OmegaConf.to_container(cfg.dataset) # TODO needs better handling

# Register custom dataset if specified in config
    register_dataset(cfg.dataset.type)
    logger0.info(f"Using dataset: {cfg.dataset.type}")

if hasattr(cfg, "validation"):
train_test_split = True
validation_dataset_cfg = OmegaConf.to_container(cfg.validation)
@@ -122,13 +156,20 @@ def main(cfg: DictConfig) -> None:
patch_shape_x = None
patch_shape_y = None
patch_shape = (patch_shape_y, patch_shape_x)
img_shape, patch_shape = set_patch_shape(img_shape, patch_shape)
if patch_shape != img_shape:
use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape)
if use_patching:
        # Utility to perform patch extraction and batching
patching = RandomPatching2D(
img_shape=img_shape,
patch_shape=patch_shape,
patch_num=getattr(cfg.training.hp, "patch_num", 1),
)
logger0.info("Patch-based training enabled")
else:
patching = None
logger0.info("Patch-based training disabled")
# interpolate global channel if patch-based model is used
if img_shape[1] != patch_shape[1]:
if use_patching:
img_in_channels += dataset_channels
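
RandomPatching2D draws a fresh set of random patch locations each time `reset_patch_indices()` is called in the training loop further down. A condensed lifecycle sketch, with illustrative shape values:

# Condensed lifecycle sketch for RandomPatching2D; shapes are illustrative.
from modulus.utils.patching import RandomPatching2D

patching = RandomPatching2D(
    img_shape=(512, 640),    # full (H, W) of the training images
    patch_shape=(448, 448),  # per-patch (H, W), multiple of 32
    patch_num=16,            # random patches drawn per image per iteration
)
for _ in range(3):  # schematic stand-in for the training loop
    patching.reset_patch_indices()  # re-draw patch locations, as in train.py
    # ...compute the loss with loss_fn(..., patching=patching)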

# Instantiate the model and move to device.
@@ -147,41 +188,23 @@ def main(cfg: DictConfig) -> None:
}
standard_model_cfgs = { # default parameters for different network types
"regression": {
"img_channels": 4,
"N_grid_channels": 4,
"embedding_type": "zero",
"checkpoint_level": songunet_checkpoint_level,
},
"lt_aware_ce_regression": {
"img_channels": 4,
"N_grid_channels": 4,
"embedding_type": "zero",
"lead_time_channels": 4,
"lead_time_steps": 9,
"prob_channels": prob_channels,
"checkpoint_level": songunet_checkpoint_level,
"model_type": "SongUNetPosLtEmbd",
},
"diffusion": {
"img_channels": img_out_channels,
"gridtype": "sinusoidal",
"N_grid_channels": 4,
"checkpoint_level": songunet_checkpoint_level,
},
"patched_diffusion": {
"img_channels": img_out_channels,
"gridtype": "learnable",
"N_grid_channels": 100,
"checkpoint_level": songunet_checkpoint_level,
},
"lt_aware_patched_diffusion": {
"img_channels": img_out_channels,
"gridtype": "learnable",
"N_grid_channels": 100,
"lead_time_channels": 20,
"lead_time_steps": 9,
"checkpoint_level": songunet_checkpoint_level,
"model_type": "SongUNetPosLtEmbd",
},
}
model_args.update(standard_model_cfgs[cfg.model.name])
@@ -229,6 +252,8 @@ def main(cfg: DictConfig) -> None:
output_device=dist.device,
find_unused_parameters=dist.find_unused_parameters,
)
if cfg.wandb.watch_model and dist.rank == 0:
wandb.watch(model)

# Load the regression checkpoint if applicable
if hasattr(cfg.training.io, "regression_checkpoint_path"):
@@ -244,19 +269,15 @@ def main(cfg: DictConfig) -> None:
logger0.success("Loaded the pre-trained regression model")

# Instantiate the loss function
patch_num = getattr(cfg.training.hp, "patch_num", 1)
if cfg.model.name in (
"diffusion",
"patched_diffusion",
"lt_aware_patched_diffusion",
):
loss_fn = ResLoss(
loss_fn = ResidualLoss(
regression_net=regression_net,
img_shape_x=img_shape[1],
img_shape_y=img_shape[0],
patch_shape_x=patch_shape[1],
patch_shape_y=patch_shape[0],
patch_num=patch_num,
img_shape_x=img_shape[1],
hr_mean_conditioning=cfg.model.hr_mean_conditioning,
)
elif cfg.model.name == "regression":
@@ -317,12 +338,15 @@ def main(cfg: DictConfig) -> None:
img_clean = img_clean.to(dist.device).to(torch.float32).contiguous()
img_lr = img_lr.to(dist.device).to(torch.float32).contiguous()
labels = labels.to(dist.device).contiguous()
            # Sample new random patches for this iteration (no-op when patching is disabled)
            if patching is not None:
                patching.reset_patch_indices()
loss_fn_kwargs = {
"net": model,
"img_clean": img_clean,
"img_lr": img_lr,
"labels": labels,
"augment_pipe": None,
"patching": patching,
}
if lead_time_label:
lead_time_label = lead_time_label[0].to(dist.device).contiguous()
@@ -352,6 +376,12 @@ def main(cfg: DictConfig) -> None:
writer.add_scalar(
"training_loss_running_mean", average_loss_running_mean, cur_nimg
)
wandb.log(
{
"training_loss": average_loss,
"training_loss_running_mean": average_loss_running_mean,
}
)

ptt = is_time_for_periodic_task(
cur_nimg,
@@ -435,6 +465,11 @@ def main(cfg: DictConfig) -> None:
writer.add_scalar(
"validation_loss", average_valid_loss, cur_nimg
)
wandb.log(
{
"validation_loss": average_valid_loss,
}
)

if is_time_for_periodic_task(
cur_nimg,
@@ -486,6 +521,14 @@ def main(cfg: DictConfig) -> None:
epoch=cur_nimg,
)

        # Retain only the n most recent checkpoints, if desired
if cfg.training.io.save_n_recent_checkpoints > 0:
for suffix in [".mdlus", ".pt"]:
ckpts = checkpoint_list(checkpoint_dir, suffix=suffix)
while len(ckpts) > cfg.training.io.save_n_recent_checkpoints:
os.remove(os.path.join(checkpoint_dir, ckpts[0]))
ckpts = ckpts[1:]

# Done.
logger0.info("Training Completed.")

2 changes: 1 addition & 1 deletion modulus/metrics/diffusion/__init__.py
@@ -20,7 +20,7 @@
EDMLossSR,
RegressionLoss,
RegressionLossCE,
ResLoss,
ResidualLoss,
VELoss,
VELoss_dfsr,
VPLoss,