diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d8c9bbce4..6501510272 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,10 +11,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added ReGen score-based data assimilation example +- General purpose patching API for patch-based diffusion +- New positional embedding selection strategy for CorrDiff SongUNet models - Added Multi-Storage Client to allow checkpointing to/from Object Storage ### Changed +- Simplified CorrDiff config files, updated default values +- Refactored CorrDiff losses and samplers to use the patching API +- Support for non-square images and patches in patch-based diffusion + ### Deprecated ### Removed diff --git a/docs/api/physicsnemo.utils.rst b/docs/api/physicsnemo.utils.rst index 767146a630..3e4e759a8c 100644 --- a/docs/api/physicsnemo.utils.rst +++ b/docs/api/physicsnemo.utils.rst @@ -40,6 +40,14 @@ Filesystem utils Generative utils ---------------- +.. automodule:: physicsnemo.utils.generative.deterministic_sampler + :members: + :show-inheritance: + +.. automodule:: physicsnemo.utils.generative.stochastic_sampler + :members: + :show-inheritance: + .. automodule:: physicsnemo.utils.generative.utils :members: :show-inheritance: @@ -62,4 +70,11 @@ Weather / Climate utils :show-inheritance: .. automodule:: physicsnemo.utils.zenith_angle + :show-inheritance: + +Patching utils +-------------- + +.. automodule:: physicsnemo.utils.patching + :members: :show-inheritance: \ No newline at end of file diff --git a/docs/img/corrdiff_training_loss.png b/docs/img/corrdiff_training_loss.png new file mode 100644 index 0000000000..d6e9659e3a Binary files /dev/null and b/docs/img/corrdiff_training_loss.png differ diff --git a/examples/generative/corrdiff/README.md b/examples/generative/corrdiff/README.md index dfa9bbf072..94c641288a 100644 --- a/examples/generative/corrdiff/README.md +++ b/examples/generative/corrdiff/README.md @@ -1,10 +1,32 @@ # Generative Correction Diffusion Model (CorrDiff) for Km-scale Atmospheric Downscaling +## Table of Contents +- [Generative Correction Diffusion Model (CorrDiff) for Km-scale Atmospheric Downscaling](#generative-correction-diffusion-model-corrdiff-for-km-scale-atmospheric-downscaling) + - [Table of Contents](#table-of-contents) + - [Problem overview](#problem-overview) + - [Getting started with the HRRR-Mini example](#getting-started-with-the-hrrr-mini-example) + - [Preliminaries](#preliminaries) + - [Configuration basics](#configuration-basics) + - [Training the regression model](#training-the-regression-model) + - [Training the diffusion model](#training-the-diffusion-model) + - [Generation](#generation) + - [Another example: Taiwan dataset](#another-example-taiwan-dataset) + - [Dataset \& Datapipe](#dataset--datapipe) + - [Training the models](#training-the-models) + - [Sampling and Model Evaluation](#sampling-and-model-evaluation) + - [Logging and Monitoring](#logging-and-monitoring) + - [Training CorrDiff on a Custom Dataset](#training-corrdiff-on-a-custom-dataset) + - [Defining a Custom Dataset](#defining-a-custom-dataset) + - [Training configuration](#training-configuration) + - [Generation configuration](#generation-configuration) + - [FAQs](#faqs) + - [References](#references) + ## Problem overview To improve weather hazard predictions without expensive simulations, a cost-effective -stochasticdownscaling model, [CorrDiff](https://arxiv.org/abs/2309.15214), is trained +stochastic downscaling model, [CorrDiff](https://arxiv.org/abs/2309.15214), is trained using high-resolution weather data and coarser ERA5 reanalysis. CorrDiff employs a two-step approach with UNet and diffusion to address multi-scale challenges, showing strong performance in predicting weather @@ -16,220 +38,495 @@ weather forecasts.

-## Getting started +## Getting started with the HRRR-Mini example + +To get started with CorrDiff, we provide a simplified version called CorrDiff-Mini that combines: + +1. A smaller neural network architecture that reduces memory usage and training time +2. A reduced training dataset, based on the HRRR dataset, that contains fewer samples (available at [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/resources/modulus_datasets-hrrr_mini)) -To build custom CorrDiff versions, you can get started by training the "Mini" version of CorrDiff, which uses smaller training samples and a smaller network to reduce training costs from thousands of GPU hours to around 10 hours on A100 GPUs while still producing reasonable results. It also includes a simple data loader that can be used as a baseline for training CorrDiff on custom datasets. +Together, these modifications reduce training time from thousands of GPU hours to around 10 hours on A100 GPUs. The simplified data loader included with CorrDiff-Mini also serves as a helpful example for training CorrDiff on custom datasets. Note that CorrDiff-Mini is intended for learning and educational purposes only - its predictions should not be used for real applications. ### Preliminaries Start by installing PhysicsNeMo (if not already installed) and copying this folder (`examples/generative/corrdiff`) to a system with a GPU available. Also download the CorrDiff-Mini dataset from [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/resources/modulus_datasets-hrrr_mini). ### Configuration basics -CorrDiff training is handled by `train.py` and controlled by YAML configuration files handled by [Hydra](https://hydra.cc/docs/intro/). Prebuilt configuration files are found in the `conf` directory. You can choose the configuration file using the `--config-name` option. The main configuration file specifies the training dataset, the model configuration and the training options. The details of these are given in the corresponding configuration files. To change a configuration option, you can either edit the configuration files or use the Hydra command line overrides. For example, the training batch size is controlled by the option `training.hp.total_batch_size`. We can override this from the command line with the `++` syntax: `python train.py ++training.hp.total_batch_size=64` would set run the training with the batch size set to 64. +CorrDiff training is managed through `train.py` and uses YAML configuration files powered by [Hydra](https://hydra.cc/docs/intro/). The configuration system is organized as follows: + +- **Base Configurations**: Located in the `conf/base` directory +- **Configuration Files**: + - **Training Configurations**: + - GEFS-HRRR dataset (continental United States): + - `conf/config_training_gefs_hrrr_regression.yaml` - Configuration for training the regression model on GEFS-HRRR dataset + - `conf/config_training_gefs_hrrr_diffusion.yaml` - Configuration for training the diffusion model on GEFS-HRRR dataset + - HRRR-Mini dataset (smaller continental United States,): + - `conf/config_training_hrrr_mini_regression.yaml` - Simplified regression model training setup for the HRRR-Mini example + - `conf/config_training_hrrr_mini_diffusion.yaml` - Simplified diffusion model training setup for the HRRR-Mini example + - Taiwan dataset: + - `conf/config_training_taiwan_regression.yaml` - Configuration for training the regression model on Taiwan weather data + - `conf/config_training_taiwan_diffusion.yaml` - Configuration for training the diffusion model on Taiwan weather data + - Custom dataset: + - `conf/config_training_custom.yaml` - Template configuration for training on custom datasets + - **Generation Configurations**: + - `conf/config_generate_taiwan.yaml` - Settings for generating predictions using Taiwan-trained models + - `conf/config_generate_hrrr_mini.yaml` - Settings for generating predictions using HRRR-Mini models + - `conf/config_generate_gefs_hrrr.yaml` - Settings for generating predictions using GEFS-HRRR models + - `conf/config_generate_custom.yaml` - Template configuration for generation with custom trained models + +To select a specific configuration, use the `--config-name` option when running the training script. Each training configuration file defines three main components: +1. Training dataset parameters +2. Model architecture settings +3. Training hyperparameters + +You can modify configuration options in two ways: +1. **Direct Editing**: Modify the YAML files directly +2. **Command Line Override**: Use Hydra's `++` syntax to override settings at runtime + +For example, to change the training batch size (controlled by `training.hp.total_batch_size`): +```bash +python train.py ++training.hp.total_batch_size=64 # Sets batch size to 64 +``` + +This modular configuration system allows for flexible experimentation while maintaining reproducibility. ### Training the regression model -To train the CorrDiff-Mini regression model, we use the main configuration file [config_training_mini_regression.yaml](conf/config_training_mini_regression.yaml). This includes the following components: -* The HRRR-Mini dataset: [conf/dataset/hrrrmini.yaml](conf/dataset/hrrrmini.yaml) -* The GEFS-HRRR dataset: [conf/dataset/hrrrmini.yaml](conf/dataset/gefs_hrrr.yaml) -* The CorrDiff-Mini regression model: [conf/model/corrdiff_regression_mini.yaml](conf/model/corrdiff_regression_mini.yaml) -* The CorrDiff-Mini regression training options: [conf/training/corrdiff_regression_mini.yaml](conf/training/corrdiff_regression_mini.yaml) -* The CorrDiff-GEFS-HRRR regression training options: [conf/model/corrdiff_regression_mini.yaml](conf/training/config_training_gefs_regression.yaml) - -To start the training, run: + +CorrDiff uses a two-step training process: +1. Train a deterministic regression model +2. Train a diffusion model using the pre-trained regression model + +For the CorrDiff-Mini regression model, we use the following configuration components: + +The top-level configuration file `config_training_hrrr_mini_regression.yaml` contains the most commonly modified parameters: +- `dataset`: Dataset type and paths (`hrrr_mini`, `gefs_hrrr`, `cwb`, or `custom`) +- `model`: Model architecture type (`regression`, `diffusion`, etc.) +- `model_size`: Model capacity (`normal` or `mini` for faster experiments) +- `training`: High-level training parameters (duration, batch size, IO settings) +- `wandb`: Weights & Biases logging settings (`mode`, `results_dir`, `watch_model`) + +This configuration automatically loads these specific files from `conf/base`: +- `dataset/hrrr_mini.yaml`: HRRR-Mini dataset parameters (data paths, variables) +- `model/regression.yaml`: Regression UNet architecture settings +- `model_size/mini.yaml`: Reduced model capacity settings for faster training +- `training/regression.yaml`: Training loop parameters specific to regression model + +These base configuration files contain more detailed settings that are less commonly modified but give fine-grained control over the training process. + +To begin training, execute the following command using [`train.py`](train.py): ```bash -python train.py --config-name=config_training_mini_regression.yaml ++dataset.data_path=/hrrr_mini_train.nc ++dataset.stats_path=/stats.json +python train.py --config-name=config_training_hrrr_mini_regression.yaml ``` -where you should replace both instances of `` with the absolute path to the directory containing the downloaded HRRR-Mini dataset. -The training will require a few hours on a single A100 GPU. If training is interrupted, it will automatically continue from the latest checkpoint when restarted. Multi-GPU and multi-node training are supported and will launch automatically when the training is run in a `torchrun` or MPI environment. +**Training Details:** +- Duration: A few hours on a single A100 GPU +- Checkpointing: Automatically resumes from latest checkpoint if interrupted +- Multi-GPU Support: Compatible with `torchrun` or MPI for distributed training -The results, including logs and checkpoints, are saved by default to `outputs/mini_generation/`. You can direct the checkpoints to be saved elsewhere by setting: `++training.io.checkpoint_dir=`. - -> **_Out of memory?_** CorrDiff-Mini trains by default with a batch size of 256 (set by `training.hp.total_batch_size`). If you're using a single GPU, especially one with a smaller amout of memory, you might see out-of-memory error. If that happens, set a smaller batch size per GPU, e.g.: `++training.hp.batch_size_per_gpu=64`. CorrDiff training will then automatically use gradient accumulation to train with an effective batch size of `training.hp.total_batch_size`. +> **💡 Memory Management** +> The default configuration uses a batch size of 256 (controlled by `training.hp.total_batch_size`). If you encounter memory constraints, particularly on GPUs with limited memory, you can reduce the per-GPU batch size by setting `++training.hp.batch_size_per_gpu=64`. CorrDiff will automatically employ gradient accumulation to maintain the desired effective batch size while using less memory. ### Training the diffusion model -The pre-trained regression model is needed to train the diffusion model. Assuming you trained the regression model for the default 2 million samples, the final checkpoint will be `checkpoints_regression/UNet.0.2000000.mdlus`. -Save the final regression checkpoint into a new location, then run: +After successfully training the regression model, you can proceed with training the diffusion model. The process requires: + +- A pre-trained regression model checkpoint +- The same dataset used for regression training +- Configuration file [`conf/config_training_hrrr_mini_diffusion.yaml`](conf/config_training_hrrr_mini_diffusion.yaml) + +To start the diffusion model training, execute: ```bash -python train.py --config-name=config_training_mini_diffusion.yaml ++dataset.data_path=/hrrr_mini_train.nc ++dataset.stats_path=/stats.json ++training.io.regression_checkpoint_path= +python train.py --config-name=config_training_hrrr_mini_diffusion.yaml \ + ++training.io.regression_checkpoint_path= ``` where `` should point to the saved regression checkpoint. -Once the training is completed, copy the latest checkpoint (`checkpoints_diffusion/EDMPrecondSR.0.8000000.mdlus`) to a file. +The training will generate checkpoints in the `checkpoints_diffusion` directory. Upon completion, the final model will be saved as `EDMPrecondSR.0.8000000.mdlus`. ### Generation -Use the `generate.py` script to generate samples with the trained networks: +Once both models are trained, you can use [`generate.py`](generate.py) to create new predictions. The generation process requires: + +**Required Files:** +- Trained regression model checkpoint +- Trained diffusion model checkpoint +- Configuration file [`conf/config_generate_hrrr_mini.yaml`](conf/config_generate_hrrr_mini.yaml) + +Execute the generation command: ```bash -python generate.py --config-name="config_generate_mini.yaml" ++generation.io.res_ckpt_filename= ++generation.io.reg_ckpt_filename= ++generation.io.output_filename= +python generate.py --config-name="config_generate_hrrr_mini.yaml" \ + ++generation.io.res_ckpt_filename= \ + ++generation.io.reg_ckpt_filename= ``` -where `` and `` should point to the regression and diffusion model checkpoints, respectively, and `` indicates the output NetCDF4 file. -You can open the output file with e.g. the Python NetCDF4 library. The inputs are saved in the `input` group of the file, the ground truth data in the `truth` group, and the CorrDiff prediction in the `prediction` group. +The output is saved as a NetCDF4 file containing three groups: +- `input`: The original input data +- `truth`: The ground truth data for comparison +- `prediction`: The CorrDiff model predictions -## Configs +You can analyze the results using the Python NetCDF4 library or visualization tools of your choice. -The `conf` directory contains the configuration files for the model, data, -training, etc. The configs are given in YAML format and use the `omegaconf` -library to manage them. Several example configs are given for training -different models that are regression, diffusion, and patched-based diffusion -models. -The default configs are set to train the regression model. -To train the other models, please adjust `conf/config_training.yaml` -according to the comments. Alternatively, you can create a new config file -and specify it using the `--config-name` option. +## Another example: Taiwan dataset +### Dataset & Datapipe -## Dataset & Datapipe +The Taiwan example demonstrates CorrDiff training on a high-resolution weather dataset conditioned on the low-resolution [ERA5 dataset](https://www.ecmwf.int/en/forecasts/dataset/ecmwf-reanalysis-v5). This dataset is available for non-commercial use under the [CC BY-NC-ND 4.0 license](https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode.en). -In this example, CorrDiff training is demonstrated on the Taiwan dataset, -conditioned on the [ERA5 dataset](https://www.ecmwf.int/en/forecasts/dataset/ecmwf-reanalysis-v5). -We have made this dataset available for non-commercial use under the -[CC BY-NC-ND 4.0 license](https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode.en) -and can be downloaded from [https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/resources/modulus_datasets_cwa](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/resources/modulus_datasets_cwa) -by `ngc registry resource download-version "nvidia/modulus/modulus_datasets_cwa:v1"`. -The datapipe in this example is tailored specifically for the Taiwan dataset. -A light-weight datapipe for the HRRR dataset is also available and can be used -with the CorrDiff-mini model. -For other datasets, you will need to create a custom datapipe. -You can use the lightweight HRRR datapipe as a starting point for developing your new one. +**Dataset Access:** +- Location: [NGC Catalog - CWA Dataset](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/resources/modulus_datasets_cwa) +- Download Command: + ```bash + ngc registry resource download-version "nvidia/modulus/modulus_datasets_cwa:v1" + ``` +### Training the models -## Training the models +The Taiwan example supports three types of models, each serving a different purpose: +1. **Regression Model**: Basic deterministic model +2. **Diffusion Model**: Full probabilistic model +3. **Patch-based Diffusion Model**: Memory-efficient variant that processes small spatial regions to improve scalability -There are several models available for training in this example, including -a regression, a diffusion, and a patched-based diffusion model. -The Patch-based diffusion model uses small subsets of the target region during -both training and generation to enhance the scalability of the model. -Apart from the dataset configs the main configs for training are `model`, -`training`, and `validation`. These can be adjusted accordingly depending on -whether you are training the regression, diffusion, or the patch-based -diffusion model. Note that training the varients of the diffusion model -requres a trained regression checkpoint, and the path to that checkpoint should -be included in the `conf/training/corrdiff_diffusion.yaml ` file. -Therefore, you should start with training -a regression model, followed by training a diffusion model. To choose which model -to train, simply change the configs in `conf/config_training.yaml`. +The patch-based approach divides the target region into smaller subsets during both training and generation, making it particularly useful for memory-constrained environments or large spatial domains. -For training the regression model, your `config_training.yaml` should be: +**Configuration Structure:** -``` -hydra: - job: - chdir: true - name: regression - run: - dir: ./outputs/${hydra:job.name} +The top-level configuration file `config_training_taiwan_regression.yaml` contains commonly modified parameters: +- `dataset`: Set to `cwb` for the Taiwan Central Weather Bureau dataset +- `model`: Model type (`regression`, `diffusion`, or `patched_diffusion`) +- `model_size`: Model capacity (`normal` recommended for Taiwan dataset) +- `training.hp`: Training duration and batch size settings +- `wandb`: Experiment tracking configuration -defaults: +This configuration automatically loads these specific files from `conf/base`: +- `dataset/cwb.yaml`: Taiwan dataset parameters +- `model/regression.yaml` or `model/diffusion.yaml`: Model architecture settings +- `training/regression.yaml` or `training/diffusion.yaml`: Training parameters - # Dataset - - dataset/cwb_train +When training the diffusion variants, you'll need to specify the path to your pre-trained regression checkpoint in `training.io.regression_checkpoint_path`. This is essential as the diffusion model learns to predict residuals on top of the regression model's predictions. - # Model - - model/corrdiff_regression +**Training Commands:** - # Training - - training/corrdiff_regression +For single-GPU training: +```bash +python train.py --config-name=config_training_taiwan_regression.yaml +``` - # Validation - - validation/basic - ``` +For multi-GPU or multi-node training: +```bash +torchrun --standalone --nnodes= --nproc_per_node= train.py +``` -Similarly, for taining of the diffusion model, you should have: +To switch between model types, simply change the configuration name in the training command (e.g., `config_training_taiwan_diffusion.yaml` for the diffusion model). -``` -hydra: - job: - chdir: true - name: diffusion - run: - dir: ./outputs/${hydra:job.name} +### Sampling and Model Evaluation -defaults: +The evaluation pipeline for CorrDiff models consists of two main components: - # Dataset - - dataset/cwb_train +1. **Sample Generation** ([`generate.py`](generate.py)): + Generates predictions and saves them in a netCDF file format. The process uses configuration settings from [`conf/config_generate.yaml`](conf/config_generate.yaml). + ```bash + python generate.py --config-name=config_generate_taiwan.yaml + ``` - # Model - - model/corrdiff_diffusion +2. **Performance Scoring** ([`score_samples.py`](score_samples.py)): + Computes both deterministic metrics (like MSE, MAE) and probabilistic scores for the generated samples. + ```bash + python score_samples.py path= output= + ``` - # Training - - training/corrdiff_diffusion +For visualization and analysis, you have several options: +- Use the plotting scripts in the [`inference`](inference/) directory +- Visualize results with [Earth2Studio](https://github.com/NVIDIA/earth2studio) +- Create custom visualizations using the NetCDF4 output structure - # Validation - - validation/basic -``` +### Logging and Monitoring -To train the model, run +CorrDiff supports two powerful tools for experiment tracking and visualization: -```python train.py``` +**TensorBoard Integration:** +TensorBoard provides real-time monitoring of training metrics when running in a Docker container: -You can monitor the training progress using TensorBoard. -Open a new terminal, navigate to the example directory, and run: +1. Configure Docker: + ```bash + docker run -p 6006:6006 ... # Include port forwarding + ``` -```tensorboard --logdir=outputs/``` +2. Start TensorBoard: + ```bash + tensorboard --logdir=/path/to/logdir --port=6006 + ``` -If using a shared cluster, you may need to forward the port to see the tensorboard logs. -Data parallelism is supported. Use `torchrun` -To launch a multi-GPU or multi-node training: +3. Set up SSH tunnel (for remote servers): + ```bash + ssh -L 6006:localhost:6006 @ + ``` -```torchrun --standalone --nnodes= --nproc_per_node= train.py``` +4. Access the dashboard at `http://localhost:6006` -### Sampling and Model Evaluation +**Weights & Biases Integration:** +CorrDiff includes integration with Weights & Biases for experiment tracking. The following parameters are hardcoded in the code: -Model evaluation is split into two components. `generate.py` creates a netCDF file -for the generated outputs, and `score_samples.py` computes deterministic and probablistic -scores. +- Project name: "Modulus-Launch" +- Entity: "Modulus" +- Run name: Generated based on configuration job name +- Group: "CorrDiff-DDP-Group" -To generate samples and save output in a netCDF file, run: +You can configure the following wandb parameters in the configuration files: -```bash -python generate.py +```yaml +wandb: + mode: offline # Options: "online", "offline", "disabled" + results_dir: "./wandb" # Directory to store wandb results + watch_model: true # Whether to track model parameters and gradients ``` -This will use the base configs specified in the `conf/config_generate.yaml` file. -Next, to score the generated samples, run: - -```bash -python score_samples.py path= output= +To use wandb: + +1. Initialize wandb (first time only): + ```bash + wandb login + ``` + +2. Training runs will automatically log to the wandb project, tracking: + - Training and validation metrics + - Model architecture details + - System resource usage + - Hyperparameters + +You can access your experiment dashboard at Weights & Biases website. + +## Training CorrDiff on a Custom Dataset + +This repository includes examples of **CorrDiff** training on specific datasets, such as **Taiwan** and **HRRR**. However, many use cases require training **CorrDiff** on a **custom high-resolution dataset**. The steps below outline the process. + +### Defining a Custom Dataset + +To train CorrDiff on a custom dataset, you need to implement a custom dataset class that inherits from `DownscalingDataset` defined in [`datasets/base.py`](./datasets/base.py). This base class defines the interface that all dataset implementations must follow. + +**Required Implementation:** + +1. Your dataset class must inherit from `DownscalingDataset` and implement its + abstract methods, for example: + - `longitude()` and `latitude()`: Return coordinate arrays + - `input_channels()` and `output_channels()`: Define metadata for input/output variables + - `time()`: Return time values + - `image_shape()`: Return spatial dimensions + - `__len__()`: Return total number of samples + - `__getitem__()`: Return data for a given index + +The most important method is `__getitem__`, which must return a tuple of tensors: +```python +def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Returns: + Tuple containing: + - img_clean: Target high-resolution data [output_channels, height, width] + - img_lr: Input low-resolution data [input_channels, height, width] + - lead_time_label: (Optional) Lead time information [1] + """ + # Your implementation here + # For basic implementation without lead time: + return img_clean, img_lr + # If including lead time information: + # return img_clean, img_lr, lead_time_label ``` -Some legacy plotting scripts are also available in the `inference` directory. -You can also bring your checkpoints to [earth2studio] -for further anaylysis and visualizations. +2. Configure your dataset in the YAML configuration file. Any parameters below + will be passed to your dataset's `__init__` method. For example: + ```yaml + dataset: + type: path/to/your/dataset.py::CustomDataset # Path to file::class name + # All parameters below will be passed to your dataset's __init__ + data_path: /path/to/your/data + stats_path: /path/to/statistics.json # Optional normalization stats + input_variables: ["temperature", "pressure"] # Example parameters + output_variables: ["high_res_temperature"] + invariant_variables: ["topography"] # Optional static fields + # Add any other parameters needed by your dataset class + ``` + +**Important Notes:** +- The training script will automatically: + 1. Parse the `type` field to locate your dataset file and class + 2. Register your custom dataset class using `register_dataset()` + 3. Pass all other fields in the `dataset` section as kwargs to your class constructor +- All tensors should be properly normalized (use `normalize_input`/`normalize_output` methods if needed) +- Ensure consistent dimensions across all samples +- Channel metadata should accurately describe your data variables + + +For reference implementations of dataset classes, look at: +- [`datasets/hrrrmini.py`](./datasets/hrrrmini.py) - Simple example using NetCDF format +- [`datasets/cwb.py`](./datasets/cwb.py) - More complex example -## Logging -We use TensorBoard for logging training and validation losses, as well as -the learning rate during training. To visualize TensorBoard running in a -Docker container on a remote server from your local desktop, follow these steps: +### Training configuration -1. **Expose the Port in Docker:** - Expose port 6006 in the Docker container by including - `-p 6006:6006` in your docker run command. +After implementing your custom dataset, you can proceed with the two-step training process followed by generation. The configuration system uses a hierarchical structure that balances ease of use with detailed control over the training process. -2. **Launch TensorBoard:** - Start TensorBoard within the Docker container: - ```bash - tensorboard --logdir=/path/to/logdir --port=6006 - ``` +**Top-level Configuration** (`config_training_custom.yaml`): +This file serves as your primary interface for configuring the training process. It contains commonly modified parameters that can be set either directly in the file or through command-line overrides: -3. **Set Up SSH Tunneling:** - Create an SSH tunnel to forward port 6006 from the remote server to your local machine: - ```bash - ssh -L 6006:localhost:6006 @ - ``` - Replace `` with your SSH username and `` with the IP address - of your remote server. You can use a different port if necessary. +- `dataset`: Configuration for your custom dataset implementation, including paths and variables +- `model`: Core model settings, including type selection (`regression` or `diffusion`) +- `training`: High-level training parameters like batch size and duration +- `wandb`: Weights & Biases logging settings (`mode`, `results_dir`, `watch_model`) -4. **Access TensorBoard:** - Open your web browser and navigate to `http://localhost:6006` to view TensorBoard. +**Fine-grained Control**: +The base configuration files in `conf/base/` provide detailed control over specific components. These files are automatically loaded based on your top-level choices: + +- `model/*.yaml`: Contains architecture-specific settings for network depth, attention mechanisms, and embedding configurations +- `training/*.yaml`: Defines training loop behavior, including optimizer settings and checkpoint frequency +- `model_size/*.yaml`: Provides preset configurations for different model capacities + +While direct modification of these base files is typically unnecessary, any +parameter can be overridden using Hydra's `++` syntax. For example, to reduce +the learning rate to 0.0001: + +```bash +python train.py --config-name=config_training_custom.yaml ++training.hp.lr=0.0001 +``` + +This configuration system allows you to start with sensible defaults while maintaining the flexibility to customize any aspect of the training process. + +You can directly modify the training configuration file to change the dataset, +model, and training parameters, or use Hydra's `++` syntax to override +them. Once the regression model is trained, proceed with training the diffusion +model. During training, you can fine-tune various parameters. The most commonly adjusted parameters include: + +- `training.hp.total_batch_size`: Controls the total batch size across all GPUs +- `training.hp.batch_size_per_gpu`: Adjusts per-GPU memory usage +- `training.hp.patch_shape_x/y`: Sets dimensions for patch-based training +- `training.hp.training_duration`: Defines total training steps +- `training.hp.lr_rampup`: Controls learning rate warmup period + +> **Starting with a Small Model** +> When developing a new dataset implementation, it is recommended to start with a smaller model for faster iteration and debugging. You can do this by setting `model_size: mini` in your configuration file: +> ```yaml +> defaults: +> - model_size: mini # Use smaller architecture for testing +> ``` +> This is similar to the model used in the HRRR-Mini example and can +> significantly reduce testing time. After debugging, you can switch +> back to the full model by setting the `model_size` setting to `normal`. + +> **Note on Patch Size Selection** +> When implementing a patch-based training, choosing the right patch size is critical for model performance. The patch dimensions are controlled by `patch_shape_x` and `patch_shape_y` in your configuration file. To determine optimal patch sizes: +> 1. Calculate the auto-correlation function of your data using the provided utilities in [`inference/power_spectra.py`](./inference/power_spectra.py): +> - `average_power_spectrum()` +> - `power_spectra_to_acf()` +> 2. Set patch dimensions to match or exceed the distance at which auto-correlation approaches zero +> 3. This ensures each patch captures the full spatial correlation structure of your data +> +> This analysis helps balance computational efficiency with the preservation of important physical relationships in your data. + +### Generation configuration + +After training both models successfully, you can use CorrDiff's generation pipeline to create predictions. The generation system follows a similar hierarchical configuration structure as training. + +**Top-level Configuration** (`config_generate_custom.yaml`): +This file serves as the main interface for controlling the generation process. It defines essential parameters that can be modified either directly or through command-line overrides. + +**Fine-grained Control**: +The base configuration files in `conf/base/generation` provide fine-grained control over +the generation process. For example, `sampling/stochastic.yaml` controls the +stochastic sampling process (noise scheduling, number of sampling steps, +classifier-free guidance settings). While these base configurations are typically used as-is, you can override any +parameter directly in the configuration file or using Hydra's `++` syntax. For +example to increase the number of ensembles generated per input, you can run: + +```bash +python generate.py --config-name=config_generate_custom.yaml \ + ++generation.io.res_ckpt_filename=/path/to/diffusion/checkpoint.mdlus \ + ++generation.io.reg_ckpt_filename=/path/to/regression/checkpoint.mdlus \ + ++dataset.type=path/to/your/dataset.py::CustomDataset \ + ++generation.num_ensembles=10 +``` + +Key generation parameters that can be adjusted include for example: + +- `generation.num_ensembles`: Number of samples to generate per input +- `generation.patch_shape_x/y`: Patch dimensions for patch-based generation + +The generated samples are saved in a NetCDF file with three main components: +- Input data: The original low-resolution inputs +- Ground truth: The actual high-resolution data (if available) +- Predictions: The model-generated high-resolution outputs + +### FAQs + +1. **Are there pre-trained checkpoints available and when should they be used for training/inference?** + Pre-trained checkpoints are available through NVIDIA AI Enterprise. For + example, a trained model for the continental United States on the GEFS-HRRR + dataset is available + [here](https://build.nvidia.com/nvidia/corrdiff/modelcard). However, note + that these checkpoints are not necessarily compatible with the current + implementation of `train.py` and `generate.py` in CorrDiff. Typically, these + checkpoints should only be used for inference in + [Earth2Studio](https://github.com/NVIDIA/earth2studio). It is therefore generally + recommended to start training CorrDiff models from a scratch. If you do have + a checkpoint compatible with the current implementation of `train.py` and + `generate.py` (e.g. from one of your own previous training run), it is + recommended to restart training from your checkpoint rather than from scratch + if the following conditions are met: + - Your custom dataset covers a region included in the training data of the checkpoint (e.g., a sub-region of the continental United States for the checkpoint mentioned above). + - At most half of the variables in your dataset are also included in the training data of the checkpoint. + + Training from scratch is recommended for all other cases. + +1. **How many samples are needed to train a CorrDiff model?** + The more, the better. As a rule of thumb, at least 50,000 samples are necessary. + *Note: For patch-based diffusion, each patch can be counted as a sample.* + +2. **How many GPUs are required to train CorrDiff?** + A single GPU is sufficient as long as memory is not exhausted, but this may + result in extremely slow training. To accelerate training, CorrDiff + leverages distributed data parallelism. The total training wall-clock time + roughly decreases linearly with the number of GPUs. Most CorrDiff training + examples have been conducted with 64 A100 GPUs. If you encounter an + out-of-memory error, reduce `batch_size_per_gpu` or, for + patch-based diffusion models, decrease the patch size—ensuring it remains + larger than the auto-correlation distance. + +3. **How long does it take to train CorrDiff on a custom dataset?** + Training CorrDiff on the continental United States dataset required + approximately 5,000 A100 GPU hours. This corresponds to roughly 80 hours of + wall-clock time with 64 GPUs. You can expect the cost to scale + linearly with the number of samples available. + +4. **What are CorrDiff's current limitations for custom datasets?** + The main limitation of CorrDiff is the maximum _downscaling ratio_ it can + achieve. For a purely spatial super-resolution task (where input and output variables are the same), CorrDiff can reliably achieve a maximum resolution scaling of ×16. If the task involves inferring new output variables, the maximum reliable spatial super-resolution is ×11. + +5. **What does a successful training look like?** + In a successful training run, the loss function should decrease monotonically, as shown below: +

+ +

-**Note:** Ensure the remote server’s firewall allows connections on port `6006` -and that your local machine’s firewall allows outgoing connections. +1. **Which hyperparameters are most important?** + One of the most crucial hyperparameters is the patch size for a patch-based + diffusion model (`patch_shape_x` and `patch_shape_y` in the configuration file). A larger + patch size increases computational cost and GPU memory requirements, while a + smaller patch size may lead to a loss of physical information. The patch + size should not be smaller than the auto-correlation distance, which can be + determined using the auto-correlation plotting utility. Other important hyperparameters include: + + - Training duration (`training.hp.training_duration`): Total number of + samples to process during training. Values between 1M and 30M samples are + typical, depending on the size of the dataset and on the type of model + (regression or diffusion). + - Learning rate ramp-up (`training.hp.lr_rampup`): Number of samples over + which learning rate gradually increases. In some cases, `lr_rampup=0` is + sufficient, but if training is unstable, it may be necessary to increase + it. Values between 0 and 200M samples are typical. + - Learning rate (`training.hp.lr`): Base learning rate that controls how + quickly model parameters are updated. It may be decreased if training is + unstable, and increased if training is slow. + - Batch size per GPU (`training.hp.batch_size_per_gpu`): Number of samples + processed in parallel on each GPU. It needs to be reduced if you encounter + an out-of-memory error. ## References diff --git a/examples/generative/corrdiff/conf/dataset/hrrrmini.yaml b/examples/generative/corrdiff/conf/base/__init__.py similarity index 83% rename from examples/generative/corrdiff/conf/dataset/hrrrmini.yaml rename to examples/generative/corrdiff/conf/base/__init__.py index 0b31cb3fa0..b2f171d4ac 100644 --- a/examples/generative/corrdiff/conf/dataset/hrrrmini.yaml +++ b/examples/generative/corrdiff/conf/base/__init__.py @@ -13,8 +13,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -type: hrrr_mini -data_path: /data/corrdiff-mini/hrrr_mini_train.nc -stats_path: /data/corrdiff-mini/stats.json -output_variables: ['10u', '10v'] \ No newline at end of file diff --git a/examples/generative/corrdiff/conf/config_generate.yaml b/examples/generative/corrdiff/conf/base/dataset/custom.yaml similarity index 65% rename from examples/generative/corrdiff/conf/config_generate.yaml rename to examples/generative/corrdiff/conf/base/dataset/custom.yaml index f5a84f8438..e17cbd50d7 100644 --- a/examples/generative/corrdiff/conf/config_generate.yaml +++ b/examples/generative/corrdiff/conf/base/dataset/custom.yaml @@ -14,23 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -hydra: - job: - chdir: true - name: generation - run: - dir: ./outputs/${hydra:job.name} - -# Get defaults -defaults: - - # Dataset - - dataset/cwb_generate - - # Sampler - - sampler/stochastic - #- sampler/deterministic - - # Generation - - generation/base - #- generation/patched_based +# Dataset type. Must be overridden. +type: ??? +# Path to .nc data file. Must be overridden. +data_path: ??? +# Path to json stats file. Must be overriden. +stats_path: ??? +# Names of input channels. Must be overridden. +input_variables: ??? +# Names of output channels. Must be overridden. +output_variables: ??? +# Names of invariants variables. Optional. +invariant_variables: ??? diff --git a/examples/generative/corrdiff/conf/dataset/cwb_train.yaml b/examples/generative/corrdiff/conf/base/dataset/cwb.yaml similarity index 73% rename from examples/generative/corrdiff/conf/dataset/cwb_train.yaml rename to examples/generative/corrdiff/conf/base/dataset/cwb.yaml index 346f98ba32..b686cb4b3a 100644 --- a/examples/generative/corrdiff/conf/dataset/cwb_train.yaml +++ b/examples/generative/corrdiff/conf/base/dataset/cwb.yaml @@ -1,4 +1,3 @@ - # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. # SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 @@ -15,15 +14,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Dataset type. Do not modify. type: cwb -data_path: /code/2023-01-24-cwb-4years.zarr +# Path to data file. Must be overridden. +data_path: ??? +# Indices of input channels in_channels: [0, 1, 2, 3, 4, 9, 10, 11, 12, 17, 18, 19] +# Indices of output channels out_channels: [0, 1, 2, 3] +# Shape of the image img_shape_x: 448 img_shape_y: 448 +# Add grid coordinates to the image add_grid: true +# Factor to downscale the image ds_factor: 4 +# Path to min and max values of the data min_path: null max_path: null +# Path to global means of the data global_means_path: null +# Path to global stds of the data global_stds_path: null \ No newline at end of file diff --git a/examples/generative/corrdiff/conf/dataset/gefs_hrrr.yaml b/examples/generative/corrdiff/conf/base/dataset/gefs_hrrr.yaml similarity index 68% rename from examples/generative/corrdiff/conf/dataset/gefs_hrrr.yaml rename to examples/generative/corrdiff/conf/base/dataset/gefs_hrrr.yaml index 3b67d77bbb..2a69065652 100644 --- a/examples/generative/corrdiff/conf/dataset/gefs_hrrr.yaml +++ b/examples/generative/corrdiff/conf/base/dataset/gefs_hrrr.yaml @@ -14,13 +14,33 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Dataset type. Do not modify. type: gefs_hrrr -data_path: /data -stats_path: /data/stats.json +# Path to .nc data file. Must be overridden. +data_path: ??? +# Path to json stats file. Must be overriden. +stats_path: ??? +# Names of output channels. output_variables: ["u10m", "v10m", "t2m", "precip", "cat_snow", "cat_ice", "cat_freez", "cat_rain", "cat_none"] +# Names of probability variables. prob_variables: ["cat_snow", "cat_ice", "cat_freez", "cat_rain"] +# Names of input surface variables. input_surface_variables: ["u10m", "v10m", "t2m", "q2m", "sp", "msl", "precipitable_water"] +# Names of input isobaric variables. input_isobaric_variables: ['u1000', 'u925', 'u850', 'u700', 'u500', 'u250', 'v1000', 'v925', 'v850', 'v700', 'v500', 'v250', 'z1000', 'z925', 'z850', 'z700', 'z500', 'z200', 't1000', 't925', 't850', 't700', 't500', 't100', 'r1000', 'r925', 'r850', 'r700', 'r500', 'r100'] +# Factor to downscale the image. ds_factor: 4 train: False -hrrr_window: [[1,1057], [4,1796]] # need dims to be divisible by 16 [[0,1024], [0,1024]] +# Years to train the model. +train_years: [2020, 2021, 2022, 2023] +# Years to validate the model. +valid_years: [2024] +# Whether to normalize the data. +normalize: True +# Whether to shard the data. +shard: False +overfit: False +# Whether to use all the data. +use_all: False +sample_shape: [-1, -1] +hrrr_window: [[1,1057], [4,1796]] # need dims to be divisible by 16 diff --git a/examples/generative/corrdiff/conf/model/corrdiff_regression_gefs_hrrr.yaml b/examples/generative/corrdiff/conf/base/dataset/hrrr_mini.yaml similarity index 75% rename from examples/generative/corrdiff/conf/model/corrdiff_regression_gefs_hrrr.yaml rename to examples/generative/corrdiff/conf/base/dataset/hrrr_mini.yaml index ef4ccd655b..aa6f0a171b 100644 --- a/examples/generative/corrdiff/conf/model/corrdiff_regression_gefs_hrrr.yaml +++ b/examples/generative/corrdiff/conf/base/dataset/hrrr_mini.yaml @@ -14,8 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: lt_aware_ce_regression - # Name of the preconditioner -hr_mean_conditioning: False - # High-res mean (regression's output) as additional condition - +# Dataset type +type: hrrr_mini +# Path to .nc data file. Must be overridden. +data_path: ??? +# Path to json stats file. Must be overriden. +stats_path: ??? +# Names of output channels. Must be overridden. +output_variables: ['10u', '10v'] diff --git a/examples/generative/corrdiff/conf/generation/patched_based.yaml b/examples/generative/corrdiff/conf/base/generation/base_all.yaml similarity index 57% rename from examples/generative/corrdiff/conf/generation/patched_based.yaml rename to examples/generative/corrdiff/conf/base/generation/base_all.yaml index a8a514f098..c655332d43 100644 --- a/examples/generative/corrdiff/conf/generation/patched_based.yaml +++ b/examples/generative/corrdiff/conf/base/generation/base_all.yaml @@ -14,35 +14,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -num_ensembles: 64 - # Number of ensembles to generate per input -seed_batch_size: 4 - # Size of the batched inference +defaults: + - sampler: stochastic + # Recommended is stochastic sampler. Change to deterministic if needed. + +num_ensembles: ??? +# Number of ensembles to generate per input. Should be overridden. +seed_batch_size: ??? +# Size of the batched inference. Should be overridden. inference_mode: all - # Choose between "all" (regression + diffusion), "regression" or "diffusion" -patch_size: 448 -patch_shape_x: 448 -patch_shape_y: 448 - # Patch size. Patch-based sampling will be utilized if these dimensions differ from - # img_shape_x and img_shape_y -overlap_pixels: 4 - # Number of overlapping pixels between adjacent patches -boundary_pixels: 2 - # Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary - # artifact. +# Choose between "all" (regression + diffusion), "regression" or "diffusion" hr_mean_conditioning: true -gridtype: learnable -N_grid_channels: 100 -sample_res: full - # Sampling resolution +# Whether to use hr_mean_conditioning times_range: null -times: - - 2021-02-02T00:00:00 - - 2021-03-02T00:00:00 - - 2021-04-02T00:00:00 - # hurricane - - 2021-09-12T00:00:00 - - 2021-09-12T12:00:00 +# Time range to generate. Can be overridden. +has_lead_time: False +# Whether the model has lead time. perf: force_fp16: false @@ -55,9 +42,3 @@ perf: num_writer_workers: 1 # number of workers to use for writing file # To support multiple workers a threadsafe version of the netCDF library must be used - -io: - res_ckpt_filename: diffusion_checkpoint.mdlus - # Checkpoint filename for the diffusion model - reg_ckpt_filename: regression_checkpoint.mdlus - # Checkpoint filename for the mean predictor model diff --git a/examples/generative/corrdiff/conf/model/corrdiff_regression_mini.yaml b/examples/generative/corrdiff/conf/base/generation/non_patched.yaml similarity index 87% rename from examples/generative/corrdiff/conf/model/corrdiff_regression_mini.yaml rename to examples/generative/corrdiff/conf/base/generation/non_patched.yaml index d62f26aac1..3a3ef7a42b 100644 --- a/examples/generative/corrdiff/conf/model/corrdiff_regression_mini.yaml +++ b/examples/generative/corrdiff/conf/base/generation/non_patched.yaml @@ -15,9 +15,6 @@ # limitations under the License. defaults: - - corrdiff_regression + - base_all -model_args: - model_channels: 64 - channel_mult: [1, 2, 2] - attn_resolutions: [16] +patching: False \ No newline at end of file diff --git a/examples/generative/corrdiff/conf/config_training_gefs_regression.yaml b/examples/generative/corrdiff/conf/base/generation/patched.yaml similarity index 63% rename from examples/generative/corrdiff/conf/config_training_gefs_regression.yaml rename to examples/generative/corrdiff/conf/base/generation/patched.yaml index 4b5a85e267..1de48db343 100644 --- a/examples/generative/corrdiff/conf/config_training_gefs_regression.yaml +++ b/examples/generative/corrdiff/conf/base/generation/patched.yaml @@ -14,21 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -hydra: - job: - chdir: true - name: gefs_hrrr_regression - run: - dir: ./outputs/${hydra:job.name} - -# Get defaults defaults: + - base_all - # Dataset - - dataset/gefs_hrrr - - # Model - - model/corrdiff_regression_gefs_hrrr - - # Training - - training/corrdiff_regression_gefs_hrrr +patching: True +# Use patch-based sampling +overlap_pix: 4 +# Number of overlapping pixels between adjacent patches +boundary_pix: 2 +# Number of boundary pixels to be cropped out. 2 is recommended to address the boundary +# artifact. +patch_shape_x: ??? +patch_shape_y: ??? + # Patch size. Patch-based sampling will be utilized if these dimensions + # differ from img_shape_x and img_shape_y. Needs to be overridden. \ No newline at end of file diff --git a/examples/generative/corrdiff/conf/sampler/deterministic.yaml b/examples/generative/corrdiff/conf/base/generation/sampler/deterministic.yaml similarity index 96% rename from examples/generative/corrdiff/conf/sampler/deterministic.yaml rename to examples/generative/corrdiff/conf/base/generation/sampler/deterministic.yaml index bf0777aaab..7e527d90f2 100644 --- a/examples/generative/corrdiff/conf/sampler/deterministic.yaml +++ b/examples/generative/corrdiff/conf/base/generation/sampler/deterministic.yaml @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +# @package _global_.sampler type: deterministic num_steps: 9 diff --git a/examples/generative/corrdiff/conf/sampler/stochastic.yaml b/examples/generative/corrdiff/conf/base/generation/sampler/stochastic.yaml similarity index 96% rename from examples/generative/corrdiff/conf/sampler/stochastic.yaml rename to examples/generative/corrdiff/conf/base/generation/sampler/stochastic.yaml index 314645e08f..7eb902fab4 100644 --- a/examples/generative/corrdiff/conf/sampler/stochastic.yaml +++ b/examples/generative/corrdiff/conf/base/generation/sampler/stochastic.yaml @@ -1,3 +1,5 @@ +# @package _global_.sampler + # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. # SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 @@ -15,6 +17,4 @@ # limitations under the License. type: stochastic -boundary_pix: 2 -overlap_pix: 4 #overlap_pix has to be no less than 2*boundary_pix diff --git a/examples/generative/corrdiff/conf/model/corrdiff_patched_diffusion_gefs_hrrr.yaml b/examples/generative/corrdiff/conf/base/model/diffusion.yaml similarity index 60% rename from examples/generative/corrdiff/conf/model/corrdiff_patched_diffusion_gefs_hrrr.yaml rename to examples/generative/corrdiff/conf/base/model/diffusion.yaml index e8b8a5a42a..f2063db353 100644 --- a/examples/generative/corrdiff/conf/model/corrdiff_patched_diffusion_gefs_hrrr.yaml +++ b/examples/generative/corrdiff/conf/base/model/diffusion.yaml @@ -14,11 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: lt_aware_patched_diffusion - # Name of the preconditioner +name: diffusion +# Model type. hr_mean_conditioning: True - # High-res mean (regression's output) as additional condition -scale_cond_input: True - # If true, also scales the input conditioning - # For backward compatibility, this is true by default - # We recommend setting this to false for new training runs \ No newline at end of file +# Recommended to use high-res conditioning for diffusion. + +# Standard model parameters. +model_args: + gridtype: "sinusoidal" + # Type of positional grid to use: 'sinusoidal', 'learnable', 'linear'. + # Controls how positional information is encoded. + N_grid_channels: 4 + # Number of channels for positional grid embeddings + embedding_type: "zero" + # Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++, + # 'zero' for none diff --git a/examples/generative/corrdiff/conf/base/model/lt_aware_ce_regression.yaml b/examples/generative/corrdiff/conf/base/model/lt_aware_ce_regression.yaml new file mode 100644 index 0000000000..b4dc43510d --- /dev/null +++ b/examples/generative/corrdiff/conf/base/model/lt_aware_ce_regression.yaml @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +name: lt_aware_ce_regression + # Model type. +hr_mean_conditioning: False + # No high-res conditioning for regression. + +# Default model parameters. +model_args: + N_grid_channels: 4 + # Number of channels for positional grid embeddings + embedding_type: "zero" + # Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++, + # 'zero' for none + lead_time_channels: 4 + # Number of channels for lead-time embeddings + lead_time_steps: 9 + # Number of lead-time steps + model_type: "SongUNetPosLtEmbd" + # Type of model architecture: 'SongUNetPosLtEmbd' for lead-time aware UNet with + # positional embeddings, 'SongUNetPosEmbd' for UNet with positional + # embeddings, 'SongUNet' for UNet without positional embeddings, + # 'DhariwalUNet' for UNet with Fourier embeddings. If not provided, default + # to 'SongUNetPosEmbd'. diff --git a/examples/generative/corrdiff/conf/base/model/lt_aware_patched_diffusion.yaml b/examples/generative/corrdiff/conf/base/model/lt_aware_patched_diffusion.yaml new file mode 100644 index 0000000000..38481cce77 --- /dev/null +++ b/examples/generative/corrdiff/conf/base/model/lt_aware_patched_diffusion.yaml @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: lt_aware_patched_diffusion +# Model type. +hr_mean_conditioning: True +# Recommended to use high-res conditioning for diffusion. + +# Standard model parameters. +model_args: + gridtype: "learnable" + # Type of positional grid to use: 'sinusoidal', 'learnable', 'linear'. + # Controls how positional information is encoded. + N_grid_channels: 100 + # Number of channels for positional grid embeddings + lead_time_channels: 20 + # Number of channels for lead-time embeddings + lead_time_steps: 9 + # Number of lead-time steps + model_type: "SongUNetPosLtEmbd" + # Type of model architecture: 'SongUNetPosLtEmbd' for lead-time aware UNet with + # positional embeddings, 'SongUNetPosEmbd' for UNet with positional + # embeddings, 'SongUNet' for UNet without positional embeddings, + # 'DhariwalUNet' for UNet with Fourier embeddings. If not provided, default + # to 'SongUNetPosEmbd'. diff --git a/examples/generative/corrdiff/conf/model/corrdiff_patched_diffusion.yaml b/examples/generative/corrdiff/conf/base/model/patched_diffusion.yaml similarity index 67% rename from examples/generative/corrdiff/conf/model/corrdiff_patched_diffusion.yaml rename to examples/generative/corrdiff/conf/base/model/patched_diffusion.yaml index 22cd7f791a..bf88bbb649 100644 --- a/examples/generative/corrdiff/conf/model/corrdiff_patched_diffusion.yaml +++ b/examples/generative/corrdiff/conf/base/model/patched_diffusion.yaml @@ -14,11 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: patched_diffusion - # Name of the preconditioner +name: diffusion +# Model type. hr_mean_conditioning: True - # High-res mean (regression's output) as additional condition -scale_cond_input: True - # If true, also scales the input conditioning - # For backward compatibility, this is true by default - # We recommend setting this to false for new training runs \ No newline at end of file +# Recommended to use high-res conditioning for diffusion. + +# Standard model parameters. +model_args: + gridtype: "learnable" + # Type of positional grid to use: 'sinusoidal', 'learnable', 'linear'. + # Controls how positional information is encoded. + N_grid_channels: 100 + # Number of channels for positional grid embeddings \ No newline at end of file diff --git a/examples/generative/corrdiff/conf/model/corrdiff_diffusion.yaml b/examples/generative/corrdiff/conf/base/model/regression.yaml similarity index 68% rename from examples/generative/corrdiff/conf/model/corrdiff_diffusion.yaml rename to examples/generative/corrdiff/conf/base/model/regression.yaml index e0607efffa..1be8082326 100644 --- a/examples/generative/corrdiff/conf/model/corrdiff_diffusion.yaml +++ b/examples/generative/corrdiff/conf/base/model/regression.yaml @@ -14,11 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: diffusion - # Name of the preconditioner +name: regression +# Model type. hr_mean_conditioning: false - # High-res mean (regression's output) as additional condition -scale_cond_input: True - # If true, also scales the input conditioning - # For backward compatibility, this is true by default - # We recommend setting this to false for new training runs \ No newline at end of file +# No high-res conditioning for regression. + +# Default regression model parameters. Do not modify. +model_args: + "N_grid_channels": 4 + # Number of channels for positional grid embeddings + "embedding_type": "zero" + # Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++, + # 'zero' for none \ No newline at end of file diff --git a/examples/generative/corrdiff/conf/model/corrdiff_diffusion_mini.yaml b/examples/generative/corrdiff/conf/base/model_size/mini.yaml similarity index 76% rename from examples/generative/corrdiff/conf/model/corrdiff_diffusion_mini.yaml rename to examples/generative/corrdiff/conf/base/model_size/mini.yaml index 9b54039343..2eb8f8aba7 100644 --- a/examples/generative/corrdiff/conf/model/corrdiff_diffusion_mini.yaml +++ b/examples/generative/corrdiff/conf/base/model_size/mini.yaml @@ -1,3 +1,5 @@ +# @package _global_.model + # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. # SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 @@ -14,16 +16,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -defaults: - - corrdiff_diffusion - -hr_mean_conditioning: True -scale_cond_input: false - # If true, also scales the input conditioning - # For backward compatibility, this is true by default - # We recommend setting this to false for new training runs model_args: + # Base multiplier for the number of channels across the network. model_channels: 64 + # Per-resolution multipliers for the number of channels. channel_mult: [1, 2, 2] + # Resolutions at which self-attention layers are applied. attn_resolutions: [16] \ No newline at end of file diff --git a/examples/generative/corrdiff/conf/config_generate_mini.yaml b/examples/generative/corrdiff/conf/base/model_size/normal.yaml similarity index 70% rename from examples/generative/corrdiff/conf/config_generate_mini.yaml rename to examples/generative/corrdiff/conf/base/model_size/normal.yaml index a2dd3c29de..dd3450a33d 100644 --- a/examples/generative/corrdiff/conf/config_generate_mini.yaml +++ b/examples/generative/corrdiff/conf/base/model_size/normal.yaml @@ -1,3 +1,5 @@ +# @package _global_.model + # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. # SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 @@ -14,23 +16,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -hydra: - job: - chdir: true - name: mini_generation - run: - dir: ./outputs/${hydra:job.name} - -# Get defaults -defaults: - - # Dataset - - dataset/hrrrmini - - # Sampler - - sampler/stochastic - #- sampler/deterministic - # Generation - - generation/mini - #- generation/patched_based +model_args: + # Base multiplier for the number of channels across the network. + model_channels: 128 + # Per-resolution multipliers for the number of channels. + channel_mult: [1, 2, 2, 2, 2] + # Resolutions at which self-attention layers are applied. + attention_levels: [28] \ No newline at end of file diff --git a/examples/generative/corrdiff/conf/training/corrdiff_regression.yaml b/examples/generative/corrdiff/conf/base/training/base_all.yaml similarity index 82% rename from examples/generative/corrdiff/conf/training/corrdiff_regression.yaml rename to examples/generative/corrdiff/conf/base/training/base_all.yaml index 2d27efb8db..d669adcfd4 100644 --- a/examples/generative/corrdiff/conf/training/corrdiff_regression.yaml +++ b/examples/generative/corrdiff/conf/base/training/base_all.yaml @@ -20,20 +20,20 @@ hp: # Training duration based on the number of processed samples total_batch_size: 256 # Total batch size - batch_size_per_gpu: 2 + batch_size_per_gpu: "auto" # Batch size per GPU lr: 0.0002 # Learning rate grad_clip_threshold: null - # no gradient clipping for defualt non-patch-based training + # no gradient clipping for default non-patch-based training lr_decay: 1 # LR decay rate - lr_rampup: 10000000 + lr_rampup: 0 # Rampup for learning rate, in number of samples # Performance perf: - fp_optimizations: fp32 + fp_optimizations: amp-bf16 # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} dataloader_workers: 4 @@ -41,13 +41,17 @@ perf: songunet_checkpoint_level: 0 # 0 means no checkpointing # Gradient checkpointing level, value is number of layers to checkpoint -# I/O +# IO io: + regression_checkpoint_path: null + # Where to load the regression checkpoint. Should be overridden. print_progress_freq: 1000 # How often to print progress save_checkpoint_freq: 5000 # How often to save the checkpoints, measured in number of processed samples + save_n_recent_checkpoints: -1 + # Set to a positive integer to only keep the most recent n checkpoints validation_freq: 5000 # how often to record the validation loss, measured in number of processed samples validation_steps: 10 - # how many loss evaluations are used to compute the validation loss per checkpoint + # how many loss evaluations are used to compute the validation loss per checkpoint diff --git a/examples/generative/corrdiff/conf/model/corrdiff_regression.yaml b/examples/generative/corrdiff/conf/base/training/diffusion.yaml similarity index 83% rename from examples/generative/corrdiff/conf/model/corrdiff_regression.yaml rename to examples/generative/corrdiff/conf/base/training/diffusion.yaml index fa96f1533b..503a914d27 100644 --- a/examples/generative/corrdiff/conf/model/corrdiff_regression.yaml +++ b/examples/generative/corrdiff/conf/base/training/diffusion.yaml @@ -14,8 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: regression - # Name of the preconditioner -hr_mean_conditioning: False - # High-res mean (regression's output) as additional condition +defaults: + - base_all +io: + regression_checkpoint_path: ??? + # Where to load the regression checkpoint. Must be overridden. diff --git a/examples/generative/corrdiff/conf/training/corrdiff_regression_gefs_hrrr.yaml b/examples/generative/corrdiff/conf/base/training/lt_aware_ce_regression.yaml similarity index 69% rename from examples/generative/corrdiff/conf/training/corrdiff_regression_gefs_hrrr.yaml rename to examples/generative/corrdiff/conf/base/training/lt_aware_ce_regression.yaml index d7621715fe..69c84093fa 100644 --- a/examples/generative/corrdiff/conf/training/corrdiff_regression_gefs_hrrr.yaml +++ b/examples/generative/corrdiff/conf/base/training/lt_aware_ce_regression.yaml @@ -15,32 +15,23 @@ # limitations under the License. defaults: - - corrdiff_regression - + - base_all # Hyperparameters hp: - training_duration: 2000000 + training_duration: 1000000 # Training duration based on the number of processed samples total_batch_size: 1 # Total batch size batch_size_per_gpu: 1 # Batch size per GPU - lr_rampup: 0 - # Rampup for learning rate, in number of samples -# Performance perf: - fp_optimizations: amp-bf16 - # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] - # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} - dataloader_workers: 1 - # DataLoader worker processes - songunet_checkpoint_level: 2 # 0 means no checkpointing + songunet_checkpoint_level: 2 # I/O io: print_progress_freq: 1 # How often to print progress save_checkpoint_freq: 5 - # How often to save the checkpoints, measured in number of processed samples + # How often to save the checkpoints, measured in number of processed samples \ No newline at end of file diff --git a/examples/generative/corrdiff/conf/training/corrdiff_patched_diffusion_gefs_hrrr.yaml b/examples/generative/corrdiff/conf/base/training/lt_aware_patched_diffusion.yaml similarity index 65% rename from examples/generative/corrdiff/conf/training/corrdiff_patched_diffusion_gefs_hrrr.yaml rename to examples/generative/corrdiff/conf/base/training/lt_aware_patched_diffusion.yaml index 68405bc309..5e63cf47c1 100644 --- a/examples/generative/corrdiff/conf/training/corrdiff_patched_diffusion_gefs_hrrr.yaml +++ b/examples/generative/corrdiff/conf/base/training/lt_aware_patched_diffusion.yaml @@ -14,42 +14,40 @@ # See the License for the specific language governing permissions and # limitations under the License. +defaults: + - base_all + # Hyperparameters hp: - training_duration: 200000 - # Training duration based on the number of processed images, measured in kilo images (thousands of images) + training_duration: 10000000 + # Training duration based on the number of processed images total_batch_size: 1 # Total batch size batch_size_per_gpu: 1 # Batch size per GPU - lr: 0.0002 - # Learning rate grad_clip_threshold: 1e6 - # no gradient clipping for defualt non-patch-based training + # no gradient clipping for default non-patch-based training lr_decay: 0.7 # LR decay rate - patch_shape_x: 448 - patch_shape_y: 448 - # Patch size. Patch training is used if these dimensions differ from img_shape_x and img_shape_y - patch_num: 4 - # Number of patches from a single sample. Total number of patches is patch_num * batch_size_global + patch_shape_x: ??? + patch_shape_y: ??? + # Patch size. Patch training is used if these dimensions differ from + # img_shape_x and img_shape_y. Should be overridden. + patch_num: ??? + # Number of patches from a single sample. Total number of patches is + # patch_num * batch_size_global. Should be overridden. lr_rampup: 1000000 # Rampup for learning rate, in number of samples # Performance perf: - fp_optimizations: amp-bf16 - # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] - # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} - dataloader_workers: 4 - # DataLoader worker processes songunet_checkpoint_level: 1 # 0 means no checkpointing # Gradient checkpointing level, value is number of layers to checkpoint # I/O io: - regression_checkpoint_path: /lustre/fsw/portfolios/coreai/projects/coreai_climate_earth2/tge/gefs_regression/checkpoints_lt_aware_ce_regression/UNet.0.15.mdlus - # Where to load the regression checkpoint + regression_checkpoint_path: ??? + # Where to load the regression checkpoint. Must be overridden. print_progress_freq: 1 # How often to print progress save_checkpoint_freq: 5 @@ -57,4 +55,4 @@ io: validation_freq: 1 # how often to record the validation loss, measured in number of processed samples validation_steps: 1000 - # how many loss evaluations are used to compute the validation loss per checkpoint + # how many loss evaluations are used to compute the validation loss per checkpoint \ No newline at end of file diff --git a/examples/generative/corrdiff/conf/base/training/patched_diffusion.yaml b/examples/generative/corrdiff/conf/base/training/patched_diffusion.yaml new file mode 100644 index 0000000000..ed0002473f --- /dev/null +++ b/examples/generative/corrdiff/conf/base/training/patched_diffusion.yaml @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - base_all + +# Hyperparameters +hp: + training_duration: 10000000 + # Training duration based on the number of processed samples + grad_clip_threshold: 1e6 + # no gradient clipping for default non-patch-based training + lr_decay: 0.7 + # LR decay rate + patch_shape_x: ??? + patch_shape_y: ??? + # Patch size. Patch training is used if these dimensions differ from + # img_shape_x and img_shape_y. Should be overridden. + patch_num: ??? + # Number of patches from a single sample. Total number of patches is + # patch_num * batch_size_global. Should be overridden. + +# I/O +io: + regression_checkpoint_path: ??? + # Where to load the regression checkpoint. Must be overridden. + save_checkpoint_freq: 500000 + # How often to save the checkpoints, measured in number of processed samples + validation_freq: 50000 + # how often to record the validation loss, measured in number of processed samples + diff --git a/examples/generative/corrdiff/conf/validation/cwb.yaml b/examples/generative/corrdiff/conf/base/training/regression.yaml similarity index 85% rename from examples/generative/corrdiff/conf/validation/cwb.yaml rename to examples/generative/corrdiff/conf/base/training/regression.yaml index 15e2aa6dc1..d2e3174075 100644 --- a/examples/generative/corrdiff/conf/validation/cwb.yaml +++ b/examples/generative/corrdiff/conf/base/training/regression.yaml @@ -14,7 +14,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Validation dataset options -# (need to set dataset.train_test_split == true to have an effect) -train: false -all_times: false \ No newline at end of file +defaults: + - base_all \ No newline at end of file diff --git a/examples/generative/corrdiff/conf/config_generate_custom.yaml b/examples/generative/corrdiff/conf/config_generate_custom.yaml new file mode 100644 index 0000000000..616d390039 --- /dev/null +++ b/examples/generative/corrdiff/conf/config_generate_custom.yaml @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: false + name: # Change `my_job_name` + run: + dir: .//${hydra:job.name} # Change `my_output_dir` + searchpath: + - pkg://conf/base # Do not modify + +# Base parameters for dataset, model, and generation +defaults: + + - dataset: custom + # The dataset type for training. + # Accepted values: + # `gefs_hrrr`: full GEFS-HRRR dataset for continental US. + # `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments. + # `cwb`: full CWB dataset for Taiwan. + # `custom`: user-defined dataset. Parameters need to be specified below. + + - generation: patched + # The base generation parameters. + # Accepted values: + # `patched`: base parameters for a patch-based model + # `non_patched`: base parameters for a non-patched model + + +# Dataset parameters. Used for `custom` dataset type. +# Modify or add below parameters that should be passed as argument to the +# user-defined dataset class. +dataset: + type: + # Path to the user-defined dataset class. The user-defined dataset class is + # automatically loaded from the path. The user-defined class "DatasetClass" + # must be defined in the path "path/to/dataset.py". + data_path: + # Path to .nc data file + stats_path: + # Path to json stats file + input_variables: [] + # Names or indices of input channels + output_variables: [] + # Names or indices of output channels + invariant_variables: null + # Names or indices of invariant channels. Optional. + +# Generation parameters to specialize +generation: + num_ensembles: 64 + # int, number of ensembles to generate per input + seed_batch_size: 4 + # int, size of the batched inference + patch_shape_x: 448 + patch_shape_y: 448 + # int, patch size. Only used for `generation: patched`. For custom dataset, + # this should be determined based on an autocorrelation plot. + times: + - YYYY-MM-DDThh:mm:ss # Replace with target value + # List[str], time stamps in ISO 8601 format. Replace and list desired target + # time stamps. + io: + res_ckpt_filename: + # Path to checkpoint file for the diffusion model + reg_ckpt_filename: + # Path to checkpoint filename for the mean predictor model + +# Parameters for wandb logging +wandb: + mode: offline + # Configure whether to use wandb: "offline", "online", "disabled" diff --git a/examples/generative/corrdiff/conf/config_generate_gefs_hrrr.yaml b/examples/generative/corrdiff/conf/config_generate_gefs_hrrr.yaml index 758f46f3e9..0814816ca2 100644 --- a/examples/generative/corrdiff/conf/config_generate_gefs_hrrr.yaml +++ b/examples/generative/corrdiff/conf/config_generate_gefs_hrrr.yaml @@ -15,22 +15,73 @@ # limitations under the License. hydra: - job: - chdir: true - name: gefs_hrrr_generation - run: - dir: output/${hydra:job.name} + job: + chdir: false + name: generate_gefs_hrrr + run: + dir: ./outputs/${hydra:job.name} + searchpath: + - pkg://conf/base # Do not modify -# Get defaults +# Base parameters for dataset, model, and generation defaults: - # Dataset - - dataset/gefs_hrrr + - dataset: gefs_hrrr + # The dataset type for training. + # Accepted values: + # `gefs_hrrr`: full GEFS-HRRR dataset for continental US. + # `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments. + # `cwb`: full CWB dataset for Taiwan. + # `custom`: user-defined dataset. Parameters need to be specified below. - # Sampler - - sampler/stochastic - #- sampler/deterministic + - generation: patched + # The base generation parameters. + # Accepted values: + # `patched`: base parameters for a patch-based model + # `non_patched`: base parameters for a non-patched model - # Generation - - generation/patched_based_gefs_hrrr - #- generation/patched_based + +# Dataset parameters. Used for `custom` dataset type. +# Modify or add below parameters that should be passed as argument to the +# user-defined dataset class. +dataset: + data_path: ./data + # Path to .nc data file + stats_path: ./data/stats.json + # Path to json stats file + + +# Generation parameters to specialize +generation: + num_ensembles: 1 + # int, number of ensembles to generate per input + seed_batch_size: 1 + # int, size of the batched inference + patch_shape_x: 448 + patch_shape_y: 448 + # int, patch size. Only used for `generation: patched`. For custom dataset, + # this should be determined based on an autocorrelation plot. + times: + - "2024011212f00" + - "2024011212f03" + - "2024011212f06" + - "2024011212f09" + - "2024011212f12" + - "2024011212f15" + - "2024011212f18" + - "2024011212f21" + - "2024011212f24" + # List[str], time stamps in ISO 8601 format. Replace and list desired target + # time stamps. + has_lead_time: True + + io: + res_ckpt_filename: + # Path to checkpoint file for the diffusion model + reg_ckpt_filename: + # Path to checkpoint filename for the mean predictor model + +# Parameters for wandb logging +wandb: + mode: offline + # Configure whether to use wandb: "offline", "online", "disabled" \ No newline at end of file diff --git a/examples/generative/corrdiff/conf/config_generate_hrrr_mini.yaml b/examples/generative/corrdiff/conf/config_generate_hrrr_mini.yaml new file mode 100644 index 0000000000..2c425495ae --- /dev/null +++ b/examples/generative/corrdiff/conf/config_generate_hrrr_mini.yaml @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: false + name: generate_hrrr_mini + run: + dir: ./outputs/${hydra:job.name} + searchpath: + - pkg://conf/base # Do not modify + +# Base parameters for dataset, model, and generation +defaults: + + - dataset: hrrr_mini + # The dataset type for training. + # Accepted values: + # `gefs_hrrr`: full GEFS-HRRR dataset for continental US. + # `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments. + # `cwb`: full CWB dataset for Taiwan. + # `custom`: user-defined dataset. Parameters need to be specified below. + + - generation: non_patched + # The base generation parameters. + # Accepted values: + # `patched`: base parameters for a patch-based model + # `non_patched`: base parameters for a non-patched model + + +# Dataset parameters. Used for `custom` dataset type. +# Modify or add below parameters that should be passed as argument to the +# user-defined dataset class. +dataset: + data_path: ./data/corrdiff-mini/hrrr_mini_train.nc + # Path to .nc data file + stats_path: ./data/corrdiff-mini/stats.json + # Path to json stats file + +# Generation parameters to specialize +generation: + num_ensembles: 2 + # int, number of ensembles to generate per input + seed_batch_size: 1 + # int, size of the batched inference + times: + - 2020-02-02T00:00:00 + # List[str], time stamps in ISO 8601 format. Replace and list desired target + # time stamps. + io: + res_ckpt_filename: + # Path to checkpoint file for the diffusion model + reg_ckpt_filename: + # Path to checkpoint filename for the mean predictor model + +# Parameters for wandb logging +wandb: + mode: offline + # Configure whether to use wandb: "offline", "online", "disabled" \ No newline at end of file diff --git a/examples/generative/corrdiff/conf/config_generate_taiwan.yaml b/examples/generative/corrdiff/conf/config_generate_taiwan.yaml new file mode 100644 index 0000000000..e66153d940 --- /dev/null +++ b/examples/generative/corrdiff/conf/config_generate_taiwan.yaml @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: false + name: generate_taiwan + run: + dir: ./outputs/${hydra:job.name} + searchpath: + - pkg://conf/base # Do not modify + +# Base parameters for dataset, model, and generation +defaults: + + - dataset: cwb + # The dataset type for training. + # Accepted values: + # `gefs_hrrr`: full GEFS-HRRR dataset for continental US. + # `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments. + # `cwb`: full CWB dataset for Taiwan. + # `custom`: user-defined dataset. Parameters need to be specified below. + + - generation: non_patched + # The base generation parameters. + # Accepted values: + # `patched`: base parameters for a patch-based model + # `non_patched`: base parameters for a non-patched model + + +# Dataset parameters. Used for `custom` dataset type. +# Modify or add below parameters that should be passed as argument to the +# user-defined dataset class. +dataset: + data_path: ./data/2023-01-24-cwb-4years.zarr + train: False + all_times: True + + +# Generation parameters to specialize +generation: + num_ensembles: 64 + # int, number of ensembles to generate per input + seed_batch_size: 1 + # int, size of the batched inference + hr_mean_conditioning: false + # Whether to use hr_mean_conditioning + times: + - 2021-02-02T00:00:00 + - 2021-03-02T00:00:00 + - 2021-04-02T00:00:00 + # hurricane + - 2021-09-12T00:00:00 + - 2021-09-12T12:00:00 + # List[str], time stamps in ISO 8601 format. Replace and list desired target + # time stamps. + io: + res_ckpt_filename: + # Path to checkpoint file for the diffusion model + reg_ckpt_filename: + # Path to checkpoint filename for the mean predictor model + +# Parameters for wandb logging +wandb: + mode: offline + # Configure whether to use wandb: "offline", "online", "disabled" diff --git a/examples/generative/corrdiff/conf/config_training.yaml b/examples/generative/corrdiff/conf/config_training.yaml deleted file mode 100644 index 1d3694a37c..0000000000 --- a/examples/generative/corrdiff/conf/config_training.yaml +++ /dev/null @@ -1,41 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -hydra: - job: - chdir: true - name: regression # choose from [regression, diffusion, patched_diffusion] - run: - dir: ./outputs/${hydra:job.name} - -# Get defaults -defaults: - - # Dataset - - dataset/cwb_train - - # Model - - model/corrdiff_regression - #- model/corrdiff_diffusion - #- model/corrdiff_patched_diffusion - - # Training - - training/corrdiff_regression - #- training/corrdiff_diffusion - #- training/corrdiff_patched_diffusion - - # Validation (comment out to disable validation) - - validation/cwb diff --git a/examples/generative/corrdiff/conf/config_training_custom.yaml b/examples/generative/corrdiff/conf/config_training_custom.yaml new file mode 100644 index 0000000000..4bdf09d61a --- /dev/null +++ b/examples/generative/corrdiff/conf/config_training_custom.yaml @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: false + name: # Change `my_job_name` + run: + dir: .//${hydra:job.name} # Change `my_output_dir` + searchpath: + - pkg://conf/base # Do not modify + +# Base parameters for dataset, model, training, and validation +defaults: + + - dataset: custom + # The dataset type for training. + # Accepted values: + # `gefs_hrrr`: full GEFS-HRRR dataset for continental US. + # `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments. + # `cwb`: full CWB dataset for Taiwan. + # `custom`: user-defined dataset. Parameters need to be specified below. + + - model: diffusion + # The model type. + # Accepted values: + # `regression`: a regression UNet for deterministic predictions + # `lt_aware_ce_regression`: similar to `regression` but with lead time + # conditioning + # `diffusion`: a diffusion UNet for residual predictions + # `patched_diffusion`: a more memory-efficient diffusion model + # `lt_aware_patched_diffusion`: similar to `patched_diffusion` but + # with lead time conditioning + + - model_size: normal + # The model size configuration. + # Accepted values: + # `normal`: normal model size + # `mini`: smaller model size for fast experiments + + - training: ${model} + # The base training parameters. Determined by the model type. + + +# Dataset parameters. Used for `custom` dataset type. +# Modify or add below parameters that should be passed as argument to the +# user-defined dataset class. +dataset: + type: + # Path to the user-defined dataset class. The user-defined dataset class is + # automatically loaded from the path. The user-defined class "DatasetClass" + # must be defined in the path "path/to/dataset.py". + data_path: + # Path to .nc data file + stats_path: + # Path to json stats file + input_variables: [] + # Names or indices of input channels + output_variables: [] + # Names or indices of output channels + invariant_variables: null + # Names or indices of invariant channels. Optional. + +# Training parameters +training: + hp: + training_duration: 10000000 + # Training duration based on the number of processed samples + total_batch_size: 256 + # Total batch size + batch_size_per_gpu: "auto" + # Batch size per GPU. Set to "auto" to automatically determine the batch + # size based on the number of GPUs. + patch_shape_x: 448 + patch_shape_y: 448 + # Patch size. Only used for `model: patched_diffusion` or `model: + # lt_aware_patched_diffusion`. For custom dataset, this should be + # determined based on an autocorrelation plot. + patch_num: 10 + # Number of patches from a single sample. Total number of patches is + # patch_num * total_batch_size. Only used for `model: patched_diffusion` + # or `model: lt_aware_patched_diffusion`. + lr: 0.0002 + # Learning rate + lr_rampup: 0 + # Rampup for learning rate, in number of samples + io: + regression_checkpoint_path: + # Path to load the regression checkpoint + +# Parameters for wandb logging +wandb: + mode: offline + # Configure whether to use wandb: "offline", "online", "disabled" + results_dir: "./wandb" + # Directory to store wandb results + watch_model: false + # If true, wandb will track model parameters and gradients \ No newline at end of file diff --git a/examples/generative/corrdiff/conf/config_training_gefs_diffusion.yaml b/examples/generative/corrdiff/conf/config_training_gefs_diffusion.yaml deleted file mode 100644 index 14cc8be653..0000000000 --- a/examples/generative/corrdiff/conf/config_training_gefs_diffusion.yaml +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -hydra: - job: - chdir: true - name: gefs_hrrr_diffusion - run: - dir: ./outputs/${hydra:job.name} - -# Get defaults -defaults: - - # Dataset - - dataset/gefs_hrrr - - # Model - - model/corrdiff_patched_diffusion_gefs_hrrr - - # Training - - training/corrdiff_patched_diffusion_gefs_hrrr diff --git a/examples/generative/corrdiff/conf/config_training_gefs_hrrr_diffusion.yaml b/examples/generative/corrdiff/conf/config_training_gefs_hrrr_diffusion.yaml new file mode 100644 index 0000000000..4e44e51905 --- /dev/null +++ b/examples/generative/corrdiff/conf/config_training_gefs_hrrr_diffusion.yaml @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: false + name: gefs_hrrr_diffusion + run: + dir: ./output/${hydra:job.name} + searchpath: + - pkg://conf/base # Do not modify + +# Base parameters for dataset, model, training, and validation +defaults: + + - dataset: gefs_hrrr + # The dataset type for training. + # Accepted values: + # `gefs_hrrr`: full GEFS-HRRR dataset for continental US. + # `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments. + # `cwb`: full CWB dataset for Taiwan. + # `custom`: user-defined dataset. Parameters need to be specified below. + + - model: lt_aware_patched_diffusion + # The model type. + # Accepted values: + # `regression`: a regression UNet for deterministic predictions + # `lt_aware_ce_regression`: similar to `regression` but with lead time + # conditioning + # `diffusion`: a diffusion UNet for residual predictions + # `patched_diffusion`: a more memory-efficient diffusion model + # `lt_aware_patched_diffusion`: similar to `patched_diffusion` but + # with lead time conditioning + + - model_size: normal + # The model size configuration. + # Accepted values: + # `normal`: normal model size + # `mini`: smaller model size for fast experiments + + - training: ${model} + # The base training parameters. Determined by the model type. + + +# Dataset parameters. Used for `custom` dataset type. +# Modify or add below parameters that should be passed as argument to the +# user-defined dataset class. +dataset: + data_path: ./data + # Path to .nc data file + stats_path: ./data/stats.json + # Path to json stats file + +# Training parameters +training: + hp: + training_duration: 10000000 + # Training duration based on the number of processed samples + patch_shape_x: 448 + patch_shape_y: 448 + # Patch size. Patch training is used if these dimensions differ from + # img_shape_x and img_shape_y. + patch_num: 4 + # Number of patches from a single sample. Total number of patches is + # patch_num * batch_size_global. + io: + regression_checkpoint_path: + # Path to load the regression checkpoint + +# Parameters for wandb logging +wandb: + mode: offline + # Configure whether to use wandb: "offline", "online", "disabled" + results_dir: "./wandb" + # Directory to store wandb results + watch_model: false + # If true, wandb will track model parameters and gradients diff --git a/examples/generative/corrdiff/conf/config_training_gefs_hrrr_regression.yaml b/examples/generative/corrdiff/conf/config_training_gefs_hrrr_regression.yaml new file mode 100644 index 0000000000..de544c3cc6 --- /dev/null +++ b/examples/generative/corrdiff/conf/config_training_gefs_hrrr_regression.yaml @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: false + name: gefs_hrrr_regression + run: + dir: ./output/${hydra:job.name} + searchpath: + - pkg://conf/base # Do not modify + +# Base parameters for dataset, model, training, and validation +defaults: + + - dataset: gefs_hrrr + # The dataset type for training. + # Accepted values: + # `gefs_hrrr`: full GEFS-HRRR dataset for continental US. + # `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments. + # `cwb`: full CWB dataset for Taiwan. + # `custom`: user-defined dataset. Parameters need to be specified below. + + - model: lt_aware_ce_regression + # The model type. + # Accepted values: + # `regression`: a regression UNet for deterministic predictions + # `lt_aware_ce_regression`: similar to `regression` but with lead time + # conditioning + # `diffusion`: a diffusion UNet for residual predictions + # `patched_diffusion`: a more memory-efficient diffusion model + # `lt_aware_patched_diffusion`: similar to `patched_diffusion` but + # with lead time conditioning + + - model_size: normal + # The model size configuration. + # Accepted values: + # `normal`: normal model size + # `mini`: smaller model size for fast experiments + + - training: ${model} + # The base training parameters. Determined by the model type. + + +# Dataset parameters. Used for `custom` dataset type. +# Modify or add below parameters that should be passed as argument to the +# user-defined dataset class. +dataset: + data_path: ./data + # Path to .nc data file + stats_path: ./data/stats.json + # Path to json stats file + + +# Training parameters +training: + hp: + training_duration: 1000000 + # Training duration based on the number of processed samples + +# Parameters for wandb logging +wandb: + mode: offline + # Configure whether to use wandb: "offline", "online", "disabled" + results_dir: "./wandb" + # Directory to store wandb results + watch_model: false + # If true, wandb will track model parameters and gradients \ No newline at end of file diff --git a/examples/generative/corrdiff/conf/config_training_hrrr_mini_diffusion.yaml b/examples/generative/corrdiff/conf/config_training_hrrr_mini_diffusion.yaml new file mode 100644 index 0000000000..d8f871edbc --- /dev/null +++ b/examples/generative/corrdiff/conf/config_training_hrrr_mini_diffusion.yaml @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: false + name: hrrr_mini_diffusion + run: + dir: ./output/${hydra:job.name} + searchpath: + - pkg://conf/base # Do not modify + +# Base parameters for dataset, model, training, and validation +defaults: + + - dataset: hrrr_mini + # The dataset type for training. + # Accepted values: + # `gefs_hrrr`: full GEFS-HRRR dataset for continental US. + # `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments. + # `cwb`: full CWB dataset for Taiwan. + # `custom`: user-defined dataset. Parameters need to be specified below. + + - model: diffusion + # The model type. + # Accepted values: + # `regression`: a regression UNet for deterministic predictions + # `lt_aware_ce_regression`: similar to `regression` but with lead time + # conditioning + # `diffusion`: a diffusion UNet for residual predictions + # `patched_diffusion`: a more memory-efficient diffusion model + # `lt_aware_patched_diffusion`: similar to `patched_diffusion` but + # with lead time conditioning + + - model_size: mini + # The model size configuration. + # Accepted values: + # `normal`: normal model size + # `mini`: smaller model size for fast experiments + + - training: ${model} + # The base training parameters. Determined by the model type. + + +# Dataset parameters. Used for `custom` dataset type. +# Modify or add below parameters that should be passed as argument to the +# user-defined dataset class. +dataset: + data_path: ./data/corrdiff-mini/hrrr_mini_train.nc + # Path to .nc data file + stats_path: ./data/corrdiff-mini/stats.json + # Path to json stats file + +# Training parameters +training: + hp: + training_duration: 8000000 + # Training duration based on the number of processed samples + io: + print_progress_freq: 10000 + regression_checkpoint_path: + # Path to load the regression checkpoint + +# Parameters for wandb logging +wandb: + mode: offline + # Configure whether to use wandb: "offline", "online", "disabled" + results_dir: "./wandb" + # Directory to store wandb results + watch_model: false + # If true, wandb will track model parameters and gradients diff --git a/examples/generative/corrdiff/conf/config_training_hrrr_mini_regression.yaml b/examples/generative/corrdiff/conf/config_training_hrrr_mini_regression.yaml new file mode 100644 index 0000000000..c94553749d --- /dev/null +++ b/examples/generative/corrdiff/conf/config_training_hrrr_mini_regression.yaml @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: false + name: hrrr_mini_regression + run: + dir: ./output/${hydra:job.name} + searchpath: + - pkg://conf/base # Do not modify + +# Base parameters for dataset, model, training, and validation +defaults: + + - dataset: hrrr_mini + # The dataset type for training. + # Accepted values: + # `gefs_hrrr`: full GEFS-HRRR dataset for continental US. + # `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments. + # `cwb`: full CWB dataset for Taiwan. + # `custom`: user-defined dataset. Parameters need to be specified below. + + - model: regression + # The model type. + # Accepted values: + # `regression`: a regression UNet for deterministic predictions + # `lt_aware_ce_regression`: similar to `regression` but with lead time + # conditioning + # `diffusion`: a diffusion UNet for residual predictions + # `patched_diffusion`: a more memory-efficient diffusion model + # `lt_aware_patched_diffusion`: similar to `patched_diffusion` but + # with lead time conditioning + + - model_size: mini + # The model size configuration. + # Accepted values: + # `normal`: normal model size + # `mini`: smaller model size for fast experiments + + - training: ${model} + # The base training parameters. Determined by the model type. + + +# Dataset parameters. Used for `custom` dataset type. +# Modify or add below parameters that should be passed as argument to the +# user-defined dataset class. +dataset: + data_path: ./data/corrdiff-mini/hrrr_mini_train.nc + # Path to .nc data file + stats_path: ./data/corrdiff-mini/stats.json + # Path to json stats file + +# Training parameters +training: + hp: + training_duration: 2000000 + # Training duration based on the number of processed samples + io: + print_progress_freq: 10000 + +# Parameters for wandb logging +wandb: + mode: offline + # Configure whether to use wandb: "offline", "online", "disabled" + results_dir: "./wandb" + # Directory to store wandb results + watch_model: false + # If true, wandb will track model parameters and gradients diff --git a/examples/generative/corrdiff/conf/config_training_mini_diffusion.yaml b/examples/generative/corrdiff/conf/config_training_mini_diffusion.yaml deleted file mode 100644 index da0fad7a3d..0000000000 --- a/examples/generative/corrdiff/conf/config_training_mini_diffusion.yaml +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -hydra: - job: - chdir: true - name: mini_diffusion - run: - dir: ./outputs/${hydra:job.name} - -# Get defaults -defaults: - - # Dataset - - dataset/hrrrmini - - # Model - - model/corrdiff_diffusion_mini - - # Training - - training/corrdiff_diffusion_mini diff --git a/examples/generative/corrdiff/conf/config_training_mini_regression.yaml b/examples/generative/corrdiff/conf/config_training_mini_regression.yaml deleted file mode 100644 index 231695ea56..0000000000 --- a/examples/generative/corrdiff/conf/config_training_mini_regression.yaml +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -hydra: - job: - chdir: true - name: mini_regression - run: - dir: ./outputs/${hydra:job.name} - -# Get defaults -defaults: - - # Dataset - - dataset/hrrrmini - - # Model - - model/corrdiff_regression_mini - - # Training - - training/corrdiff_regression_mini diff --git a/examples/generative/corrdiff/conf/config_training_taiwan_diffusion.yaml b/examples/generative/corrdiff/conf/config_training_taiwan_diffusion.yaml new file mode 100644 index 0000000000..c2d3e2b78e --- /dev/null +++ b/examples/generative/corrdiff/conf/config_training_taiwan_diffusion.yaml @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: false + name: taiwan_diffusion + run: + dir: ./output/${hydra:job.name} + searchpath: + - pkg://conf/base # Do not modify + +# Base parameters for dataset, model, training, and validation +defaults: + + - dataset: cwb + # The dataset type for training. + # Accepted values: + # `gefs_hrrr`: full GEFS-HRRR dataset for continental US. + # `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments. + # `cwb`: full CWB dataset for Taiwan. + # `custom`: user-defined dataset. Parameters need to be specified below. + + - model: diffusion + # The model type. + # Accepted values: + # `regression`: a regression UNet for deterministic predictions + # `lt_aware_ce_regression`: similar to `regression` but with lead time + # conditioning + # `diffusion`: a diffusion UNet for residual predictions + # `patched_diffusion`: a more memory-efficient diffusion model + # `lt_aware_patched_diffusion`: similar to `patched_diffusion` but + # with lead time conditioning + + - model_size: normal + # The model size configuration. + # Accepted values: + # `normal`: normal model size + # `mini`: smaller model size for fast experiments + + - training: ${model} + # The base training parameters. Determined by the model type. + + +# Dataset parameters. Used for `custom` dataset type. +# Modify or add below parameters that should be passed as argument to the +# user-defined dataset class. +dataset: + data_path: ./data/2023-01-24-cwb-4years.zarr + +model: + hr_mean_conditioning: false + # High-res mean (regression's output) as additional condition + +# Training parameters +training: + hp: + training_duration: 200000000 + # Training duration based on the number of processed samples + lr_rampup: 10000000 + # Rampup for learning rate, in number of samples + io: + regression_checkpoint_path: + # Path to load the regression checkpoint + +# Additional parameters for validation +validation: + train: false + all_times: false + +# Parameters for wandb logging +wandb: + mode: offline + # Configure whether to use wandb: "offline", "online", "disabled" + results_dir: "./wandb" + # Directory to store wandb results + watch_model: false + # If true, wandb will track model parameters and gradients \ No newline at end of file diff --git a/examples/generative/corrdiff/conf/config_training_taiwan_regression.yaml b/examples/generative/corrdiff/conf/config_training_taiwan_regression.yaml new file mode 100644 index 0000000000..43fca74251 --- /dev/null +++ b/examples/generative/corrdiff/conf/config_training_taiwan_regression.yaml @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: false + name: taiwan_regression + run: + dir: ./output/${hydra:job.name} + searchpath: + - pkg://conf/base # Do not modify + +# Base parameters for dataset, model, training, and validation +defaults: + + - dataset: cwb + # The dataset type for training. + # Accepted values: + # `gefs_hrrr`: full GEFS-HRRR dataset for continental US. + # `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments. + # `cwb`: full CWB dataset for Taiwan. + # `custom`: user-defined dataset. Parameters need to be specified below. + + - model: regression + # The model type. + # Accepted values: + # `regression`: a regression UNet for deterministic predictions + # `lt_aware_ce_regression`: similar to `regression` but with lead time + # conditioning + # `diffusion`: a diffusion UNet for residual predictions + # `patched_diffusion`: a more memory-efficient diffusion model + # `lt_aware_patched_diffusion`: similar to `patched_diffusion` but + # with lead time conditioning + + - model_size: normal + # The model size configuration. + # Accepted values: + # `normal`: normal model size + # `mini`: smaller model size for fast experiments + + - training: ${model} + # The base training parameters. Determined by the model type. + + +# Dataset parameters. Used for `custom` dataset type. +# Modify or add below parameters that should be passed as argument to the +# user-defined dataset class. +dataset: + data_path: ./data/2023-01-24-cwb-4years.zarr + +# Training parameters +training: + hp: + training_duration: 200000000 + # Training duration based on the number of processed samples + lr_rampup: 10000000 + # Rampup for learning rate, in number of samples + +# Parameters for wandb logging +wandb: + mode: offline + # Configure whether to use wandb: "offline", "online", "disabled" + results_dir: "./wandb" + # Directory to store wandb results + watch_model: false + # If true, wandb will track model parameters and gradients \ No newline at end of file diff --git a/examples/generative/corrdiff/conf/dataset/cwb_generate.yaml b/examples/generative/corrdiff/conf/dataset/cwb_generate.yaml deleted file mode 100644 index 8e4360f486..0000000000 --- a/examples/generative/corrdiff/conf/dataset/cwb_generate.yaml +++ /dev/null @@ -1,31 +0,0 @@ - -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -type: cwb -data_path: /code/2023-01-24-cwb-4years.zarr -in_channels: [0, 1, 2, 3, 4, 9, 10, 11, 12, 17, 18, 19] -out_channels: [0, 1, 2, 3] -img_shape_x: 448 -img_shape_y: 448 -add_grid: true -ds_factor: 4 -min_path: null -max_path: null -global_means_path: null -global_stds_path: null -train: False -all_times: True \ No newline at end of file diff --git a/examples/generative/corrdiff/conf/generation/base.yaml b/examples/generative/corrdiff/conf/generation/base.yaml deleted file mode 100644 index 039d033ba8..0000000000 --- a/examples/generative/corrdiff/conf/generation/base.yaml +++ /dev/null @@ -1,63 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -num_ensembles: 64 - # Number of ensembles to generate per input -seed_batch_size: 1 - # Size of the batched inference -inference_mode: all - # Choose between "all" (regression + diffusion), "regression" or "diffusion" -patch_size: 448 -patch_shape_x: 448 -patch_shape_y: 448 - # Patch size. Patch-based sampling will be utilized if these dimensions differ from - # img_shape_x and img_shape_y -overlap_pixels: 0 - # Number of overlapping pixels between adjacent patches -boundary_pixels: 0 - # Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary - # artifact. -hr_mean_conditioning: false -gridtype: "sinusoidal" -N_grid_channels: 4 -sample_res: full - # Sampling resolution -times_range: null -times: - - 2021-02-02T00:00:00 - - 2021-03-02T00:00:00 - - 2021-04-02T00:00:00 - # hurricane - - 2021-09-12T00:00:00 - - 2021-09-12T12:00:00 - -perf: - force_fp16: false - # Whether to force fp16 precision for the model. If false, it'll use the precision - # specified upon training. - use_torch_compile: false - # whether to use torch.compile on the diffusion model - # this will make the first time stamp generation very slow due to compilation overheads - # but will significantly speed up subsequent inference runs - num_writer_workers: 1 - # number of workers to use for writing file - # To support multiple workers a threadsafe version of the netCDF library must be used - -io: - res_ckpt_filename: diffusion_checkpoint.mdlus - # Checkpoint filename for the diffusion model - reg_ckpt_filename: regression_checkpoint.mdlus - # Checkpoint filename for the mean predictor model diff --git a/examples/generative/corrdiff/conf/generation/mini.yaml b/examples/generative/corrdiff/conf/generation/mini.yaml deleted file mode 100644 index a2842d0a93..0000000000 --- a/examples/generative/corrdiff/conf/generation/mini.yaml +++ /dev/null @@ -1,39 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -defaults: - - base - -num_ensembles: 2 - # Number of ensembles to generate per input -seed_batch_size: 1 - # Size of the batched inference -inference_mode: all - # Choose between "all" (regression + diffusion), "regression" or "diffusion" -hr_mean_conditioning: True -gridtype: "sinusoidal" -N_grid_channels: 4 -sample_res: full - # Sampling resolution -times_range: null -times: - - 2020-02-02T00:00:00 - -io: - res_ckpt_filename: diffusion_checkpoint.mdlus - # Checkpoint filename for the diffusion model - reg_ckpt_filename: regression_checkpoint.mdlus - # Checkpoint filename for the mean predictor model diff --git a/examples/generative/corrdiff/conf/generation/patched_based_gefs_hrrr.yaml b/examples/generative/corrdiff/conf/generation/patched_based_gefs_hrrr.yaml deleted file mode 100644 index 26fbc2fa76..0000000000 --- a/examples/generative/corrdiff/conf/generation/patched_based_gefs_hrrr.yaml +++ /dev/null @@ -1,68 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -num_ensembles: 1 - # Number of ensembles to generate per input -seed_batch_size: 1 - # Size of the batched inference -inference_mode: all - # Choose between "all" (regression + diffusion), "regression" or "diffusion" -patch_size: 448 -patch_shape_x: 448 -patch_shape_y: 448 - # Patch size. Patch-based sampling will be utilized if these dimensions differ from - # img_shape_x and img_shape_y -overlap_pixels: 4 - # Number of overlapping pixels between adjacent patches -boundary_pixels: 2 - # Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary - # artifact. -hr_mean_conditioning: true -gridtype: learnable -N_grid_channels: 100 -sample_res: full - # Sampling resolution -times_range: null -times: - - "2024011212f00" - - "2024011212f03" - - "2024011212f06" - - "2024011212f09" - - "2024011212f12" - - "2024011212f15" - - "2024011212f18" - - "2024011212f21" - - "2024011212f24" - -has_lead_time: true - -perf: - force_fp16: false - # Whether to force fp16 precision for the model. If false, it'll use the precision - # specified upon training. - use_torch_compile: false - # whether to use torch.compile on the diffusion model - # this will make the first time stamp generation very slow due to compilation overheads - # but will significantly speed up subsequent inference runs - num_writer_workers: 1 - # number of workers to use for writing file - # To support multiple workers a threadsafe version of the netCDF library must be used - -io: - res_ckpt_filename: EDMPrecondSRV2_updated.0.5821440.mdlus - # Checkpoint filename for the diffusion model - reg_ckpt_filename: UNet_updated.0.1960960.mdlus - # Checkpoint filename for the mean predictor model diff --git a/examples/generative/corrdiff/conf/references/config_data_ref.yaml b/examples/generative/corrdiff/conf/references/config_data_ref.yaml deleted file mode 100644 index e4372926ec..0000000000 --- a/examples/generative/corrdiff/conf/references/config_data_ref.yaml +++ /dev/null @@ -1,597 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -full_field: &FULL_FIELD - - batch_size: 1 - dt: 1 - n_history: 0 - img_shape_x: 448 - img_shape_y: 448 - normalization: v1 #minmax - # era5_data_dir: '/lustre/fsw/sw_climate_fno/cwb-align' #old - # cwb_data_dir: '/lustre/fsw/sw_climate_fno/cwb-rwrf-pad-2212/all_ranges' #old - # relative to CWB_ROOT environment variable - train_data_path: '2023-01-24-cwb-4years.zarr' #new, zarr - num_data_workers: 1 - add_grid: !!bool False #adds position embedding - N_grid_channels: 0 - gridtype: 'sinusoidal' #options 'sinusoidal' or 'linear' - add_topo: !!bool False #adds position embedding - ds_factor: 1 - - - -# era5-cwb-crop448-grid-fcn -full_field_train_crop448_grid_12inchans_fcn_4outchans_4x_6layer: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [0,1,2,3,4,9,10,11,12,17,18,19] - out_channels: [0,17,18,19] #[17,18,19] - roll: !!bool False - patch_size: 448 - crop_size_x: 448 - crop_size_y: 448 - N_grid_channels: 4 - add_grid: True - ds_factor: 4 - - -full_field_val_crop448_grid_12inchans_fcn_4outchans_4x_6layer: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [0,1,2,3,4,9,10,11,12,17,18,19] - out_channels: [0,17,18,19] #[0,9,10,11] - roll: !!bool False - patch_size: 448 - crop_size_x: 448 - crop_size_y: 448 - N_grid_channels: 4 - add_grid: True - ds_factor: 4 - -out_of_sample: - <<: *FULL_FIELD - in_channels: [0,1,2,3,4,9,10,11,12,17,18,19] - out_channels: [0,17,18,19] #[0,9,10,11] - roll: !!bool False - patch_size: 448 - crop_size_x: 448 - crop_size_y: 448 - N_grid_channels: 4 - add_grid: True - ds_factor: 4 - # works on NGC (by noah) - train_data_path: '/root/data/diffusions/2023-05-31-very-out-of-sample.nc' #new, zarr - - -# era5-cwb-crop448-grid-fcn -full_field_train_crop448_grid_12inchans_fcn_4outchans_4x_normv2: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [0,1,2,3,4,9,10,11,12,17,18,19] - out_channels: [0,17,18,19] #[17,18,19] - roll: !!bool False - patch_size: 448 - crop_size_x: 448 - crop_size_y: 448 - N_grid_channels: 4 - add_grid: True - ds_factor: 4 - normalization: v2 - - - -# era5-cwb-crop448-grid-fcn -full_field_train_crop448_grid_12inchans_fcn_4outchans_4x: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [0,1,2,3,4,9,10,11,12,17,18,19] - out_channels: [0,17,18,19] #[17,18,19] - roll: !!bool False - patch_size: 448 - crop_size_x: 448 - crop_size_y: 448 - N_grid_channels: 4 - add_grid: True - ds_factor: 4 - - -full_field_val_crop448_grid_12inchans_fcn_4outchans_4x: &validation #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [0,1,2,3,4,9,10,11,12,17,18,19] - out_channels: [0,17,18,19] #[0,9,10,11] - roll: !!bool False - patch_size: 448 - crop_size_x: 448 - crop_size_y: 448 - N_grid_channels: 4 - add_grid: True - ds_factor: 4 - - -# era5-cwb-crop448-grid -full_field_train_crop448_grid_20inchans_4outchans_4x: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [0,17,18,19] #[17,18,19] - roll: !!bool False - patch_size: 448 - crop_size_x: 448 - crop_size_y: 448 - N_grid_channels: 4 - add_grid: True - ds_factor: 4 - - -full_field_val_crop448_grid_20inchans_4outchans_4x: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [0,17,18,19] #[17, 18, 19] - roll: !!bool False - patch_size: 448 - crop_size_x: 448 - crop_size_y: 448 - N_grid_channels: 4 - add_grid: True - ds_factor: 4 - - -# era5-cwb-crop112-grid -full_field_train_crop112_grid_20inchans_4outchans_4x: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [0,17,18,19] #[17,18,19] - roll: !!bool False - patch_size: 112 - crop_size_x: 112 - crop_size_y: 112 - N_grid_channels: 4 - add_grid: True - ds_factor: 4 - - -full_field_val_crop112_grid_20inchans_4outchans_4x: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [0,17,18,19] #[17, 18, 19] - roll: !!bool False - patch_size: 112 - crop_size_x: 448 - crop_size_y: 448 - N_grid_channels: 4 - add_grid: True - ds_factor: 4 - - -# era5-cwb-crop112-grid -full_field_train_crop112_grid_20inchans_19outchans_4x: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] #[17,18,19] - roll: !!bool False - patch_size: 112 - crop_size_x: 112 - crop_size_y: 112 - N_grid_channels: 4 - add_grid: True - ds_factor: 4 - - -full_field_val_crop112_grid_20inchans_19outchans_4x: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] #[17, 18, 19] - roll: !!bool False - patch_size: 112 - crop_size_x: 448 - crop_size_y: 448 - N_grid_channels: 4 - add_grid: True - ds_factor: 4 - - - -# era5-cwb-crop112-grid -full_field_train_crop112_grid_20inchans_4x: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [17,18,19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - roll: !!bool False - patch_size: 112 - crop_size_x: 112 - crop_size_y: 112 - N_grid_channels: 4 - add_grid: True - ds_factor: 4 - - -full_field_val_crop112_grid_20inchans_4x: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [17, 18, 19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - roll: !!bool False - patch_size: 112 - crop_size_x: 448 - crop_size_y: 448 - N_grid_channels: 4 - add_grid: True - ds_factor: 4 - - -# era5-cwb-crop64-grid -full_field_train_crop64_grid_20inchans_4x: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [17, 18, 19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - roll: !!bool False - patch_size: 64 - crop_size_x: 64 - crop_size_y: 64 - N_grid_channels: 4 - add_grid: True - ds_factor: 4 - - -full_field_val_crop64_grid_20inchans_4x: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [17, 18, 19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - roll: !!bool False - patch_size: 64 - crop_size_x: 448 - crop_size_y: 448 - N_grid_channels: 4 - add_grid: True - ds_factor: 4 - -# era5-cwb-crop64-grid -full_field_train_crop64_grid_20inchans: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [17, 18, 19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - roll: !!bool False - patch_size: 64 - crop_size_x: 64 - crop_size_y: 64 - N_grid_channels: 4 - add_grid: True - - -full_field_val_crop64_grid_20inchans: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [17, 18, 19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - roll: !!bool False - patch_size: 64 - crop_size_x: 448 - crop_size_y: 448 - N_grid_channels: 4 - add_grid: True - - -# era5-cwb-crop64-grid -full_field_train_crop64_grid: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [17, 18, 19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [17, 18, 19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - roll: !!bool False - patch_size: 64 - crop_size_x: 64 - crop_size_y: 64 - N_grid_channels: 4 - add_grid: True - - -full_field_val_crop64_grid: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [17, 18, 19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [17, 18, 19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - roll: !!bool False - patch_size: 64 - crop_size_x: 448 - crop_size_y: 448 - N_grid_channels: 4 - add_grid: True - - - -# era5-cwb-crop64 -full_field_train_crop64: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [17.18,19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [17, 18, 19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - roll: !!bool False - patch_size: 64 - crop_size_x: 64 - crop_size_y: 64 - - -full_field_val_crop64: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [17, 18, 19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [17, 18, 19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - roll: !!bool False - patch_size: 64 - crop_size_x: 448 - crop_size_y: 448 - - -# era5-cwb-crop112 -full_field_train_crop112: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [17.18,19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [17, 18, 19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - roll: !!bool False - patch_size: 112 - crop_size_x: 112 - crop_size_y: 112 - - -full_field_val_crop112: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [17, 18, 19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [17, 18, 19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - roll: !!bool False - patch_size: 112 - crop_size_x: 112 - crop_size_y: 112 - - -# era5-cwb-crop224 -full_field_train_crop224: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [17.18,19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [17, 18, 19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - roll: !!bool False - patch_size: 224 - crop_size_x: 224 - crop_size_y: 224 - - -full_field_val_crop224: #config for single gpu training to catch bugs and overfit - <<: *FULL_FIELD - in_channels: [17, 18, 19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - out_channels: [17, 18, 19] #[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] - roll: !!bool False - patch_size: 224 - crop_size_x: 224 - crop_size_y: 224 - -validation_small: - <<: *validation - times: - - 2021-02-02T00:00:00 - - 2021-03-02T00:00:00 - - 2021-04-02T00:00:00 - # hurricane - - 2021-09-12T00:00:00 - - 2021-09-12T12:00:00 - -test: - <<: *validation - times: - - 2021-02-02T00:00:00 - -validation_big: - <<: *validation - times: - # old times - - 2021-02-02T00:00:00 - - 2021-03-02T00:00:00 - - 2021-04-02T00:00:00 - # hurricane - - 2021-09-12T00:00:00 - - 2021-09-12T12:00:00 - # new times - # generated this list of 200 times using - # python3 work/noah/get_random_times.py - - 2021-02-02T06:00:00 - - 2021-02-07T07:00:00 - - 2021-02-08T16:00:00 - - 2021-02-09T20:00:00 - - 2021-02-10T14:00:00 - - 2021-02-10T17:00:00 - - 2021-02-11T08:00:00 - - 2021-02-11T20:00:00 - - 2021-02-15T20:00:00 - - 2021-02-17T18:00:00 - - 2021-02-17T21:00:00 - - 2021-02-18T09:00:00 - - 2021-02-18T16:00:00 - - 2021-02-19T09:00:00 - - 2021-02-20T10:00:00 - - 2021-02-22T18:00:00 - - 2021-02-25T20:00:00 - - 2021-02-26T10:00:00 - - 2021-02-27T23:00:00 - - 2021-02-28T04:00:00 - - 2021-03-04T01:00:00 - - 2021-03-04T11:00:00 - - 2021-03-04T22:00:00 - - 2021-03-06T00:00:00 - - 2021-03-07T13:00:00 - - 2021-03-08T01:00:00 - - 2021-03-10T16:00:00 - - 2021-03-13T08:00:00 - - 2021-03-14T08:00:00 - - 2021-03-14T09:00:00 - - 2021-03-14T12:00:00 - - 2021-03-15T16:00:00 - - 2021-03-16T09:00:00 - - 2021-03-16T21:00:00 - - 2021-03-19T05:00:00 - - 2021-03-20T08:00:00 - - 2021-03-22T18:00:00 - - 2021-03-23T02:00:00 - - 2021-03-23T12:00:00 - - 2021-03-24T09:00:00 - - 2021-03-25T22:00:00 - - 2021-03-27T17:00:00 - - 2021-03-27T19:00:00 - - 2021-03-28T17:00:00 - - 2021-03-29T06:00:00 - - 2021-04-01T16:00:00 - - 2021-04-02T01:00:00 - - 2021-04-04T13:00:00 - - 2021-04-05T01:00:00 - - 2021-04-05T08:00:00 - - 2021-04-08T12:00:00 - - 2021-04-12T18:00:00 - - 2021-04-14T02:00:00 - - 2021-04-14T09:00:00 - - 2021-04-17T05:00:00 - - 2021-04-17T21:00:00 - - 2021-04-19T07:00:00 - - 2021-04-22T01:00:00 - - 2021-04-23T01:00:00 - - 2021-04-23T08:00:00 - - 2021-04-24T21:00:00 - - 2021-04-28T02:00:00 - - 2021-04-28T08:00:00 - - 2021-04-28T17:00:00 - - 2021-04-29T21:00:00 - - 2021-07-02T05:00:00 - - 2021-07-03T21:00:00 - - 2021-07-03T22:00:00 - - 2021-07-04T11:00:00 - - 2021-07-09T18:00:00 - - 2021-07-11T15:00:00 - - 2021-07-12T00:00:00 - - 2021-07-12T05:00:00 - - 2021-07-13T20:00:00 - - 2021-07-16T21:00:00 - - 2021-07-18T14:00:00 - - 2021-07-19T11:00:00 - - 2021-07-21T00:00:00 - - 2021-07-22T07:00:00 - - 2021-07-23T01:00:00 - - 2021-07-26T05:00:00 - - 2021-07-28T03:00:00 - - 2021-07-29T11:00:00 - - 2021-07-29T18:00:00 - - 2021-07-31T05:00:00 - - 2021-07-31T06:00:00 - - 2021-08-01T06:00:00 - - 2021-08-03T08:00:00 - - 2021-08-06T21:00:00 - - 2021-08-07T11:00:00 - - 2021-08-08T20:00:00 - - 2021-08-08T23:00:00 - - 2021-08-09T03:00:00 - - 2021-08-10T04:00:00 - - 2021-08-10T19:00:00 - - 2021-08-11T06:00:00 - - 2021-08-15T11:00:00 - - 2021-08-16T01:00:00 - - 2021-08-16T04:00:00 - - 2021-08-16T11:00:00 - - 2021-08-17T09:00:00 - - 2021-08-17T21:00:00 - - 2021-08-20T10:00:00 - - 2021-08-22T20:00:00 - - 2021-08-23T06:00:00 - - 2021-08-24T16:00:00 - - 2021-08-25T11:00:00 - - 2021-08-26T01:00:00 - - 2021-08-26T17:00:00 - - 2021-08-28T16:00:00 - - 2021-08-30T01:00:00 - - 2021-08-30T10:00:00 - - 2021-08-30T19:00:00 - - 2021-08-30T22:00:00 - - 2021-08-31T16:00:00 - - 2021-09-01T09:00:00 - - 2021-09-03T06:00:00 - - 2021-09-05T19:00:00 - - 2021-09-12T13:00:00 - - 2021-09-12T20:00:00 - - 2021-09-14T06:00:00 - - 2021-09-17T06:00:00 - - 2021-09-18T06:00:00 - - 2021-09-18T10:00:00 - - 2021-09-19T06:00:00 - - 2021-09-20T00:00:00 - - 2021-09-21T00:00:00 - - 2021-09-21T01:00:00 - - 2021-09-21T17:00:00 - - 2021-09-22T01:00:00 - - 2021-09-22T17:00:00 - - 2021-09-26T03:00:00 - - 2021-09-26T19:00:00 - - 2021-09-30T05:00:00 - - 2021-09-30T13:00:00 - - 2021-10-03T10:00:00 - - 2021-10-03T13:00:00 - - 2021-10-04T13:00:00 - - 2021-10-04T14:00:00 - - 2021-10-04T17:00:00 - - 2021-10-05T04:00:00 - - 2021-10-05T08:00:00 - - 2021-10-06T22:00:00 - - 2021-10-08T09:00:00 - - 2021-10-08T15:00:00 - - 2021-10-09T03:00:00 - - 2021-10-11T22:00:00 - - 2021-10-18T06:00:00 - - 2021-10-22T03:00:00 - - 2021-10-23T02:00:00 - - 2021-10-24T08:00:00 - - 2021-10-24T11:00:00 - - 2021-10-24T16:00:00 - - 2021-10-24T22:00:00 - - 2021-10-26T04:00:00 - - 2021-10-27T22:00:00 - - 2021-10-28T19:00:00 - - 2021-10-30T12:00:00 - - 2021-10-30T16:00:00 - - 2021-10-30T23:00:00 - - 2021-10-31T08:00:00 - - 2021-11-02T01:00:00 - - 2021-11-04T23:00:00 - - 2021-11-06T06:00:00 - - 2021-11-07T15:00:00 - - 2021-11-09T10:00:00 - - 2021-11-09T20:00:00 - - 2021-11-12T13:00:00 - - 2021-11-13T17:00:00 - - 2021-11-14T14:00:00 - - 2021-11-15T00:00:00 - - 2021-11-15T04:00:00 - - 2021-11-16T06:00:00 - - 2021-11-20T10:00:00 - - 2021-11-23T13:00:00 - - 2021-11-23T21:00:00 - - 2021-11-26T03:00:00 - - 2021-11-26T06:00:00 - - 2021-11-26T07:00:00 - - 2021-11-26T14:00:00 - - 2021-12-02T10:00:00 - - 2021-12-05T20:00:00 - - 2021-12-06T03:00:00 - - 2021-12-10T06:00:00 - - 2021-12-11T07:00:00 - - 2021-12-14T13:00:00 - - 2021-12-16T05:00:00 - - 2021-12-17T01:00:00 - - 2021-12-17T22:00:00 - - 2021-12-18T00:00:00 - - 2021-12-18T03:00:00 - - 2021-12-19T03:00:00 - - 2021-12-19T21:00:00 - - 2021-12-21T07:00:00 - - 2021-12-22T11:00:00 - - 2021-12-22T22:00:00 - - 2021-12-25T08:00:00 - - 2021-12-25T18:00:00 - - 2021-12-25T22:00:00 - - 2021-12-31T23:00:00 diff --git a/examples/generative/corrdiff/conf/training/corrdiff_diffusion.yaml b/examples/generative/corrdiff/conf/training/corrdiff_diffusion.yaml deleted file mode 100644 index a0855e759e..0000000000 --- a/examples/generative/corrdiff/conf/training/corrdiff_diffusion.yaml +++ /dev/null @@ -1,55 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Hyperparameters -hp: - training_duration: 200000000 - # Training duration based on the number of processed samples - total_batch_size: 256 - # Total batch size - batch_size_per_gpu: 2 - # Batch size per GPU - lr: 0.0002 - # Learning rate - grad_clip_threshold: null - # no gradient clipping for defualt non-patch-based training - lr_decay: 1 - # LR decay rate - lr_rampup: 10000000 - # Rampup for learning rate, in number of samples - -# Performance -perf: - fp_optimizations: fp32 - # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] - # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} - dataloader_workers: 4 - # DataLoader worker processes - songunet_checkpoint_level: 0 # 0 means no checkpointing - # Gradient checkpointing level, value is number of layers to checkpoint - -# I/O -io: - regression_checkpoint_path: checkpoints/regression.mdlus - # Where to load the regression checkpoint - print_progress_freq: 1000 - # How often to print progress - save_checkpoint_freq: 5000 - # How often to save the checkpoints, measured in number of processed samples - validation_freq: 5000 - # how often to record the validation loss, measured in number of processed samples - validation_steps: 10 - # how many loss evaluations are used to compute the validation loss per checkpoint diff --git a/examples/generative/corrdiff/conf/training/corrdiff_diffusion_mini.yaml b/examples/generative/corrdiff/conf/training/corrdiff_diffusion_mini.yaml deleted file mode 100644 index 9c9bdf0d14..0000000000 --- a/examples/generative/corrdiff/conf/training/corrdiff_diffusion_mini.yaml +++ /dev/null @@ -1,44 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -defaults: - - corrdiff_diffusion - - -# Hyperparameters -hp: - training_duration: 8000000 - # Training duration based on the number of processed samples - total_batch_size: 256 - # Total batch size - batch_size_per_gpu: "auto" - # Batch size per GPU - lr_rampup: 0 - # Rampup for learning rate, in number of samples - -# Performance -perf: - fp_optimizations: amp-bf16 - # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] - # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} - dataloader_workers: 1 - # DataLoader worker processes - -# I/O -io: - regression_checkpoint_path: checkpoints/regression_mini.mdlus - # Where to load the regression checkpoint - print_progress_freq: 10000 diff --git a/examples/generative/corrdiff/conf/training/corrdiff_patched_diffusion.yaml b/examples/generative/corrdiff/conf/training/corrdiff_patched_diffusion.yaml deleted file mode 100644 index 4c3448fe0f..0000000000 --- a/examples/generative/corrdiff/conf/training/corrdiff_patched_diffusion.yaml +++ /dev/null @@ -1,60 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Hyperparameters -hp: - training_duration: 200000 - # Training duration based on the number of processed images, measured in kilo images (thousands of images) - total_batch_size: 2560 - # Total batch size - batch_size_per_gpu: 20 - # Batch size per GPU - lr: 0.0002 - # Learning rate - grad_clip_threshold: 1e5 - # no gradient clipping for defualt non-patch-based training - lr_decay: 0.7 - # LR decay rate - patch_shape_x: 64 - patch_shape_y: 64 - # Patch size. Patch training is used if these dimensions differ from img_shape_x and img_shape_y - patch_num: 10 - # Number of patches from a single sample. Total number of patches is patch_num * batch_size_global - lr_rampup: 10000000 - # Rampup for learning rate, in number of samples - -# Performance -perf: - fp_optimizations: fp32 - # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] - # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} - dataloader_workers: 4 - # DataLoader worker processes - songunet_checkpoint_level: 0 # 0 means no checkpointing - # Gradient checkpointing level, value is number of layers to checkpoint - -# I/O -io: - regression_checkpoint_path: checkpoints/regression.mdlus - # Where to load the regression checkpoint - print_progress_freq: 1000 - # How often to print progress - save_checkpoint_freq: 500000 - # How often to save the checkpoints, measured in number of processed samples - validation_freq: 50000 - # how often to record the validation loss, measured in number of processed samples - validation_steps: 10 - # how many loss evaluations are used to compute the validation loss per checkpoint diff --git a/examples/generative/corrdiff/conf/training/corrdiff_regression_mini.yaml b/examples/generative/corrdiff/conf/training/corrdiff_regression_mini.yaml deleted file mode 100644 index 92c2a136d1..0000000000 --- a/examples/generative/corrdiff/conf/training/corrdiff_regression_mini.yaml +++ /dev/null @@ -1,43 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -defaults: - - corrdiff_regression - - -# Hyperparameters -hp: - training_duration: 2000000 - # Training duration based on the number of processed samples - total_batch_size: 256 - # Total batch size - batch_size_per_gpu: "auto" - # Batch size per GPU - lr_rampup: 0 - # Rampup for learning rate, in number of samples - -# Performance -perf: - fp_optimizations: amp-bf16 - # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] - # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} - dataloader_workers: 1 - # DataLoader worker processes - -# I/O -io: - # Where to load the regression checkpoint - print_progress_freq: 10000 \ No newline at end of file diff --git a/examples/generative/corrdiff/datasets/base.py b/examples/generative/corrdiff/datasets/base.py index 22b00d252c..d7b783fa47 100644 --- a/examples/generative/corrdiff/datasets/base.py +++ b/examples/generative/corrdiff/datasets/base.py @@ -32,7 +32,11 @@ class ChannelMetadata: class DownscalingDataset(torch.utils.data.Dataset, ABC): - """An abstract class that defines the interface for downscaling datasets.""" + """Abstract class for dataset with downscaling paired data + + A DownscalingDataset has high-resolution output data (target) paired with + low-resolution input data. + """ @abstractmethod def longitude(self) -> np.ndarray: diff --git a/examples/generative/corrdiff/datasets/cwb.py b/examples/generative/corrdiff/datasets/cwb.py index 91f469633c..de0555e581 100644 --- a/examples/generative/corrdiff/datasets/cwb.py +++ b/examples/generative/corrdiff/datasets/cwb.py @@ -97,12 +97,11 @@ def __getitem__(self, idx): idx_to_load = self._get_valid_time_index(idx) target = self.group["cwb"][idx_to_load] input = self.group["era5"][idx_to_load] - label = 0 target = self.normalize_output(target[None, ...])[0] input = self.normalize_input(input[None, ...])[0] - return target, input, label + return target, input def longitude(self): """The longitude. useful for plotting""" @@ -396,7 +395,7 @@ def info(self): return self._dataset.info() def __getitem__(self, idx): - (target, input, _) = self._dataset[idx] + (target, input) = self._dataset[idx] # crop and downsamples # rolling if self.train and self.roll: @@ -437,7 +436,7 @@ def __getitem__(self, idx): target, "tar", *reshape_args, normalize=False ) # 3x720x1440 - return target, input, idx + return target, input def input_channels(self): """Metadata for the input channels. A list of dictionaries, one for each channel""" diff --git a/examples/generative/corrdiff/datasets/dataset.py b/examples/generative/corrdiff/datasets/dataset.py index 8c7f92ee0e..883029ae83 100644 --- a/examples/generative/corrdiff/datasets/dataset.py +++ b/examples/generative/corrdiff/datasets/dataset.py @@ -17,6 +17,8 @@ from typing import Iterable, Tuple, Union import copy import torch +import importlib.util +from pathlib import Path from physicsnemo.utils.generative import InfiniteSampler from physicsnemo.distributed import DistributedManager @@ -32,6 +34,59 @@ } +def register_dataset(dataset_spec: str) -> None: + """ + Register a new dataset class from a file path specification. + + Parameters + ---------- + dataset_spec : str + String specification in the format "path_to_file.py::dataset_class" + + Raises + ------ + ValueError + If the dataset_spec format is invalid or if the file doesn't exist + ImportError + If the dataset class cannot be imported + """ + if dataset_spec in known_datasets: + return # Dataset already registered + try: + file_path, class_name = dataset_spec.split("::") + except ValueError: + raise ValueError( + "Invalid dataset specification. Expected format: " + "'path_to_file.py::dataset_class'" + ) + if class_name in known_datasets: + return # Dataset already registered + + # Convert to Path and validate + file_path = Path(file_path) + if not file_path.exists(): + raise ValueError(f"Dataset file not found: {file_path}") + if not file_path.suffix == ".py": + raise ValueError(f"Dataset file must be a Python file: {file_path}") + + # Import the module and get the class + spec = importlib.util.spec_from_file_location(file_path.stem, str(file_path)) + if spec is None or spec.loader is None: + raise ImportError(f"Could not load spec for {file_path}") + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + try: + dataset_class = getattr(module, class_name) + except AttributeError: + raise ImportError(f"Could not find dataset class '{class_name}' in {file_path}") + + # Register the dataset + known_datasets[dataset_spec] = dataset_class + return + + def init_train_valid_datasets_from_config( dataset_cfg: dict, dataloader_cfg: Union[dict, None] = None, diff --git a/examples/generative/corrdiff/datasets/gefs_hrrr.py b/examples/generative/corrdiff/datasets/gefs_hrrr.py index d17a5d5d2a..114c1d39d4 100644 --- a/examples/generative/corrdiff/datasets/gefs_hrrr.py +++ b/examples/generative/corrdiff/datasets/gefs_hrrr.py @@ -126,6 +126,8 @@ class HrrrForecastGEFSDataset(DownscalingDataset): Expects data to be stored under directory specified by 'location' GEFS under /gefs/ HRRR under /hrrr/ + Within each directory, there should be one zarr file per + year containing the data of interest. """ def __init__( @@ -142,7 +144,7 @@ def __init__( train_years: Iterable[int] = (2020, 2021, 2022, 2023), valid_years: Iterable[int] = (2024,), hrrr_window: Union[Tuple[Tuple[int, int], Tuple[int, int]], None] = None, - sample_shape: Tuple[int, int] = None, + sample_shape: Tuple[int, int] = [-1, -1], ds_factor: int = 1, shard: bool = False, overfit: bool = False, @@ -468,7 +470,7 @@ def image_shape(self) -> Tuple[int, int]: return (y_end - y_start, x_end - x_start) def _get_crop_box(self): - if self.sample_shape == None: + if self.sample_shape == [-1, -1]: return self.hrrr_window ((y_start, y_end), (x_start, x_end)) = self.hrrr_window @@ -480,10 +482,11 @@ def _get_crop_box(self): return ((y0, y1), (x0, x1)) def __getitem__(self, global_idx): + """Return a tuple of: + - hrrr_field: High-resolution HRRR output data + - gefs_field: Low-resolution GEFS input data + - lead_time_label: Lead time """ - Return data as a dict (so we can potentially add extras, metadata, etc if desired - """ - torch.cuda.nvtx.range_push("hrrr_dataloader:get") if self.overfit: global_idx = 42 time_index = self._global_idx_to_datetime(global_idx) @@ -505,7 +508,7 @@ def __getitem__(self, global_idx): ) gefs_sample = self.normalize_input(gefs_sample) torch.cuda.nvtx.range_pop() - return hrrr_sample, gefs_sample, global_idx, int(time_index[-2:]) // 3 + return hrrr_sample, gefs_sample, int(time_index[-2:]) // 3 def _global_idx_to_datetime(self, global_idx): """ diff --git a/examples/generative/corrdiff/datasets/hrrrmini.py b/examples/generative/corrdiff/datasets/hrrrmini.py index 4537e0780a..d9ade48250 100644 --- a/examples/generative/corrdiff/datasets/hrrrmini.py +++ b/examples/generative/corrdiff/datasets/hrrrmini.py @@ -70,7 +70,7 @@ def __init__( ) def __getitem__(self, idx): - """Return the data sample (output, input, 0) at index idx.""" + """Return the data sample (output, input) at index idx.""" x = self.upsample(self.input[idx].copy()) # add invariants to input @@ -82,7 +82,7 @@ def __getitem__(self, idx): x = self.normalize_input(x) y = self.normalize_output(y) - return (y, x, 0) + return (y, x) def __len__(self): return self.input.shape[0] diff --git a/examples/generative/corrdiff/generate.py b/examples/generative/corrdiff/generate.py index 6f3a1ff669..ae5460033a 100644 --- a/examples/generative/corrdiff/generate.py +++ b/examples/generative/corrdiff/generate.py @@ -21,12 +21,15 @@ import nvtx import numpy as np import netCDF4 as nc +import contextlib + from physicsnemo.distributed import DistributedManager from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper +from physicsnemo.utils.patching import GridPatching2D from physicsnemo import Module + from concurrent.futures import ThreadPoolExecutor from functools import partial -from einops import rearrange from torch.distributed import gather @@ -45,6 +48,7 @@ save_images, ) from helpers.train_helpers import set_patch_shape +from datasets.dataset import register_dataset @hydra.main(version_base="1.2", config_path="conf", config_name="config_generate") @@ -85,6 +89,11 @@ def main(cfg: DictConfig) -> None: # Create dataset object dataset_cfg = OmegaConf.to_container(cfg.dataset) + + # Register dataset (if custom dataset) + register_dataset(cfg.dataset.type) + logger0.info(f"Using dataset: {cfg.dataset.type}") + if "has_lead_time" in cfg.generation: has_lead_time = cfg.generation["has_lead_time"] else: @@ -96,19 +105,23 @@ def main(cfg: DictConfig) -> None: img_out_channels = len(dataset.output_channels()) # Parse the patch shape - if hasattr(cfg.generation, "patch_shape_x"): # TODO better config handling + if cfg.generation.patching: patch_shape_x = cfg.generation.patch_shape_x - else: - patch_shape_x = None - if hasattr(cfg.generation, "patch_shape_y"): patch_shape_y = cfg.generation.patch_shape_y else: - patch_shape_y = None + patch_shape_x, patch_shape_y = None, None patch_shape = (patch_shape_y, patch_shape_x) - img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) - if patch_shape != img_shape: + use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) + if use_patching: + patching = GridPatching2D( + img_shape=img_shape, + patch_shape=patch_shape, + boundary_pix=cfg.generation.boundary_pix, + overlap_pix=cfg.generation.overlap_pix, + ) logger0.info("Patch-based training enabled") else: + patching = None logger0.info("Patch-based training disabled") # Parse the inference mode @@ -164,44 +177,27 @@ def main(cfg: DictConfig) -> None: solver=cfg.sampler.solver, ) elif cfg.sampler.type == "stochastic": - sampler_fn = partial( - stochastic_sampler, - img_shape=img_shape[1], - patch_shape=patch_shape[1], - boundary_pix=cfg.sampler.boundary_pix, - overlap_pix=cfg.sampler.overlap_pix, - ) + sampler_fn = partial(stochastic_sampler, patching=patching) else: raise ValueError(f"Unknown sampling method {cfg.sampling.type}") # Main generation definition def generate_fn(): - img_shape_y, img_shape_x = img_shape with nvtx.annotate("generate_fn", color="green"): - if cfg.generation.sample_res == "full": - image_lr_patch = image_lr - else: - torch.cuda.nvtx.range_push("rearrange") - image_lr_patch = rearrange( - image_lr, - "b c (h1 h) (w1 w) -> (b h1 w1) c h w", - h1=img_shape_y // patch_shape[0], - w1=img_shape_x // patch_shape[1], - ) - torch.cuda.nvtx.range_pop() - image_lr_patch = image_lr_patch.to(memory_format=torch.channels_last) + # (1, C, H, W) + img_lr = image_lr.to(memory_format=torch.channels_last) if net_reg: with nvtx.annotate("regression_model", color="yellow"): image_reg = regression_step( net=net_reg, - img_lr=image_lr_patch, + img_lr=img_lr, latents_shape=( cfg.generation.seed_batch_size, img_out_channels, img_shape[0], img_shape[1], - ), + ), # (batch_size, C, H, W) lead_time_label=lead_time_label, ) if net_res: @@ -213,16 +209,15 @@ def generate_fn(): image_res = diffusion_step( net=net_res, sampler_fn=sampler_fn, - seed_batch_size=cfg.generation.seed_batch_size, img_shape=img_shape, img_out_channels=img_out_channels, rank_batches=rank_batches, - img_lr=image_lr_patch.expand( + img_lr=img_lr.expand( cfg.generation.seed_batch_size, -1, -1, -1 ).to(memory_format=torch.channels_last), rank=dist.rank, device=device, - hr_mean=mean_hr, + mean_hr=mean_hr, lead_time_label=lead_time_label, ) if cfg.generation.inference_mode == "regression": @@ -232,13 +227,6 @@ def generate_fn(): else: image_out = image_reg + image_res - if cfg.generation.sample_res != "full": - image_out = rearrange( - image_out, - "(b h1 w1) c h w -> b c (h1 h) (w1 w)", - h1=img_shape_y // patch_shape[0], - w1=img_shape_x // patch_shape[1], - ) # Gather tensors on rank 0 if dist.world_size > 1: if dist.rank == 0: @@ -279,8 +267,18 @@ def generate_fn(): # add attributes f.cfg = str(cfg) - with torch.cuda.profiler.profile(): - with torch.autograd.profiler.emit_nvtx(): + torch_cuda_profiler = ( + torch.cuda.profiler.profile() + if torch.cuda.is_available() + else contextlib.nullcontext() + ) + torch_nvtx_profiler = ( + torch.autograd.profiler.emit_nvtx() + if torch.cuda.is_available() + else contextlib.nullcontext() + ) + with torch_cuda_profiler: + with torch_nvtx_profiler: data_loader = torch.utils.data.DataLoader( dataset=dataset, sampler=sampler, batch_size=1, pin_memory=True @@ -302,11 +300,29 @@ def generate_fn(): ) writer_threads = [] - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) + # Create timer objects only if CUDA is available + use_cuda_timing = torch.cuda.is_available() + if use_cuda_timing: + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + else: + # Dummy no-op functions for CPU case + class DummyEvent: + def record(self): + pass + + def synchronize(self): + pass + + def elapsed_time(self, _): + return 0 + + start = end = DummyEvent() times = dataset.time() - for image_tar, image_lr, index, *lead_time_label in iter(data_loader): + for index, (image_tar, image_lr, *lead_time_label) in enumerate( + iter(data_loader) + ): time_index += 1 if dist.rank == 0: logger0.info(f"starting index: {time_index}") @@ -339,15 +355,17 @@ def generate_fn(): image_tar.cpu(), image_lr.cpu(), time_index, - index[0], + index, has_lead_time, ) ) end.record() end.synchronize() - elapsed_time = start.elapsed_time(end) / 1000.0 # Convert ms to s + elapsed_time = ( + start.elapsed_time(end) / 1000.0 if use_cuda_timing else 0 + ) # Convert ms to s timed_steps = time_index + 1 - warmup_steps - if dist.rank == 0: + if dist.rank == 0 and use_cuda_timing: average_time_per_batch_element = elapsed_time / timed_steps / batch_size logger.info( f"Total time to run {timed_steps} steps and {batch_size} members = {elapsed_time} s" diff --git a/examples/generative/corrdiff/helpers/train_helpers.py b/examples/generative/corrdiff/helpers/train_helpers.py index d4529ac821..218d6f1969 100644 --- a/examples/generative/corrdiff/helpers/train_helpers.py +++ b/examples/generative/corrdiff/helpers/train_helpers.py @@ -17,6 +17,7 @@ import torch import numpy as np from omegaconf import ListConfig +import warnings def set_patch_shape(img_shape, patch_shape): @@ -26,12 +27,21 @@ def set_patch_shape(img_shape, patch_shape): patch_shape_x = img_shape_x if (patch_shape_y is None) or (patch_shape_y > img_shape_y): patch_shape_y = img_shape_y - if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: + if patch_shape_x == img_shape_x and patch_shape_y == img_shape_y: + use_patching = False + else: + use_patching = True + if use_patching: if patch_shape_x != patch_shape_y: + warnings.warn( + f"You are using rectangular patches " + f"of shape {(patch_shape_y, patch_shape_x)}, " + f"which are an experimental feature." + ) raise NotImplementedError("Rectangular patch not supported yet") if patch_shape_x % 32 != 0 or patch_shape_y % 32 != 0: raise ValueError("Patch shape needs to be a multiple of 32") - return (img_shape_y, img_shape_x), (patch_shape_y, patch_shape_x) + return use_patching, (img_shape_y, img_shape_x), (patch_shape_y, patch_shape_x) def set_seed(rank): diff --git a/examples/generative/corrdiff/train.py b/examples/generative/corrdiff/train.py index f1a006e2b2..bce79e89b3 100644 --- a/examples/generative/corrdiff/train.py +++ b/examples/generative/corrdiff/train.py @@ -19,18 +19,26 @@ from omegaconf import DictConfig, OmegaConf from torch.nn.parallel import DistributedDataParallel from torch.utils.tensorboard import SummaryWriter +import wandb +from hydra.core.hydra_config import HydraConfig + from physicsnemo import Module -from physicsnemo.models.diffusion import UNet, EDMPrecondSR +from physicsnemo.models.diffusion import UNet, EDMPrecondSuperResolution from physicsnemo.distributed import DistributedManager -from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper -from physicsnemo.metrics.diffusion import RegressionLoss, ResLoss, RegressionLossCE + +from physicsnemo.metrics.diffusion import RegressionLoss, ResidualLoss, RegressionLossCE +from physicsnemo.utils.patching import RandomPatching2D + +from physicsnemo.launch.logging.wandb import initialize_wandb from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper from physicsnemo.launch.utils import ( load_checkpoint, save_checkpoint, get_checkpoint_dir, ) -from datasets.dataset import init_train_valid_datasets_from_config + +from datasets.dataset import init_train_valid_datasets_from_config, register_dataset + from helpers.train_helpers import ( set_patch_shape, set_seed, @@ -41,6 +49,23 @@ ) +def checkpoint_list(path, suffix=".mdlus"): + """Helper function to return sorted list, in ascending order, of checkpoints in a path""" + checkpoints = [] + for file in os.listdir(path): + if file.endswith(suffix): + # Split the filename and extract the index + try: + index = int(file.split(".")[-2]) + checkpoints.append((index, file)) + except ValueError: + continue + + # Sort by index and return filenames + checkpoints.sort(key=lambda x: x[0]) + return [file for _, file in checkpoints] + + # Train the CorrDiff model using the configurations in "conf/config_training.yaml" @hydra.main(version_base="1.2", config_path="conf", config_name="config_training") def main(cfg: DictConfig) -> None: @@ -54,10 +79,24 @@ def main(cfg: DictConfig) -> None: writer = SummaryWriter(log_dir="tensorboard") logger = PythonLogger("main") # General python logger logger0 = RankZeroLoggingWrapper(logger, dist) # Rank 0 logger + initialize_wandb( + project="Modulus-Launch", + entity="Modulus", + name=f"CorrDiff-Training-{HydraConfig.get().job.name}", + group="CorrDiff-DDP-Group", + mode=cfg.wandb.mode, + config=OmegaConf.to_container(cfg), + results_dir=cfg.wandb.results_dir, + ) # Resolve and parse configs OmegaConf.resolve(cfg) dataset_cfg = OmegaConf.to_container(cfg.dataset) # TODO needs better handling + + # Register custom dataset if specified in config + register_dataset(cfg.dataset.type) + logger0.info(f"Using dataset: {cfg.dataset.type}") + if hasattr(cfg, "validation"): train_test_split = True validation_dataset_cfg = OmegaConf.to_container(cfg.validation) @@ -86,7 +125,7 @@ def main(cfg: DictConfig) -> None: data_loader_kwargs = { "pin_memory": True, "num_workers": cfg.training.perf.dataloader_workers, - "prefetch_factor": 2, + "prefetch_factor": 2 if cfg.training.perf.dataloader_workers > 0 else None, } ( dataset, @@ -126,75 +165,31 @@ def main(cfg: DictConfig) -> None: patch_shape_x = None patch_shape_y = None patch_shape = (patch_shape_y, patch_shape_x) - img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) - if patch_shape != img_shape: + use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) + if use_patching: + # Utility to perform patches extraction and batching + patching = RandomPatching2D( + img_shape=img_shape, + patch_shape=patch_shape, + patch_num=getattr(cfg.training.hp, "patch_num", 1), + ) logger0.info("Patch-based training enabled") else: + patching = None logger0.info("Patch-based training disabled") # interpolate global channel if patch-based model is used - if img_shape[1] != patch_shape[1]: + if use_patching: img_in_channels += dataset_channels # Instantiate the model and move to device. - if cfg.model.name not in ( - "regression", - "lt_aware_ce_regression", - "diffusion", - "patched_diffusion", - "lt_aware_patched_diffusion", - ): - raise ValueError("Invalid model") model_args = { # default parameters for all networks "img_out_channels": img_out_channels, "img_resolution": list(img_shape), "use_fp16": fp16, + "checkpoint_level": songunet_checkpoint_level, } - standard_model_cfgs = { # default parameters for different network types - "regression": { - "img_channels": 4, - "N_grid_channels": 4, - "embedding_type": "zero", - "checkpoint_level": songunet_checkpoint_level, - }, - "lt_aware_ce_regression": { - "img_channels": 4, - "N_grid_channels": 4, - "embedding_type": "zero", - "lead_time_channels": 4, - "lead_time_steps": 9, - "prob_channels": prob_channels, - "checkpoint_level": songunet_checkpoint_level, - "model_type": "SongUNetPosLtEmbd", - }, - "diffusion": { - "img_channels": img_out_channels, - "gridtype": "sinusoidal", - "N_grid_channels": 4, - "checkpoint_level": songunet_checkpoint_level, - }, - "patched_diffusion": { - "img_channels": img_out_channels, - "gridtype": "learnable", - "N_grid_channels": 100, - "checkpoint_level": songunet_checkpoint_level, - }, - "lt_aware_patched_diffusion": { - "img_channels": img_out_channels, - "gridtype": "learnable", - "N_grid_channels": 100, - "lead_time_channels": 20, - "lead_time_steps": 9, - "checkpoint_level": songunet_checkpoint_level, - "model_type": "SongUNetPosLtEmbd", - }, - } - model_args.update(standard_model_cfgs[cfg.model.name]) - if cfg.model.name in ( - "diffusion", - "patched_diffusion", - "lt_aware_patched_diffusion", - ): - model_args["scale_cond_input"] = cfg.model.scale_cond_input + if cfg.model.name == "lt_aware_ce_regression": + model_args["prob_channels"] = prob_channels if hasattr(cfg.model, "model_args"): # override defaults from config file model_args.update(OmegaConf.to_container(cfg.model.model_args)) if cfg.model.name == "regression": @@ -210,20 +205,36 @@ def main(cfg: DictConfig) -> None: **model_args, ) elif cfg.model.name == "lt_aware_patched_diffusion": - model = EDMPrecondSR( + model = EDMPrecondSuperResolution( img_in_channels=img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"], **model_args, ) - else: # diffusion or patched diffusion - model = EDMPrecondSR( + elif cfg.model.name == "diffusion": + model = EDMPrecondSuperResolution( img_in_channels=img_in_channels + model_args["N_grid_channels"], **model_args, ) + elif cfg.model.name == "patched_diffusion": + model = EDMPrecondSuperResolution( + img_in_channels=img_in_channels + model_args["N_grid_channels"], + **model_args, + ) + else: + raise ValueError(f"Invalid model: {cfg.model.name}") model.train().requires_grad_(True).to(dist.device) + # Check if regression model is used with patching + if ( + cfg.model.name in ["regression", "lt_aware_ce_regression"] + and patching is not None + ): + raise ValueError( + f"Regression model ({cfg.model.name}) cannot be used with patch-based training. " + ) + # Enable distributed data parallel if applicable if dist.world_size > 1: model = DistributedDataParallel( @@ -233,9 +244,14 @@ def main(cfg: DictConfig) -> None: output_device=dist.device, find_unused_parameters=dist.find_unused_parameters, ) + if cfg.wandb.watch_model and dist.rank == 0: + wandb.watch(model) # Load the regression checkpoint if applicable - if hasattr(cfg.training.io, "regression_checkpoint_path"): + if ( + hasattr(cfg.training.io, "regression_checkpoint_path") + and cfg.training.io.regression_checkpoint_path is not None + ): regression_checkpoint_path = to_absolute_path( cfg.training.io.regression_checkpoint_path ) @@ -248,19 +264,13 @@ def main(cfg: DictConfig) -> None: logger0.success("Loaded the pre-trained regression model") # Instantiate the loss function - patch_num = getattr(cfg.training.hp, "patch_num", 1) if cfg.model.name in ( "diffusion", "patched_diffusion", "lt_aware_patched_diffusion", ): - loss_fn = ResLoss( + loss_fn = ResidualLoss( regression_net=regression_net, - img_shape_x=img_shape[1], - img_shape_y=img_shape[0], - patch_shape_x=patch_shape[1], - patch_shape_y=patch_shape[0], - patch_num=patch_num, hr_mean_conditioning=cfg.model.hr_mean_conditioning, ) elif cfg.model.name == "regression": @@ -317,17 +327,20 @@ def main(cfg: DictConfig) -> None: optimizer.zero_grad(set_to_none=True) loss_accum = 0 for _ in range(num_accumulation_rounds): - img_clean, img_lr, labels, *lead_time_label = next(dataset_iterator) + img_clean, img_lr, *lead_time_label = next(dataset_iterator) img_clean = img_clean.to(dist.device).to(torch.float32).contiguous() img_lr = img_lr.to(dist.device).to(torch.float32).contiguous() - labels = labels.to(dist.device).contiguous() loss_fn_kwargs = { "net": model, "img_clean": img_clean, "img_lr": img_lr, - "labels": labels, "augment_pipe": None, } + # Sample new random patches for this iteration and add patching to + # loss arguments + if patching is not None: + patching.reset_patch_indices() + loss_fn_kwargs.update({"patching": patching}) if lead_time_label: lead_time_label = lead_time_label[0].to(dist.device).contiguous() loss_fn_kwargs.update({"lead_time_label": lead_time_label}) @@ -356,6 +369,12 @@ def main(cfg: DictConfig) -> None: writer.add_scalar( "training_loss_running_mean", average_loss_running_mean, cur_nimg ) + wandb.log( + { + "training_loss": average_loss, + "training_loss_running_mean": average_loss_running_mean, + } + ) ptt = is_time_for_periodic_task( cur_nimg, @@ -400,7 +419,7 @@ def main(cfg: DictConfig) -> None: ): with torch.no_grad(): for _ in range(cfg.training.io.validation_steps): - img_clean_valid, img_lr_valid, labels_valid = next( + img_clean_valid, img_lr_valid, *lead_time_label_valid = next( validation_dataset_iterator ) @@ -412,14 +431,20 @@ def main(cfg: DictConfig) -> None: img_lr_valid = ( img_lr_valid.to(dist.device).to(torch.float32).contiguous() ) - labels_valid = labels_valid.to(dist.device).contiguous() - loss_valid = loss_fn( - net=model, - img_clean=img_clean_valid, - img_lr=img_lr_valid, - labels=labels_valid, - augment_pipe=None, - ) + loss_valid_kwargs = { + "net": model, + "img_clean": img_clean_valid, + "img_lr": img_lr_valid, + "augment_pipe": None, + } + if lead_time_label_valid: + lead_time_label_valid = ( + lead_time_label_valid[0].to(dist.device).contiguous() + ) + loss_valid_kwargs.update( + {"lead_time_label": lead_time_label_valid} + ) + loss_valid = loss_fn(**loss_valid_kwargs) loss_valid = ( (loss_valid.sum() / batch_size_per_gpu).cpu().item() ) @@ -439,6 +464,11 @@ def main(cfg: DictConfig) -> None: writer.add_scalar( "validation_loss", average_valid_loss, cur_nimg ) + wandb.log( + { + "validation_loss": average_valid_loss, + } + ) if is_time_for_periodic_task( cur_nimg, @@ -470,7 +500,8 @@ def main(cfg: DictConfig) -> None: f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" ] logger0.info(" ".join(fields)) - torch.cuda.reset_peak_memory_stats() + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() # Save checkpoints if dist.world_size > 1: @@ -490,6 +521,14 @@ def main(cfg: DictConfig) -> None: epoch=cur_nimg, ) + # Retain only the recent n checkpoints, if desired + if cfg.training.io.save_n_recent_checkpoints > 0: + for suffix in [".mdlus", ".pt"]: + ckpts = checkpoint_list(checkpoint_dir, suffix=suffix) + while len(ckpts) > cfg.training.io.save_n_recent_checkpoints: + os.remove(os.path.join(checkpoint_dir, ckpts[0])) + ckpts = ckpts[1:] + # Done. logger0.info("Training Completed.") diff --git a/physicsnemo/metrics/diffusion/__init__.py b/physicsnemo/metrics/diffusion/__init__.py index 8673ce8eb2..f00d88a026 100644 --- a/physicsnemo/metrics/diffusion/__init__.py +++ b/physicsnemo/metrics/diffusion/__init__.py @@ -20,7 +20,7 @@ EDMLossSR, RegressionLoss, RegressionLossCE, - ResLoss, + ResidualLoss, VELoss, VELoss_dfsr, VPLoss, diff --git a/physicsnemo/metrics/diffusion/loss.py b/physicsnemo/metrics/diffusion/loss.py index 18dde13b51..6d51e8bb8c 100644 --- a/physicsnemo/metrics/diffusion/loss.py +++ b/physicsnemo/metrics/diffusion/loss.py @@ -18,11 +18,13 @@ """Loss functions used in the paper "Elucidating the Design Space of Diffusion-Based Generative Models".""" -import random -from typing import Callable, Optional, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch +from torch import Tensor + +from physicsnemo.utils.patching import RandomPatching2D class VPLoss: @@ -333,7 +335,7 @@ def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): sigma = (rnd_normal * self.P_std + self.P_mean).exp() weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 - # augment for conditional generaiton + # augment for conditional generation img_tot = torch.cat((img_clean, img_lr), dim=1) y_tot, augment_labels = ( augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) @@ -349,16 +351,13 @@ def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): class RegressionLoss: """ - Regression loss function for the U-Net for deterministic predictions. + Regression loss function for the deterministic predictions. + Note: this loss does not apply any reduction. - Parameters + Attributes ---------- - P_mean: float, optional - Mean value for `sigma` computation, by default -1.2. - P_std: float, optional: - Standard deviation for `sigma` computation, by default 1.2. - sigma_data: float, optional - Standard deviation for data, by default 0.5. + sigma_data: float + Standard deviation for data. Deprecated and ignored. Note ---- @@ -368,43 +367,68 @@ class RegressionLoss: arXiv preprint arXiv:2309.15214. """ - def __init__( - self, P_mean: float = -1.2, P_std: float = 1.2, sigma_data: float = 0.5 - ): - self.P_mean = P_mean - self.P_std = P_std - self.sigma_data = sigma_data + def __init__(self): + """ + Arguments + ---------- + """ + return - def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): + def __call__( + self, + net: torch.nn.Module, + img_clean: torch.Tensor, + img_lr: torch.Tensor, + augment_pipe: Optional[ + Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]] + ] = None, + ) -> torch.Tensor: """ - Calculate and return the loss for the U-Net for deterministic predictions. + Calculate and return the regression loss for + deterministic predictions. - Parameters: + Parameters ---------- - net: torch.nn.Module + net : torch.nn.Module The neural network model that will make predictions. + Expected signature: `net(x, img_lr, + augment_labels=augment_labels, force_fp32=False)`, where: + x (torch.Tensor): Tensor of shape (B, C_hr, H, W). Is zero-filled. + img_lr (torch.Tensor): Low-resolution input of shape (B, C_lr, H, W) + augment_labels (torch.Tensor, optional): Optional augmentation + labels, returned by `augment_pipe`. + force_fp32 (bool, optional): Whether to force the model to use + fp32, by default False. + Returns: + torch.Tensor: Predictions of shape (B, C_hr, H, W) + + img_clean : torch.Tensor + High-resolution input images of shape (B, C_hr, H, W). + Used as ground truth and for data augmentation if 'augment_pipe' is provided. + + img_lr : torch.Tensor + Low-resolution input images of shape (B, C_lr, H, W). + Used as input to the neural network. + + augment_pipe : callable, optional + An optional data augmentation function. + Expected signature: + img_tot (torch.Tensor): Concatenated high and low resolution + images of shape (B, C_hr+C_lr, H, W) + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - Augmented images of shape (B, C_hr+C_lr, H, W) + - Optional augmentation labels - img_clean: torch.Tensor - Input images (high resolution) to the neural network. - - img_lr: torch.Tensor - Input images (low resolution) to the neural network. - - labels: torch.Tensor - Ground truth labels for the input images. - - augment_pipe: callable, optional - An optional data augmentation function that takes images as input and - returns augmented images. If not provided, no data augmentation is applied. - - Returns: + Returns ------- torch.Tensor - A tensor representing the loss calculated based on the network's - predictions. + A tensor representing the per-sample element-wise squared + difference between the network's predictions and the high + resolution images `img_clean` (possibly data-augmented by + `augment_pipe`). + Shape: (B, C_hr, H, W), same as `img_clean`. """ - rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) - sigma = (rnd_normal * self.P_std + self.P_mean).exp() weight = ( 1.0 # (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 ) @@ -416,100 +440,213 @@ def __call__(self, net, img_clean, img_lr, labels=None, augment_pipe=None): y = y_tot[:, : img_clean.shape[1], :, :] y_lr = y_tot[:, img_clean.shape[1] :, :, :] - input = torch.zeros_like(y, device=img_clean.device) - D_yn = net(input, y_lr, sigma, labels, augment_labels=augment_labels) + zero_input = torch.zeros_like(y, device=img_clean.device) + D_yn = net(zero_input, y_lr, force_fp32=False, augment_labels=augment_labels) loss = weight * ((D_yn - y) ** 2) return loss -class ResLoss: +class ResidualLoss: """ Mixture loss function for denoising score matching. - Parameters + This class implements a loss function that combines deterministic + regression with denoising score matching. It uses a pre-trained regression + network to compute residuals before applying the diffusion process. + + Attributes ---------- - P_mean: float, optional - Mean value for `sigma` computation, by default -1.2. - P_std: float, optional: - Standard deviation for `sigma` computation, by default 1.2. - sigma_data: float, optional - Standard deviation for data, by default 0.5. + regression_net : torch.nn.Module + The regression network used for computing residuals. + P_mean : float + Mean value for noise level computation. + P_std : float + Standard deviation for noise level computation. + sigma_data : float + Standard deviation for data weighting. + hr_mean_conditioning : bool + Flag indicating whether to use high-resolution mean for conditioning. Note ---- Reference: Mardani, M., Brenowitz, N., Cohen, Y., Pathak, J., Chen, C.Y., - Liu, C.C.,Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. - Generative Residual Diffusion Modeling for Km-scale Atmospheric Downscaling. - arXiv preprint arXiv:2309.15214. + Liu, C.C., Vahdat, A., Kashinath, K., Kautz, J. and Pritchard, M., 2023. + Generative Residual Diffusion Modeling for Km-scale Atmospheric + Downscaling. arXiv preprint arXiv:2309.15214. """ def __init__( self, - regression_net, - img_shape_x, - img_shape_y, - patch_shape_x, - patch_shape_y, - patch_num, + regression_net: torch.nn.Module, P_mean: float = 0.0, P_std: float = 1.2, sigma_data: float = 0.5, hr_mean_conditioning: bool = False, ): - self.unet = regression_net + """ + Arguments + ---------- + regression_net : torch.nn.Module + Pre-trained regression network used to compute residuals. + Expected signature: `net(zero_input, y_lr, + lead_time_label=lead_time_label, augment_labels=augment_labels)` or + `net(zero_input, y_lr, augment_labels=augment_labels)`, where: + zero_input (torch.Tensor): Zero tensor of shape (B, C_hr, H, W) + y_lr (torch.Tensor): Low-resolution input of shape (B, C_lr, H, W) + lead_time_label (torch.Tensor, optional): Optional lead time labels + augment_labels (torch.Tensor, optional): Optional augmentation labels + Returns: + torch.Tensor: Predictions of shape (B, C_hr, H, W) + + P_mean : float, optional + Mean value for noise level computation, by default 0.0. + + P_std : float, optional + Standard deviation for noise level computation, by default 1.2. + + sigma_data : float, optional + Standard deviation for data weighting, by default 0.5. + + hr_mean_conditioning : bool, optional + Whether to use high-resolution mean for conditioning predicted, by default False. + When True, the mean prediction from `regression_net` is channel-wise + concatenated with `img_lr` for conditioning. + """ + self.regression_net = regression_net self.P_mean = P_mean self.P_std = P_std self.sigma_data = sigma_data - self.img_shape_x = img_shape_x - self.img_shape_y = img_shape_y - self.patch_shape_x = patch_shape_x - self.patch_shape_y = patch_shape_y - self.patch_num = patch_num self.hr_mean_conditioning = hr_mean_conditioning def __call__( self, - net, - img_clean, - img_lr, - labels=None, - lead_time_label=None, - augment_pipe=None, - ): + net: torch.nn.Module, + img_clean: Tensor, + img_lr: Tensor, + patching: Optional[RandomPatching2D] = None, + lead_time_label: Optional[Tensor] = None, + augment_pipe: Optional[ + Callable[[Tensor], Tuple[Tensor, Optional[Tensor]]] + ] = None, + ) -> Tensor: """ Calculate and return the loss for denoising score matching. - Parameters: - ---------- - net: torch.nn.Module - The neural network model that will make predictions. + This method computes a mixture loss that combines deterministic + regression with denoising score matching. It first computes residuals + using the regression network, then applies the diffusion process to + these residuals. + + In addition to the standard denoising score matching loss, this method + also supports optional patching for multi-diffusion. In this case, the spatial + dimensions of the input are decomposed into `P` smaller patches of shape + (H_patch, W_patch), that are grouped along the batch dimension, and the + model is applied to each patch individually. In the following, if `patching` + is not provided, then the input is not patched and `P=1` and `(H_patch, + W_patch) = (H, W)`. When patching is used, the original non-patched conditioning is + interpolated onto a spatial grid of shape `(H_patch, W_patch)` and channel-wise + concatenated to the patched conditioning. This ensures that each patch + maintains global information from the entire domain. + + The diffusion model `net` is expected to be conditioned on an input with + `C_cond` channels, which should be: + - `C_cond = C_lr` if `hr_mean_conditioning` is `False` and + `patching` is None. + - `C_cond = C_hr + C_lr` if `hr_mean_conditioning` is `True` and + `patching` is None. + - `C_cond = C_hr + 2*C_lr` if `hr_mean_conditioning` is `True` and + `patching` is not None. + - `C_cond = 2*C_lr` if `hr_mean_conditioning` is `False` and + `patching` is not None. + Additionally, `C_cond` should also include any embedding channels, + such as positional embeddings or time embeddings. + + Note: this loss function does not apply any reduction. - img_clean: torch.Tensor - Input images (high resolution) to the neural network. - - img_lr: torch.Tensor - Input images (low resolution) to the neural network. - - labels: torch.Tensor - Ground truth labels for the input images. - - augment_pipe: callable, optional - An optional data augmentation function that takes images as input and - returns augmented images. If not provided, no data augmentation is applied. + Parameters + ---------- + net : torch.nn.Module + The neural network model for the diffusion process. + Expected signature: `net(latent, y_lr, sigma, + embedding_selector=embedding_selector, lead_time_label=lead_time_label, + augment_labels=augment_labels)`, where: + latent (torch.Tensor): Noisy input of shape (B[*P], C_hr, H_patch, W_patch) + y_lr (torch.Tensor): Conditioning of shape (B[*P], C_cond, H_patch, W_patch) + sigma (torch.Tensor): Noise level of shape (B[*P], 1, 1, 1) + embedding_selector (callable, optional): Function to select + positional embeddings. Only used if `patching` is provided. + lead_time_label (torch.Tensor, optional): Lead time labels. + augment_labels (torch.Tensor, optional): Augmentation labels + Returns: + torch.Tensor: Predictions of shape (B[*P], C_hr, H_patch, W_patch) + + img_clean : torch.Tensor + High-resolution input images of shape (B, C_hr, H, W). + Used as ground truth and for data augmentation if 'augment_pipe' is provided. + + img_lr : torch.Tensor + Low-resolution input images of shape (B, C_lr, H, W). + Used as input to the regression network and conditioning for the + diffusion process. + + patching : Optional[RandomPatching2D], optional + Patching strategy for processing large images, by default None. See + :class:`physicsnemo.utils.patching.RandomPatching2D` for details. + When provided, the patching strategy is used for both image patches + and positional embeddings selection in the diffusion model `net`. + Transforms tensors from shape (B, C, H, W) to (B*P, C, H_patch, + W_patch). + + lead_time_label : Optional[torch.Tensor], optional + Labels for lead-time aware predictions, by default None. + Shape can vary based on model requirements, typically (B,) or scalar. + + augment_pipe : Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]]] + Data augmentation function. + Expected signature: + img_tot (torch.Tensor): Concatenated high and low resolution images + of shape (B, C_hr+C_lr, H, W) + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - Augmented images of shape (B, C_hr+C_lr, H, W) + - Optional augmentation labels - Returns: + Returns ------- torch.Tensor - A tensor representing the loss calculated based on the network's - predictions. + If patching is not used: + A tensor of shape (B, C_hr, H, W) representing the per-sample loss. + If patching is used: + A tensor of shape (B*P, C_hr, H_patch, W_patch) representing + the per-patch loss. + + Raises + ------ + ValueError + If patching is provided but is not an instance of RandomPatching2D. + If shapes of img_clean and img_lr are incompatible. """ + # Safety check: enforce patching object + if patching and not isinstance(patching, RandomPatching2D): + raise ValueError("patching must be a 'RandomPatching2D' object.") + # Safety check: enforce shapes + if ( + img_clean.shape[0] != img_lr.shape[0] + or img_clean.shape[2:] != img_lr.shape[2:] + ): + raise ValueError( + f"Shape mismatch between img_clean {img_clean.shape} and " + f"img_lr {img_lr.shape}. " + f"Batch size, height and width must match." + ) + rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) sigma = (rnd_normal * self.P_std + self.P_mean).exp() weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 - # augment for conditional generaiton + # augment for conditional generation img_tot = torch.cat((img_clean, img_lr), dim=1) y_tot, augment_labels = ( augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) @@ -517,31 +654,20 @@ def __call__( y = y_tot[:, : img_clean.shape[1], :, :] y_lr = y_tot[:, img_clean.shape[1] :, :, :] y_lr_res = y_lr - - # global index - b = y.shape[0] - Nx = torch.arange(self.img_shape_x).int() - Ny = torch.arange(self.img_shape_y).int() - grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0)[ - None, - ].expand(b, -1, -1, -1) + batch_size = y.shape[0] # form residual if lead_time_label is not None: - y_mean = self.unet( + y_mean = self.regression_net( torch.zeros_like(y, device=img_clean.device), y_lr_res, - sigma, - labels, lead_time_label=lead_time_label, augment_labels=augment_labels, ) else: - y_mean = self.unet( + y_mean = self.regression_net( torch.zeros_like(y, device=img_clean.device), y_lr_res, - sigma, - labels, augment_labels=augment_labels, ) @@ -549,82 +675,35 @@ def __call__( if self.hr_mean_conditioning: y_lr = torch.cat((y_mean, y_lr), dim=1).contiguous() - global_index = None + # patchified training # conditioning: cat(y_mean, y_lr, input_interp, pos_embd), 4+12+100+4 - if ( - self.img_shape_x != self.patch_shape_x - or self.img_shape_y != self.patch_shape_y - ): - c_in = y_lr.shape[1] - c_out = y.shape[1] - rnd_normal = torch.randn( - [img_clean.shape[0] * self.patch_num, 1, 1, 1], device=img_clean.device - ) - sigma = (rnd_normal * self.P_std + self.P_mean).exp() - weight = (sigma**2 + self.sigma_data**2) / ( - sigma * self.sigma_data - ) ** 2 - - # global interpolation - input_interp = torch.nn.functional.interpolate( - img_lr, - (self.patch_shape_y, self.patch_shape_x), - mode="bilinear", - ) + if patching: + # Patched residual + # (batch_size * patch_num, c_out, patch_shape_y, patch_shape_x) + y_patched = patching.apply(input=y) + # Patched conditioning on y_lr and interp(img_lr) + # (batch_size * patch_num, 2*c_in, patch_shape_y, patch_shape_x) + y_lr_patched = patching.apply(input=y_lr, additional_input=img_lr) + + # Function to select the correct positional embedding for each + # patch + def patch_embedding_selector(emb): + # emb: (N_pe, image_shape_y, image_shape_x) + # return: (batch_size * patch_num, N_pe, patch_shape_y, patch_shape_x) + return patching.apply(emb[None].expand(batch_size, -1, -1, -1)) + + y = y_patched + y_lr = y_lr_patched + else: + patch_embedding_selector = None - # patch generation from a single sample (not from random samples due to memory consumption of regression) - y_new = torch.zeros( - b * self.patch_num, - c_out, - self.patch_shape_y, - self.patch_shape_x, - device=img_clean.device, - ) - y_lr_new = torch.zeros( - b * self.patch_num, - c_in + input_interp.shape[1], - self.patch_shape_y, - self.patch_shape_x, - device=img_clean.device, - ) - global_index = torch.zeros( - b * self.patch_num, - 2, - self.patch_shape_y, - self.patch_shape_x, - dtype=torch.int, - device=img_clean.device, - ) - for i in range(self.patch_num): - rnd_x = random.randint(0, self.img_shape_x - self.patch_shape_x) - rnd_y = random.randint(0, self.img_shape_y - self.patch_shape_y) - y_new[b * i : b * (i + 1),] = y[ - :, - :, - rnd_y : rnd_y + self.patch_shape_y, - rnd_x : rnd_x + self.patch_shape_x, - ] - global_index[b * i : b * (i + 1),] = grid[ - :, - :, - rnd_y : rnd_y + self.patch_shape_y, - rnd_x : rnd_x + self.patch_shape_x, - ] - y_lr_new[b * i : b * (i + 1),] = torch.cat( - ( - y_lr[ - :, - :, - rnd_y : rnd_y + self.patch_shape_y, - rnd_x : rnd_x + self.patch_shape_x, - ], - input_interp, - ), - 1, - ) - y = y_new - y_lr = y_lr_new + # Noise + rnd_normal = torch.randn([y.shape[0], 1, 1, 1], device=img_clean.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + # Input + noise latent = y + torch.randn_like(y) * sigma if lead_time_label is not None: @@ -632,8 +711,7 @@ def __call__( latent, y_lr, sigma, - labels, - global_index=global_index, + embedding_selector=patch_embedding_selector, lead_time_label=lead_time_label, augment_labels=augment_labels, ) @@ -642,8 +720,7 @@ def __call__( latent, y_lr, sigma, - labels, - global_index=global_index, + embedding_selector=patch_embedding_selector, augment_labels=augment_labels, ) loss = weight * ((D_yn - y) ** 2) @@ -792,20 +869,19 @@ def __call__(self, net, images, labels, augment_pipe=None): class RegressionLossCE: """ - A regression loss function for the GEFS-HRRR model with probability channels, adapted - from RegressionLoss. In this version, probability channels are evaluated using - CrossEntropyLoss instead of MSELoss. - - Parameters + A regression loss function for deterministic predictions with probability + channels and lead time labels. Adapted from + :class:`physicsnemo.metrics.diffusion.loss.RegressionLoss`. In this version, + probability channels are evaluated using CrossEntropyLoss instead of + squared error. + Note: this loss does not apply any reduction. + + Attributes ---------- - P_mean: float, optional - Mean value for `sigma` computation, by default -1.2. - P_std: float, optional: - Standard deviation for `sigma` computation, by default 1.2. - sigma_data: float, optional - Standard deviation for data, by default 0.5. - prob_channels: list, optional - A index list of output probability channels. + entropy : torch.nn.CrossEntropyLoss + Cross entropy loss function used for probability channels. + prob_channels : list[int] + List of channel indices to be treated as probability channels. Note ---- @@ -817,62 +893,86 @@ class RegressionLossCE: def __init__( self, - P_mean: float = -1.2, - P_std: float = 1.2, - sigma_data: float = 0.5, - prob_channels: list = [4, 5, 6, 7, 8], + prob_channels: list[int] = [4, 5, 6, 7, 8], ): - self.P_mean = P_mean - self.P_std = P_std - self.sigma_data = sigma_data + """ + Arguments + ---------- + prob_channels: list[int], optional + List of channel indices from the target tensor to be treated as + probability channels. Cross entropy loss is computed over these + channels, while the remaining channels are treated as scalar + channels and the squared error loss is computed over them. By + default, [4, 5, 6, 7, 8]. + """ self.entropy = torch.nn.CrossEntropyLoss(reduction="none") self.prob_channels = prob_channels def __call__( self, - net, - img_clean, - img_lr, - lead_time_label=None, - labels=None, - augment_pipe=None, - ): + net: torch.nn.Module, + img_clean: torch.Tensor, + img_lr: torch.Tensor, + lead_time_label: Optional[torch.Tensor] = None, + augment_pipe: Optional[ + Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]] + ] = None, + ) -> torch.Tensor: """ - Calculate and return the loss for the U-Net for deterministic predictions. + Calculate and return the loss for deterministic + predictions, treating specific channels as probability distributions. - Parameters: + Parameters ---------- - net: torch.nn.Module + net : torch.nn.Module The neural network model that will make predictions. + Expected signature: `net(input, img_lr, lead_time_label=lead_time_label, augment_labels=augment_labels)`, + where: + input (torch.Tensor): Tensor of shape (B, C_hr, H, W). Zero-filled. + y_lr (torch.Tensor): Low-resolution input of shape (B, C_lr, H, W) + lead_time_label (torch.Tensor, optional): Optional lead time + labels. If provided, should be of shape (B,). + augment_labels (torch.Tensor, optional): Optional augmentation + labels, returned by `augment_pipe`. + Returns: + torch.Tensor: Predictions of shape (B, C_hr, H, W) + + img_clean : torch.Tensor + High-resolution input images of shape (B, C_hr, H, W). + Used as ground truth and for data augmentation if `augment_pipe` is provided. + + img_lr : torch.Tensor + Low-resolution input images of shape (B, C_lr, H, W). + Used as input to the neural network. + + lead_time_label : Optional[torch.Tensor], optional + Lead time labels for temporal predictions, by default None. + Shape can vary based on model requirements, typically (B,) or scalar. + + augment_pipe : Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]]] + Data augmentation function. + Expected signature: + img_tot (torch.Tensor): Concatenated high and low resolution + images of shape (B, C_hr+C_lr, H, W). + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - Augmented images of shape (B, C_hr+C_lr, H, W) + - Optional augmentation labels - img_clean: torch.Tensor - Input images (high resolution) to the neural network. - - img_lr: torch.Tensor - Input images (low resolution) to the neural network. - - lead_time_label: torch.Tensor - Lead time labels for input batches. - - labels: torch.Tensor - Ground truth labels for the input images. - - augment_pipe: callable, optional - An optional data augmentation function that takes images as input and - returns augmented images. If not provided, no data augmentation is applied. - - Returns: + Returns ------- torch.Tensor - A tensor representing the loss calculated based on the network's - predictions. + A tensor of shape (B, C_loss, H, W) representing the pixel-wise + loss., where `C_loss = C_hr - len(prob_channels) + 1`. More + specifically, the last channel of the output tensor corresponds to + the cross-entropy loss computed over the channels specified in + `prob_channels`, while the first `C_hr - len(prob_channels)` + channels of the output tensor correspond to the squared error loss. """ all_channels = list(range(img_clean.shape[1])) # [0, 1, 2, ..., 10] scalar_channels = [ item for item in all_channels if item not in self.prob_channels ] - rnd_normal = torch.randn([img_clean.shape[0], 1, 1, 1], device=img_clean.device) - sigma = (rnd_normal * self.P_std + self.P_mean).exp() weight = ( 1.0 # (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 ) @@ -890,8 +990,6 @@ def __call__( D_yn = net( input, y_lr, - sigma, - labels, lead_time_label=lead_time_label, augment_labels=augment_labels, ) @@ -899,11 +997,10 @@ def __call__( D_yn = net( input, y_lr, - sigma, - labels, + lead_time_label=lead_time_label, augment_labels=augment_labels, ) - loss1 = weight * ((D_yn[:, scalar_channels] - y[:, scalar_channels]) ** 2) + loss1 = weight * (D_yn[:, scalar_channels] - y[:, scalar_channels]) ** 2 loss2 = ( weight * self.entropy(D_yn[:, self.prob_channels], y[:, self.prob_channels])[ diff --git a/physicsnemo/models/diffusion/__init__.py b/physicsnemo/models/diffusion/__init__.py index 3984bffd42..cee59e6754 100644 --- a/physicsnemo/models/diffusion/__init__.py +++ b/physicsnemo/models/diffusion/__init__.py @@ -29,7 +29,7 @@ from .unet import UNet, StormCastUNet from .preconditioning import ( EDMPrecond, - EDMPrecondSR, + EDMPrecondSuperResolution, VEPrecond, VPPrecond, iDDPMPrecond, diff --git a/physicsnemo/models/diffusion/preconditioning.py b/physicsnemo/models/diffusion/preconditioning.py index 52a1660804..cbc04f4f75 100644 --- a/physicsnemo/models/diffusion/preconditioning.py +++ b/physicsnemo/models/diffusion/preconditioning.py @@ -20,18 +20,13 @@ """ import importlib -import warnings from dataclasses import dataclass -from typing import List, Union +from typing import List, Literal, Tuple, Union import numpy as np import nvtx import torch -from physicsnemo.models.diffusion import ( - DhariwalUNet, # noqa: F401 for globals - SongUNet, # noqa: F401 for globals -) from physicsnemo.models.meta import ModelMetaData from physicsnemo.models.module import Module @@ -694,10 +689,10 @@ def round_sigma(sigma: Union[float, List, torch.Tensor]): @dataclass -class EDMPrecondSRMetaData(ModelMetaData): - """EDMPrecondSR meta data""" +class EDMPrecondSuperResolutionMetaData(ModelMetaData): + """EDMPrecondSuperResolution meta data""" - name: str = "EDMPrecondSR" + name: str = "EDMPrecondSuperResolution" # Optimization jit: bool = False cuda_graphs: bool = False @@ -713,33 +708,50 @@ class EDMPrecondSRMetaData(ModelMetaData): auto_grad: bool = False -class EDMPrecondSR(Module): +class EDMPrecondSuperResolution(Module): """ Improved preconditioning proposed in the paper "Elucidating the Design Space of - Diffusion-Based Generative Models" (EDM) for super-resolution tasks + Diffusion-Based Generative Models" (EDM). + + This is a variant of `EDMPrecond` that is specifically designed for super-resolution + tasks. It wraps a neural network that predicts the denoised high-resolution image + given a noisy high-resolution image, and additional conditioning that includes a + low-resolution image, and a noise level. Parameters ---------- - img_resolution : int - Image resolution. - img_channels : int - Number of color channels. + img_resolution : Union[int, Tuple[int, int]] + Spatial resolution `(H, W)` of the image. If a single int is provided, + the image is assumed to be square. img_in_channels : int - Number of input color channels. + Number of input channels in the low-resolution input image. img_out_channels : int - Number of output color channels. - use_fp16 : bool - Execute the underlying model at FP16 precision?, by default False. - sigma_min : float + Number of output channels in the high-resolution output image. + use_fp16 : bool, optional + Whether to use half-precision floating point (FP16) for model execution, + by default False. + model_type : str, optional + Class name of the underlying model. Must be one of the following: + 'SongUNet', 'SongUNetPosEmbd', 'SongUNetPosLtEmbd', 'DhariwalUNet'. + Defaults to 'SongUNetPosEmbd'. + sigma_data : float, optional + Expected standard deviation of the training data, by default 0.5. + sigma_min : float, optional Minimum supported noise level, by default 0.0. - sigma_max : float + sigma_max : float, optional Maximum supported noise level, by default inf. - sigma_data : float - Expected standard deviation of the training data, by default 0.5. - model_type :str - Class name of the underlying model, by default "SongUNetPosEmbd". **model_kwargs : dict - Keyword arguments for the underlying model. + Keyword arguments passed to the underlying model `__init__` method. + + See Also + -------- + For information on model types and their usage: + :class:`~physicsnemo.models.diffusion.SongUNet`: Basic U-Net for diffusion models + :class:`~physicsnemo.models.diffusion.SongUNetPosEmbd`: U-Net with positional embeddings + :class:`~physicsnemo.models.diffusion.SongUNetPosLtEmbd`: U-Net with positional and lead-time embeddings + + Please refer to the documentation of these classes for details on how to call + and use these models directly. Note ---- @@ -755,28 +767,26 @@ class EDMPrecondSR(Module): def __init__( self, - img_resolution, - img_channels, - img_in_channels, - img_out_channels, - use_fp16=False, + img_resolution: Union[int, Tuple[int, int]], + img_in_channels: int, + img_out_channels: int, + use_fp16: bool = False, + model_type: Literal[ + "SongUNetPosEmbd", "SongUNetPosLtEmbd", "SongUNet", "DhariwalUNet" + ] = "SongUNetPosEmbd", + sigma_data: float = 0.5, sigma_min=0.0, sigma_max=float("inf"), - sigma_data=0.5, - model_type="SongUNetPosEmbd", - scale_cond_input=True, - **model_kwargs, + **model_kwargs: dict, ): - super().__init__(meta=EDMPrecondSRMetaData) + super().__init__(meta=EDMPrecondSuperResolutionMetaData) self.img_resolution = img_resolution - self.img_channels = img_channels # TODO: this is not used, remove it self.img_in_channels = img_in_channels self.img_out_channels = img_out_channels self.use_fp16 = use_fp16 + self.sigma_data = sigma_data self.sigma_min = sigma_min self.sigma_max = sigma_max - self.sigma_data = sigma_data - self.scale_cond_input = scale_cond_input model_class = getattr(network_module, model_type) self.model = model_class( @@ -785,38 +795,74 @@ def __init__( out_channels=img_out_channels, **model_kwargs, ) # TODO needs better handling - self.scaling_fn = self._get_scaling_fn() - - def _get_scaling_fn(self): - if self.scale_cond_input: - warnings.warn( - "scale_cond_input=True does not properly scale the conditional input. " - "(see https://github.com/NVIDIA/modulus/issues/229). " - "This setup will be deprecated. " - "Please set scale_cond_input=False.", - DeprecationWarning, - ) - return self._legacy_scaling_fn - else: - return self._scaling_fn + self.scaling_fn = self._scaling_fn @staticmethod - def _scaling_fn(x, img_lr, c_in): - return torch.cat([c_in * x, img_lr.to(x.dtype)], dim=1) + def _scaling_fn( + x: torch.Tensor, img_lr: torch.Tensor, c_in: torch.Tensor + ) -> torch.Tensor: + """ + Scale input tensors by first scaling the high-resolution tensor and then + concatenating with the low-resolution tensor. - @staticmethod - def _legacy_scaling_fn(x, img_lr, c_in): - return c_in * torch.cat([x, img_lr.to(x.dtype)], dim=1) + Parameters + ---------- + x : torch.Tensor + Noisy high-resolution image of shape (B, C_hr, H, W). + img_lr : torch.Tensor + Low-resolution image of shape (B, C_lr, H, W). + c_in : torch.Tensor + Scaling factor of shape (B, 1, 1, 1). + + Returns + ------- + torch.Tensor + Scaled and concatenated tensor of shape (B, C_in+C_out, H, W). + """ + return torch.cat([c_in * x, img_lr.to(x.dtype)], dim=1) - @nvtx.annotate(message="EDMPrecondSR", color="orange") + @nvtx.annotate(message="EDMPrecondSuperResolution", color="orange") def forward( self, - x, - img_lr, - sigma, - force_fp32=False, - **model_kwargs, - ): + x: torch.Tensor, + img_lr: torch.Tensor, + sigma: torch.Tensor, + force_fp32: bool = False, + **model_kwargs: dict, + ) -> torch.Tensor: + """ + Forward pass of the EDMPrecondSuperResolution model wrapper. + + This method applies the EDM preconditioning to compute the denoised image + from a noisy high-resolution image and low-resolution conditioning image. + + Parameters + ---------- + x : torch.Tensor + Noisy high-resolution image of shape (B, C_hr, H, W). The number of + channels `C_hr` should be equal to `img_out_channels`. + img_lr : torch.Tensor + Low-resolution conditioning image of shape (B, C_lr, H, W). The number + of channels `C_lr` should be equal to `img_in_channels`. + sigma : torch.Tensor + Noise level of shape (B) or (B, 1) or (B, 1, 1, 1). + force_fp32 : bool, optional + Whether to force FP32 precision regardless of the `use_fp16` attribute, + by default False. + **model_kwargs : dict + Additional keyword arguments to pass to the underlying model + `self.model` forward method. + + Returns + ------- + torch.Tensor + Denoised high-resolution image of shape (B, C_hr, H, W). + + Raises + ------ + ValueError + If the model output dtype doesn't match the expected dtype. + """ # Concatenate input channels x = x.to(torch.float32) sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) @@ -853,10 +899,23 @@ def forward( return D_x @staticmethod - def round_sigma(sigma: Union[float, List, torch.Tensor]): + def round_sigma(sigma: Union[float, List, torch.Tensor]) -> torch.Tensor: """ Convert a given sigma value(s) to a tensor representation. - See EDMPrecond.round_sigma + + Parameters + ---------- + sigma : Union[float, List, torch.Tensor] + Sigma value(s) to convert. + + Returns + ------- + torch.Tensor + Tensor representation of sigma values. + + See Also + -------- + EDMPrecond.round_sigma """ return EDMPrecond.round_sigma(sigma) @@ -910,7 +969,8 @@ def __init__( self.img_channels = img_channels self.label_dim = label_dim self.use_fp16 = use_fp16 - self.model = globals()[model_type]( + model_class = getattr(network_module, model_type) + self.model = model_class( img_resolution=img_resolution, in_channels=self.img_channels, out_channels=img_channels, @@ -1009,7 +1069,8 @@ def __init__( self.img_channels = img_channels self.label_dim = label_dim self.use_fp16 = use_fp16 - self.model = globals()[model_type]( + model_class = getattr(network_module, model_type) + self.model = model_class( img_resolution=img_resolution, in_channels=model_kwargs["model_channels"] * 2, out_channels=img_channels, diff --git a/physicsnemo/models/diffusion/song_unet.py b/physicsnemo/models/diffusion/song_unet.py index d38484ba28..f5eeaaf517 100644 --- a/physicsnemo/models/diffusion/song_unet.py +++ b/physicsnemo/models/diffusion/song_unet.py @@ -20,7 +20,7 @@ """ from dataclasses import dataclass -from typing import List, Union +from typing import Callable, List, Optional, Union import numpy as np import nvtx @@ -71,7 +71,8 @@ class SongUNet(Module): Parameters ----------- img_resolution : Union[List[int], int] - The resolution of the input/output image, 1 value represents a square image. + The resolution of the input/output image. Can be a single int for square images + or a list [height, width] for rectangular images. in_channels : int Number of channels in the input image. out_channels : int @@ -81,7 +82,7 @@ class SongUNet(Module): augment_dim : int, optional Dimensionality of augmentation labels; 0 means no augmentation. By default 0. model_channels : int, optional - Base multiplier for the number of channels across the network, by default 128. + Base multiplier for the number of channels across the network. By default 128. channel_mult : List[int], optional Per-resolution multipliers for the number of channels. By default [1,2,2,2]. channel_mult_emb : int, optional @@ -93,29 +94,29 @@ class SongUNet(Module): dropout : float, optional Dropout probability applied to intermediate activations. By default 0.10. label_dropout : float, optional - Dropout probability of class labels for classifier-free guidance. By default 0.0. + Dropout probability of class labels for classifier-free guidance. By default 0.0. embedding_type : str, optional - Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++, 'zero' for none + Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++, 'zero' for none. By default 'positional'. channel_mult_noise : int, optional Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. encoder_type : str, optional - Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. By default - 'standard'. + Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++, 'skip' for skip connections. + By default 'standard'. decoder_type : str, optional - Decoder architecture: 'standard' for both DDPM++ and NCSN++. By default - 'standard'. - resample_filter : List[int], optional (default=[1,1]) - Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. - checkpoint_level : int, optional (default=0) - How many layers should use gradient checkpointing, 0 is None - additive_pos_embed: bool = False, - Set to True to add a learned position embedding after the first conv (used in StormCast) - + Decoder architecture: 'standard' or 'skip' for skip connections. By default 'standard'. + resample_filter : List[int], optional + Resampling filter coefficients: [1,1] for DDPM++, [1,3,3,1] for NCSN++. By default [1,1]. + checkpoint_level : int, optional + Number of layers that should use gradient checkpointing (0 disables checkpointing). + Higher values trade memory for computation. By default 0. + additive_pos_embed : bool, optional + If True, adds a learned positional embedding after the first convolution layer. + Used in StormCast model. By default False. Reference ---------- - Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and + Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and Poole, B., 2020. Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456. @@ -413,19 +414,29 @@ def forward(self, x, noise_labels, class_labels, augment_labels=None): class SongUNetPosEmbd(SongUNet): - """ - Reimplementation of the DDPM++ and NCSN++ architectures, U-Net variants with - optional self-attention,embeddings, and encoder-decoder components. + """Extends SongUNet with positional embeddings. This model supports conditional and unconditional setups, as well as several options for various internal architectural choices such as encoder and decoder type, embedding type, etc., making it flexible and adaptable to different tasks and configurations. + This model adds positional embeddings to the base SongUNet architecture. The embeddings + can be selected using either a selector function or global indices, with the selector + approach being more computationally efficient. + + The model provides two methods for selecting positional embeddings: + + 1. Using a selector function (preferred method). See + :meth:`positional_embedding_selector` for details. + 2. Using global indices. See :meth:`positional_embedding_indexing` for + details. + Parameters - ----------- + ---------- img_resolution : Union[List[int], int] - The resolution of the input/output image, 1 value represents a square image. + The resolution of the input/output image. Can be a single int for square images + or a list [height, width] for rectangular images. in_channels : int Number of channels in the input image. out_channels : int @@ -435,39 +446,40 @@ class SongUNetPosEmbd(SongUNet): augment_dim : int, optional Dimensionality of augmentation labels; 0 means no augmentation. By default 0. model_channels : int, optional - Base multiplier for the number of channels across the network, by default 128. + Base multiplier for the number of channels across the network. By default 128. channel_mult : List[int], optional - Per-resolution multipliers for the number of channels. By default [1,2,2,2]. + Per-resolution multipliers for the number of channels. By default [1,2,2,2,2]. channel_mult_emb : int, optional Multiplier for the dimensionality of the embedding vector. By default 4. num_blocks : int, optional Number of residual blocks per resolution. By default 4. attn_resolutions : List[int], optional - Resolutions at which self-attention layers are applied. By default [16]. + Resolutions at which self-attention layers are applied. By default [28]. dropout : float, optional Dropout probability applied to intermediate activations. By default 0.13. label_dropout : float, optional - Dropout probability of class labels for classifier-free guidance. By default 0.0. + Dropout probability of class labels for classifier-free guidance. By default 0.0. embedding_type : str, optional Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++. By default 'positional'. channel_mult_noise : int, optional Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. encoder_type : str, optional - Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. By default - 'standard'. + Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++, 'skip' for skip connections. + By default 'standard'. decoder_type : str, optional - Decoder architecture: 'standard' for both DDPM++ and NCSN++. By default - 'standard'. - resample_filter : List[int], optional (default=[1,1]) - Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. - - - Reference - ---------- - Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and - Poole, B., 2020. Score-based generative modeling through stochastic differential - equations. arXiv preprint arXiv:2011.13456. + Decoder architecture: 'standard' or 'skip' for skip connections. By default 'standard'. + resample_filter : List[int], optional + Resampling filter coefficients: [1,1] for DDPM++, [1,3,3,1] for NCSN++. By default [1,1]. + gridtype : str, optional + Type of positional grid to use: 'sinusoidal', 'learnable', 'linear', or 'test'. + Controls how positional information is encoded. By default 'sinusoidal'. + N_grid_channels : int, optional + Number of channels in the positional embedding grid. For 'sinusoidal' must be 4 or + multiple of 4. For 'linear' must be 2. By default 4. + checkpoint_level : int, optional + Number of layers that should use gradient checkpointing (0 disables checkpointing). + Higher values trade memory for computation. By default 0. Note ----- @@ -476,13 +488,41 @@ class SongUNetPosEmbd(SongUNet): Example -------- - >>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2) + >>> import torch + >>> from physicsnemo.models.diffusion.song_unet import SongUNetPosEmbd + >>> from physicsnemo.utils.patching import GridPatching2D + >>> + >>> # Model initialization - in_channels must include both original input channels (2) + >>> # and the positional embedding channels (N_grid_channels=4 by default) + >>> model = SongUNetPosEmbd(img_resolution=16, in_channels=2+4, out_channels=2) >>> noise_labels = torch.randn([1]) >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> # The input has only the original 2 channels - positional embeddings are + >>> # added automatically inside the forward method >>> input_image = torch.ones([1, 2, 16, 16]) >>> output_image = model(input_image, noise_labels, class_labels) >>> output_image.shape torch.Size([1, 2, 16, 16]) + >>> + >>> # Using a global index to select all positional embeddings + >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(16, 16)) + >>> global_index = patching.global_index(batch_size=1) + >>> output_image = model( + ... input_image, noise_labels, class_labels, + ... global_index=global_index + ... ) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + >>> + >>> # Using a custom embedding selector to select all positional embeddings + >>> def patch_embedding_selector(emb): + ... return patching.apply(emb[None].expand(1, -1, -1, -1)) + >>> output_image = model( + ... input_image, noise_labels, class_labels, + ... embedding_selector=patch_embedding_selector + ... ) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) """ def __init__( @@ -535,56 +575,184 @@ def __init__( @nvtx.annotate(message="SongUNet", color="blue") def forward( - self, x, noise_labels, class_labels, global_index=None, augment_labels=None + self, + x, + noise_labels, + class_labels, + global_index: Optional[torch.Tensor] = None, + embedding_selector: Optional[Callable] = None, + augment_labels=None, ): - # append positional embedding to input conditioning + if embedding_selector is not None and global_index is not None: + raise ValueError( + "Cannot provide both embedding_selector and global_index. " + "embedding_selector is the preferred approach for better efficiency." + ) + + # Append positional embedding to input conditioning if self.pos_embd is not None: - selected_pos_embd = self.positional_embedding_indexing(x, global_index) + # Select positional embeddings with a selector function + if embedding_selector is not None: + selected_pos_embd = self.positional_embedding_selector( + x, embedding_selector + ) + # Select positional embeddings using global indices (selects all + # embeddings if global_index is None) + else: + selected_pos_embd = self.positional_embedding_indexing(x, global_index) x = torch.cat((x, selected_pos_embd), dim=1) return super().forward(x, noise_labels, class_labels, augment_labels) - def positional_embedding_indexing(self, x, global_index): + def positional_embedding_indexing( + self, + x: torch.Tensor, + global_index: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Select positional embeddings using global indices. + + This method either uses global indices to select specific embeddings or expands + the embeddings for the full input when no indices are provided. + + Typically used in patch-based training, where the batch dimension + contains multiple patches extracted from a larger image. + + Arguments + --------- + x : torch.Tensor + Input tensor of shape (B, C, H, W), used to determine batch size + and device. + global_index : Optional[torch.Tensor] + Optional tensor of indices for selecting embeddings. These should + correspond to the spatial indices of the batch elements in the + input tensor x. When provided, should have shape (B, 2, H, W) where + the second dimension contains y,x coordinates (indices of the + positional embedding grid). + + Returns + ------- + torch.Tensor + Selected positional embeddings with shape: + - If global_index provided: (B, N_pe, H, W) + - If global_index is None: (B, N_pe, H_pe, W_pe) + where N_pe is the number of positional embedding channels, and H_pe + and W_pe are the height and width of the positional embedding grid. + + Example + ------- + >>> # Create global indices using patching utility: + >>> from physicsnemo.utils.patching import GridPatching2D + >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8)) + >>> global_index = patching.global_index(batch_size=3) + >>> print(global_index.shape) + torch.Size([12, 2, 8, 8]) + + See Also + -------- + :meth:`physicsnemo.utils.patching.RandomPatching2D.global_index` + For generating random patch indices. + :meth:`physicsnemo.utils.patching.GridPatching2D.global_index` + For generating deterministic grid-based patch indices. + See these methods for possible ways to generate the global_index parameter. + """ + # If no global indices are provided, select all embeddings and expand + # to match the batch size of the input if global_index is None: - selected_pos_embd = ( + return ( self.pos_embd.to(x.dtype) .to(x.device)[None] .expand((x.shape[0], -1, -1, -1)) + ) # (B, N_pe, H, W) + + B = global_index.shape[0] + H = global_index.shape[2] + W = global_index.shape[3] + global_index = torch.reshape( + torch.permute(global_index, (1, 0, 2, 3)), (2, -1) + ) # (B, 2, H, W) to (2, B*H*W) + # Use advanced indexing to select the positional embeddings based on + # their y-x coordinates + selected_pos_embd = self.pos_embd.to(x.device)[ + :, global_index[0], global_index[1] + ] # (N_pe, B*H*W) + selected_pos_embd = ( + torch.permute( + torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], B, H, W)), + (1, 0, 2, 3), ) - else: - B = global_index.shape[0] - X = global_index.shape[2] - Y = global_index.shape[3] - global_index = torch.reshape( - torch.permute(global_index, (1, 0, 2, 3)), (2, -1) - ) # (B, 2, X, Y) to (2, B*X*Y) - selected_pos_embd = self.pos_embd.to(x.device)[ - :, global_index[0], global_index[1] - ] # (N_pe, B*X*Y) - selected_pos_embd = ( - torch.permute( - torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], B, X, Y)), - (1, 0, 2, 3), - ) - .to(x.device) - .to(x.dtype) - ) # (B, N_pe, X, Y) + .to(x.device) + .to(x.dtype) + ) # (B, N_pe, H, W) return selected_pos_embd + def positional_embedding_selector( + self, + x: torch.Tensor, + embedding_selector: Callable[[torch.Tensor], torch.Tensor], + ) -> torch.Tensor: + """Select positional embeddings using a selector function. + + Similar to positional_embedding_indexing, but uses a selector function + to select the embeddings. This method provides a more efficient way to + select embeddings for batches of data. + Typically used with patch-based processing, where the batch dimension + contains multiple patches extracted from a larger image. + + Arguments + --------- + x : torch.Tensor + Input tensor of shape (B, C, H, W) only used to determine dtype and + device. + embedding_selector : Callable + Function that takes as input an embedding tensor of shape (N_pe, + H_pe, W_pe) and returns selected embeddings with shape (batch_size, N_pe, H, W). + Each selected embedding should correspond to the positional + information of each batch element in x. + For patch-based processing, typically this should be based on + :meth:`physicsnemo.utils.patching.BasePatching2D.apply` method to + maintain consistency with patch extraction. + + Returns + ------- + torch.Tensor + Selected positional embeddings with shape (B, N_pe, H, W) + where N_pe is the number of positional embedding channels. + + Example + ------- + >>> # Define a selector function with a patching utility: + >>> from physicsnemo.utils.patching import GridPatching2D + >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8)) + >>> batch_size = 4 + >>> def embedding_selector(emb): + ... return patching.apply(emb[None].expand(batch_size, -1, -1, -1)) + >>> + + See Also + -------- + :meth:`physicsnemo.utils.patching.BasePatching2D.apply` + For the base patching method typically used in embedding_selector. + """ + return embedding_selector( + self.pos_embd.to(x.dtype).to(x.device) + ) # (B, N_pe, H, W) + def _get_positional_embedding(self): if self.N_grid_channels == 0: return None elif self.gridtype == "learnable": grid = torch.nn.Parameter( torch.randn(self.N_grid_channels, self.img_shape_y, self.img_shape_x) - ) + ) # (N_grid_channels, img_shape_y, img_shape_x) elif self.gridtype == "linear": if self.N_grid_channels != 2: raise ValueError("N_grid_channels must be set to 2 for gridtype linear") x = np.meshgrid(np.linspace(-1, 1, self.img_shape_y)) y = np.meshgrid(np.linspace(-1, 1, self.img_shape_x)) grid_x, grid_y = np.meshgrid(y, x) - grid = torch.from_numpy(np.stack((grid_x, grid_y), axis=0)) + grid = torch.from_numpy( + np.stack((grid_x, grid_y), axis=0) + ) # (2, img_shape_y, img_shape_x) grid.requires_grad = False elif self.gridtype == "sinusoidal" and self.N_grid_channels == 4: # print('sinusuidal grid added ......') @@ -600,7 +768,7 @@ def _get_positional_embedding(self): np.stack((grid_x1, grid_y1, grid_x2, grid_y2), axis=0), axis=0 ) ) - ) + ) # (4, img_shape_y, img_shape_x) grid.requires_grad = False elif self.gridtype == "sinusoidal" and self.N_grid_channels != 4: if self.N_grid_channels % 4 != 0: @@ -616,28 +784,39 @@ def _get_positional_embedding(self): for p_fn in [np.sin, np.cos]: grid_list.append(p_fn(grid_x * freq)) grid_list.append(p_fn(grid_y * freq)) - grid = torch.from_numpy(np.stack(grid_list, axis=0)) + grid = torch.from_numpy( + np.stack(grid_list, axis=0) + ) # (N_grid_channels, img_shape_y, img_shape_x) grid.requires_grad = False elif self.gridtype == "test" and self.N_grid_channels == 2: idx_x = torch.arange(self.img_shape_y) idx_y = torch.arange(self.img_shape_x) mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y) - grid = torch.stack((mesh_x, mesh_y), dim=0) + grid = torch.stack((mesh_x, mesh_y), dim=0) # (2, img_shape_y, img_shape_x) else: raise ValueError("Gridtype not supported.") return grid +# TODO: Lots of stuff in common with SongUNetPosEmbd. Should inherit from it +# instead of SongUNet class SongUNetPosLtEmbd(SongUNet): """ - This model is adapated from SongUNetPosEmbd, with the incoporatation of lead-time aware - embedding for the GEFS-HRRR model. The lead-time embedding is activated by setting the - lead_time_channels and lead_time_steps parameters. + This model is adapted from SongUNetPosEmbd, with the incorporation of lead-time aware + embeddings. The lead-time embedding is activated by setting the + `lead_time_channels` and `lead_time_steps` parameters. + + Like SongUNetPosEmbd, this model provides two methods for selecting positional embeddings: + 1. Using a selector function (preferred method). See + :meth:`positional_embedding_selector` for details. + 2. Using global indices. See :meth:`positional_embedding_indexing` for + details. Parameters ----------- img_resolution : Union[List[int], int] - The resolution of the input/output image, 1 value represents a square image. + The resolution of the input/output image. Can be a single int for square images + or a list [height, width] for rectangular images. in_channels : int Number of channels in the input image. out_channels : int @@ -647,43 +826,49 @@ class SongUNetPosLtEmbd(SongUNet): augment_dim : int, optional Dimensionality of augmentation labels; 0 means no augmentation. By default 0. model_channels : int, optional - Base multiplier for the number of channels across the network, by default 128. + Base multiplier for the number of channels across the network. By default 128. channel_mult : List[int], optional - Per-resolution multipliers for the number of channels. By default [1,2,2,2]. + Per-resolution multipliers for the number of channels. By default [1,2,2,2,2]. channel_mult_emb : int, optional Multiplier for the dimensionality of the embedding vector. By default 4. num_blocks : int, optional Number of residual blocks per resolution. By default 4. attn_resolutions : List[int], optional - Resolutions at which self-attention layers are applied. By default [16]. + Resolutions at which self-attention layers are applied. By default [28]. dropout : float, optional Dropout probability applied to intermediate activations. By default 0.13. label_dropout : float, optional - Dropout probability of class labels for classifier-free guidance. By default 0.0. + Dropout probability of class labels for classifier-free guidance. By default 0.0. embedding_type : str, optional Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++. By default 'positional'. channel_mult_noise : int, optional Timestep embedding size: 1 for DDPM++, 2 for NCSN++. By default 1. encoder_type : str, optional - Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. By default - 'standard'. + Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++, 'skip' for skip connections. + By default 'standard'. decoder_type : str, optional - Decoder architecture: 'standard' for both DDPM++ and NCSN++. By default - 'standard'. - resample_filter : List[int], optional (default=[1,1]) - Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. - lead_time_channels: int, optional - Length of lead time embedding vector - lead_time_steps: int, optional - Total number of lead times - - - Reference - ---------- - Reference: Song, Y., Sohl-Dickstein, J., Kingma, D.P., Kumar, A., Ermon, S. and - Poole, B., 2020. Score-based generative modeling through stochastic differential - equations. arXiv preprint arXiv:2011.13456. + Decoder architecture: 'standard' or 'skip' for skip connections. By default 'standard'. + resample_filter : List[int], optional + Resampling filter coefficients: [1,1] for DDPM++, [1,3,3,1] for NCSN++. By default [1,1]. + gridtype : str, optional + Type of positional grid to use: 'sinusoidal', 'learnable', 'linear', or 'test'. + Controls how positional information is encoded. By default 'sinusoidal'. + N_grid_channels : int, optional + Number of channels in the positional embedding grid. For 'sinusoidal' must be 4 or + multiple of 4. For 'linear' must be 2. By default 4. + lead_time_channels : int, optional + Number of channels in the lead time embedding. These are learned embeddings that + encode temporal forecast information. By default None. + lead_time_steps : int, optional + Number of discrete lead time steps to support. Each step gets its own learned + embedding vector. By default 9. + prob_channels : List[int], optional + Indices of probability output channels that should use softmax activation. + Used for classification outputs. By default empty list. + checkpoint_level : int, optional + Number of layers that should use gradient checkpointing (0 disables checkpointing). + Higher values trade memory for computation. By default 0. Note ----- @@ -692,11 +877,49 @@ class SongUNetPosLtEmbd(SongUNet): Example -------- - >>> model = SongUNet(img_resolution=16, in_channels=2, out_channels=2) + >>> import torch + >>> from physicsnemo.models.diffusion.song_unet import SongUNetPosLtEmbd + >>> from physicsnemo.utils.patching import GridPatching2D + >>> + >>> # Model initialization - in_channels must include original input channels (2), + >>> # positional embedding channels (N_grid_channels=4 by default) and + >>> # lead time embedding channels (4) + >>> model = SongUNetPosLtEmbd( + ... img_resolution=16, in_channels=2+4+4, out_channels=2, + ... lead_time_channels=4, lead_time_steps=9 + ... ) >>> noise_labels = torch.randn([1]) >>> class_labels = torch.randint(0, 1, (1, 1)) + >>> # The input has only the original 2 channels - positional embeddings and + >>> # lead time embeddings are added automatically inside the forward method >>> input_image = torch.ones([1, 2, 16, 16]) - >>> output_image = model(input_image, noise_labels, class_labels) + >>> lead_time_label = torch.tensor([3]) + >>> output_image = model( + ... input_image, noise_labels, class_labels, + ... lead_time_label=lead_time_label + ... ) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + >>> + >>> # Using global_index to select all the positional and lead time embeddings + >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(16, 16)) + >>> global_index = patching.global_index(batch_size=1) + >>> output_image = model( + ... input_image, noise_labels, class_labels, + ... lead_time_label=lead_time_label, + ... global_index=global_index + ... ) + >>> output_image.shape + torch.Size([1, 2, 16, 16]) + >>> + >>> # Using custom embedding selector to select all the positional and lead time embeddings + >>> def patch_embedding_selector(emb): + ... return patching.apply(emb[None].expand(1, -1, -1, -1)) + >>> output_image = model( + ... input_image, noise_labels, class_labels, + ... lead_time_label=lead_time_label, + ... embedding_selector=patch_embedding_selector + ... ) >>> output_image.shape torch.Size([1, 2, 16, 16]) """ @@ -767,10 +990,17 @@ def forward( noise_labels, class_labels, lead_time_label=None, - global_index=None, + global_index: Optional[torch.Tensor] = None, + embedding_selector: Optional[Callable] = None, augment_labels=None, ): - # append positional embedding to input conditioning + if embedding_selector is not None and global_index is not None: + raise ValueError( + "Cannot provide both embedding_selector and global_index. " + "embedding_selector is the preferred approach for better efficiency." + ) + + # Append positional and lead time embeddings to input conditioning embeds = [] if self.pos_embd is not None: embeds.append(self.pos_embd.to(x.device)) @@ -783,11 +1013,19 @@ def forward( ) if len(embeds) > 0: embeds = torch.cat(embeds, dim=0) - selected_pos_embd = self.positional_embedding_indexing( - x, embeds, global_index - ) + # Select embeddings using either selector function or global indices + if embedding_selector is not None: + selected_pos_embd = self.positional_embedding_selector( + x, embeds, embedding_selector + ) + else: + selected_pos_embd = self.positional_embedding_indexing( + x, embeds, global_index + ) x = torch.cat((x, selected_pos_embd), dim=1) + out = super().forward(x, noise_labels, class_labels, augment_labels) + # if training mode, let crossEntropyLoss do softmax. The model outputs logits. # if eval mode, the model outputs probability all_channels = list(range(out.shape[1])) # [0, 1, 2, ..., 10] @@ -811,30 +1049,127 @@ def forward( out_final = out return out_final - def positional_embedding_indexing(self, x, pos_embd, global_index): + def positional_embedding_indexing( + self, + x: torch.Tensor, + embeds: torch.Tensor, + global_index: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Select positional embeddings using global indices. + + This method either uses global indices to select specific embeddings or expands + the embeddings for the full input when no indices are provided. + + Arguments + --------- + x : torch.Tensor + Input tensor of shape (B, C, H, W), used to determine batch size + and device. + embeds : torch.Tensor + Combined positional and lead time embeddings tensor. + global_index : Optional[torch.Tensor] + Optional tensor of indices for selecting embeddings. These should + correspond to the spatial indices of the batch elements in the + input tensor x. When provided, should have shape (B, 2, H, W) where + the second dimension contains y,x coordinates. + + Returns + ------- + torch.Tensor + Selected embeddings with shape (B, N_pe, H, W) where N_pe is the + total number of embedding channels (positional + lead time). + + Example + ------- + >>> # Create global indices using patching utility: + >>> from physicsnemo.utils.patching import GridPatching2D + >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8)) + >>> global_index = patching.global_index(batch_size=1) + >>> global_index.shape + torch.Size([4, 2, 8, 8]) + + See Also + -------- + :meth:`physicsnemo.utils.patching.RandomPatching2D.global_index` + For generating random patch indices. + :meth:`physicsnemo.utils.patching.GridPatching2D.global_index` + For generating deterministic grid-based patch indices. + See these methods for possible ways to generate the global_index parameter. + """ if global_index is None: - selected_pos_embd = ( - pos_embd.to(x.dtype).to(x.device)[None].expand((x.shape[0], -1, -1, -1)) + return ( + embeds.to(x.dtype).to(x.device)[None].expand((x.shape[0], -1, -1, -1)) ) - else: - B = global_index.shape[0] - X = global_index.shape[2] - Y = global_index.shape[3] - global_index = torch.reshape( - torch.permute(global_index, (1, 0, 2, 3)), (2, -1) - ) # (B, 2, X, Y) to (2, B*X*Y) - selected_pos_embd = pos_embd.to(x.device)[ - :, global_index[0], global_index[1] - ] # (N_pe, B*X*Y) - selected_pos_embd = ( - torch.permute( - torch.reshape(selected_pos_embd, (pos_embd.shape[0], B, X, Y)), - (1, 0, 2, 3), - ) - .to(x.device) - .to(x.dtype) - ) # (B, N_pe, X, Y) - return selected_pos_embd + + B = global_index.shape[0] + H = global_index.shape[2] + W = global_index.shape[3] + global_index = torch.reshape( + torch.permute(global_index, (1, 0, 2, 3)), (2, -1) + ) # (B, 2, H, W) to (2, B*H*W) + selected_embeds = embeds.to(x.device)[:, global_index[0], global_index[1]] + selected_embeds = ( + torch.permute( + torch.reshape(selected_embeds, (embeds.shape[0], B, H, W)), + (1, 0, 2, 3), + ) + .to(x.device) + .to(x.dtype) + ) # (B, N_pe, H, W) + return selected_embeds + + def positional_embedding_selector( + self, + x: torch.Tensor, + embeds: torch.Tensor, + embedding_selector: Callable[[torch.Tensor], torch.Tensor], + ) -> torch.Tensor: + """Select positional embeddings using a selector function. + + Similar to positional_embedding_indexing, but uses a selector function + to select the embeddings. This method provides a more efficient way to + select embeddings for batches of data. + + Arguments + --------- + x : torch.Tensor + Input tensor of shape (B, C, H, W) used to determine batch + size and device. + embeds : torch.Tensor + Combined positional and lead time embeddings tensor of shape + (N_pe, H_pe, W_pe) where N_pe is the total number of embedding + channels. + embedding_selector : Callable + Function that takes as input an embedding tensor of shape (N_pe, + H_pe, W_pe) and returns selected embeddings with shape (B, N_pe, H, W). + Each selected embedding should correspond to the positional + information of each batch element in x. + For patch-based processing, typically this should be based on + :meth:`physicsnemo.utils.patching.BasePatching2D.apply` method to + maintain consistency with patch extraction. + + Returns + ------- + torch.Tensor + Selected embeddings with shape (B, N_pe, H, W) where N_pe is the + total number of embedding channels (positional + lead time). + + Example + ------- + >>> # Define a selector function with a patching utility: + >>> from physicsnemo.utils.patching import GridPatching2D + >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8)) + >>> batch_size = 4 + >>> def embedding_selector(emb): + ... return patching.apply(emb[None].expand(batch_size, -1, -1, -1)) + >>> + + See Also + -------- + :meth:`physicsnemo.utils.patching.BasePatching2D.apply` + For the base patching method typically used in embedding_selector. + """ + return embedding_selector(embeds.to(x.dtype).to(x.device)) # (B, N_pe, H, W) def _get_positional_embedding(self): if self.N_grid_channels == 0: @@ -842,14 +1177,16 @@ def _get_positional_embedding(self): elif self.gridtype == "learnable": grid = torch.nn.Parameter( torch.randn(self.N_grid_channels, self.img_shape_y, self.img_shape_x) - ) + ) # (N_grid_channels, img_shape_y, img_shape_x) elif self.gridtype == "linear": if self.N_grid_channels != 2: raise ValueError("N_grid_channels must be set to 2 for gridtype linear") x = np.meshgrid(np.linspace(-1, 1, self.img_shape_y)) y = np.meshgrid(np.linspace(-1, 1, self.img_shape_x)) grid_x, grid_y = np.meshgrid(y, x) - grid = torch.from_numpy(np.stack((grid_x, grid_y), axis=0)) + grid = torch.from_numpy( + np.stack((grid_x, grid_y), axis=0) + ) # (2, img_shape_y, img_shape_x) grid.requires_grad = False elif self.gridtype == "sinusoidal" and self.N_grid_channels == 4: # print('sinusuidal grid added ......') @@ -865,7 +1202,7 @@ def _get_positional_embedding(self): np.stack((grid_x1, grid_y1, grid_x2, grid_y2), axis=0), axis=0 ) ) - ) + ) # (4, img_shape_y, img_shape_x) grid.requires_grad = False elif self.gridtype == "sinusoidal" and self.N_grid_channels != 4: if self.N_grid_channels % 4 != 0: @@ -881,13 +1218,15 @@ def _get_positional_embedding(self): for p_fn in [np.sin, np.cos]: grid_list.append(p_fn(grid_x * freq)) grid_list.append(p_fn(grid_y * freq)) - grid = torch.from_numpy(np.stack(grid_list, axis=0)) + grid = torch.from_numpy( + np.stack(grid_list, axis=0) + ) # (N_grid_channels, img_shape_y, img_shape_x) grid.requires_grad = False elif self.gridtype == "test" and self.N_grid_channels == 2: idx_x = torch.arange(self.img_shape_y) idx_y = torch.arange(self.img_shape_x) mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y) - grid = torch.stack((mesh_x, mesh_y), dim=0) + grid = torch.stack((mesh_x, mesh_y), dim=0) # (2, img_shape_y, img_shape_x) else: raise ValueError("Gridtype not supported.") return grid @@ -902,5 +1241,5 @@ def _get_lead_time_embedding(self): self.img_shape_y, self.img_shape_x, ) - ) + ) # (lead_time_steps, lead_time_channels, img_shape_y, img_shape_x) return grid diff --git a/physicsnemo/models/diffusion/unet.py b/physicsnemo/models/diffusion/unet.py index 72706064d0..c9cc769dae 100644 --- a/physicsnemo/models/diffusion/unet.py +++ b/physicsnemo/models/diffusion/unet.py @@ -16,6 +16,7 @@ import importlib from dataclasses import dataclass +from typing import List, Literal, Tuple, Union import torch @@ -45,31 +46,35 @@ class MetaData(ModelMetaData): class UNet(Module): # TODO a lot of redundancy, need to clean up """ - U-Net Wrapper for CorrDiff. + U-Net Wrapper for CorrDiff deterministic regression model. Parameters ----------- - img_resolution : int - The resolution of the input/output image. - img_channels : int - Number of color channels. + img_resolution : Union[int, Tuple[int, int]] + The resolution of the input/output image. If a single int is provided, + then the image is assumed to be square. img_in_channels : int - Number of input color channels. + Number of channels in the input image. img_out_channels : int - Number of output color channels. + Number of channels in the output image. use_fp16: bool, optional - Execute the underlying model at FP16 precision?, by default False. - sigma_min: float, optional - Minimum supported noise level, by default 0. - sigma_max: float, optional - Maximum supported noise level, by default float('inf'). - sigma_data: float, optional - Expected standard deviation of the training data, by default 0.5. + Execute the underlying model at FP16 precision, by default False. model_type: str, optional - Class name of the underlying model, by default 'DhariwalUNet'. + Class name of the underlying model. Must be one of the following: + 'SongUNet', 'SongUNetPosEmbd', 'SongUNetPosLtEmbd', 'DhariwalUNet'. + Defaults to 'SongUNetPosEmbd'. **model_kwargs : dict - Keyword arguments for the underlying model. + Keyword arguments passed to the underlying model `__init__` method. + + See Also + -------- + For information on model types and their usage: + :class:`~physicsnemo.models.diffusion.SongUNet`: Basic U-Net for diffusion models + :class:`~physicsnemo.models.diffusion.SongUNetPosEmbd`: U-Net with positional embeddings + :class:`~physicsnemo.models.diffusion.SongUNetPosLtEmbd`: U-Net with positional and lead-time embeddings + Please refer to the documentation of these classes for details on how to call + and use these models directly. References ---------- @@ -81,35 +86,28 @@ class UNet(Module): # TODO a lot of redundancy, need to clean up def __init__( self, - img_resolution, - img_channels, - img_in_channels, - img_out_channels, - use_fp16=False, - sigma_min=0, - sigma_max=float("inf"), - sigma_data=0.5, - model_type="SongUNetPosEmbd", - **model_kwargs, + img_resolution: Union[int, Tuple[int, int]], + img_in_channels: int, + img_out_channels: int, + use_fp16: bool = False, + model_type: Literal[ + "SongUNetPosEmbd", "SongUNetPosLtEmbd", "SongUNet", "DhariwalUNet" + ] = "SongUNetPosEmbd", + **model_kwargs: dict, ): super().__init__(meta=MetaData) - self.img_channels = img_channels - # for compatibility with older versions that took only 1 dimension if isinstance(img_resolution, int): self.img_shape_x = self.img_shape_y = img_resolution else: - self.img_shape_x = img_resolution[0] - self.img_shape_y = img_resolution[1] + self.img_shape_y = img_resolution[0] + self.img_shape_x = img_resolution[1] self.img_in_channels = img_in_channels self.img_out_channels = img_out_channels self.use_fp16 = use_fp16 - self.sigma_min = sigma_min - self.sigma_max = sigma_max - self.sigma_data = sigma_data model_class = getattr(network_module, model_type) self.model = model_class( img_resolution=img_resolution, @@ -118,13 +116,46 @@ def __init__( **model_kwargs, ) - def forward(self, x, img_lr, sigma, force_fp32=False, **model_kwargs): + def forward( + self, + x: torch.Tensor, + img_lr: torch.Tensor, + force_fp32: bool = False, + **model_kwargs: dict, + ) -> torch.Tensor: + """ + Forward pass of the UNet wrapper model. + + This method concatenates the input tensor with the low-resolution conditioning tensor + and passes the result through the underlying model. + + Parameters + ---------- + x : torch.Tensor + The input tensor, typically zero-filled, of shape (B, C_hr, H, W). + img_lr : torch.Tensor + Low-resolution conditioning image of shape (B, C_lr, H, W). + force_fp32 : bool, optional + Whether to force FP32 precision regardless of the `use_fp16` attribute, + by default False. + **model_kwargs : dict + Additional keyword arguments to pass to the underlying model + `self.model` forward method. + + Returns + ------- + torch.Tensor + Output tensor (prediction) of shape (B, C_hr, H, W). + + Raises + ------ + ValueError + If the model output dtype doesn't match the expected dtype. + """ # SR: concatenate input channels if img_lr is not None: x = torch.cat((x, img_lr), dim=1) - x = x.to(torch.float32) - sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) dtype = ( torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == "cuda") @@ -133,29 +164,27 @@ def forward(self, x, img_lr, sigma, force_fp32=False, **model_kwargs): F_x = self.model( x.to(dtype), # (c_in * x).to(dtype), - torch.zeros( - sigma.numel(), dtype=sigma.dtype, device=sigma.device - ), # c_noise.flatten() + torch.zeros(x.shape[0], dtype=dtype, device=x.device), # c_noise.flatten() class_labels=None, **model_kwargs, ) if (F_x.dtype != dtype) and not torch.is_autocast_enabled(): raise ValueError( - f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead." + f"Expected the dtype to be {dtype}, " f"but got {F_x.dtype} instead." ) - # skip connection - for SR there's size mismatch bwtween input and output + # skip connection D_x = F_x.to(torch.float32) return D_x - def round_sigma(self, sigma): + def round_sigma(self, sigma: Union[float, List, torch.Tensor]) -> torch.Tensor: """ Convert a given sigma value(s) to a tensor representation. Parameters ---------- - sigma : Union[float list, torch.Tensor] + sigma : Union[float, List, torch.Tensor] The sigma value(s) to convert. Returns @@ -189,7 +218,7 @@ class StormCastUNet(Module): sigma_data: float, optional Expected standard deviation of the training data, by default 0.5. model_type: str, optional - Class name of the underlying model, by default 'DhariwalUNet'. + Class name of the underlying model, by default 'SongUNet'. **model_kwargs : dict Keyword arguments for the underlying model. diff --git a/physicsnemo/utils/corrdiff/utils.py b/physicsnemo/utils/corrdiff/utils.py index 0612d9c9d8..e4be3bf4f5 100644 --- a/physicsnemo/utils/corrdiff/utils.py +++ b/physicsnemo/utils/corrdiff/utils.py @@ -15,6 +15,7 @@ # limitations under the License. import datetime +from typing import Optional import cftime import nvtx @@ -32,33 +33,55 @@ def regression_step( net: torch.nn.Module, img_lr: torch.Tensor, latents_shape: torch.Size, - lead_time_label: torch.Tensor = None, + lead_time_label: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ - Given a low-res input, performs a regression step to produce ensemble mean. - This function performs the regression on a single instance and then replicates - the results across the batch dimension. - - Args: - net (torch.nn.Module): U-Net model for regression. - img_lr (torch.Tensor): Low-resolution input. - latents_shape (torch.Size): Shape of the latent representation. Typically - (batch_size, out_channels, image_shape_x, image_shape_y). - - - Returns: - torch.Tensor: Predicted output at the next time step. + Perform a regression step to produce ensemble mean prediction. + + This function takes a low-resolution input and performs a regression step to produce + an ensemble mean prediction. It processes a single instance and then replicates + the results across the batch dimension if needed. + + Parameters + ---------- + net : torch.nn.Module + U-Net model for regression. + img_lr : torch.Tensor + Low-resolution input to the network with shape (1, channels, height, width). + Must have a batch dimension of 1. + latents_shape : torch.Size + Shape of the latent representation with format + (batch_size, out_channels, image_shape_y, image_shape_x). + lead_time_label : Optional[torch.Tensor], optional + Lead time label tensor for lead time conditioning, + with shape (1, lead_time_dims). Default is None. + + Returns + ------- + torch.Tensor + Predicted ensemble mean at the next time step with shape matching latents_shape. + + Raises + ------ + ValueError + If img_lr has a batch size greater than 1. """ # Create a tensor of zeros with the given shape and move it to the appropriate device x_hat = torch.zeros(latents_shape, dtype=torch.float64, device=net.device) - t_hat = torch.tensor(1.0, dtype=torch.float64, device=net.device) + + # Safety check: avoid silently ignoring batch elements in img_lr + if img_lr.shape[0] > 1: + raise ValueError( + f"Expected img_lr to have a batch size of 1, " + f"but found {img_lr.shape[0]}." + ) # Perform regression on a single batch element with torch.inference_mode(): if lead_time_label is not None: - x = net(x_hat[0:1], img_lr, t_hat, lead_time_label=lead_time_label) + x = net(x=x_hat[0:1], img_lr=img_lr, lead_time_label=lead_time_label) else: - x = net(x_hat[0:1], img_lr, t_hat) + x = net(x=x_hat[0:1], img_lr=img_lr) # If the batch size is greater than 1, repeat the prediction if x_hat.shape[0] > 1: @@ -67,48 +90,85 @@ def regression_step( return x -def diffusion_step( # TODO generalize the module and add defaults +def diffusion_step( net: torch.nn.Module, sampler_fn: callable, - seed_batch_size: int, img_shape: tuple, img_out_channels: int, rank_batches: list, img_lr: torch.Tensor, rank: int, device: torch.device, - hr_mean: torch.Tensor = None, + mean_hr: torch.Tensor = None, lead_time_label: torch.Tensor = None, ) -> torch.Tensor: """ Generate images using diffusion techniques as described in the relevant paper. - Args: - net (torch.nn.Module): The diffusion model network. - sampler_fn (callable): Function used to sample images from the diffusion model. - seed_batch_size (int): Number of seeds per batch. - img_shape (tuple): Shape of the images, (height, width). - img_out_channels (int): Number of output channels for the image. - rank_batches (list): List of batches of seeds to process. - img_lr (torch.Tensor): Low-resolution input image. - rank (int): Rank of the current process for distributed processing. - device (torch.device): Device to perform computations. - mean_hr (torch.Tensor, optional): High-resolution mean tensor, to be used as an additional input. By default None. - - Returns: - torch.Tensor: Generated images concatenated across batches. + This function applies a diffusion model to generate high-resolution images based on + low-resolution inputs. It supports optional conditioning on high-resolution mean + predictions and lead time labels. + + For each low-resolution sample in `img_lr`, the function generates multiple + high-resolution samples, with different random seeds, specified in `rank_batches`. + The function then concatenates these high-resolution samples across the batch dimension. + + Parameters + ---------- + net : torch.nn.Module + The diffusion model network. + sampler_fn : callable + Function used to sample images from the diffusion model. + img_shape : tuple + Shape of the images, (height, width). + img_out_channels : int + Number of output channels for the image. + rank_batches : list + List of batches of seeds to process. + img_lr : torch.Tensor + Low-resolution input image with shape (seed_batch_size, channels_lr, height, width). + rank : int, optional + Rank of the current process for distributed processing. + device : torch.device, optional + Device to perform computations. + mean_hr : torch.Tensor, optional + High-resolution mean tensor to be used as an additional input, + with shape (1, channels_hr, height, width). Default is None. + lead_time_label : torch.Tensor, optional + Lead time label tensor for temporal conditioning, + with shape (batch_size, lead_time_dims). Default is None. + + Returns + ------- + torch.Tensor + Generated images concatenated across batches with shape + (seed_batch_size * len(rank_batches), out_channels, height, width). """ + # Check img_lr dimensions match expected shape + if img_lr.shape[2:] != img_shape: + raise ValueError( + f"img_lr shape {img_lr.shape[2:]} does not match expected shape img_shape {img_shape}" + ) + + # Check mean_hr dimensions if provided + if mean_hr is not None: + if mean_hr.shape[2:] != img_shape: + raise ValueError( + f"mean_hr shape {mean_hr.shape[2:]} does not match expected shape img_shape {img_shape}" + ) + if mean_hr.shape[0] != 1: + raise ValueError(f"mean_hr must have batch size 1, got {mean_hr.shape[0]}") + img_lr = img_lr.to(memory_format=torch.channels_last) # Handling of the high-res mean additional_args = {} - if hr_mean is not None: - additional_args["mean_hr"] = hr_mean + if mean_hr is not None: + additional_args["mean_hr"] = mean_hr if lead_time_label is not None: additional_args["lead_time_label"] = lead_time_label - additional_args["img_shape"] = img_shape # Loop over batches all_images = [] @@ -122,7 +182,7 @@ def diffusion_step( # TODO generalize the module and add defaults rnd = StackedRandomGenerator(device, batch_seeds) latents = rnd.randn( [ - seed_batch_size, + img_lr.shape[0], img_out_channels, img_shape[0], img_shape[1], diff --git a/physicsnemo/utils/generative/__init__.py b/physicsnemo/utils/generative/__init__.py index a708ccb3d6..a08d9784f4 100644 --- a/physicsnemo/utils/generative/__init__.py +++ b/physicsnemo/utils/generative/__init__.py @@ -15,7 +15,7 @@ # limitations under the License. from .deterministic_sampler import deterministic_sampler -from .stochastic_sampler import image_batching, image_fuse, stochastic_sampler +from .stochastic_sampler import stochastic_sampler from .utils import ( EasyDict, InfiniteSampler, diff --git a/physicsnemo/utils/generative/deterministic_sampler.py b/physicsnemo/utils/generative/deterministic_sampler.py index 9d79cf6ce7..5033bd735b 100644 --- a/physicsnemo/utils/generative/deterministic_sampler.py +++ b/physicsnemo/utils/generative/deterministic_sampler.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable, Literal, Optional import numpy as np import nvtx @@ -26,33 +27,142 @@ @nvtx.annotate(message="deterministic_sampler", color="red") def deterministic_sampler( - net, - latents, - img_lr, - img_shape=None, - class_labels=None, - randn_like=torch.randn_like, - num_steps=18, - sigma_min=None, - sigma_max=None, - rho=7, - solver="heun", - discretization="edm", - schedule="linear", - scaling="none", - epsilon_s=1e-3, - C_1=0.001, - C_2=0.008, - M=1000, - alpha=1, - S_churn=0, - S_min=0, - S_max=float("inf"), - S_noise=1, -): + net: torch.nn.Module, + latents: torch.Tensor, + img_lr: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + randn_like: Callable = torch.randn_like, + num_steps: int = 18, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + rho: float = 7.0, + solver: Literal["heun", "euler"] = "heun", + discretization: Literal["vp", "ve", "iddpm", "edm"] = "edm", + schedule: Literal["vp", "ve", "linear"] = "linear", + scaling: Literal["vp", "none"] = "none", + epsilon_s: float = 1e-3, + C_1: float = 0.001, + C_2: float = 0.008, + M: int = 1000, + alpha: float = 1.0, + S_churn: int = 0, + S_min: float = 0.0, + S_max: float = float("inf"), + S_noise: float = 1.0, +) -> torch.Tensor: """ - Generalized sampler, representing the superset of all sampling methods discussed - in the paper "Elucidating the Design Space of Diffusion-Based Generative Models" + Generalized sampler, representing the superset of all sampling methods + discussed in the paper "Elucidating the Design Space of Diffusion-Based + Generative Models" (EDM). + - https://arxiv.org/abs/2206.00364 + + This function integrates an ODE (probability flow) or SDE over multiple + time-steps to generate samples from the diffusion model provided by the + argument 'net'. It can be used to combine multiple choices to + design a custom sampler, including multiple integration solver, + discretization method, noise schedule, and so on. + + Parameters: + ----------- + net : torch.nn.Module + The diffusion model to use in the sampling process. + latents : torch.Tensor + The latent random noise used as the initial condition for the + stochastic ODE. + img_lr : torch.Tensor + Low-resolution input image for conditioning the diffusion process. + Passed as a keywork argument to the model 'net'. + class_labels : Optional[torch.Tensor] + Labels of the classes used as input to a class-conditionned + diffusion model. Passed as a keyword argument to the model 'net'. + If provided, it must be a tensor containing integer values. + Defaults to None, in which case it is ignored. + randn_like: Callable + Random Number Generator to generate random noise that is added + during the stochastic sampling. Must have the same signature as + torch.randn_like and return torch.Tensor. Defaults to + torch.randn_like. + num_steps : Optional[int] + Number of time-steps for the stochastic ODE integration. Defaults + to 18. + sigma_min : Optional[float] + Minimum noise level for the diffusion process. 'sigma_min', + 'sigma_max', and 'rho' are used to compute the time-step + discretization, based on the choice of discretization. For the + default choice ("discretization='heun'"), the noise level schedule + is computed as: + :math:`\sigma_i = (\sigma_{max}^{1/\rho} + i / (num_steps - 1) * (\sigma_{min}^{1/\rho} - \sigma_{max}^{1/\rho}))^{rho}`. + For other choices of 'discretization', see details in the EDM + paper. Defaults to None, in which case defaults values depending + of the specified discretization are used. + sigma_max : Optional[float] + Maximum noise level for the diffusion process. See sigma_min for + details. Defaults to None, in which case defaults values depending + of the specified discretization are used. + rho : float, optional + Exponent used in the noise schedule. See sigma_min for details. + Only used when 'discretization' is 'heun'. Values in the range [5, + 10] produce better images. Lower values lead to truncation errors + equalized over all time steps. Defaults to 7. + solver : Literal["heun", "euler"] + The numerical method used to integrate the stochastic ODE. "euler" + is 1st order solver, which is faster but produces lower-quality + images. "heun" is 2nd order, more expensive, but produces + higher-quality images. Defaults to "heun". + discretization : Literal["vp", "ve", "iddpm", "edm"] + The method to discretize time-steps :math:`t_i` in the + diffusion process. See the EDM papper for details. Defaults to + "edm". + schedule : Literal["vp", "ve", "linear"] + The type of noise level schedule. Defaults to "linear". If + schedule='ve', then :math:`\sigma(t) = \sqrt{t}`. If + schedule='linear', then :math:`\sigma(t) = t`. If schedule='vp', + see EDM paper for details. Defaults to "linear". + scaling : Literal["vp", "none"] + The type of time-dependent signal scaling :math:`s(t)`, such that + :math:`x = s(t) \hat{x}`. See EDM paper for details on the 'vp' + scaling. Defaults to 'none', in which case :math:`s(t)=1`. + epsilon_s : float, optional + Parameter to compute both the noise level schedule and the + time-step discetization. Only used when discretization='vp' or + schedule='vp'. Ignored in other cases. Defaults to 1e-3. + C_1 : float, optional + Parameters to compute the time-step discetization. Only used when + discretization='iddpm'. Defaults to 0.001. + C_2 : float, optional + Same as for C_1. Only used when discretization='iddpm'. Defaults to + 0.008. + M : int, optional + Same as for C_1 and C_2. Only used when discretization='iddpm'. + Defaults to 1000. + alpha : float, optional + Controls (i.e. multiplies) the step size :math:`t_{i+1} - + \hat{t}_i` in the stochastic sampler, where :math:`\hat{t}_i` is + the temporarily increased noise level. Defaults to 1.0, which is + the recommended value. + S_churn : int, optional + Controls the amount of stochasticty injected in the SDE in the + stochatsic sampler. Larger values of S_churn lead to larger values + of :math:`\hat{t}_i`, which in turn lead to injecting more + stochasticity in the SDE by Defaults to 0, which means no + stochasticity is injected. + S_min : float, optional + S_min and S_max control the time-step range obver which + stochasticty is injected in the SDE. Stochasticity is injected + through `\hat{t}_i` for time-steps :math:`t_i` such that + :math:`S_{min} \leq t_i \leq S_{max}`. Defaults to 0.0. + S_max : float, optional + See S_min. Defaults to float("inf"). + S_noise : float, optional + Controls the amount of stochasticty injected in the SDE in the + stochatsic sampler. Added signal noise is proportinal to + :math:`\epsilon_i` where `\epsilon_i ~ N(0, S_{noise}^2)`. Defaults + to 1.0. + + Returns + ------- + torch.Tensor: + Generated batch of samples. Same shape as the input 'latents'. """ # conditioning @@ -89,7 +199,8 @@ def deterministic_sampler( ve_sigma_deriv = lambda t: 0.5 / t.sqrt() ve_sigma_inv = lambda sigma: sigma**2 - # Select default noise level range based on the specified time step discretization. + # Select default noise level range based on the specified + # time step discretization. if sigma_min is None: vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s) sigma_min = {"vp": vp_def, "ve": 0.02, "iddpm": 0.002, "edm": 0.002}[ @@ -223,7 +334,8 @@ def deterministic_sampler( ).to(torch.float64) d_prime = ( sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime) - ) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised + ) * x_prime + -sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised x_next = x_hat + h * ( (1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime ) diff --git a/physicsnemo/utils/generative/stochastic_sampler.py b/physicsnemo/utils/generative/stochastic_sampler.py index ddcf9cc7f8..b2bac0e1ce 100644 --- a/physicsnemo/utils/generative/stochastic_sampler.py +++ b/physicsnemo/utils/generative/stochastic_sampler.py @@ -15,287 +15,21 @@ # limitations under the License. -import math -from typing import Any, Callable, Optional +from typing import Callable, Optional import torch from torch import Tensor - -def image_batching( - input: Tensor, - img_shape_y: int, - img_shape_x: int, - patch_shape_y: int, - patch_shape_x: int, - batch_size: int, - overlap_pix: int, - boundary_pix: int, - input_interp: Optional[Tensor] = None, -) -> Tensor: - """ - Splits a full image into a batch of patched images. - - This function takes a full image and splits it into patches, adding padding where necessary. - It can also concatenate additional interpolated data to each patch if provided. - - Parameters - ---------- - input : Tensor - The input tensor representing the full image with shape (batch_size, channels, img_shape_x, img_shape_y). - img_shape_x : int - The width (x-dimension) of the original full image. - img_shape_y : int - The height (y-dimension) of the original full image. - patch_shape_x : int - The width (x-dimension) of each image patch. - patch_shape_y : int - The height (y-dimension) of each image patch. - batch_size : int - The original batch size before patching. - overlap_pix : int - The number of overlapping pixels between adjacent patches. - boundary_pix : int - The number of pixels to crop as a boundary from each patch. - input_interp : Optional[Tensor], optional - Optional additional data to concatenate to each patch with shape (batch_size, interp_channels, patch_shape_x, patch_shape_y). - By default None. - - Returns - ------- - Tensor - A tensor containing the image patches, with shape (total_patches * batch_size, channels [+ interp_channels], patch_shape_x, patch_shape_y). - """ - patch_num_x = math.ceil(img_shape_x / (patch_shape_x - overlap_pix - boundary_pix)) - patch_num_y = math.ceil(img_shape_y / (patch_shape_y - overlap_pix - boundary_pix)) - padded_shape_x = ( - (patch_shape_x - overlap_pix - boundary_pix) * (patch_num_x - 1) - + patch_shape_x - + boundary_pix - ) - padded_shape_y = ( - (patch_shape_y - overlap_pix - boundary_pix) * (patch_num_y - 1) - + patch_shape_y - + boundary_pix - ) - pad_x_right = padded_shape_x - img_shape_x - boundary_pix - pad_y_right = padded_shape_y - img_shape_y - boundary_pix - input_padded = torch.zeros( - input.shape[0], input.shape[1], padded_shape_y, padded_shape_x - ).to(input.device) - image_padding = torch.nn.ReflectionPad2d( - (boundary_pix, pad_x_right, boundary_pix, pad_y_right) - ).to( - input.device - ) # (padding_left,padding_right,padding_top,padding_bottom) - input_padded = image_padding(input) - patch_num = patch_num_x * patch_num_y - if input_interp is not None: - output = torch.zeros( - patch_num * batch_size, - input.shape[1] + input_interp.shape[1], - patch_shape_y, - patch_shape_x, - ).to(input.device) - else: - output = torch.zeros( - patch_num * batch_size, input.shape[1], patch_shape_y, patch_shape_x - ).to(input.device) - for x_index in range(patch_num_x): - for y_index in range(patch_num_y): - x_start = x_index * (patch_shape_x - overlap_pix - boundary_pix) - y_start = y_index * (patch_shape_y - overlap_pix - boundary_pix) - if input_interp is not None: - output[ - (x_index * patch_num_y + y_index) - * batch_size : (x_index * patch_num_y + y_index + 1) - * batch_size, - ] = torch.cat( - ( - input_padded[ - :, - :, - y_start : y_start + patch_shape_y, - x_start : x_start + patch_shape_x, - ], - input_interp, - ), - dim=1, - ) - else: - output[ - (x_index * patch_num_y + y_index) - * batch_size : (x_index * patch_num_y + y_index + 1) - * batch_size, - ] = input_padded[ - :, - :, - y_start : y_start + patch_shape_y, - x_start : x_start + patch_shape_x, - ] - return output - - -def image_fuse( - input: Tensor, - img_shape_y: int, - img_shape_x: int, - patch_shape_y: int, - patch_shape_x: int, - batch_size: int, - overlap_pix: int, - boundary_pix: int, -) -> Tensor: - """ - Reconstructs a full image from a batch of patched images. - - This function takes a batch of image patches and reconstructs the full image - by stitching the patches together. The function accounts for overlapping and - boundary pixels, ensuring that overlapping areas are averaged. - - Parameters - ---------- - input : Tensor - The input tensor containing the image patches with shape (total_patches * batch_size, channels, patch_shape_x, patch_shape_y). - img_shape_x : int - The width (x-dimension) of the original full image. - img_shape_y : int - The height (y-dimension) of the original full image. - patch_shape_x : int - The width (x-dimension) of each image patch. - patch_shape_y : int - The height (y-dimension) of each image patch. - batch_size : int - The original batch size before patching. - overlap_pix : int - The number of overlapping pixels between adjacent patches. - boundary_pix : int - The number of pixels to crop as a boundary from each patch. - - Returns - ------- - Tensor - The reconstructed full image tensor with shape (batch_size, channels, img_shape_x, img_shape_y). - - """ - patch_num_x = math.ceil(img_shape_x / (patch_shape_x - overlap_pix - boundary_pix)) - patch_num_y = math.ceil(img_shape_y / (patch_shape_y - overlap_pix - boundary_pix)) - padded_shape_x = ( - (patch_shape_x - overlap_pix - boundary_pix) * (patch_num_x - 1) - + patch_shape_x - + boundary_pix - ) - padded_shape_y = ( - (patch_shape_y - overlap_pix - boundary_pix) * (patch_num_y - 1) - + patch_shape_y - + boundary_pix - ) - pad_x_right = padded_shape_x - img_shape_x - boundary_pix - pad_y_right = padded_shape_y - img_shape_y - boundary_pix - residual_x = patch_shape_x - pad_x_right # residual pixels in the last patch - residual_y = patch_shape_y - pad_y_right # residual pixels in the last patch - output = torch.zeros( - batch_size, input.shape[1], img_shape_y, img_shape_x, device=input.device - ) - one_map = torch.ones(1, 1, input.shape[2], input.shape[3], device=input.device) - count_map = torch.zeros( - 1, 1, img_shape_y, img_shape_x, device=input.device - ) # to count the overlapping times - for x_index in range(patch_num_x): - for y_index in range(patch_num_y): - x_start = x_index * (patch_shape_x - overlap_pix - boundary_pix) - y_start = y_index * (patch_shape_y - overlap_pix - boundary_pix) - if (x_index == patch_num_x - 1) and (y_index != patch_num_y - 1): - output[ - :, :, y_start : y_start + patch_shape_y - 2 * boundary_pix, x_start: - ] += input[ - (x_index * patch_num_y + y_index) - * batch_size : (x_index * patch_num_y + y_index + 1) - * batch_size, - :, - boundary_pix : patch_shape_y - boundary_pix, - boundary_pix : residual_x + boundary_pix, - ] - count_map[ - :, :, y_start : y_start + patch_shape_y - 2 * boundary_pix, x_start: - ] += one_map[ - :, - :, - boundary_pix : patch_shape_y - boundary_pix, - boundary_pix : residual_x + boundary_pix, - ] - elif (y_index == patch_num_y - 1) and ((x_index != patch_num_x - 1)): - output[ - :, :, y_start:, x_start : x_start + patch_shape_x - 2 * boundary_pix - ] += input[ - (x_index * patch_num_y + y_index) - * batch_size : (x_index * patch_num_y + y_index + 1) - * batch_size, - :, - boundary_pix : residual_y + boundary_pix, - boundary_pix : patch_shape_x - boundary_pix, - ] - count_map[ - :, :, y_start:, x_start : x_start + patch_shape_x - 2 * boundary_pix - ] += one_map[ - :, - :, - boundary_pix : residual_y + boundary_pix, - boundary_pix : patch_shape_x - boundary_pix, - ] - elif x_index == patch_num_x - 1 and y_index == patch_num_y - 1: - output[:, :, y_start:, x_start:] += input[ - (x_index * patch_num_y + y_index) - * batch_size : (x_index * patch_num_y + y_index + 1) - * batch_size, - :, - boundary_pix : residual_y + boundary_pix, - boundary_pix : residual_x + boundary_pix, - ] - count_map[:, :, y_start:, x_start:] += one_map[ - :, - :, - boundary_pix : residual_y + boundary_pix, - boundary_pix : residual_x + boundary_pix, - ] - else: - output[ - :, - :, - y_start : y_start + patch_shape_y - 2 * boundary_pix, - x_start : x_start + patch_shape_x - 2 * boundary_pix, - ] += input[ - (x_index * patch_num_y + y_index) - * batch_size : (x_index * patch_num_y + y_index + 1) - * batch_size, - :, - boundary_pix : patch_shape_y - boundary_pix, - boundary_pix : patch_shape_x - boundary_pix, - ] - count_map[ - :, - :, - y_start : y_start + patch_shape_y - 2 * boundary_pix, - x_start : x_start + patch_shape_x - 2 * boundary_pix, - ] += one_map[ - :, - :, - boundary_pix : patch_shape_y - boundary_pix, - boundary_pix : patch_shape_x - boundary_pix, - ] - return output / count_map +from physicsnemo.utils.patching import GridPatching2D def stochastic_sampler( - net: Any, + net: torch.nn.Module, latents: Tensor, img_lr: Tensor, class_labels: Optional[Tensor] = None, randn_like: Callable[[Tensor], Tensor] = torch.randn_like, - img_shape: int = 448, - patch_shape: int = 448, - overlap_pix: int = 4, - boundary_pix: int = 2, + patching: Optional[GridPatching2D] = None, mean_hr: Optional[Tensor] = None, lead_time_label: Optional[Tensor] = None, num_steps: int = 18, @@ -308,31 +42,61 @@ def stochastic_sampler( S_noise: float = 1, ) -> Tensor: """ - Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution and patch-based diffusion. + Proposed EDM sampler (Algorithm 2) with minor changes to enable + super-resolution and patch-based diffusion. Parameters ---------- - net : Any - The neural network model that generates denoised images from noisy inputs. + net : torch.nn.Module + The neural network model that generates denoised images from noisy + inputs. + Expected signature: `net(x, x_lr, t_hat, class_labels, + lead_time_label=lead_time_label, embedding_selector=embedding_selector)`, + where: + x (torch.Tensor): Noisy input of shape (batch_size, C_out, H, W) + x_lr (torch.Tensor): Conditioning input of shape (batch_size, C_cond, H, W) + t_hat (torch.Tensor): Noise level of shape (batch_size, 1, 1, 1) or scalar + class_labels (torch.Tensor, optional): Optional class labels + lead_time_label (torch.Tensor, optional): Optional lead time labels + embedding_selector (callable, optional): Function to select + positional embeddings. Used for patch-based diffusion. + Returns: + torch.Tensor: Denoised prediction of shape (batch_size, C_out, H, W) + + Required attributes: + sigma_min (float): Minimum supported noise level for the model + sigma_max (float): Maximum supported noise level for the model + round_sigma (callable): Method to convert sigma values to tensor representation latents : Tensor - The latent variables (e.g., noise) used as the initial input for the sampler. + The latent variables (e.g., noise) used as the initial input for the + sampler. Has shape (batch_size, C_out, img_shape_y, img_shape_x). img_lr : Tensor - Low-resolution input image for conditioning the super-resolution process. + Low-resolution input image for conditioning the super-resolution + process. Must have shape (batch_size, C_lr, img_lr_ shape_y, + img_lr_shape_x). class_labels : Optional[Tensor], optional - Class labels for conditional generation, if required by the model. By default None. + Class labels for conditional generation, if required by the model. By + default None. randn_like : Callable[[Tensor], Tensor] - Function to generate random noise with the same shape as the input tensor. + Function to generate random noise with the same shape as the input + tensor. By default torch.randn_like. - img_shape : int - The height and width of the full image (assumed to be square). By default 448. - patch_shape : int - The height and width of each patch (assumed to be square). By default 448. - overlap_pix : int - Number of overlapping pixels between adjacent patches. By default 4. - boundary_pix : int - Number of pixels to be cropped as a boundary from each patch. By default 2. + patching : Optional[GridPatching2D], optional + A patching utility for patch-based diffusion. Implements methods to + extract patches from an image and batch the patches along `dim=0`. + Should also implement a `fuse` method to reconstruct the original image + from a batch of patches. See + :class:`physicsnemo.utils.patching.GridPatching2D` for details. By + default None, in which case non-patched diffusion is used. mean_hr : Optional[Tensor], optional - Optional tensor containing mean high-resolution images for conditioning. By default None. + Optional tensor containing mean high-resolution images for + conditioning. Must have same height and width as `img_lr`, with shape + (B_hr, C_hr, img_lr_shape_y, img_lr_shape_x) where the batch dimension + B_hr can be either 1, either equal to batch_size, or can be omitted. If + B_hr = 1 or is omitted, `mean_hr` will be expanded to match the shape + of `img_lr`. By default None. + lead_time_label : Optional[Tensor], optional + Optional lead time labels. By default None. num_steps : int Number of time steps for the sampler. By default 18. sigma_min : float @@ -342,7 +106,8 @@ def stochastic_sampler( rho : float Exponent used in the time step discretization. By default 7. S_churn : float - Churn parameter controlling the level of noise added in each step. By default 0. + Churn parameter controlling the level of noise added in each step. By + default 0. S_min : float Minimum time step for applying churn. By default 0. S_max : float @@ -353,19 +118,42 @@ def stochastic_sampler( Returns ------- Tensor - The final denoised image produced by the sampler. + The final denoised image produced by the sampler. Same shape as + `latents`: (batch_size, C_out, img_shape_y, img_shape_x). + + See Also + -------- + :class:`physicsnemo.models.diffusion.EDMPrecondSuperResolution`: A model + wrapper that provides preconditioning for super-resolution diffusion + models and implements the required interface for this sampler. """ # Adjust noise levels based on what's supported by the network. - "Proposed EDM sampler (Algorithm 2) with minor changes to enable super-resolution." + # Proposed EDM sampler (Algorithm 2) with minor changes to enable + # super-resolution/ sigma_min = max(sigma_min, net.sigma_min) sigma_max = min(sigma_max, net.sigma_max) - if isinstance(img_shape, tuple): - img_shape_y, img_shape_x = img_shape - else: - img_shape_x = img_shape_y = img_shape - if patch_shape > img_shape_x or patch_shape > img_shape_y: - patch_shape = min(img_shape_x, img_shape_y) + + # Safety check on type of patching + if patching is not None and not isinstance(patching, GridPatching2D): + raise ValueError("patching must be an instance of GridPatching2D.") + + # Safety check: if patching is used then img_lr and latents must have same + # height and width, otherwise there is mismatch in the number + # of patches extracted to form the final batch_size. + if patching: + if img_lr.shape[-2:] != latents.shape[-2:]: + raise ValueError( + f"img_lr and latents must have the same height and width, " + f"but found {img_lr.shape[-2:]} vs {latents.shape[-2:]}. " + ) + # img_lr and latents must also have the same batch_size, otherwise mismatch + # when processed by the network + if img_lr.shape[0] != latents.shape[0]: + raise ValueError( + f"img_lr and latents must have the same batch size, but found " + f"{img_lr.shape[0]} vs {latents.shape[0]}." + ) # Time step discretization. step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) @@ -379,46 +167,32 @@ def stochastic_sampler( [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] ) # t_N = 0 - b = latents.shape[0] - Nx = torch.arange(img_shape_x) - Ny = torch.arange(img_shape_y) - grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0)[ - None, - ].expand(b, -1, -1, -1) + batch_size = img_lr.shape[0] # conditioning = [mean_hr, img_lr, global_lr, pos_embd] - batch_size = img_lr.shape[0] x_lr = img_lr if mean_hr is not None: + if mean_hr.shape[-2:] != img_lr.shape[-2:]: + raise ValueError( + f"mean_hr and img_lr must have the same height and width, " + f"but found {mean_hr.shape[-2:]} vs {img_lr.shape[-2:]}." + ) x_lr = torch.cat((mean_hr.expand(x_lr.shape[0], -1, -1, -1), x_lr), dim=1) - global_index = None # input and position padding + patching - if patch_shape != img_shape_x or patch_shape != img_shape_y: - input_interp = torch.nn.functional.interpolate( - img_lr, (patch_shape, patch_shape), mode="bilinear" - ) - x_lr = image_batching( - x_lr, - img_shape_y, - img_shape_x, - patch_shape, - patch_shape, - batch_size, - overlap_pix, - boundary_pix, - input_interp, - ) - global_index = image_batching( - grid.float(), - img_shape_y, - img_shape_x, - patch_shape, - patch_shape, - batch_size, - overlap_pix, - boundary_pix, - ).int() + if patching: + # Patched conditioning [x_lr, mean_hr] + # (batch_size * patch_num, C_in + C_out, patch_shape_y, patch_shape_x) + x_lr = patching.apply(input=x_lr, additional_input=img_lr) + + # Function to select the correct positional embedding for each patch + def patch_embedding_selector(emb): + # emb: (N_pe, image_shape_y, image_shape_x) + # return: (batch_size * patch_num, N_pe, patch_shape_y, patch_shape_x) + return patching.apply(emb[None].expand(batch_size, -1, -1, -1)) + + else: + patch_embedding_selector = None # Main sampling loop. x_next = latents.to(torch.float64) * t_steps[0] @@ -430,26 +204,14 @@ def stochastic_sampler( x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur) - # Euler step. Perform patching operation on score tensor if patch-based generation is used - # denoised = net(x_hat, t_hat, class_labels,lead_time_label=lead_time_label).to(torch.float64) #x_lr + # Euler step. Perform patching operation on score tensor if patch-based + # generation is used denoised = net(x_hat, t_hat, + # class_labels,lead_time_label=lead_time_label).to(torch.float64) - if patch_shape != img_shape_x or patch_shape != img_shape_y: - x_hat_batch = image_batching( - x_hat, - img_shape_y, - img_shape_x, - patch_shape, - patch_shape, - batch_size, - overlap_pix, - boundary_pix, - ) - else: - x_hat_batch = x_hat - x_hat_batch = x_hat_batch.to(latents.device) + x_hat_batch = (patching.apply(input=x_hat) if patching else x_hat).to( + latents.device + ) x_lr = x_lr.to(latents.device) - if global_index is not None: - global_index = global_index.to(latents.device) if lead_time_label is not None: denoised = net( @@ -458,7 +220,7 @@ def stochastic_sampler( t_hat, class_labels, lead_time_label=lead_time_label, - global_index=global_index, + embedding_selector=patch_embedding_selector, ).to(torch.float64) else: denoised = net( @@ -466,40 +228,24 @@ def stochastic_sampler( x_lr, t_hat, class_labels, - global_index=global_index, + embedding_selector=patch_embedding_selector, ).to(torch.float64) - if patch_shape != img_shape_x or patch_shape != img_shape_y: + if patching: + # Un-patch the denoised image + # (batch_size, C_out, img_shape_y, img_shape_x) + denoised = patching.fuse(input=denoised, batch_size=batch_size) - denoised = image_fuse( - denoised, - img_shape_y, - img_shape_x, - patch_shape, - patch_shape, - batch_size, - overlap_pix, - boundary_pix, - ) d_cur = (x_hat - denoised) / t_hat x_next = x_hat + (t_next - t_hat) * d_cur # Apply 2nd order correction. if i < num_steps - 1: - if patch_shape != img_shape_x or patch_shape != img_shape_y: - x_next_batch = image_batching( - x_next, - img_shape_y, - img_shape_x, - patch_shape, - patch_shape, - batch_size, - overlap_pix, - boundary_pix, - ) - else: - x_next_batch = x_next - # ask about this fix - x_next_batch = x_next_batch.to(latents.device) + # Patched input + # (batch_size * patch_num, C_out, patch_shape_y, patch_shape_x) + x_next_batch = (patching.apply(input=x_next) if patching else x_next).to( + latents.device + ) + if lead_time_label is not None: denoised = net( x_next_batch, @@ -507,7 +253,7 @@ def stochastic_sampler( t_next, class_labels, lead_time_label=lead_time_label, - global_index=global_index, + embedding_selector=patch_embedding_selector, ).to(torch.float64) else: denoised = net( @@ -515,19 +261,13 @@ def stochastic_sampler( x_lr, t_next, class_labels, - global_index=global_index, + embedding_selector=patch_embedding_selector, ).to(torch.float64) - if patch_shape != img_shape_x or patch_shape != img_shape_y: - denoised = image_fuse( - denoised, - img_shape_y, - img_shape_x, - patch_shape, - patch_shape, - batch_size, - overlap_pix, - boundary_pix, - ) + if patching: + # Un-patch the denoised image + # (batch_size, C_out, img_shape_y, img_shape_x) + denoised = patching.fuse(input=denoised, batch_size=batch_size) + d_prime = (x_next - denoised) / t_next x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) return x_next diff --git a/physicsnemo/utils/generative/utils.py b/physicsnemo/utils/generative/utils.py index dcbb127e6f..347457c368 100644 --- a/physicsnemo/utils/generative/utils.py +++ b/physicsnemo/utils/generative/utils.py @@ -29,7 +29,7 @@ import sys import types import warnings -from typing import Any, List, Tuple, Union +from typing import Any, Iterator, List, Tuple, Union import cftime import numpy as np @@ -553,14 +553,37 @@ def decorator(*args, **kwargs): # indefinitely, shuffling items as it goes. -class InfiniteSampler(torch.utils.data.Sampler): # pragma: no cover - """ - Sampler for torch.utils.data.DataLoader that loops over the dataset - indefinitely, shuffling items as it goes. +class InfiniteSampler(torch.utils.data.Sampler[int]): # pragma: no cover + """Sampler for torch.utils.data.DataLoader that loops over the dataset indefinitely. + + This sampler yields indices indefinitely, optionally shuffling items as it goes. + It can also perform distributed sampling when rank and num_replicas are specified. + + Parameters + ---------- + dataset : torch.utils.data.Dataset + The dataset to sample from + rank : int, default=0 + The rank of the current process within num_replicas processes + num_replicas : int, default=1 + The number of processes participating in distributed sampling + shuffle : bool, default=True + Whether to shuffle the indices + seed : int, default=0 + Random seed for reproducibility when shuffling + window_size : float, default=0.5 + Fraction of dataset to use as window for shuffling. Must be between 0 and 1. + A larger window means more thorough shuffling but slower iteration. """ def __init__( - self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5 + self, + dataset: torch.utils.data.Dataset, + rank: int = 0, + num_replicas: int = 1, + shuffle: bool = True, + seed: int = 0, + window_size: float = 0.5, ): if not len(dataset) > 0: raise ValueError("Dataset must contain at least one item") @@ -578,7 +601,7 @@ def __init__( self.seed = seed self.window_size = window_size - def __iter__(self): + def __iter__(self) -> Iterator[int]: order = np.arange(len(self.dataset)) rnd = None window = 0 diff --git a/physicsnemo/utils/patching.py b/physicsnemo/utils/patching.py new file mode 100644 index 0000000000..a3570fd50c --- /dev/null +++ b/physicsnemo/utils/patching.py @@ -0,0 +1,745 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +import random +import warnings +from abc import ABC, abstractmethod +from typing import Optional, Tuple + +import torch +from einops import rearrange +from torch import Tensor + +""" +This module defines utilities, including classes and functions, for domain +decomposition. +""" + + +class BasePatching2D(ABC): + """ + Abstract base class for 2D image patching operations. + + This class provides a foundation for implementing various image patching + strategies. + It handles basic validation and provides abstract methods that must be + implemented by subclasses. + + Parameters + ---------- + img_shape : Tuple[int, int] + The height and width of the input images (img_shape_y, img_shape_x). + patch_shape : Tuple[int, int] + The height and width of the patches (patch_shape_y, patch_shape_x) to + extract. + """ + + def __init__( + self, img_shape: Tuple[int, int], patch_shape: Tuple[int, int] + ) -> None: + # Check that img_shape and patch_shape are 2D + if len(img_shape) != 2: + raise ValueError(f"img_shape must be 2D, got {len(img_shape)}D") + if len(patch_shape) != 2: + raise ValueError(f"patch_shape must be 2D, got {len(patch_shape)}D") + + # Make sure patches fit within the image + if any(p > i for p, i in zip(patch_shape, img_shape)): + warnings.warn( + f"Patch shape {patch_shape} is larger than " + f"image shape {img_shape}. " + f"Patches will be cropped to fit within the image." + ) + self.img_shape = img_shape + self.patch_shape = tuple(min(p, i) for p, i in zip(patch_shape, img_shape)) + + @abstractmethod + def apply(self, input: Tensor, **kwargs) -> Tensor: + """ + Apply the patching operation to the input tensor. + + Parameters + ---------- + input : Tensor + Input tensor of shape (batch_size, channels, img_shape_y, + img_shape_x). + **kwargs : dict + Additional keyword arguments specific to the patching + implementation. + + Returns + ------- + Tensor + Patched tensor, shape depends on specific implementation. + """ + pass + + def fuse(self, input: Tensor, **kwargs) -> Tensor: + """ + Fuse patches back into a complete image. + + Parameters + ---------- + input : Tensor + Input tensor containing patches. + **kwargs : dict + Additional keyword arguments specific to the fusion implementation. + + Returns + ------- + Tensor + Fused tensor, shape depends on specific implementation. + + Raises + ------ + NotImplementedError + If the subclass does not implement this method. + """ + raise NotImplementedError("'fuse' method must be implemented in subclasses.") + + def global_index(self, batch_size: int) -> Tensor: + """ + Returns a tensor containing the global indices for each patch. + + Global indices correspond to (y, x) global grid coordinates of each + element within the original image (before patching). It is typically + used to keep track of the original position of each patch in the + original image. + + Parameters + ---------- + batch_size : int + The size of the batch of images to patch. + + Returns + ------- + Tensor + A tensor of shape (batch_size * self.patch_num, 2, patch_shape_y, + patch_shape_x). `global_index[:, 0, :, :]` contains the + y-coordinate (height), and `global_index[:, 1, :, :]` contains the + x-coordinate (width). + """ + Ny = torch.arange(self.img_shape[0]).int() + Nx = torch.arange(self.img_shape[1]).int() + grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0)[ + None, + ].expand(batch_size, -1, -1, -1) + global_index = self.apply(grid) + return global_index + + +class RandomPatching2D(BasePatching2D): + """ + Class for randomly extracting patches from 2D images. + + This class provides utilities to randomly extract patches from images + represented as 4D tensors. It maintains a list of random patch indices + that can be reset as needed. + + Parameters + ---------- + img_shape : Tuple[int, int] + The height and width of the input images (img_shape_y, img_shape_x). + patch_shape : Tuple[int, int] + The height and width of the patches (patch_shape_y, patch_shape_x) to + extract. + patch_num : int + The number of patches to extract. + + Attributes + ---------- + patch_indices : List[Tuple[int, int]] + The indices of the patches to extract from the images. These indices + correspond to the (y, x) coordinates of the lower left corner of each + patch. + + See Also + -------- + :class:`physicsnemo.utils.patching.BasePatching2D` + The base class providing the patching interface. + :class:`physicsnemo.utils.patching.GridPatching2D` + Alternative patching strategy using deterministic patch locations. + """ + + def __init__( + self, img_shape: Tuple[int, int], patch_shape: Tuple[int, int], patch_num: int + ) -> None: + """ + Initialize the RandomPatching2D object with the provided image shape, + patch shape, and number of patches to extract. + + Parameters + ---------- + img_shape : Tuple[int, int] + The height and width of the input images (img_shape_y, + img_shape_x). + patch_shape : Tuple[int, int] + The height and width of the patches (patch_shape_y, patch_shape_x) + to extract. + patch_num : int + The number of patches to extract. + + Returns + ------- + None + """ + super().__init__(img_shape, patch_shape) + self._patch_num = patch_num + # Generate the indices of the patches to extract + self.reset_patch_indices() + + @property + def patch_num(self) -> int: + """ + Get the number of patches to extract. + + Returns + ------- + int + The number of patches to extract. + """ + return self._patch_num + + def set_patch_sum(self, value: int) -> None: + """ + Set the number of patches to extract and reset patch indices. + This is the only way to modify the patch_num value. + + Parameters + ---------- + value : int + The new number of patches to extract. + """ + self._patch_num = value + self.reset_patch_indices() + + def reset_patch_indices(self) -> None: + """ + Generate new random indices for the patches to extract. These are the + starting indices of the patches to extract (upper left corner). + + Returns + ------- + None + """ + self.patch_indices = [ + ( + random.randint(0, self.img_shape[0] - self.patch_shape[0]), + random.randint(0, self.img_shape[1] - self.patch_shape[1]), + ) + for _ in range(self.patch_num) + ] + return + + def apply( + self, + input: Tensor, + additional_input: Optional[Tensor] = None, + ) -> Tensor: + """ + Applies the patching operation by extracting patches specified by + `self.patch_indices` from the `input` Tensor. Extracted patches are + batched along the first dimension of the output. The layout of the + output assumes that for any i, `out[B * i: B * (i + 1)]` + corresponds to the same patch exacted from each batch element of + `input`. + + Arguments + --------- + input : Tensor + The input tensor representing the full image with shape + (batch_size, channels_in, img_shape_y, img_shape_x). + additional_input : Optional[Tensor], optional + If provided, it is concatenated to each patch along `dim=1`. + Must have same batch size as `input`. Bilinear interpolation + is used to interpolate `additional_input` onto a 2D grid of shape + (patch_shape_y, patch_shape_x). + + Returns + ------- + Tensor + A tensor of shape (batch_size * self.patch_num, channels [+ + additional_channels], patch_shape_y, patch_shape_x). If + `additional_input` is provided, its channels are concatenated + along the channel dimension. + """ + B = input.shape[0] + out = torch.zeros( + B * self.patch_num, + ( + input.shape[1] + + (additional_input.shape[1] if additional_input is not None else 0) + ), + self.patch_shape[0], + self.patch_shape[1], + device=input.device, + ) + if additional_input is not None: + add_input_interp = torch.nn.functional.interpolate( + input=additional_input, size=self.patch_shape, mode="bilinear" + ) + for i, (py, px) in enumerate(self.patch_indices): + if additional_input is not None: + out[B * i : B * (i + 1),] = torch.cat( + ( + input[ + :, + :, + py : py + self.patch_shape[0], + px : px + self.patch_shape[1], + ], + add_input_interp, + ), + dim=1, + ) + else: + out[B * i : B * (i + 1),] = input[ + :, + :, + py : py + self.patch_shape[0], + px : px + self.patch_shape[1], + ] + return out + + +class GridPatching2D(BasePatching2D): + """ + Class for deterministically extracting patches from 2D images in a grid pattern. + + This class provides utilities to extract patches from images in a + deterministic manner, with configurable overlap and boundary pixels. + The patches are extracted in a grid-like pattern covering the entire image. + + Parameters + ---------- + img_shape : Tuple[int, int] + The height and width of the input images (img_shape_y, img_shape_x). + patch_shape : Tuple[int, int] + The height and width of the patches (patch_shape_y, patch_shape_x) to + extract. + overlap_pix : int, optional + Number of pixels to overlap between adjacent patches, by default 0. + boundary_pix : int, optional + Number of pixels to crop as boundary from each patch, by default 0. + + Attributes + ---------- + patch_num : int + Total number of patches that will be extracted from the image, + calculated as patch_num_x * patch_num_y. + + See Also + -------- + :class:`physicsnemo.utils.patching.BasePatching2D` + The base class providing the patching interface. + :class:`physicsnemo.utils.patching.RandomPatching2D` + Alternative patching strategy using random patch locations. + """ + + def __init__( + self, + img_shape: Tuple[int, int], + patch_shape: Tuple[int, int], + overlap_pix: int = 0, + boundary_pix: int = 0, + ): + super().__init__(img_shape, patch_shape) + self.overlap_pix = overlap_pix + self.boundary_pix = boundary_pix + patch_num_x = math.ceil( + img_shape[1] / (patch_shape[1] - overlap_pix - boundary_pix) + ) + patch_num_y = math.ceil( + img_shape[0] / (patch_shape[0] - overlap_pix - boundary_pix) + ) + self.patch_num = patch_num_x * patch_num_y + + def apply( + self, + input: Tensor, + additional_input: Optional[Tensor] = None, + ) -> Tensor: + """ + Apply deterministic patching to the input tensor. + + Splits the input tensor into patches in a grid-like pattern. Can + optionally concatenate additional interpolated data to each patch. + Extracted patches are batched along the first dimension of the output. + The layout of the output assumes that for any i, `out[B * i: B * (i + 1)]` + corresponds to the same patch exacted from each batch element of + `input`. The patches can be reconstructed back into the original image + using the fuse method. + + Parameters + ---------- + input : Tensor + Input tensor of shape (batch_size, channels, img_shape_y, + img_shape_x). + additional_input : Optional[Tensor], optional + Additional data to concatenate to each patch. Will be interpolated + to match patch dimensions. Shape must be (batch_size, + additional_channels, H, W), by default None. + + Returns + ------- + Tensor + Tensor containing patches with shape (batch_size * patch_num, + channels [+ additional_channels], patch_shape_y, patch_shape_x). + If additional_input is provided, its channels are concatenated + along the channel dimension. + + See Also + -------- + :func:`physicsnemo.utils.patching.image_batching` + The underlying function used to perform the patching operation. + """ + if additional_input is not None: + add_input_interp = torch.nn.functional.interpolate( + input=additional_input, size=self.patch_shape, mode="bilinear" + ) + else: + add_input_interp = None + out = image_batching( + input=input, + patch_shape_y=self.patch_shape[0], + patch_shape_x=self.patch_shape[1], + overlap_pix=self.overlap_pix, + boundary_pix=self.boundary_pix, + input_interp=add_input_interp, + ) + return out + + def fuse(self, input: Tensor, batch_size: int) -> Tensor: + """ + Fuse patches back into a complete image. + + Reconstructs the original image by stitching together patches, + accounting for overlapping regions and boundary pixels. In overlapping + regions, values are averaged. + + Parameters + ---------- + input : Tensor + Input tensor containing patches with shape (batch_size * patch_num, + channels, patch_shape_y, patch_shape_x). + batch_size : int + The original batch size before patching. + + Returns + ------- + Tensor + Reconstructed image tensor with shape (batch_size, channels, + img_shape_y, img_shape_x). + + See Also + -------- + :func:`physicsnemo.utils.patching.image_fuse` + The underlying function used to perform the fusion operation. + """ + out = image_fuse( + input=input, + img_shape_y=self.img_shape[0], + img_shape_x=self.img_shape[1], + batch_size=batch_size, + overlap_pix=self.overlap_pix, + boundary_pix=self.boundary_pix, + ) + return out + + +def image_batching( + input: Tensor, + patch_shape_y: int, + patch_shape_x: int, + overlap_pix: int, + boundary_pix: int, + input_interp: Optional[Tensor] = None, +) -> Tensor: + """ + Splits a full image into a batch of patched images. + + This function takes a full image and splits it into patches, adding padding + where necessary. It can also concatenate additional interpolated data to + each patch if provided. + + Parameters + ---------- + input : Tensor + The input tensor representing the full image with shape (batch_size, + channels, img_shape_y, img_shape_x). + patch_shape_y : int + The height (y-dimension) of each image patch. + patch_shape_x : int + The width (x-dimension) of each image patch. + overlap_pix : int + The number of overlapping pixels between adjacent patches. + boundary_pix : int + The number of pixels to crop as a boundary from each patch. + input_interp : Optional[Tensor], optional + Optional additional data to concatenate to each patch with shape + (batch_size, interp_channels, patch_shape_y, patch_shape_x). + By default None. + + Returns + ------- + Tensor + A tensor containing the image patches, with shape (total_patches * + batch_size, channels [+ interp_channels], patch_shape_x, + patch_shape_y). + """ + # Infer sizes from input image + batch_size, _, img_shape_y, img_shape_x = input.shape + + # Safety check: make sure patch_shapes are large enough to accommodate + # overlaps and boundaries pixels + if (patch_shape_x - overlap_pix - boundary_pix) < 1: + raise ValueError( + f"patch_shape_x must verify patch_shape_x ({patch_shape_x}) >= " + f"1 + overlap_pix ({overlap_pix}) + boundary_pix ({boundary_pix})" + ) + if (patch_shape_y - overlap_pix - boundary_pix) < 1: + raise ValueError( + f"patch_shape_y must verify patch_shape_y ({patch_shape_y}) >= " + f"1 + overlap_pix ({overlap_pix}) + boundary_pix ({boundary_pix})" + ) + # Safety check: validate input_interp dimensions if provided + if input_interp is not None: + if input_interp.shape[0] != batch_size: + raise ValueError( + f"input_interp batch size ({input_interp.shape[0]}) must match " + f"input batch size ({batch_size})" + ) + if (input_interp.shape[2] != patch_shape_y) or ( + input_interp.shape[3] != patch_shape_x + ): + raise ValueError( + f"input_interp patch shape ({input_interp.shape[2]}, {input_interp.shape[3]}) " + f"must match specified patch shape ({patch_shape_y}, {patch_shape_x})" + ) + + # Safety check: make sure patch_shape is large enough in comparison to + # overlap_pix and boundary_pix. Otherwise, number of patches extracted by + # unfold differs from the expected number of patches. + if patch_shape_x <= overlap_pix + 2 * boundary_pix: + raise ValueError( + f"patch_shape_x ({patch_shape_x}) must verify " + f"patch_shape_x ({patch_shape_x}) > " + f"overlap_pix ({overlap_pix}) + 2 * boundary_pix ({boundary_pix})" + ) + if patch_shape_y <= overlap_pix + 2 * boundary_pix: + raise ValueError( + f"patch_shape_y ({patch_shape_y}) must verify " + f"patch_shape_y ({patch_shape_y}) > " + f"overlap_pix ({overlap_pix}) + 2 * boundary_pix ({boundary_pix})" + ) + + patch_num_x = math.ceil(img_shape_x / (patch_shape_x - overlap_pix - boundary_pix)) + patch_num_y = math.ceil(img_shape_y / (patch_shape_y - overlap_pix - boundary_pix)) + padded_shape_x = ( + (patch_shape_x - overlap_pix - boundary_pix) * (patch_num_x - 1) + + patch_shape_x + + boundary_pix + ) + padded_shape_y = ( + (patch_shape_y - overlap_pix - boundary_pix) * (patch_num_y - 1) + + patch_shape_y + + boundary_pix + ) + pad_x_right = padded_shape_x - img_shape_x - boundary_pix + pad_y_right = padded_shape_y - img_shape_y - boundary_pix + image_padding = torch.nn.ReflectionPad2d( + (boundary_pix, pad_x_right, boundary_pix, pad_y_right) + ).to( + input.device + ) # (padding_left,padding_right,padding_top,padding_bottom) + input_padded = image_padding(input) + patch_num = patch_num_x * patch_num_y + x_unfold = torch.nn.functional.unfold( + input=input_padded.view(_cast_type(input_padded)), # Cast to float + kernel_size=(patch_shape_y, patch_shape_x), + stride=( + patch_shape_y - overlap_pix - boundary_pix, + patch_shape_x - overlap_pix - boundary_pix, + ), + ).to(input_padded.dtype) + x_unfold = rearrange( + x_unfold, + "b (c p_h p_w) (nb_p_h nb_p_w) -> (nb_p_w nb_p_h b) c p_h p_w", + p_h=patch_shape_y, + p_w=patch_shape_x, + nb_p_h=patch_num_y, + nb_p_w=patch_num_x, + ) + if input_interp is not None: + input_interp_repeated = rearrange( + torch.repeat_interleave( + input=input_interp, + repeats=patch_num, + dim=0, + output_size=x_unfold.shape[0], + ), + "(b p) c h w -> (p b) c h w", + p=patch_num, + ) + return torch.cat((x_unfold, input_interp_repeated), dim=1) + else: + return x_unfold + + +def image_fuse( + input: Tensor, + img_shape_y: int, + img_shape_x: int, + batch_size: int, + overlap_pix: int, + boundary_pix: int, +) -> Tensor: + """ + Reconstructs a full image from a batch of patched images. Reverts the patching + operation performed by image_batching(). + + This function takes a batch of image patches and reconstructs the full + image by stitching the patches together. The function accounts for + overlapping and boundary pixels, ensuring that overlapping areas are + averaged. + + Parameters + ---------- + input : Tensor + The input tensor containing the image patches with shape (patch_num * batch_size, channels, patch_shape_y, patch_shape_x). + img_shape_y : int + The height (y-dimension) of the original full image. + img_shape_x : int + The width (x-dimension) of the original full image. + batch_size : int + The original batch size before patching. + overlap_pix : int + The number of overlapping pixels between adjacent patches. + boundary_pix : int + The number of pixels to crop as a boundary from each patch. + + Returns + ------- + Tensor + The reconstructed full image tensor with shape (batch_size, channels, + img_shape_y, img_shape_x). + + See Also + -------- + :func:`physicsnemo.utils.patching.image_batching` + The function this reverses, which splits images into patches. + """ + + # Infer sizes from input image shape + patch_shape_y, patch_shape_x = input.shape[2], input.shape[3] + + # Calculate the number of patches in each dimension + patch_num_x = math.ceil(img_shape_x / (patch_shape_x - overlap_pix - boundary_pix)) + patch_num_y = math.ceil(img_shape_y / (patch_shape_y - overlap_pix - boundary_pix)) + + # Calculate the shape of the input after padding + padded_shape_x = ( + (patch_shape_x - overlap_pix - boundary_pix) * (patch_num_x - 1) + + patch_shape_x + + boundary_pix + ) + padded_shape_y = ( + (patch_shape_y - overlap_pix - boundary_pix) * (patch_num_y - 1) + + patch_shape_y + + boundary_pix + ) + # Calculate the shape of the padding to add to input + pad_x_right = padded_shape_x - img_shape_x - boundary_pix + pad_y_right = padded_shape_y - img_shape_y - boundary_pix + pad = (boundary_pix, pad_x_right, boundary_pix, pad_y_right) + + # Count local overlaps between patches + input_ones = torch.ones( + (batch_size, input.shape[1], padded_shape_y, padded_shape_x), + device=input.device, + ) + overlap_count = torch.nn.functional.unfold( + input=input_ones, + kernel_size=(patch_shape_y, patch_shape_x), + stride=( + patch_shape_y - overlap_pix - boundary_pix, + patch_shape_x - overlap_pix - boundary_pix, + ), + ) + overlap_count = torch.nn.functional.fold( + input=overlap_count, + output_size=(padded_shape_y, padded_shape_x), + kernel_size=(patch_shape_y, patch_shape_x), + stride=( + patch_shape_y - overlap_pix - boundary_pix, + patch_shape_x - overlap_pix - boundary_pix, + ), + ) + + # Reshape input to make it 3D to apply fold + x = rearrange( + input, + "(nb_p_w nb_p_h b) c p_h p_w -> b (c p_h p_w) (nb_p_h nb_p_w)", + p_h=patch_shape_y, + p_w=patch_shape_x, + nb_p_h=patch_num_y, + nb_p_w=patch_num_x, + ) + # Stitch patches together (by summing over overlapping patches) + x_folded = torch.nn.functional.fold( + input=x, + output_size=(padded_shape_y, padded_shape_x), + kernel_size=(patch_shape_y, patch_shape_x), + stride=( + patch_shape_y - overlap_pix - boundary_pix, + patch_shape_x - overlap_pix - boundary_pix, + ), + ) + + # Remove padding + x_no_padding = x_folded[ + ..., pad[2] : pad[2] + img_shape_y, pad[0] : pad[0] + img_shape_x + ] + overlap_count_no_padding = overlap_count[ + ..., pad[2] : pad[2] + img_shape_y, pad[0] : pad[0] + img_shape_x + ] + + # Normalize by overlap count + return x_no_padding / overlap_count_no_padding + + +def _cast_type(input: Tensor) -> torch.dtype: + """Return float type based on input tensor type. + + Parameters + ---------- + input : Tensor + Input tensor to determine float type from + + Returns + ------- + torch.dtype + Float type corresponding to input tensor type for int32/64, + otherwise returns original dtype + """ + if input.dtype == torch.int32: + return torch.float32 + elif input.dtype == torch.int64: + return torch.float64 + else: + return input.dtype diff --git a/test/metrics/diffusion/test_losses.py b/test/metrics/diffusion/test_losses.py index 4e0a4d3b5a..0c8bf27cdb 100644 --- a/test/metrics/diffusion/test_losses.py +++ b/test/metrics/diffusion/test_losses.py @@ -14,15 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch from physicsnemo.metrics.diffusion import ( EDMLoss, + RegressionLoss, RegressionLossCE, + ResidualLoss, VELoss, VELoss_dfsr, VPLoss, ) +from physicsnemo.models.diffusion import EDMPrecondSuperResolution, UNet +from physicsnemo.utils.patching import RandomPatching2D # VPLoss tests @@ -53,15 +58,10 @@ def test_sigma_method(): assert sigma_vals.shape == t.shape -def fake_net(y, sigma, labels, augment_labels=None): - return torch.tensor([1.0]) - - -def fake_condition_net(y, sigma, condition, class_labels=None, augment_labels=None): - return torch.tensor([1.0]) - - def test_call_method_vp(): + def fake_net(y, sigma, labels, augment_labels=None): + return torch.tensor([1.0]) + loss_func = VPLoss() images = torch.tensor([[[[1.0]]]]) @@ -97,6 +97,9 @@ def test_veloss_initialization(): def test_call_method_ve(): loss_func = VELoss() + def fake_net(y, sigma, labels, augment_labels=None): + return torch.tensor([1.0]) + images = torch.tensor([[[[1.0]]]]) labels = None @@ -130,6 +133,12 @@ def test_edmloss_initialization(): def test_call_method_edm(): + def fake_condition_net(y, sigma, condition, class_labels=None, augment_labels=None): + return torch.tensor([1.0]) + + def fake_net(y, sigma, labels, augment_labels=None): + return torch.tensor([1.0]) + loss_func = EDMLoss() img = torch.tensor([[[[1.0]]]]) @@ -155,79 +164,79 @@ def mock_augment_pipe(imgs): # RegressionLoss tests -# def test_regressionloss_initialization(): -# loss_func = RegressionLoss() -# assert loss_func.P_mean == -1.2 -# assert loss_func.P_std == 1.2 -# assert loss_func.sigma_data == 0.5 +def test_call_method_regressionloss(): -# loss_func = RegressionLoss(P_mean=-2.0, P_std=2.0, sigma_data=0.3) -# assert loss_func.P_mean == -2.0 -# assert loss_func.P_std == 2.0 -# assert loss_func.sigma_data == 0.3 + # With a fake network + def fake_net(input, y_lr, augment_labels=None, force_fp32=False): + return torch.tensor([1.0]) -# def fake_net(input, y_lr, sigma, labels, augment_labels=None): -# return torch.tensor([1.0]) + loss_func = RegressionLoss() + img_clean = torch.tensor([[[[1.0]]]]) + img_lr = torch.tensor([[[[0.5]]]]) -# def test_call_method(): -# loss_func = RegressionLoss() + # Without augmentation + loss_value = loss_func(fake_net, img_clean, img_lr) + assert isinstance(loss_value, torch.Tensor) -# img_clean = torch.tensor([[[[1.0]]]]) -# img_lr = torch.tensor([[[[0.5]]]]) -# labels = None + # With augmentation + def mock_augment_pipe(imgs): + return imgs, None -# # Without augmentation -# loss_value = loss_func(fake_net, img_clean, img_lr, labels) -# assert isinstance(loss_value, torch.Tensor) + loss_value_with_augmentation = loss_func( + fake_net, img_clean, img_lr, mock_augment_pipe + ) + assert isinstance(loss_value_with_augmentation, torch.Tensor) -# # With augmentation -# def mock_augment_pipe(imgs): -# return imgs, None -# loss_value_with_augmentation = loss_func( -# fake_net, img_clean, img_lr, labels, mock_augment_pipe -# ) -# assert isinstance(loss_value_with_augmentation, torch.Tensor) +# More realistic test with a UNet model +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_call_method_regressionloss_with_unet(device): + + res, inc, outc = 64, 2, 3 + model = UNet( + img_resolution=res, + img_in_channels=inc, + img_out_channels=outc, + model_type="SongUNet", + ).to(device) + img_clean = torch.ones([1, outc, res, res]).to(device) + img_lr = torch.randn([1, inc, res, res]).to(device) + loss_func = RegressionLoss() + loss_value = loss_func(model, img_clean, img_lr) + assert isinstance(loss_value, torch.Tensor) + assert loss_value.shape == img_clean.shape + # RegressionLossCE tests def test_regressionlossce_initialization(): loss_func = RegressionLossCE() - assert loss_func.P_mean == -1.2 - assert loss_func.P_std == 1.2 - assert loss_func.sigma_data == 0.5 assert loss_func.prob_channels == [4, 5, 6, 7, 8] - loss_func = RegressionLossCE( - P_mean=-2.0, P_std=2.0, sigma_data=0.3, prob_channels=[1, 2, 3, 4] - ) - assert loss_func.P_mean == -2.0 - assert loss_func.P_std == 2.0 - assert loss_func.sigma_data == 0.3 + loss_func = RegressionLossCE(prob_channels=[1, 2, 3, 4]) assert loss_func.prob_channels == [1, 2, 3, 4] -def leadtime_fake_net( - input, y_lr, sigma, labels, lead_time_label=None, augment_labels=None -): - return torch.zeros(1, 4, 29, 29) +def test_call_method_regressionlossce(): + def leadtime_fake_net(input, y_lr, lead_time_label=None, augment_labels=None): + return torch.zeros(1, 4, 29, 29) - -def test_call_method(): prob_channels = [0, 2] loss_func = RegressionLossCE(prob_channels=prob_channels) img_clean = torch.zeros(1, 4, 29, 29) img_lr = torch.zeros(1, 4, 29, 29) - labels = None lead_time_label = None # Without augmentation loss_value = loss_func( - leadtime_fake_net, img_clean, img_lr, lead_time_label, labels + leadtime_fake_net, + img_clean, + img_lr, + lead_time_label, ) assert isinstance(loss_value, torch.Tensor) assert loss_value.shape == (1, 3, 29, 29) @@ -237,96 +246,238 @@ def mock_augment_pipe(imgs): return imgs, None loss_value_with_augmentation = loss_func( - leadtime_fake_net, img_clean, img_lr, lead_time_label, labels, mock_augment_pipe + leadtime_fake_net, img_clean, img_lr, lead_time_label, mock_augment_pipe ) assert isinstance(loss_value_with_augmentation, torch.Tensor) assert loss_value.shape == (1, 3, 29, 29) -# MixtureLoss tests - - -# def test_mixtureloss_initialization(): -# loss_func = MixtureLoss() -# assert loss_func.P_mean == -1.2 -# assert loss_func.P_std == 1.2 -# assert loss_func.sigma_data == 0.5 - -# loss_func = MixtureLoss(P_mean=-2.0, P_std=2.0, sigma_data=0.3) -# assert loss_func.P_mean == -2.0 -# assert loss_func.P_std == 2.0 -# assert loss_func.sigma_data == 0.3 - - -# def fake_net(latent, y_lr, sigma, labels, augment_labels=None): -# return torch.tensor([1.0]) - - -# def test_call_method(): -# loss_func = MixtureLoss() - -# img_clean = torch.tensor([[[[1.0]]]]) -# img_lr = torch.tensor([[[[0.5]]]]) -# labels = None - -# # Without augmentation -# loss_value = loss_func(fake_net, img_clean, img_lr, labels) -# assert isinstance(loss_value, torch.Tensor) - -# # With augmentation -# def mock_augment_pipe(imgs): -# return imgs, None - -# loss_value_with_augmentation = loss_func( -# fake_net, img_clean, img_lr, labels, mock_augment_pipe -# ) -# assert isinstance(loss_value_with_augmentation, torch.Tensor) - - -# ResLoss tests +# More realistic test with a UNet model and lead-time conditioning +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_call_method_regressionlossce_with_unet(device): + res, inc, outc = 64, 3, 4 + N_pos, lead_time_channels = 2, 4 + prob_channels = [0, 2] + model = UNet( + img_resolution=res, + img_in_channels=inc + N_pos + lead_time_channels, + img_out_channels=outc, + model_type="SongUNetPosLtEmbd", + gridtype="test", + lead_time_channels=lead_time_channels, + prob_channels=prob_channels, + N_grid_channels=N_pos, + ).to(device) + + img_clean = torch.ones([1, outc, res, res]).to(device) + img_lr = torch.randn([1, inc, res, res]).to(device) + lead_time_label = torch.tensor(8).to(device) + loss_func = RegressionLossCE(prob_channels=prob_channels) + loss_value = loss_func(model, img_clean, img_lr, lead_time_label=lead_time_label) + assert isinstance(loss_value, torch.Tensor) + assert loss_value.shape == (1, outc - len(prob_channels) + 1, res, res) -# def test_resloss_initialization(): -# # Mock the model loading -# ResLoss.unet = torch.nn.Linear(1, 1).cuda() -# loss_func = ResLoss() -# assert loss_func.P_mean == 0.0 -# assert loss_func.P_std == 1.2 -# assert loss_func.sigma_data == 0.5 +# ResidualLoss tests -# loss_func = ResLoss(P_mean=-2.0, P_std=2.0, sigma_data=0.3) -# assert loss_func.P_mean == -2.0 -# assert loss_func.P_std == 2.0 -# assert loss_func.sigma_data == 0.3 +def test_residualloss_initialization(): + # Mock regression network + regression_net = torch.nn.Linear(1, 1) -# def fake_net(latent, y_lr, sigma, labels, augment_labels=None): -# return torch.tensor([1.0]) + # Test default parameters + loss_func = ResidualLoss( + regression_net=regression_net, + ) + assert loss_func.P_mean == 0.0 + assert loss_func.P_std == 1.2 + assert loss_func.sigma_data == 0.5 + assert loss_func.hr_mean_conditioning is False + + # Test custom parameters + loss_func = ResidualLoss( + regression_net=regression_net, + P_mean=1.0, + P_std=2.0, + sigma_data=0.3, + hr_mean_conditioning=True, + ) + assert loss_func.P_mean == 1.0 + assert loss_func.P_std == 2.0 + assert loss_func.sigma_data == 0.3 + assert loss_func.hr_mean_conditioning is True + + +def test_residualloss_call_method(): + def fake_residual_net( + x, + img_lr, + sigma, + labels=None, + global_index=None, + embedding_selector=None, + augment_labels=None, + ): + return torch.zeros_like(x) + + # Mock regression network that returns scaled input + class DummyRegNet(torch.nn.Module): + def forward(self, x, *args, **kwargs): + return 0.9 * x + + regression_net = DummyRegNet() + loss_func = ResidualLoss( + regression_net=regression_net, + ) + # Create test inputs + batch_size = 2 + channels = 3 + img_clean = torch.randn(batch_size, channels, 32, 32) + img_lr = torch.randn(batch_size, channels, 32, 32) -# def test_call_method(): -# # Mock the model loading -# ResLoss.unet = torch.nn.Linear(1, 1).cuda() + # Test without patching or augmentation + loss_value = loss_func(fake_residual_net, img_clean, img_lr) + assert isinstance(loss_value, torch.Tensor) + assert loss_value.shape == (batch_size, channels, 32, 32) -# loss_func = ResLoss() + # Test with augmentation + def mock_augment_pipe(imgs): + return imgs, None -# img_clean = torch.tensor([[[[1.0]]]]) -# img_lr = torch.tensor([[[[0.5]]]]) -# labels = None + loss_value_with_augmentation = loss_func( + fake_residual_net, img_clean, img_lr, augment_pipe=mock_augment_pipe + ) + assert isinstance(loss_value_with_augmentation, torch.Tensor) + assert loss_value_with_augmentation.shape == (batch_size, channels, 32, 32) -# # Without augmentation -# loss_value = loss_func(fake_net, img_clean, img_lr, labels) -# assert isinstance(loss_value, torch.Tensor) + # Test with patching + patch_num = 4 + patch_shape = (16, 16) + patching = RandomPatching2D( + img_shape=(32, 32), patch_shape=patch_shape, patch_num=patch_num + ) + loss_value_with_patching = loss_func( + fake_residual_net, img_clean, img_lr, patching=patching + ) + assert isinstance(loss_value_with_patching, torch.Tensor) + # Shape should be (batch_size * patch_num, channels, patch_shape_y, patch_shape_x) + expected_shape = (batch_size * patch_num, channels, patch_shape[0], patch_shape[1]) + assert loss_value_with_patching.shape == expected_shape + + # Test error on invalid patching object + with pytest.raises(ValueError): + loss_func( + fake_residual_net, img_clean, img_lr, patching="invalid patching object" + ) + + +# More realistic test with a UNet model +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_call_method_residualloss_with_unet(device): + + res, inc, outc = 64, 2, 3 + N_pos = 2 + regression_model = UNet( + img_resolution=res, + img_in_channels=inc + N_pos, + img_out_channels=outc, + model_type="SongUNetPosEmbd", + N_grid_channels=N_pos, + gridtype="test", + ).to(device) + diffusion_model = EDMPrecondSuperResolution( + img_resolution=res, + img_in_channels=inc + N_pos, + img_out_channels=outc, + model_type="SongUNetPosEmbd", + N_grid_channels=N_pos, + gridtype="test", + ).to(device) + + img_clean = torch.ones([1, outc, res, res]).to(device) + img_lr = torch.randn([1, inc, res, res]).to(device) + + # Without hr_mean_conditioning + loss_func = ResidualLoss( + regression_net=regression_model, hr_mean_conditioning=False + ) + loss_value = loss_func(diffusion_model, img_clean, img_lr) + assert isinstance(loss_value, torch.Tensor) + assert loss_value.shape == img_clean.shape + + +# Test with UNets and hr_mean_conditioning +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_call_method_residualloss_with_unet_hr_mean_conditioning(device): + res, inc, outc = 64, 2, 3 + N_pos = 2 + regression_model = UNet( + img_resolution=res, + img_in_channels=inc + N_pos, + img_out_channels=outc, + model_type="SongUNetPosEmbd", + N_grid_channels=N_pos, + gridtype="test", + ).to(device) + diffusion_model = EDMPrecondSuperResolution( + img_resolution=res, + img_in_channels=inc + N_pos + outc, + img_out_channels=outc, + model_type="SongUNetPosEmbd", + N_grid_channels=N_pos, + gridtype="test", + ).to(device) + + img_clean = torch.ones([1, outc, res, res]).to(device) + img_lr = torch.randn([1, inc, res, res]).to(device) + + # With hr_mean_conditioning + loss_func = ResidualLoss(regression_net=regression_model, hr_mean_conditioning=True) + loss_value = loss_func(diffusion_model, img_clean, img_lr) + assert isinstance(loss_value, torch.Tensor) + assert loss_value.shape == img_clean.shape -# # With augmentation -# def mock_augment_pipe(imgs): -# return imgs, None -# loss_value_with_augmentation = loss_func( -# fake_net, img_clean, img_lr, labels, mock_augment_pipe -# ) -# assert isinstance(loss_value_with_augmentation, torch.Tensor) +# Test with UNets, hr_mean_conditioning, and lead-time aware embedding +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_call_method_residualloss_with_lt_unet_hr_mean_conditioning(device): + res, inc, outc = 64, 2, 3 + N_pos, lead_time_channels = 2, 4 + prob_channels = [0, 2] + regression_model = UNet( + img_resolution=res, + img_in_channels=inc + N_pos + lead_time_channels, + img_out_channels=outc, + model_type="SongUNetPosLtEmbd", + N_grid_channels=N_pos, + gridtype="test", + lead_time_channels=lead_time_channels, + prob_channels=prob_channels, + ).to(device) + diffusion_model = EDMPrecondSuperResolution( + img_resolution=res, + img_in_channels=inc + outc + N_pos + lead_time_channels, + img_out_channels=outc, + model_type="SongUNetPosLtEmbd", + N_grid_channels=N_pos, + gridtype="test", + lead_time_channels=lead_time_channels, + prob_channels=prob_channels, + ).to(device) + + img_clean = torch.ones([1, outc, res, res]).to(device) + img_lr = torch.randn([1, inc, res, res]).to(device) + lead_time_label = torch.tensor(8).to(device) + + # With hr_mean_conditioning + loss_func = ResidualLoss(regression_net=regression_model, hr_mean_conditioning=True) + loss_value = loss_func( + diffusion_model, img_clean, img_lr, lead_time_label=lead_time_label + ) + assert isinstance(loss_value, torch.Tensor) + assert loss_value.shape == img_clean.shape # VELoss_dfsr tests @@ -366,6 +517,9 @@ def test_get_beta_schedule_method(): def test_call_method_ve_dfsr(): + def fake_net(y, sigma, labels, augment_labels=None): + return torch.tensor([1.0]) + loss_func = VELoss_dfsr() images = torch.tensor([[[[1.0]]]]) diff --git a/test/models/diffusion/test_preconditioning.py b/test/models/diffusion/test_preconditioning.py index f1f5f86da7..9f15a4b98c 100644 --- a/test/models/diffusion/test_preconditioning.py +++ b/test/models/diffusion/test_preconditioning.py @@ -20,27 +20,24 @@ from physicsnemo.models.diffusion.preconditioning import ( EDMPrecond, - EDMPrecondSR, + EDMPrecondSuperResolution, VEPrecond_dfsr, VEPrecond_dfsr_cond, ) from physicsnemo.models.module import Module -@pytest.mark.parametrize("scale_cond_input", [True, False]) -def test_EDMPrecondSR_forward(scale_cond_input): +def test_EDMPrecondSuperResolution_forward(): b, c_target, x, y = 1, 3, 8, 8 c_cond = 4 # Create an instance of the preconditioner - model = EDMPrecondSR( + model = EDMPrecondSuperResolution( img_resolution=x, - img_channels=c_target, img_in_channels=c_cond, img_out_channels=c_target, use_fp16=False, model_type="SongUNet", - scale_cond_input=scale_cond_input, ) latents = torch.ones((b, c_target, x, y)) @@ -59,15 +56,15 @@ def test_EDMPrecondSR_forward(scale_cond_input): @import_or_fail("termcolor") -def test_EDMPrecondSR_serialization(tmp_path, pytestconfig): +def test_EDMPrecondSuperResolution_serialization(tmp_path, pytestconfig): from physicsnemo.launch.utils import load_checkpoint, save_checkpoint - module = EDMPrecondSR(8, 1, 1, 1, scale_cond_input=False) + module = EDMPrecondSuperResolution(8, 1, 1) model_path = tmp_path / "output.mdlus" module.save(model_path.as_posix()) loaded = Module.from_checkpoint(model_path.as_posix()) - assert isinstance(loaded, EDMPrecondSR) + assert isinstance(loaded, EDMPrecondSuperResolution) save_checkpoint(path=tmp_path, models=module, epoch=1) epoch = load_checkpoint(path=tmp_path) assert epoch == 1 diff --git a/test/models/diffusion/test_song_unet_pos_embd.py b/test/models/diffusion/test_song_unet_pos_embd.py index 224a519cc0..43d6b4a304 100644 --- a/test/models/diffusion/test_song_unet_pos_embd.py +++ b/test/models/diffusion/test_song_unet_pos_embd.py @@ -69,8 +69,10 @@ def test_song_unet_forward(device): def test_song_unet_global_indexing(device): torch.manual_seed(0) N_pos = 2 - batch_shape_x = 32 - batch_shape_y = 64 + patch_shape_y = 64 + patch_shape_x = 32 + offset_y = 12 + offset_x = 45 # Construct the DDM++ UNet model model = UNet( img_resolution=128, @@ -79,20 +81,73 @@ def test_song_unet_global_indexing(device): gridtype="test", N_grid_channels=N_pos, ).to(device) - input_image = torch.ones([1, 2, batch_shape_x, batch_shape_y]).to(device) + input_image = torch.ones([1, 2, patch_shape_y, patch_shape_x]).to(device) noise_labels = noise_labels = torch.randn([1]).to(device) class_labels = torch.randint(0, 1, (1, 1)).to(device) - idx_x = torch.arange(45, 45 + batch_shape_x) - idx_y = torch.arange(12, 12 + batch_shape_y) - mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y) - global_index = torch.stack((mesh_x, mesh_y), dim=0)[None].to(device) + idx_x = torch.arange(patch_shape_x) + offset_x + idx_y = torch.arange(patch_shape_y) + offset_y + mesh_x, mesh_y = torch.meshgrid(idx_y, idx_x, indexing="ij") + global_index = torch.stack((mesh_x, mesh_y), dim=0)[None].to( + device + ) # (2, patch_shape_y, patch_shape_x) output_image = model(input_image, noise_labels, class_labels, global_index) pos_embed = model.positional_embedding_indexing(input_image, global_index) - assert output_image.shape == (1, 2, batch_shape_x, batch_shape_y) + assert output_image.shape == (1, 2, patch_shape_y, patch_shape_x) assert torch.equal(pos_embed, global_index) +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_song_unet_embedding_selector(device): + torch.manual_seed(0) + N_pos = 2 + patch_shape_y = 64 + patch_shape_x = 32 + offset_y = 12 + offset_x = 45 + # Construct the DDM++ UNet model + model = UNet( + img_resolution=128, + in_channels=2 + N_pos, + out_channels=2, + gridtype="test", + N_grid_channels=N_pos, + ).to(device) + input_image = torch.ones([1, 2, patch_shape_y, patch_shape_x]).to(device) + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + + # Expected embeddings should be the same as global_index + idx_x = torch.arange(patch_shape_x) + offset_x + idx_y = torch.arange(patch_shape_y) + offset_y + mesh_x, mesh_y = torch.meshgrid(idx_y, idx_x, indexing="ij") + expected_embeds = torch.stack((mesh_x, mesh_y), dim=0)[None].to( + device + ) # (2, patch_shape_y, patch_shape_x) + + # Function to select embeddings + def embedding_selector(emb): + return emb[ + None, + :, + offset_y : offset_y + patch_shape_y, + offset_x : offset_x + patch_shape_x, + ] + + output_image = model( + input_image, + noise_labels, + class_labels, + embedding_selector=embedding_selector, + ) + selected_embeds = model.positional_embedding_selector( + input_image, embedding_selector + ) + + assert output_image.shape == (1, 2, patch_shape_y, patch_shape_x) + assert torch.equal(selected_embeds, expected_embeds) + + @pytest.mark.parametrize("device", ["cuda:0", "cpu"]) def test_song_unet_constructor(device): """Test the Song UNet constructor options""" diff --git a/test/models/diffusion/test_song_unet_pos_lt_embd.py b/test/models/diffusion/test_song_unet_pos_lt_embd.py index 55df427131..da71578e7f 100644 --- a/test/models/diffusion/test_song_unet_pos_lt_embd.py +++ b/test/models/diffusion/test_song_unet_pos_lt_embd.py @@ -69,8 +69,10 @@ def test_song_unet_forward(device): def test_song_unet_lt_indexing(device): torch.manual_seed(0) N_pos = 2 - batch_shape_x = 32 - batch_shape_y = 64 + patch_shape_y = 64 + patch_shape_x = 32 + offset_y = 12 + offset_x = 45 # Construct the DDM++ UNet model lead_time_channels = 4 model = UNet( @@ -82,44 +84,70 @@ def test_song_unet_lt_indexing(device): prob_channels=[0, 1, 2, 3], N_grid_channels=N_pos, ).to(device) - input_image = torch.ones([1, 10, batch_shape_x, batch_shape_y]).to(device) + input_image = torch.ones([1, 10, patch_shape_y, patch_shape_x]).to(device) noise_labels = noise_labels = torch.randn([1]).to(device) class_labels = torch.randint(0, 1, (1, 1)).to(device) - idx_x = torch.arange(45, 45 + batch_shape_x) - idx_y = torch.arange(12, 12 + batch_shape_y) - mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y) - global_index = torch.stack((mesh_x, mesh_y), dim=0)[None].to(device) + idx_x = torch.arange(offset_x, offset_x + patch_shape_x) + idx_y = torch.arange(offset_y, offset_y + patch_shape_y) + mesh_x, mesh_y = torch.meshgrid(idx_y, idx_x, indexing="ij") + global_index = torch.stack((mesh_x, mesh_y), dim=0)[None].to( + device + ) # (2, patch_shape_y, patch_shape_x) - # pos_embed = model.positional_embedding_indexing(input_image, torch.cat([model.pos_embd, model.lt_embd], dim=0).to(device), global_index) - # assert torch.equal(pos_embed, global_index) + # Define a function to select the embeddings + def embedding_selector(emb): + return emb[ + None, + :, + offset_y : offset_y + patch_shape_y, + offset_x : offset_x + patch_shape_x, + ] model.training = True - output_image = model( + output_image_indexing = model( input_image, noise_labels, class_labels, lead_time_label=torch.tensor(8), global_index=global_index, ) - assert output_image.shape == (1, 10, batch_shape_x, batch_shape_y) + output_image_selector = model( + input_image, + noise_labels, + class_labels, + lead_time_label=torch.tensor(8), + embedding_selector=embedding_selector, + ) + assert output_image_indexing.shape == (1, 10, patch_shape_y, patch_shape_x) + assert torch.allclose(output_image_indexing, output_image_selector, atol=1e-5) model.training = False - output_image = model( + output_image_indexing = model( input_image, noise_labels, class_labels, lead_time_label=torch.tensor(8), global_index=global_index, ) - assert output_image.shape == (1, 10, batch_shape_x, batch_shape_y) + output_image_selector = model( + input_image, + noise_labels, + class_labels, + lead_time_label=torch.tensor(8), + embedding_selector=embedding_selector, + ) + assert output_image_indexing.shape == (1, 10, patch_shape_y, patch_shape_x) + assert torch.allclose(output_image_indexing, output_image_selector, atol=1e-5) @pytest.mark.parametrize("device", ["cuda:0", "cpu"]) def test_song_unet_global_indexing(device): torch.manual_seed(0) N_pos = 2 - batch_shape_x = 32 - batch_shape_y = 64 + patch_shape_y = 32 + patch_shape_x = 64 + offset_y = 12 + offset_x = 45 # Construct the DDM++ UNet model model = UNet( img_resolution=128, @@ -128,13 +156,15 @@ def test_song_unet_global_indexing(device): gridtype="test", N_grid_channels=N_pos, ).to(device) - input_image = torch.ones([1, 2, batch_shape_x, batch_shape_y]).to(device) + input_image = torch.ones([1, 2, patch_shape_y, patch_shape_x]).to(device) noise_labels = noise_labels = torch.randn([1]).to(device) class_labels = torch.randint(0, 1, (1, 1)).to(device) - idx_x = torch.arange(45, 45 + batch_shape_x) - idx_y = torch.arange(12, 12 + batch_shape_y) - mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y) - global_index = torch.stack((mesh_x, mesh_y), dim=0)[None].to(device) + idx_x = torch.arange(offset_x, offset_x + patch_shape_x) + idx_y = torch.arange(offset_y, offset_y + patch_shape_y) + mesh_x, mesh_y = torch.meshgrid(idx_y, idx_x, indexing="ij") + global_index = torch.stack((mesh_x, mesh_y), dim=0)[None].to( + device + ) # (2, patch_shape_y, patch_shape_x) output_image = model( input_image, noise_labels, class_labels, global_index=global_index @@ -142,10 +172,63 @@ def test_song_unet_global_indexing(device): pos_embed = model.positional_embedding_indexing( input_image, model.pos_embd, global_index=global_index ) - assert output_image.shape == (1, 2, batch_shape_x, batch_shape_y) + assert output_image.shape == (1, 2, patch_shape_y, patch_shape_x) assert torch.equal(pos_embed, global_index) +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_song_unet_embedding_selector(device): + torch.manual_seed(0) + N_pos = 2 + patch_shape_y = 32 + patch_shape_x = 64 + offset_y = 12 + offset_x = 45 + # Construct the DDM++ UNet model + model = UNet( + img_resolution=128, + in_channels=2 + N_pos, + out_channels=2, + gridtype="test", + N_grid_channels=N_pos, + ).to(device) + input_image = torch.ones([1, 2, patch_shape_y, patch_shape_x]).to(device) + noise_labels = torch.randn([1]).to(device) + class_labels = torch.randint(0, 1, (1, 1)).to(device) + + # Expected embeddings should be the same as global_index + idx_x = torch.arange(offset_x, offset_x + patch_shape_x) + idx_y = torch.arange(offset_y, offset_y + patch_shape_y) + mesh_x, mesh_y = torch.meshgrid(idx_y, idx_x, indexing="ij") + expected_embeds = torch.stack((mesh_x, mesh_y), dim=0)[None].to( + device + ) # (2, patch_shape_y, patch_shape_x) + + # Define a function to select the embeddings + def embedding_selector(emb): + return emb[ + None, + :, + offset_y : offset_y + patch_shape_y, + offset_x : offset_x + patch_shape_x, + ] + + output_image = model( + input_image, + noise_labels, + class_labels, + embedding_selector=embedding_selector, + ) + assert output_image.shape == (1, 2, patch_shape_y, patch_shape_x) + + # Verify that the embeddings are correctly selected + selected_embeds = model.positional_embedding_selector( + input_image, model.pos_embd, embedding_selector + ) + + assert torch.equal(selected_embeds, expected_embeds) + + @pytest.mark.parametrize("device", ["cuda:0", "cpu"]) def test_song_unet_constructor(device): """Test the Song UNet constructor options""" diff --git a/test/models/diffusion/test_unet_wrappers.py b/test/models/diffusion/test_unet_wrappers.py index 7fc3f5b6d7..d2dbaef9bc 100644 --- a/test/models/diffusion/test_unet_wrappers.py +++ b/test/models/diffusion/test_unet_wrappers.py @@ -36,15 +36,13 @@ def test_unet_forwards(device): res, inc, outc = 64, 2, 3 model = UNet( img_resolution=res, - img_channels=inc, img_in_channels=inc, img_out_channels=outc, model_type="SongUNet", ).to(device) input_image = torch.ones([1, inc, res, res]).to(device) lr_image = torch.randn([1, outc, res, res]).to(device) - sigma = torch.randn([1]).to(device) - output = model(x=input_image, img_lr=lr_image, sigma=sigma) + output = model(x=input_image, img_lr=lr_image) assert output.shape == (1, outc, res, res) # Construct the StormCastUNet model @@ -66,16 +64,14 @@ def setup_model(): model = UNet( img_resolution=res, - img_channels=inc, img_in_channels=inc, img_out_channels=outc, model_type="SongUNet", ).to(device) input_image = torch.ones([1, inc, res, res]).to(device) lr_image = torch.randn([1, outc, res, res]).to(device) - sigma = torch.randn([1]).to(device) - return model, [input_image, lr_image, sigma] + return model, [input_image, lr_image] # Check AMP model, invar = setup_model() @@ -101,14 +97,12 @@ def test_unet_checkpoint(device): res, inc, outc = 64, 2, 3 model_1 = UNet( img_resolution=res, - img_channels=inc, img_in_channels=inc, img_out_channels=outc, model_type="SongUNet", ).to(device) model_2 = UNet( img_resolution=res, - img_channels=inc, img_in_channels=inc, img_out_channels=outc, model_type="SongUNet", @@ -116,10 +110,7 @@ def test_unet_checkpoint(device): input_image = torch.ones([1, inc, res, res]).to(device) lr_image = torch.randn([1, outc, res, res]).to(device) - sigma = torch.randn([1]).to(device) - assert common.validate_checkpoint( - model_1, model_2, (*[input_image, lr_image, sigma],) - ) + assert common.validate_checkpoint(model_1, model_2, (*[input_image, lr_image],)) # Construct StormCastUNet models res, inc, outc = 64, 2, 3 diff --git a/test/utils/corrdiff/test_generation_steps.py b/test/utils/corrdiff/test_generation_steps.py index afda2b8cbd..6a82613ab6 100644 --- a/test/utils/corrdiff/test_generation_steps.py +++ b/test/utils/corrdiff/test_generation_steps.py @@ -15,12 +15,35 @@ # limitations under the License. from functools import partial +from typing import Callable, Optional import pytest import torch from pytest_utils import import_or_fail +# Mock network class +class MockNet: + def __init__(self, sigma_min=0.1, sigma_max=1000): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + + def round_sigma(self, t: torch.Tensor) -> torch.Tensor: + return t + + def __call__( + self, + x: torch.Tensor, + x_lr: torch.Tensor, + t: torch.Tensor, + class_labels: Optional[torch.Tensor], + global_index: Optional[torch.Tensor] = None, + embedding_selector: Optional[Callable] = None, + ) -> torch.Tensor: + # Mock behavior: return input tensor for testing purposes + return x * 0.9 + + @import_or_fail("cftime") @pytest.mark.parametrize("device", ["cuda:0", "cpu"]) def test_regression_step(device, pytestconfig): @@ -30,12 +53,11 @@ def test_regression_step(device, pytestconfig): # define the net mock_unet = UNet( - img_channels=2, - N_grid_channels=4, - embedding_type="zero", + img_resolution=[16, 16], img_in_channels=8, img_out_channels=2, - img_resolution=[16, 16], + N_grid_channels=4, + embedding_type="zero", ).to(device) # Define the input parameters @@ -53,17 +75,13 @@ def test_regression_step(device, pytestconfig): @pytest.mark.parametrize("device", ["cuda:0", "cpu"]) def test_diffusion_step(device, pytestconfig): - from physicsnemo.models.diffusion import EDMPrecondSR + from physicsnemo.models.diffusion import EDMPrecondSuperResolution from physicsnemo.utils.corrdiff import diffusion_step from physicsnemo.utils.generative import deterministic_sampler, stochastic_sampler # Define the preconditioner - mock_precond = EDMPrecondSR( - img_resolution=[16, 16], - img_in_channels=8, - img_out_channels=2, - img_channels=0, - scale_cond_input=False, + mock_precond = EDMPrecondSuperResolution( + img_resolution=[16, 16], img_in_channels=8, img_out_channels=2 ).to(device) # Define the input parameters @@ -79,7 +97,6 @@ def test_diffusion_step(device, pytestconfig): output = diffusion_step( net=mock_precond, sampler_fn=sampler_fn, - seed_batch_size=1, img_shape=(16, 16), img_out_channels=2, rank_batches=[[0]], @@ -101,7 +118,6 @@ def test_diffusion_step(device, pytestconfig): output = diffusion_step( net=mock_precond, sampler_fn=sampler_fn, - seed_batch_size=1, img_shape=(16, 16), img_out_channels=2, rank_batches=[[0]], @@ -112,3 +128,129 @@ def test_diffusion_step(device, pytestconfig): # Assertions assert output.shape == (1, 2, 16, 16), "Output shape mismatch" + + +@import_or_fail("cftime") +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_diffusion_step_rectangle(device, pytestconfig): + from physicsnemo.utils.corrdiff import diffusion_step + from physicsnemo.utils.generative import stochastic_sampler + from physicsnemo.utils.patching import GridPatching2D + + img_shape_y, img_shape_x = 32, 16 + seed_batch_size = 4 + + mock_precond = MockNet() + + # Define the input parameters + img_lr = ( + torch.randn(1, 4, img_shape_y, img_shape_x) + .expand(seed_batch_size, -1, -1, -1) + .to(device) + ) + + # Stochastic sampler without patching + sampler_fn = partial( + stochastic_sampler, + num_steps=2, + ) + + # Call the function + output = diffusion_step( + net=mock_precond, + sampler_fn=sampler_fn, + img_shape=(img_shape_y, img_shape_x), + img_out_channels=2, + rank_batches=[list(range(seed_batch_size))], + img_lr=img_lr, + mean_hr=None, + rank=0, + device=device, + ) + + # Assertions + assert output.shape == ( + seed_batch_size, + 2, + img_shape_y, + img_shape_x, + ), "Output shape mismatch" + + # Test with mean_hr conditioning + + # Define the input parameters + mean_hr = torch.randn(1, 2, img_shape_y, img_shape_x).to(device) + img_lr = ( + torch.randn(1, 4, img_shape_y, img_shape_x) + .expand(seed_batch_size, -1, -1, -1) + .to(device) + ) + + # Stochastic sampler without patching + sampler_fn = partial( + stochastic_sampler, + num_steps=2, + ) + + # Call the function + output = diffusion_step( + net=mock_precond, + sampler_fn=sampler_fn, + img_shape=(img_shape_y, img_shape_x), + img_out_channels=2, + rank_batches=[list(range(seed_batch_size))], + img_lr=img_lr, + mean_hr=mean_hr, + rank=0, + device=device, + ) + + # Assertions + assert output.shape == ( + seed_batch_size, + 2, + img_shape_y, + img_shape_x, + ), "Output shape mismatch" + + # Test with mean_hr conditioning and rectangular patching + + # Define the input parameters + mean_hr = torch.randn(1, 2, img_shape_y, img_shape_x).to(device) + img_lr = ( + torch.randn(1, 4, img_shape_y, img_shape_x) + .expand(seed_batch_size, -1, -1, -1) + .to(device) + ) + + # Define patching utility + patching = GridPatching2D( + img_shape=(img_shape_y, img_shape_x), + patch_shape=(16, 10), + overlap_pix=4, + boundary_pix=2, + ) + + # Stochastic sampler with rectangular patching + sampler_fn = partial(stochastic_sampler, num_steps=2, patching=patching) + + # Call the function + output = diffusion_step( + net=mock_precond, + sampler_fn=sampler_fn, + img_shape=(img_shape_y, img_shape_x), + img_out_channels=2, + rank_batches=[list(range(seed_batch_size))], + img_lr=img_lr, + mean_hr=mean_hr, + rank=0, + device=device, + ) + + # Assertions + assert output.shape == ( + seed_batch_size, + 2, + img_shape_y, + img_shape_x, + ), "Output shape mismatch" diff --git a/test/utils/generative/test_stochastic_sampler.py b/test/utils/generative/test_stochastic_sampler.py index eaf1d9392b..528bc441c3 100644 --- a/test/utils/generative/test_stochastic_sampler.py +++ b/test/utils/generative/test_stochastic_sampler.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Callable, Optional import torch from pytest_utils import import_or_fail @@ -37,6 +37,7 @@ def __call__( t: Tensor, class_labels: Optional[Tensor], global_index: Optional[Tensor] = None, + embedding_selector: Optional[Callable] = None, ) -> Tensor: # Mock behavior: return input tensor for testing purposes return x * 0.9 @@ -57,10 +58,7 @@ def test_stochastic_sampler(pytestconfig): net=net, latents=latents, img_lr=img_lr, - img_shape=448, - patch_shape=448, - overlap_pix=4, - boundary_pix=2, + patching=None, mean_hr=None, num_steps=4, sigma_min=0.002, @@ -80,10 +78,7 @@ def test_stochastic_sampler(pytestconfig): net=net, latents=latents, img_lr=img_lr, - img_shape=448, - patch_shape=448, - overlap_pix=4, - boundary_pix=2, + patching=None, mean_hr=mean_hr, num_steps=2, sigma_min=0.002, @@ -104,10 +99,7 @@ def test_stochastic_sampler(pytestconfig): net=net, latents=latents, img_lr=img_lr, - img_shape=448, - patch_shape=448, - overlap_pix=4, - boundary_pix=2, + patching=None, mean_hr=None, num_steps=3, sigma_min=0.002, @@ -124,204 +116,46 @@ def test_stochastic_sampler(pytestconfig): ), "Churn output shape does not match expected shape" +# The test function for edm_sampler with rectangular domain and patching @import_or_fail("cftime") -def test_image_fuse_basic(pytestconfig): - - from physicsnemo.utils.generative import image_fuse - - # Basic test: No overlap, no boundary, one patch - batch_size = 1 - img_shape_x = img_shape_y = 4 - patch_shape_x = patch_shape_y = 4 - overlap_pix = 0 - boundary_pix = 0 - - input_tensor = torch.arange(1, 17).view(1, 1, 4, 4).cuda().float() - fused_image = image_fuse( - input_tensor, - img_shape_x, - img_shape_y, - patch_shape_x, - patch_shape_y, - batch_size, - overlap_pix, - boundary_pix, - ) - assert fused_image.shape == (batch_size, 1, img_shape_x, img_shape_y) - expected_output = input_tensor - assert torch.allclose( - fused_image, expected_output, atol=1e-5 - ), "Output does not match expected output." - - -@import_or_fail("cftime") -def test_image_fuse_with_boundary(pytestconfig): - - from physicsnemo.utils.generative import image_fuse - - # Test with boundary pixels - batch_size = 1 - img_shape_x = img_shape_y = 4 - patch_shape_x = patch_shape_y = 6 - overlap_pix = 0 - boundary_pix = 1 - - input_tensor = torch.ones(1, 1, 6, 6).cuda().float() # All ones for easy validation - fused_image = image_fuse( - input_tensor, - img_shape_x, - img_shape_y, - patch_shape_x, - patch_shape_y, - batch_size, - overlap_pix, - boundary_pix, - ) - assert fused_image.shape == (batch_size, 1, img_shape_x, img_shape_y) - expected_output = ( - torch.ones(1, 1, 4, 4).cuda().float() - ) # Expected output is just the inner 4x4 part - assert torch.allclose( - fused_image, expected_output, atol=1e-5 - ), "Output with boundary does not match expected output." - - -@import_or_fail("cftime") -def test_image_fuse_with_multiple_batches(pytestconfig): - - from physicsnemo.utils.generative import image_fuse - - # Test with multiple batches - batch_size = 2 - img_shape_x = img_shape_y = 4 - patch_shape_x = patch_shape_y = 4 - overlap_pix = 0 - boundary_pix = 0 - - input_tensor = ( - torch.cat( - [ - torch.arange(1, 17).view(1, 1, 4, 4), - torch.arange(17, 33).view(1, 1, 4, 4), - ] - ) - .cuda() - .float() - ) - input_tensor = input_tensor.repeat(2, 1, 1, 1) - fused_image = image_fuse( - input_tensor, - img_shape_x, - img_shape_y, - patch_shape_x, - patch_shape_y, - batch_size, - overlap_pix, - boundary_pix, - ) - assert fused_image.shape == (batch_size, 1, img_shape_x, img_shape_y) - expected_output = ( - torch.cat( - [ - torch.arange(1, 17).view(1, 1, 4, 4), - torch.arange(17, 33).view(1, 1, 4, 4), - ] - ) - .cuda() - .float() - ) - assert torch.allclose( - fused_image, expected_output, atol=1e-5 - ), "Output for multiple batches does not match expected output." - +def test_stochastic_sampler_rectangle_patching(pytestconfig): + from physicsnemo.utils.generative import stochastic_sampler + from physicsnemo.utils.patching import GridPatching2D -@import_or_fail("cftime") -def test_image_batching_basic(pytestconfig): + net = MockNet() - from physicsnemo.utils.generative import image_batching + img_shape_y, img_shape_x = 256, 64 + patch_shape_y, patch_shape_x = 16, 10 - # Test with no overlap, no boundary, no input_interp - batch_size = 1 - img_shape_x = img_shape_y = 4 - patch_shape_x = patch_shape_y = 4 - overlap_pix = 0 - boundary_pix = 0 + latents = torch.randn(2, 3, img_shape_y, img_shape_x) # Mock latents + img_lr = torch.randn(2, 3, img_shape_y, img_shape_x) # Mock low-res image - input_tensor = torch.arange(1, 17).view(1, 1, 4, 4).cuda().float() - batched_images = image_batching( - input_tensor, - img_shape_x, - img_shape_y, - patch_shape_x, - patch_shape_y, - batch_size, - overlap_pix, - boundary_pix, + # Test with patching + patching = GridPatching2D( + img_shape=(img_shape_y, img_shape_x), + patch_shape=(patch_shape_y, patch_shape_x), + overlap_pix=4, + boundary_pix=2, ) - assert batched_images.shape == (batch_size, 1, patch_shape_x, patch_shape_y) - expected_output = input_tensor - assert torch.allclose( - batched_images, expected_output, atol=1e-5 - ), "Batched images do not match expected output." - - -@import_or_fail("cftime") -def test_image_batching_with_boundary(pytestconfig): - # Test with boundary pixels, no overlap, no input_interp - - from physicsnemo.utils.generative import image_batching - batch_size = 1 - img_shape_x = img_shape_y = 4 - patch_shape_x = patch_shape_y = 6 - overlap_pix = 0 - boundary_pix = 1 - - input_tensor = torch.ones(1, 1, 4, 4).cuda().float() # All ones for easy validation - batched_images = image_batching( - input_tensor, - img_shape_x, - img_shape_y, - patch_shape_x, - patch_shape_y, - batch_size, - overlap_pix, - boundary_pix, + # Test with mean_hr conditioning + mean_hr = torch.randn(2, 3, img_shape_y, img_shape_x) + result_mean_hr = stochastic_sampler( + net=net, + latents=latents, + img_lr=img_lr, + patching=patching, + mean_hr=mean_hr, + num_steps=2, + sigma_min=0.002, + sigma_max=800, + rho=7, + S_churn=0, + S_min=0, + S_max=float("inf"), + S_noise=1, ) - assert batched_images.shape == (1, 1, patch_shape_x, patch_shape_y) - expected_output = torch.ones(1, 1, 6, 6).cuda().float() - assert torch.allclose( - batched_images, expected_output, atol=1e-5 - ), "Batched images with boundary do not match expected output." - -@import_or_fail("cftime") -def test_image_batching_with_input_interp(pytestconfig): - # Test with input_interp tensor - - from physicsnemo.utils.generative import image_batching - - batch_size = 1 - img_shape_x = img_shape_y = 4 - patch_shape_x = patch_shape_y = 4 - overlap_pix = 0 - boundary_pix = 0 - - input_tensor = torch.arange(1, 17).view(1, 1, 4, 4).cuda().float() - input_interp = torch.ones(1, 1, 4, 4).cuda().float() # All ones for easy validation - batched_images = image_batching( - input_tensor, - img_shape_x, - img_shape_y, - patch_shape_x, - patch_shape_y, - batch_size, - overlap_pix, - boundary_pix, - input_interp=input_interp, - ) - assert batched_images.shape == (batch_size, 2, patch_shape_x, patch_shape_y) - expected_output = torch.cat((input_tensor, input_interp), dim=1) - assert torch.allclose( - batched_images, expected_output, atol=1e-5 - ), "Batched images with input_interp do not match expected output." + assert ( + result_mean_hr.shape == latents.shape + ), "Mean HR conditioned output shape does not match expected shape" diff --git a/test/utils/test_patching.py b/test/utils/test_patching.py new file mode 100644 index 0000000000..4f66e8b30b --- /dev/null +++ b/test/utils/test_patching.py @@ -0,0 +1,249 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest +import torch +from einops import rearrange, repeat +from pytest_utils import import_or_fail + + +@import_or_fail("cftime") +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_image_fuse_basic(pytestconfig, device): + from physicsnemo.utils.patching import image_fuse + + # Basic test: No overlap, no boundary, one patch + batch_size = 1 + for img_shape_y, img_shape_x in ((4, 4), (8, 4)): + overlap_pix = 0 + boundary_pix = 0 + + input_tensor = ( + torch.arange(1, img_shape_y * img_shape_x + 1) + .view(1, 1, img_shape_y, img_shape_x) + .to(device) + .float() + ) + fused_image = image_fuse( + input_tensor, + img_shape_y, + img_shape_x, + batch_size, + overlap_pix, + boundary_pix, + ) + assert fused_image.shape == (batch_size, 1, img_shape_y, img_shape_x) + expected_output = input_tensor + assert torch.allclose( + fused_image, expected_output, atol=1e-5 + ), "Output does not match expected output." + + +@import_or_fail("cftime") +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_image_fuse_with_boundary(pytestconfig, device): + from physicsnemo.utils.patching import image_fuse + + # Test with boundary pixels + overlap_pix = 0 + boundary_pix = 1 + + input_tensor = torch.randn(1, 1, 8, 6).to(device).float() + fused_image = image_fuse( + input_tensor, + img_shape_y=6, + img_shape_x=4, + batch_size=1, + overlap_pix=overlap_pix, + boundary_pix=boundary_pix, + ) + assert fused_image.shape == (1, 1, 6, 4) + expected_output = input_tensor[ + :, :, boundary_pix:-boundary_pix, boundary_pix:-boundary_pix + ] + assert torch.allclose( + fused_image, expected_output, atol=1e-5 + ), "Output with boundary does not match expected output." + + +@import_or_fail("cftime") +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_image_fuse_with_multiple_batches(pytestconfig, device): + from physicsnemo.utils.patching import image_batching, image_fuse + + # Test with multiple batches + batch_size = 2 + + # Test cases: (img_shape_y, img_shape_x, patch_shape_y, patch_shape_x, overlap_pix, boundary_pix) + test_cases = [ + (32, 32, 16, 16, 0, 0), # Square image, no overlap/boundary + (64, 32, 32, 16, 0, 0), # Rectangular image, no overlap/boundary + (48, 48, 16, 16, 4, 2), # Square image, minimal overlap/boundary + (64, 48, 32, 16, 6, 2), # Rectangular, larger overlap/boundary + ] + + for ( + img_shape_y, + img_shape_x, + patch_shape_y, + patch_shape_x, + overlap_pix, + boundary_pix, + ) in test_cases: + # Create original test image + original_image = ( + torch.rand(batch_size, 3, img_shape_y, img_shape_x).to(device).float() + ) + + # Apply image_batching to split the image into patches + batched_images = image_batching( + original_image, patch_shape_y, patch_shape_x, overlap_pix, boundary_pix + ) + + # Apply image_fuse to reconstruct the image from patches + fused_image = image_fuse( + batched_images, + img_shape_y, + img_shape_x, + batch_size, + overlap_pix, + boundary_pix, + ) + + # Verify that image_fuse reverses image_batching + assert torch.allclose(fused_image, original_image, atol=1e-5), ( + f"Failed on {device}: img=({img_shape_y},{img_shape_x}), " + f"patch=({patch_shape_y},{patch_shape_x}), " + f"overlap={overlap_pix}, boundary={boundary_pix}" + ) + + +@import_or_fail("cftime") +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_image_batching_basic(pytestconfig, device): + from physicsnemo.utils.patching import image_batching + + # Test with no overlap, no boundary, no input_interp + batch_size = 1 + patch_shape_x = patch_shape_y = 4 + overlap_pix = 0 + boundary_pix = 0 + + input_tensor = torch.arange(1, 17).view(1, 1, 4, 4).to(device).float() + batched_images = image_batching( + input_tensor, + patch_shape_y, + patch_shape_x, + overlap_pix, + boundary_pix, + ) + assert batched_images.shape == (batch_size, 1, patch_shape_y, patch_shape_x) + expected_output = input_tensor + assert torch.allclose( + batched_images, expected_output, atol=1e-5 + ), "Batched images do not match expected output." + + +@import_or_fail("cftime") +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_image_batching_with_boundary(pytestconfig, device): + from physicsnemo.utils.patching import image_batching + + # Test with boundary pixels, no overlap, no input_interp + patch_shape_y = 8 + patch_shape_x = 6 + overlap_pix = 0 + boundary_pix = 1 + + input_tensor = torch.rand(1, 1, 6, 4).to(device).float() + batched_images = image_batching( + input_tensor, + patch_shape_y, + patch_shape_x, + overlap_pix, + boundary_pix, + ) + # Create expected output using reflection padding + expected_output = torch.nn.functional.pad( + input_tensor, + pad=(boundary_pix, boundary_pix, boundary_pix, boundary_pix), + mode="reflect", + ) + + assert batched_images.shape == (1, 1, patch_shape_y, patch_shape_x) + assert torch.allclose( + batched_images, expected_output, atol=1e-5 + ), "Batched images with boundary do not match expected output." + + +@import_or_fail("cftime") +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_image_batching_with_input_interp(pytestconfig, device): + from physicsnemo.utils.patching import image_batching + + # Test with input_interp tensor + patch_shape_x = patch_shape_y = 4 + overlap_pix = 0 + boundary_pix = 0 + + for img_shape_y, img_shape_x in ((4, 4), (16, 8)): + img_size = img_shape_y * img_shape_x + patch_num = (img_shape_y // patch_shape_y) * (img_shape_x // patch_shape_x) + input_tensor = ( + torch.arange(1, img_size + 1) + .view(1, 1, img_shape_y, img_shape_x) + .to(device) + .float() + ) + input_interp = ( + torch.arange(-patch_shape_y * patch_shape_x, 0) + .view(1, 1, patch_shape_y, patch_shape_x) + .to(device) + .float() + ) + batched_images = image_batching( + input_tensor, + patch_shape_y, + patch_shape_x, + overlap_pix, + boundary_pix, + input_interp=input_interp, + ) + assert batched_images.shape == (patch_num, 2, patch_shape_y, patch_shape_x) + + # Define expected_output using einops operations + expected_output = torch.cat( + ( + rearrange( + input_tensor, + "b c (nb_p_h p_h) (nb_p_w p_w) -> (b nb_p_w nb_p_h) c p_h p_w", + p_h=patch_shape_y, + p_w=patch_shape_x, + ), + repeat( + input_interp, + "b c p_h p_w -> (b nb_p_w nb_p_h) c p_h p_w", + nb_p_h=img_shape_y // patch_shape_y, + nb_p_w=img_shape_x // patch_shape_x, + ), + ), + dim=1, + ) + + assert torch.allclose( + batched_images, expected_output, atol=1e-5 + ), "Batched images with input_interp do not match expected output."