diff --git a/.gitignore b/.gitignore index 5509140f..db10b362 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,7 @@ *.DS_Store + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*.so + diff --git a/examples/train_sfno.py b/examples/train_sfno.py index 8c650b9b..950bb00d 100644 --- a/examples/train_sfno.py +++ b/examples/train_sfno.py @@ -50,6 +50,13 @@ import wandb wandb.login() +def l1_rel_error(truth, test): + batch_size = truth.shape[0] + difference = torch.zeros(batch_size) + for batch in range(batch_size): + difference[batch] = torch.mean(torch.abs(truth[batch] - test[batch]))/(torch.mean(torch.abs(truth[batch]))).item() * 100 + return difference + def l2loss_sphere(solver, prd, tar, relative=False, squared=True): loss = solver.integrate_grid((prd - tar)**2, dimensionless=True).sum(dim=-1) if relative: @@ -231,16 +238,56 @@ def log_weights_and_grads(model, iters=1): store_dict = {'iteration': iters, 'grads': grad_dict, 'weights': weights_dict} torch.save(store_dict, weights_and_grads_fname) + +def plot_prediction_vs_target(prd, tar): + """ + Plots a 3x3 grid with predictions, targets, and their absolute difference. + + Parameters: + phi (array-like): Azimuthal angle data (1D array). + theta (array-like): Polar angle data (1D array). + prd (array-like): Predicted values (shape: [n_points, 3]). + tar (array-like): Target values (shape: [n_points, 3]). + """ + fig, axes = plt.subplots(3, 3, figsize=(12, 10), constrained_layout=True) + + # Compute absolute difference + diff = np.abs(prd - tar) + + # Titles for rows + row_titles = ["Target", "Prediction", "Absolute Difference"] + + for i in range(3): # Loop over rows: prd, tar, |prd-tar| + for j in range(3): # Loop over columns (channels) + ax = axes[j, i] + if i == 0: + contour = ax.contourf(tar[j, :], levels=100, cmap='viridis') + elif i == 1: + contour = ax.contourf(prd[j, :], levels=100, cmap='viridis') + else: + contour = ax.contourf(diff[j, :], levels=100, cmap='viridis') + + + cbar = fig.colorbar(contour, ax=ax, orientation='vertical') + cbar.ax.set_ylabel(f'Channel {j + 1}', rotation=270, labelpad=15) + + ax.set_title(f"{row_titles[i]} - Channel {j + 1}") + ax.set_xlabel('Phi (Azimuthal Angle)') + ax.set_ylabel('Theta (Polar Angle)') + + plt.savefig("sfno_prediction.png") + plt.close('all') + # training function def train_model(model, dataloader, optimizer, gscaler, scheduler=None, - nepochs=20, + nepochs=200, nfuture=0, num_examples=256, - num_valid=8, + num_valid=64, loss_fn='l2', enable_amp=False, log_grads=0): @@ -307,20 +354,28 @@ def train_model(model, # perform validation valid_loss = 0 model.eval() + errors = torch.zeros((num_valid)) with torch.no_grad(): - for inp, tar in dataloader: + for index, (inp, tar) in enumerate(dataloader): prd = model(inp) + batch_size = inp.shape[0] for _ in range(nfuture): prd = model(prd) loss = l2loss_sphere(solver, prd, tar, relative=True) valid_loss += loss.item() * inp.size(0) + errors[batch_size*index:batch_size*(index+1)] = l1_rel_error(tar, prd) + + if index == 0: + plot_prediction_vs_target(prd[0].cpu(), tar[0].cpu()) valid_loss = valid_loss / len(dataloader.dataset) if scheduler is not None: scheduler.step(valid_loss) + + epoch_time = time.time() - epoch_start print(f'--------------------------------------------------------------------------------') @@ -328,6 +383,7 @@ def train_model(model, print(f'time taken: {epoch_time}') print(f'accumulated training loss: {acc_loss}') print(f'relative validation loss: {valid_loss}') + print(f'median relative error: {torch.median(errors).item()}') if wandb.run is not None: current_lr = optimizer.param_groups[0]['lr'] diff --git a/examples/train_sfno_dse.py b/examples/train_sfno_dse.py new file mode 100644 index 00000000..a2f090b3 --- /dev/null +++ b/examples/train_sfno_dse.py @@ -0,0 +1,465 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import os +import time + +from tqdm import tqdm +from functools import partial + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from torch.cuda import amp + +import numpy as np +import pandas as pd + +import matplotlib.pyplot as plt + + +import sys + +from torch_harmonics import * +from torch_harmonics.examples.sfno_dse import PdeDataset, SFNODSEFp, SFNODSEVp + + +# trains with a standard L1 of MSE loss +l1_loss = torch.nn.L1Loss() +mse_loss = torch.nn.MSELoss() + +def l1_rel_error(truth, test): + batch_size = truth.shape[0] + difference = torch.zeros(batch_size) + for batch in range(batch_size): + difference[batch] = torch.mean(torch.abs(truth[batch] - test[batch]))/(torch.mean(torch.abs(truth[batch]))).item() * 100 + return difference + +def plot_prediction_vs_target(phi, theta, prd, tar): + """ + Plots a 3x3 grid with predictions, targets, and their absolute difference. + + Parameters: + phi (array-like): Azimuthal angle data (1D array). + theta (array-like): Polar angle data (1D array). + prd (array-like): Predicted values (shape: [n_points, 3]). + tar (array-like): Target values (shape: [n_points, 3]). + """ + fig, axes = plt.subplots(3, 3, figsize=(12, 10), constrained_layout=True) + + # Compute absolute difference + diff = np.abs(prd - tar) + + # Titles for rows + row_titles = ["Target", "Prediction", "Absolute Difference"] + + for i in range(3): # Loop over rows: prd, tar, |prd-tar| + for j in range(3): # Loop over columns (channels) + ax = axes[j, i] + if i == 0: + contour = ax.tricontourf(phi, theta, tar[j, :], levels=100, cmap='viridis') + elif i == 1: + contour = ax.tricontourf(phi, theta, prd[j, :], levels=100, cmap='viridis') + else: + contour = ax.tricontourf(phi, theta, diff[j, :], levels=100, cmap='viridis') + + + cbar = fig.colorbar(contour, ax=ax, orientation='vertical') + cbar.ax.set_ylabel(f'Channel {j + 1}', rotation=270, labelpad=15) + + ax.set_title(f"{row_titles[i]} - Channel {j + 1}") + ax.set_xlabel('Phi (Azimuthal Angle)') + ax.set_ylabel('Theta (Polar Angle)') + + plt.savefig("example_prediction.png") + plt.close('all') + +# training function +def train_model_fp(model, + dataloader, + optimizer, + gscaler, + select_random_points, + theta_index, + phi_index, + scheduler=None, + nepochs=20, + nfuture=0, + num_examples=256, + num_valid=64, + loss_fn='l1', + enable_amp=False, + log_grads=0, + save_model=False, + plot_results=True): + + train_start = time.time() + + # count iterations + iters = 0 + + for epoch in range(nepochs): + + # time each epoch + epoch_start = time.time() + + dataloader.dataset.set_initial_condition('random') + dataloader.dataset.set_num_examples(num_examples) + + # get the solver for its convenience functions + solver = dataloader.dataset.solver + + # do the training + acc_loss = 0 + model.train() + + for inp, tar in dataloader: + + with torch.autocast(device_type="cuda", enabled=enable_amp): + + # Select a random set of points from the input and target grids + inp = select_random_points.get_random_sphere_data(inp, theta_index, phi_index) + tar = select_random_points.get_random_sphere_data(tar, theta_index, phi_index) + + + prd = model(inp) + for _ in range(nfuture): + prd = model(prd) + + if loss_fn == 'mse': + loss = mse_loss(prd, tar) + elif loss_fn == 'l1': + loss = l1_loss(prd, tar) + else: + raise NotImplementedError(f'Unknown loss function {loss_fn}') + + acc_loss += loss.item() * inp.size(0) + + optimizer.zero_grad(set_to_none=True) + gscaler.scale(loss).backward() + + if log_grads and iters % log_grads == 0: + log_weights_and_grads(model, iters=iters) + + gscaler.step(optimizer) + gscaler.update() + + iters += 1 + + acc_loss = acc_loss / len(dataloader.dataset) + + dataloader.dataset.set_initial_condition('random') + dataloader.dataset.set_num_examples(num_valid) + + # perform validation + valid_loss = 0 + model.eval() + + errors = torch.zeros((num_valid)) + with torch.no_grad(): + for index, (inp, tar) in enumerate(dataloader): + batch_size = inp.shape[0] + + # Select a random set of points from the input and target grids + inp = select_random_points.get_random_sphere_data(inp, theta_index, phi_index) + tar = select_random_points.get_random_sphere_data(tar, theta_index, phi_index) + + prd = model(inp) + for _ in range(nfuture): + prd = model(prd) + + + if loss_fn == 'mse': + loss = mse_loss(prd, tar) + elif loss_fn == 'l1': + loss = l1_loss(prd, tar) + else: + raise NotImplementedError(f'Unknown loss function {loss_fn}') + + valid_loss += loss.item() * inp.size(0) + errors[batch_size*index:batch_size*(index+1)] = l1_rel_error(tar, prd) + + if index == 0 and plot_results: + plot_prediction_vs_target(phi_index, theta_index, prd[0].cpu(), tar[0].cpu()) + + valid_loss = valid_loss / len(dataloader.dataset) + + if scheduler is not None: + scheduler.step(valid_loss) + + + epoch_time = time.time() - epoch_start + + print(f'--------------------------------------------------------------------------------') + print(f'Epoch {epoch} summary:') + for param_group in scheduler.optimizer.param_groups: + print(f"learning rate: {param_group['lr']}") + print(f'time taken: {epoch_time}') + print(f'accumulated training loss: {acc_loss}') + print(f'relative validation loss: {valid_loss}') + print(f'median relative error: {torch.median(errors).item()}') + + + + train_time = time.time() - train_start + + print(f'--------------------------------------------------------------------------------') + print(f'done. Training took {train_time}.') + return valid_loss + + +# training function +def train_model_vp(model, + dataloader, + optimizer, + gscaler, + select_random_points, + scheduler=None, + nepochs=20, + nfuture=0, + num_examples=256, + num_valid=64, + loss_fn='l1', + enable_amp=False, + log_grads=0, + save_model=False): + + train_start = time.time() + + degree = 22 + + # count iterations + iters = 0 + + for epoch in range(nepochs): + + # time each epoch + epoch_start = time.time() + + dataloader.dataset.set_initial_condition('random') + dataloader.dataset.set_num_examples(num_examples) + + # get the solver for its convenience functions + solver = dataloader.dataset.solver + + # do the training + acc_loss = 0 + model.train() + + for inp, tar in dataloader: + + # Select a random set of points + theta_index, phi_index, theta, phi = select_random_points.random_sets_on_sphere(5000, inp.shape[0]) + + + with torch.autocast(device_type="cuda", enabled=enable_amp): + # Select a random set of points from the input and target grids + inp = select_random_points.get_random_sphere_data(inp, theta_index, phi_index) + tar = select_random_points.get_random_sphere_data(tar, theta_index, phi_index) + + + sht_transform = BatchedRealSHTDSE(phi, theta, degree) + + + + prd = model(inp, sht_transform) + for _ in range(nfuture): + prd = model(prd) + + if loss_fn == 'mse': + loss = mse_loss(prd, tar) + elif loss_fn == 'l1': + loss = l1_loss(prd, tar) + else: + raise NotImplementedError(f'Unknown loss function {loss_fn}') + + acc_loss += loss.item() * inp.size(0) + + optimizer.zero_grad(set_to_none=True) + gscaler.scale(loss).backward() + + if log_grads and iters % log_grads == 0: + log_weights_and_grads(model, iters=iters) + + gscaler.step(optimizer) + gscaler.update() + + iters += 1 + + acc_loss = acc_loss / len(dataloader.dataset) + + dataloader.dataset.set_initial_condition('random') + dataloader.dataset.set_num_examples(num_valid) + + # perform validation + valid_loss = 0 + model.eval() + + errors = torch.zeros((num_valid)) + with torch.no_grad(): + for index, (inp, tar) in enumerate(dataloader): + + # Select a random set of points + theta_index, phi_index, theta, phi = select_random_points.random_sets_on_sphere(5000, inp.shape[0]) + batch_size = inp.shape[0] + + # Select a random set of points from the input and target grids + inp = select_random_points.get_random_sphere_data(inp, theta_index, phi_index) + tar = select_random_points.get_random_sphere_data(tar, theta_index, phi_index) + + sht_transform = BatchedRealSHTDSE(phi, theta, degree) + + prd = model(inp, sht_transform) + for _ in range(nfuture): + prd = model(prd) + + + if loss_fn == 'mse': + loss = mse_loss(prd, tar) + elif loss_fn == 'l1': + loss = l1_loss(prd, tar) + else: + raise NotImplementedError(f'Unknown loss function {loss_fn}') + + valid_loss += loss.item() * inp.size(0) + errors[batch_size*index:batch_size*(index+1)] = l1_rel_error(tar, prd) + + valid_loss = valid_loss / len(dataloader.dataset) + + if scheduler is not None: + scheduler.step(valid_loss) + + epoch_time = time.time() - epoch_start + + print(f'--------------------------------------------------------------------------------') + print(f'Epoch {epoch} summary:') + print(f'time taken: {epoch_time}') + print(f'accumulated training loss: {acc_loss}') + print(f'relative validation loss: {valid_loss}') + print(f'median relative error: {torch.median(errors).item()}') + + + + train_time = time.time() - train_start + + print(f'--------------------------------------------------------------------------------') + print(f'done. Training took {train_time}.') + return valid_loss + + +def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): + + # set seed + torch.manual_seed(333) + torch.cuda.manual_seed(333) + + # set device + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + if torch.cuda.is_available(): + torch.cuda.set_device(device.index) + + # 1 hour prediction steps + dt = 1*3600 + dt_solver = 150 + nsteps = dt//dt_solver + dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(256, 512), device=device, normalize=True) + dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0, persistent_workers=False) + + nlat = dataset.nlat + nlon = dataset.nlon + + # For selecting a fixed/variable set of uniformly randomly distributed points on the sphere + select_random_points = RandomSphericalSampling(nlon, nlat) + + def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + + # Define hyperparameters for the SFNO model + degree = 22 + width = 128 + num_layers = 4 + in_channels = out_channels = 3 + + train_fixed = False + train_variable = True + + ###################################################################### + # Training a model where collocation points are fixed between samples + ###################################################################### + if train_fixed: + # select the set of points at the beginning + num_points = 5000 # yields approximately 5000 valid points, this is not exact + theta_index, phi_index, theta, phi = select_random_points.random_points_on_sphere(num_points) + + # Using Fixed Points: initialize the matrices for the SHT + sht_transform = RealSHTDSE(phi, theta, degree) + + # Initialize the SFNO using fixed, arbitrary points + model = SFNODSEFp(in_channels, out_channels, degree, width, sht_transform, num_layers).to(device) + + # Count the number of parameters + num_params = count_parameters(model) + print(f'number of trainable params: {num_params}') + + optimizer = torch.optim.Adam(model.parameters(), lr=1E-3) + # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9) + gscaler = amp.GradScaler('cuda', enabled=enable_amp) + + # Training a model where collocation points are fixed between samples + train_model_fp(model, dataloader, optimizer, gscaler, select_random_points, theta_index, phi_index, scheduler, nepochs=200, loss_fn='l1') + + + ###################################################################### + # Training a model where collocation points vary between samples + ###################################################################### + + if train_variable: + # Initialize the SFNO using variable, arbitrary points + model = SFNODSEVp(in_channels, out_channels, degree, width, num_layers).to(device) + + # Count the number of parameters + num_params = count_parameters(model) + print(f'number of trainable params: {num_params}') + + optimizer = torch.optim.Adam(model.parameters(), lr=1E-3) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') + gscaler = amp.GradScaler('cuda', enabled=enable_amp) + + train_model_vp(model, dataloader, optimizer, gscaler, select_random_points, scheduler, nepochs=200, loss_fn='l1') + + +if __name__ == "__main__": + import torch.multiprocessing as mp + mp.set_start_method('forkserver', force=True) + + main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0) diff --git a/torch_harmonics.egg-info/PKG-INFO b/torch_harmonics.egg-info/PKG-INFO new file mode 100644 index 00000000..a9d8c136 --- /dev/null +++ b/torch_harmonics.egg-info/PKG-INFO @@ -0,0 +1,295 @@ +Metadata-Version: 2.1 +Name: torch_harmonics +Version: 0.7.2 +Summary: Differentiable signal processing on the sphere for PyTorch. +Author: Boris Bonev, Thorsten Kurth, Mauro Bisson, Massimiliano Fatica, Jean Kossaifi, Nikola Kovachki, Christian Hundt +Maintainer-email: Boris Bonev , Thorsten Kurth +Classifier: Development Status :: 3 - Alpha +Classifier: Programming Language :: Python :: 3.9 +Classifier: License :: OSI Approved :: BSD License +Classifier: Operating System :: OS Independent +Requires-Python: >=3.9 +Description-Content-Type: text/markdown +License-File: LICENSE +License-File: AUTHORS +Requires-Dist: torch>=2.4.0 +Requires-Dist: numpy<1.25,>=1.22.4 +Provides-Extra: dev +Requires-Dist: pytest>=6.0.0; extra == "dev" +Requires-Dist: coverage>=6.5.0; extra == "dev" + + + + + + + + + +# torch-harmonics + +[**Overview**](#overview) | [**Installation**](#installation) | [**More information**](#more-about-torch-harmonics) | [**Getting started**](#getting-started) | [**Contributors**](#contributors) | [**Cite us**](#cite-us) | [**References**](#references) + +[![tests](https://github.com/NVIDIA/torch-harmonics/actions/workflows/tests.yml/badge.svg)](https://github.com/NVIDIA/torch-harmonics/actions/workflows/tests.yml) +[![pypi](https://img.shields.io/pypi/v/torch_harmonics)](https://pypi.org/project/torch_harmonics/) + +## Overview + +torch-harmonics implements differentiable signal processing on the sphere. This includes differentiable implementations of the spherical harmonic transforms, vector spherical harmonic transforms and discrete-continuous convolutions on the sphere. The package was originally implemented to enable Spherical Fourier Neural Operators (SFNO) [1]. + +The SHT algorithm uses quadrature rules to compute the projection onto the associated Legendre polynomials and FFTs for the projection onto the harmonic basis. This algorithm tends to outperform others with better asymptotic scaling for most practical purposes [2]. + +torch-harmonics uses PyTorch primitives to implement these operations, making it fully differentiable. Moreover, the quadrature can be distributed onto multiple ranks making it spatially distributed. + +torch-harmonics has been used to implement a variety of differentiable PDE solvers which generated the animations below. Moreover, it has enabled the development of Spherical Fourier Neural Operators (SFNOs) [1]. + +
+ + + + + + + +
+
+ + +## Installation +A simple installation can be directly done from PyPI: + +```bash +pip install torch-harmonics +``` +If you are planning to use spherical convolutions, we recommend building the corresponding custom CUDA kernels. To enforce this, you can set the `FORCE_CUDA_EXTENSION` flag. You may also want to set appropriate architectures with the `TORCH_CUDA_ARCH_LIST` flag. Finally, make sure to disable build isolation via the `--no-build-isolation` flag to ensure that the custom kernels are built with the existing torch installation. +```bash +export FORCE_CUDA_EXTENSION=1 +export TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX" +pip install --no-build-isolation torch-harmonics +``` +:warning: Please note that the custom CUDA extensions currently only support CUDA architectures >= 7.0. + +If you want to actively develop torch-harmonics, we recommend building it in your environment from github: + +```bash +git clone git@github.com:NVIDIA/torch-harmonics.git +cd torch-harmonics +pip install -e . +``` + +Alternatively, use the Dockerfile to build your custom container after cloning: + +```bash +git clone git@github.com:NVIDIA/torch-harmonics.git +cd torch-harmonics +docker build . -t torch_harmonics +docker run --gpus all -it --rm --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 torch_harmonics +``` + +## More about torch-harmonics + +### Spherical harmonics + +The [spherical harmonics](https://en.wikipedia.org/wiki/Spherical_harmonics) are special functions defined on the two-dimensional sphere $S^2$ (embedded in three dimensions). They form an orthonormal basis of the space of square-integrable functions defined on the sphere $L^2(S^2)$ and are comparable to the harmonic functions defined on a circle/torus. The spherical harmonics are defined as + +$$ +Y_l^m(\theta, \lambda) = \sqrt{\frac{(2l + 1)}{4 \pi} \frac{(l - m)!}{(l + m)!}} P_l^m(\cos \theta) \exp(im\lambda), +$$ + +where $\theta$ and $\lambda$ are colatitude and longitude respectively, and $P_l^m$ the normalized, [associated Legendre polynomials](https://en.wikipedia.org/wiki/Associated_Legendre_polynomials). + +
+ +
+Spherical harmonics up to degree 5 +
+ +### Spherical harmonic transform + +The spherical harmonic transform (SHT) + +$$ +f_l^m = \int_{S^2} \overline{Y_{l}^{m}}(\theta, \lambda) f(\theta, \lambda) \mathrm{d} \mu(\theta, \lambda) +$$ + +realizes the projection of a signal $f(\theta, \lambda)$ on $S^2$ onto the spherical harmonics basis. The SHT generalizes the Fourier transform on the sphere. Conversely, a truncated series expansion of a function $f$ can be written in terms of spherical harmonics as + +$$ +f (\theta, \lambda) = \sum_{m=-M}^{M} \exp(im\lambda) \sum_{l=|m|}^{M} \hat f_l^m P_l^m (\cos \theta), +$$ + +where $\hat{f}_l^m$, are the expansion coefficients associated to the mode $m$, $n$. + +The implementation of the SHT follows the algorithm as presented in [2]. A direct spherical harmonic transform can be accomplished by a Fourier transform + +$$ +\hat f^m(\theta) = \frac{1}{2 \pi} \int_{0}^{2\pi} f(\theta, \lambda) \exp(-im\lambda) \mathrm{d} \lambda +$$ + +in longitude and a Legendre transform + +$$ +\hat f_l^m = \frac{1}{2} \int^{\pi}_0 \hat f^{m} (\theta) P_l^m (\cos \theta) \sin \theta \mathrm{d} \theta +$$ + +in latitude. + +### Discrete Legendre transform + +The second integral, which computed the projection onto the Legendre polynomials is realized with quadrature. On the Gaussian grid, we use Gaussian quadrature in the $\cos \theta$ domain. The integral + +$$ +\hat f_l^m = \frac{1}{2} \int_{-1}^1 \hat{f}^m(\arccos x) P_l^m (x) \mathrm{d} x +$$ + +is obtained with the substitution $x = \cos \theta$ and then approximated by the sum + +$$ +\hat f_l^m = \sum_{j=1}^{N_\theta} \hat{f}^m(\arccos x_j) P_l^m(x_j) w_j. +$$ + +Here, $x_j \in [-1,1]$ are the quadrature nodes with the respective quadrature weights $w_j$. + +### Discrete-continuous convolutions + +torch-harmonics now provides local discrete-continuous (DISCO) convolutions as outlined in [4] on the sphere. + +## Getting started + +The main functionality of `torch_harmonics` is provided in the form of `torch.nn.Modules` for composability. A minimum example is given by: + +```python +import torch +import torch_harmonics as th + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +nlat = 512 +nlon = 2*nlat +batch_size = 32 +signal = torch.randn(batch_size, nlat, nlon) + +# transform data on an equiangular grid +sht = th.RealSHT(nlat, nlon, grid="equiangular").to(device) + +coeffs = sht(signal) +``` + +To enable scalable model-parallelism, `torch-harmonics` implements a distributed variant of the SHT located in `torch_harmonics.distributed`. + +Detailed usage of torch-harmonics, alongside helpful analysis provided in a series of notebooks: + +1. [Getting started](./notebooks/getting_started.ipynb) +2. [Quadrature](./notebooks/quadrature.ipynb) +3. [Visualizing the spherical harmonics](./notebooks/plot_spherical_harmonics.ipynb) +4. [Spectral fitting vs. SHT](./notebooks/gradient_analysis.ipynb) +5. [Conditioning of the Gramian](./notebooks/conditioning_sht.ipynb) +6. [Solving the Helmholtz equation](./notebooks/helmholtz.ipynb) +7. [Solving the shallow water equations](./notebooks/shallow_water_equations.ipynb) +8. [Training Spherical Fourier Neural Operators (SFNO)](./notebooks/train_sfno.ipynb) +9. [Resampling signals on the sphere](./notebooks/resample_sphere.ipynb) + +## Remarks on automatic mixed precision (AMP) support + +Note that torch-harmonics uses Fourier transforms from `torch.fft` which in turn uses kernels from the optimized `cuFFT` library. This library supports fourier transforms of `float32` and `float64` (i.e. `single` and `double` precision) tensors for all input sizes. For `float16` (i.e. `half` precision) and `bfloat16` inputs however, the dimensions which are transformed are restricted to powers of two. Since data is converted to one of these reduced precision floating point formats when `torch.autocast` is used, torch-harmonics will issue an error when the input shapes are not powers of two. For these cases, we recommend disabling autocast for the harmonics transform specifically: + +```python +import torch +import torch_harmonics as th + +sht = th.RealSHT(512, 1024, grid="equiangular").cuda() + +with torch.autocast(device_type="cuda", enabled = True): + # do some AMP converted math here + x = some_math(x) + # convert tensor to float32 + x = x.to(torch.float32) + # now disable autocast specifically for the transform, + # making sure that the tensors are not converted + # back to reduced precision internally + with torch.autocast(device_type="cuda", enabled = False): + xt = sht(x) + + # continue operating on the transformed tensor + xt = some_more_math(xt) +``` + +Depending on the problem, it might be beneficial to upcast data to `float64` instead of `float32` precision for numerical stability. + +## Contributors + +[Boris Bonev](https://bonevbs.github.io) (bbonev@nvidia.com), [Thorsten Kurth](https://github.com/azrael417) (tkurth@nvidia.com), [Mauro Bisson](https://scholar.google.com/citations?hl=en&user=f0JE-0gAAAAJ) , [Massimiliano Fatica](https://scholar.google.com/citations?user=Deaq4uUAAAAJ&hl=en), [Nikola Kovachki](https://kovachki.github.io), [Jean Kossaifi](http://jeankossaifi.com), [Christian Hundt](https://github.com/gravitino) + +## Cite us + +If you use `torch-harmonics` in an academic paper, please cite [1] + +```bibtex +@misc{bonev2023spherical, + title={Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere}, + author={Boris Bonev and Thorsten Kurth and Christian Hundt and Jaideep Pathak and Maximilian Baust and Karthik Kashinath and Anima Anandkumar}, + year={2023}, + eprint={2306.03838}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} +``` + +## References + +[1] +Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.; +Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere; +arXiv 2306.0383, 2023. + +[2] +Schaeffer N.; +Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations; +G3: Geochemistry, Geophysics, Geosystems, 2013. + +[3] +Wang B., Wang L., Xie Z.; +Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; +Adv Comput Math, 2018. + +[4] +Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603 diff --git a/torch_harmonics.egg-info/SOURCES.txt b/torch_harmonics.egg-info/SOURCES.txt new file mode 100644 index 00000000..684f0dca --- /dev/null +++ b/torch_harmonics.egg-info/SOURCES.txt @@ -0,0 +1,168 @@ +.coveragerc +.gitattributes +.gitignore +AUTHORS +Changelog.md +Dockerfile +LICENSE +README.md +pyproject.toml +setup.py +.github/dependabot.yml +.github/workflows/deploy_pypi.yml +.github/workflows/tests.yml +examples/example_prediction.png +examples/minimal_example.py +examples/sfno_prediction.png +examples/train_sfno.py +examples/.ipynb_checkpoints/example_prediction-checkpoint.png +examples/.ipynb_checkpoints/sfno_prediction-checkpoint.png +examples/.ipynb_checkpoints/train_sfno-checkpoint.py +examples/wandb/debug-cli.llingsch.log +examples/wandb/run-20241201_161810-q8w40874/run-q8w40874.wandb +examples/wandb/run-20241201_161810-q8w40874/files/config.yaml +examples/wandb/run-20241201_161810-q8w40874/files/output.log +examples/wandb/run-20241201_161810-q8w40874/files/requirements.txt +examples/wandb/run-20241201_161810-q8w40874/files/wandb-metadata.json +examples/wandb/run-20241201_161810-q8w40874/files/wandb-summary.json +examples/wandb/run-20241201_161810-q8w40874/logs/debug-internal.log +examples/wandb/run-20241201_161810-q8w40874/logs/debug.log +examples/wandb/run-20241201_173250-vfe7zrm4/run-vfe7zrm4.wandb +examples/wandb/run-20241201_173250-vfe7zrm4/files/config.yaml +examples/wandb/run-20241201_173250-vfe7zrm4/files/output.log +examples/wandb/run-20241201_173250-vfe7zrm4/files/requirements.txt +examples/wandb/run-20241201_173250-vfe7zrm4/files/wandb-metadata.json +examples/wandb/run-20241201_173250-vfe7zrm4/files/wandb-summary.json +examples/wandb/run-20241201_173250-vfe7zrm4/logs/debug-internal.log +examples/wandb/run-20241201_173250-vfe7zrm4/logs/debug.log +examples/wandb/run-20241201_192110-lbvctode/run-lbvctode.wandb +examples/wandb/run-20241201_192110-lbvctode/files/config.yaml +examples/wandb/run-20241201_192110-lbvctode/files/output.log +examples/wandb/run-20241201_192110-lbvctode/files/requirements.txt +examples/wandb/run-20241201_192110-lbvctode/files/wandb-metadata.json +examples/wandb/run-20241201_192110-lbvctode/files/wandb-summary.json +examples/wandb/run-20241201_192110-lbvctode/logs/debug-internal.log +examples/wandb/run-20241201_192110-lbvctode/logs/debug.log +examples/wandb/run-20241201_192236-ksdfj7ol/run-ksdfj7ol.wandb +examples/wandb/run-20241201_192236-ksdfj7ol/files/config.yaml +examples/wandb/run-20241201_192236-ksdfj7ol/files/output.log +examples/wandb/run-20241201_192236-ksdfj7ol/files/requirements.txt +examples/wandb/run-20241201_192236-ksdfj7ol/files/wandb-metadata.json +examples/wandb/run-20241201_192236-ksdfj7ol/files/wandb-summary.json +examples/wandb/run-20241201_192236-ksdfj7ol/logs/debug-internal.log +examples/wandb/run-20241201_192236-ksdfj7ol/logs/debug.log +examples/wandb/run-20241201_192525-s66p1xiw/run-s66p1xiw.wandb +examples/wandb/run-20241201_192525-s66p1xiw/files/config.yaml +examples/wandb/run-20241201_192525-s66p1xiw/files/output.log +examples/wandb/run-20241201_192525-s66p1xiw/files/requirements.txt +examples/wandb/run-20241201_192525-s66p1xiw/files/wandb-metadata.json +examples/wandb/run-20241201_192525-s66p1xiw/files/wandb-summary.json +examples/wandb/run-20241201_192525-s66p1xiw/logs/debug-internal.log +examples/wandb/run-20241201_192525-s66p1xiw/logs/debug.log +examples/wandb/run-20241201_230053-5zwrn21g/run-5zwrn21g.wandb +examples/wandb/run-20241201_230053-5zwrn21g/files/config.yaml +examples/wandb/run-20241201_230053-5zwrn21g/files/output.log +examples/wandb/run-20241201_230053-5zwrn21g/files/requirements.txt +examples/wandb/run-20241201_230053-5zwrn21g/files/wandb-metadata.json +examples/wandb/run-20241201_230053-5zwrn21g/files/wandb-summary.json +examples/wandb/run-20241201_230053-5zwrn21g/logs/debug-internal.log +examples/wandb/run-20241201_230053-5zwrn21g/logs/debug.log +images/allen-cahn.gif +images/ginzburg-landau.gif +images/sfno.gif +images/spherical_harmonics.gif +images/zonal_jet.gif +images/logo/logo.png +notebooks/conditioning_sht.ipynb +notebooks/getting_started.ipynb +notebooks/gradient_analysis.ipynb +notebooks/helmholtz.ipynb +notebooks/plot_spherical_harmonics.ipynb +notebooks/plotting.py +notebooks/quadrature.ipynb +notebooks/resample_sphere.ipynb +notebooks/shallow_water_equations.ipynb +notebooks/train_sfno.ipynb +tests/run_tests.sh +tests/test_convolution.py +tests/test_distributed_convolution.py +tests/test_distributed_sht.py +tests/test_sht.py +torch_harmonics/__init__.py +torch_harmonics/_disco_convolution.py +torch_harmonics/convolution.py +torch_harmonics/legendre.py +torch_harmonics/quadrature.py +torch_harmonics/random_fields.py +torch_harmonics/random_sampling.py +torch_harmonics/resampling.py +torch_harmonics/sht.py +torch_harmonics/sht_dse.py +torch_harmonics.egg-info/PKG-INFO +torch_harmonics.egg-info/SOURCES.txt +torch_harmonics.egg-info/dependency_links.txt +torch_harmonics.egg-info/requires.txt +torch_harmonics.egg-info/top_level.txt +torch_harmonics/.ipynb_checkpoints/__init__-checkpoint.py +torch_harmonics/.ipynb_checkpoints/random_fields-checkpoint.py +torch_harmonics/.ipynb_checkpoints/random_sampling-checkpoint.py +torch_harmonics/.ipynb_checkpoints/resampling-checkpoint.py +torch_harmonics/.ipynb_checkpoints/sht-checkpoint.py +torch_harmonics/.ipynb_checkpoints/sht_dse-checkpoint.py +torch_harmonics/__pycache__/__init__.cpython-311.pyc +torch_harmonics/__pycache__/_disco_convolution.cpython-311.pyc +torch_harmonics/__pycache__/convolution.cpython-311.pyc +torch_harmonics/__pycache__/legendre.cpython-311.pyc +torch_harmonics/__pycache__/quadrature.cpython-311.pyc +torch_harmonics/__pycache__/random_fields.cpython-311.pyc +torch_harmonics/__pycache__/random_sampling.cpython-311.pyc +torch_harmonics/__pycache__/resampling.cpython-311.pyc +torch_harmonics/__pycache__/sht.cpython-311.pyc +torch_harmonics/__pycache__/sht_dse.cpython-311.pyc +torch_harmonics/csrc/disco/disco.h +torch_harmonics/csrc/disco/disco_cuda.cuh +torch_harmonics/csrc/disco/disco_cuda_bwd.cu +torch_harmonics/csrc/disco/disco_cuda_fwd.cu +torch_harmonics/csrc/disco/disco_helpers.cpp +torch_harmonics/csrc/disco/disco_interface.cu +torch_harmonics/csrc/disco/.ipynb_checkpoints/disco_cuda_bwd-checkpoint.cu +torch_harmonics/distributed/__init__.py +torch_harmonics/distributed/distributed_convolution.py +torch_harmonics/distributed/distributed_sht.py +torch_harmonics/distributed/primitives.py +torch_harmonics/distributed/utils.py +torch_harmonics/distributed/.ipynb_checkpoints/distributed_convolution-checkpoint.py +torch_harmonics/distributed/.ipynb_checkpoints/primitives-checkpoint.py +torch_harmonics/distributed/.ipynb_checkpoints/utils-checkpoint.py +torch_harmonics/examples/__init__.py +torch_harmonics/examples/pde_sphere.py +torch_harmonics/examples/shallow_water_equations.py +torch_harmonics/examples/.ipynb_checkpoints/__init__-checkpoint.py +torch_harmonics/examples/.ipynb_checkpoints/pde_sphere-checkpoint.py +torch_harmonics/examples/.ipynb_checkpoints/shallow_water_equations-checkpoint.py +torch_harmonics/examples/__pycache__/__init__.cpython-311.pyc +torch_harmonics/examples/__pycache__/pde_sphere.cpython-311.pyc +torch_harmonics/examples/__pycache__/shallow_water_equations.cpython-311.pyc +torch_harmonics/examples/sfno/__init__.py +torch_harmonics/examples/sfno/.ipynb_checkpoints/__init__-checkpoint.py +torch_harmonics/examples/sfno/__pycache__/__init__.cpython-311.pyc +torch_harmonics/examples/sfno/models/__init__.py +torch_harmonics/examples/sfno/models/activations.py +torch_harmonics/examples/sfno/models/contractions.py +torch_harmonics/examples/sfno/models/factorizations.py +torch_harmonics/examples/sfno/models/layers.py +torch_harmonics/examples/sfno/models/sfno.py +torch_harmonics/examples/sfno/models/.ipynb_checkpoints/__init__-checkpoint.py +torch_harmonics/examples/sfno/models/.ipynb_checkpoints/sfno-checkpoint.py +torch_harmonics/examples/sfno/models/__pycache__/__init__.cpython-311.pyc +torch_harmonics/examples/sfno/models/__pycache__/activations.cpython-311.pyc +torch_harmonics/examples/sfno/models/__pycache__/contractions.cpython-311.pyc +torch_harmonics/examples/sfno/models/__pycache__/layers.cpython-311.pyc +torch_harmonics/examples/sfno/models/__pycache__/sfno.cpython-311.pyc +torch_harmonics/examples/sfno/utils/pde_dataset.py +torch_harmonics/examples/sfno/utils/.ipynb_checkpoints/pde_dataset-checkpoint.py +torch_harmonics/examples/sfno/utils/__pycache__/pde_dataset.cpython-311.pyc +torch_harmonics/examples/sfno_dse/__init__.py +torch_harmonics/examples/sfno_dse/models/__init__.py +torch_harmonics/examples/sfno_dse/models/sfno_dse.py +torch_harmonics/examples/sfno_dse/utils/pde_dataset.py \ No newline at end of file diff --git a/torch_harmonics.egg-info/dependency_links.txt b/torch_harmonics.egg-info/dependency_links.txt new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/torch_harmonics.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/torch_harmonics.egg-info/requires.txt b/torch_harmonics.egg-info/requires.txt new file mode 100644 index 00000000..e8a74ef6 --- /dev/null +++ b/torch_harmonics.egg-info/requires.txt @@ -0,0 +1,6 @@ +torch>=2.4.0 +numpy<1.25,>=1.22.4 + +[dev] +pytest>=6.0.0 +coverage>=6.5.0 diff --git a/torch_harmonics.egg-info/top_level.txt b/torch_harmonics.egg-info/top_level.txt new file mode 100644 index 00000000..6763aa04 --- /dev/null +++ b/torch_harmonics.egg-info/top_level.txt @@ -0,0 +1 @@ +torch_harmonics diff --git a/torch_harmonics/.ipynb_checkpoints/__init__-checkpoint.py b/torch_harmonics/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 00000000..5b3c5dea --- /dev/null +++ b/torch_harmonics/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,41 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +__version__ = "0.7.2" + +from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT +from .sht_dse import RealSHTDSE, BatchedRealSHTDSE +from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 +from .resampling import ResampleS2 +from . import quadrature +from . import random_fields +from . import examples +from .random_sampling import RandomSphericalSampling diff --git a/torch_harmonics/.ipynb_checkpoints/random_fields-checkpoint.py b/torch_harmonics/.ipynb_checkpoints/random_fields-checkpoint.py new file mode 100644 index 00000000..0ba2a6c3 --- /dev/null +++ b/torch_harmonics/.ipynb_checkpoints/random_fields-checkpoint.py @@ -0,0 +1,137 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import torch +from .sht import InverseRealSHT + +class GaussianRandomFieldS2(torch.nn.Module): + def __init__(self, nlat, alpha=2.0, tau=3.0, sigma=None, radius=1.0, grid="equiangular", dtype=torch.float32): + super().__init__() + r""" + A mean-zero Gaussian Random Field on the sphere with Matern covariance: + C = sigma^2 (-Lap + tau^2 I)^(-alpha). + + Lap is the Laplacian on the sphere, I the identity operator, + and sigma, tau, alpha are scalar parameters. + + Note: C is trace-class on L^2 if and only if alpha > 1. + + Parameters + ---------- + nlat : int + Number of latitudinal modes; + longitudinal modes are 2*nlat. + alpha : float, default is 2 + Regularity parameter. Larger means smoother. + tau : float, default is 3 + Lenght-scale parameter. Larger means more scales. + sigma : float, default is None + Scale parameter. Larger means bigger. + If None, sigma = tau**(0.5*(2*alpha - 2.0)). + radius : float, default is 1 + Radius of the sphere. + grid : string, default is "equiangular" + Grid type. Currently supports "equiangular" and + "legendre-gauss". + dtype : torch.dtype, default is torch.float32 + Numerical type for the calculations. + """ + + #Number of latitudinal modes. + self.nlat = nlat + + #Default value of sigma if None is given. + if sigma is None: + assert alpha > 1.0, f"Alpha must be greater than one, got {alpha}." + sigma = tau**(0.5*(2*alpha - 2.0)) + + # Inverse SHT + self.isht = InverseRealSHT(self.nlat, 2*self.nlat, grid=grid, norm='backward').to(dtype=dtype) + + #Square root of the eigenvalues of C. + sqrt_eig = torch.tensor([j*(j+1) for j in range(self.nlat)]).view(self.nlat,1).repeat(1, self.nlat+1) + sqrt_eig = torch.tril(sigma*(((sqrt_eig/radius**2) + tau**2)**(-alpha/2.0))) + sqrt_eig[0,0] = 0.0 + sqrt_eig = sqrt_eig.unsqueeze(0) + self.register_buffer('sqrt_eig', sqrt_eig) + + #Save mean and var of the standard Gaussian. + #Need these to re-initialize distribution on a new device. + mean = torch.tensor([0.0]).to(dtype=dtype) + var = torch.tensor([1.0]).to(dtype=dtype) + self.register_buffer('mean', mean) + self.register_buffer('var', var) + + #Standard normal noise sampler. + self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var) + + def forward(self, N, xi=None): + r""" + Sample random functions from a spherical GRF. + + Parameters + ---------- + N : int + Number of functions to sample. + xi : torch.Tensor, default is None + Noise is a complex tensor of size (N, nlat, nlat+1). + If None, new Gaussian noise is sampled. + If xi is provided, N is ignored. + + Output + ------- + u : torch.Tensor + N random samples from the GRF returned as a + tensor of size (N, nlat, 2*nlat) on a equiangular grid. + """ + #Sample Gaussian noise. + if xi is None: + xi = self.gaussian_noise.sample(torch.Size((N, self.nlat, self.nlat + 1, 2))).squeeze() + xi = torch.view_as_complex(xi) + + #Karhunen-Loeve expansion. + u = self.isht(xi*self.sqrt_eig) + + return u + + #Override cuda and to methods so sampler gets initialized with mean + #and variance on the correct device. + def cuda(self, *args, **kwargs): + super().cuda(*args, **kwargs) + self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var) + + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var) + + return self diff --git a/torch_harmonics/.ipynb_checkpoints/random_sampling-checkpoint.py b/torch_harmonics/.ipynb_checkpoints/random_sampling-checkpoint.py new file mode 100644 index 00000000..9ad43dd5 --- /dev/null +++ b/torch_harmonics/.ipynb_checkpoints/random_sampling-checkpoint.py @@ -0,0 +1,130 @@ +import numpy as np +import torch + +class RandomSphericalSampling: + r""" + Defines a module for sampling a (uniformly) random set of measurement points from a grid. + + [1] L. Lingsch, M. Michelis, E. de Bezenac, S. M. Perera, R. K. Katzschmann, S. Mishra; + Beyond Regular Grids: Fourier-Based Neural Operators on Arbitrary Domains; ICML 2024 + """ + def __init__(self, number_points_x, number_points_y): + # the data must be equispaced + self.number_points_x = number_points_x + self.number_points_y = number_points_y + np.random.seed(0) + + def random_points_on_sphere(self, n): + r""" + This function generates points within a 2x2x2 cube, centered at the origin. + Points with a radius<=1 are projected to a sphere with radius 1, centered at the origin. + Points with a radius > 1 are excluded. The newly generated random points are used to select + the closest points from the original grid, removing any duplicate points in this selection. + + Inputs: + class variables + n; approximate number of points to be selected (doubled, as about half the randomly generated points must be removed) + Outputs: + theta_index; vector indices of the original grid points to be selected along polar angle + phi_index; vector indices of the original grid points to be selected along azimuthal angle + >> used for selecting the points from the original data + theta_angle; vector of polar angles for points, ranging from 0 to pi + phi_angle; vector of azimuthal angles for points, ranging from 0 to 2*pi + """ + # Double the number of points to be selected, as approximately half will not be valid + n = n*2 + + # Generate random points in 3D space + x = np.random.uniform(-1, 1, n) + y = np.random.uniform(-1, 1, n) + z = np.random.uniform(-1, 1, n) + + # remove all points with radius greater than 1 (slightly less than half of all points) + magnitude = np.sqrt(x**2 + y**2 + z**2) + mask = magnitude <= 1.0 + magnitude_filtered = magnitude[mask] + x = x[mask] + y = y[mask] + z = z[mask] + + # Normalize the points to lie on the unit sphere + x /= magnitude_filtered + y /= magnitude_filtered + z /= magnitude_filtered + + # Return the points on the sphere + r = np.sqrt(x**2 + y**2 + z**2) + theta = np.arccos(z / r) + phi = np.arctan2(y, x) + np.pi + + theta = np.floor(theta*self.number_points_y / np.pi) + phi = np.floor(phi*self.number_points_x / (2*np.pi)) + + # remove duplicate points (there are about 2% duplicates, generally) + # Combine phi and theta into a 2D array + positions = np.column_stack((phi, theta)) + # Remove duplicate positions + unique_positions = np.unique(positions, axis=0) + + # Extract the cleaned phi and theta vectors + phi_index = unique_positions[:, 0] + theta_index = unique_positions[:, 1] + + phi_angle = torch.from_numpy(phi_index) / self.number_points_x * 2 * torch.pi + theta_angle = torch.from_numpy(theta_index) / self.number_points_y * torch.pi + + self.theta_index = theta_index + self.phi_index = phi_index + + return theta_index, phi_index, theta_angle.to(torch.float), phi_angle.to(torch.float) + + def random_sets_on_sphere(self, n, batch_size): + theta_indices = [] + phi_indices = [] + thetas = [] + phis = [] + + for _ in range(batch_size): + theta_index, phi_index, theta, phi = self.random_points_on_sphere(n) + theta_indices.append(theta_index) + phi_indices.append(phi_index) + thetas.append(theta) + phis.append(phi) + + + max_cols = min(matrix.size for matrix in theta_indices) + + padded_vectors = torch.zeros(len(theta_indices), max_cols, dtype=torch.float32, requires_grad=True) + theta_index = np.zeros((len(theta_indices), max_cols)) + phi_index = np.zeros((len(theta_indices), max_cols)) + theta = torch.zeros_like(padded_vectors) + phi = torch.zeros_like(padded_vectors) + + + for i in range(batch_size): + theta_index[i, :max_cols] = theta_indices[i][ :max_cols] + phi_index[i, :max_cols] = phi_indices[i][ :max_cols] + theta[i, :max_cols] = thetas[i][ :max_cols] + phi[i, :max_cols] = phis[i][ :max_cols] + + # for i in range(batch_size): + # n_points = theta_indices[i].size + # theta_index[i, :n_points] = theta_indices[i] + # phi_index[i, :n_points] = phi_indices[i] + # theta[i, :n_points] = thetas[i] + # phi[i, :n_points] = phis[i] + + return theta_index, phi_index, theta, phi + + + def get_random_sphere_data(self, data, thetas, phis): + + batch_size = thetas.shape[0] + num_points = thetas.shape[1] + + batch_indices = torch.arange(batch_size).unsqueeze(1).expand(-1, num_points) + + data_sparse = data[batch_indices,:,thetas,phis] + + return data_sparse.permute(0,2,1) + \ No newline at end of file diff --git a/torch_harmonics/.ipynb_checkpoints/resampling-checkpoint.py b/torch_harmonics/.ipynb_checkpoints/resampling-checkpoint.py new file mode 100644 index 00000000..55e5a0ba --- /dev/null +++ b/torch_harmonics/.ipynb_checkpoints/resampling-checkpoint.py @@ -0,0 +1,134 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from typing import List, Tuple, Union, Optional +import math +import numpy as np + +import torch +import torch.nn as nn + +from torch_harmonics.quadrature import _precompute_latitudes + + +class ResampleS2(nn.Module): + def __init__( + self, + nlat_in: int, + nlon_in: int, + nlat_out: int, + nlon_out: int, + grid_in: Optional[str] = "equiangular", + grid_out: Optional[str] = "equiangular", + mode: Optional[str] = "bilinear", + ): + + super().__init__() + + # currently only bilinear is supported + if mode == "bilinear": + self.mode = mode + else: + raise NotImplementedError(f"unknown interpolation mode {mode}") + + self.nlat_in, self.nlon_in = nlat_in, nlon_in + self.nlat_out, self.nlon_out = nlat_out, nlon_out + + self.grid_in = grid_in + self.grid_out = grid_out + + # for upscaling the latitudes we will use interpolation + self.lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in) + self.lons_in = np.linspace(0, 2 * math.pi, nlon_in, endpoint=False) + self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out) + self.lons_out = np.linspace(0, 2 * math.pi, nlon_out, endpoint=False) + + # prepare the interpolation by computing indices to the left and right of each output latitude + lat_idx = np.searchsorted(self.lats_in, self.lats_out, side="right") - 1 + # to guarantee everything stays in bounds + lat_idx = np.where(self.lats_out == self.lats_in[-1], lat_idx - 1, lat_idx) + + # compute the interpolation weights along the latitude + lat_weights = torch.from_numpy((self.lats_out - self.lats_in[lat_idx]) / np.diff(self.lats_in)[lat_idx]).float() + lat_weights = lat_weights.unsqueeze(-1) + + # convert to tensor + lat_idx = torch.LongTensor(lat_idx) + + # register buffers + self.register_buffer("lat_idx", lat_idx, persistent=False) + self.register_buffer("lat_weights", lat_weights, persistent=False) + + # get left and right indices but this time make sure periodicity in the longitude is handled + lon_idx_left = np.searchsorted(self.lons_in, self.lons_out, side="right") - 1 + lon_idx_right = np.where(self.lons_out >= self.lons_in[-1], np.zeros_like(lon_idx_left), lon_idx_left + 1) + + # get the difference + diff = self.lons_in[lon_idx_right] - self.lons_in[lon_idx_left] + diff = np.where(diff < 0.0, diff + 2 * math.pi, diff) + lon_weights = torch.from_numpy((self.lons_out - self.lons_in[lon_idx_left]) / diff).float() + + # convert to tensor + lon_idx_left = torch.LongTensor(lon_idx_left) + lon_idx_right = torch.LongTensor(lon_idx_right) + + # register buffers + self.register_buffer("lon_idx_left", lon_idx_left, persistent=False) + self.register_buffer("lon_idx_right", lon_idx_right, persistent=False) + self.register_buffer("lon_weights", lon_weights, persistent=False) + + def extra_repr(self): + r""" + Pretty print module + """ + return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}" + + def _upscale_longitudes(self, x: torch.Tensor): + # do the interpolation + x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights) + return x + + # old deprecated method with repeat_interleave + # def _upscale_longitudes(self, x: torch.Tensor): + # # for artifact-free upsampling in the longitudinal direction + # x = torch.repeat_interleave(x, self.lon_scale_factor, dim=-1) + # x = torch.roll(x, - self.lon_shift, dims=-1) + # return x + + def _upscale_latitudes(self, x: torch.Tensor): + # do the interpolation + x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights) + return x + + def forward(self, x: torch.Tensor): + x = self._upscale_latitudes(x) + x = self._upscale_longitudes(x) + return x diff --git a/torch_harmonics/.ipynb_checkpoints/sht-checkpoint.py b/torch_harmonics/.ipynb_checkpoints/sht-checkpoint.py new file mode 100644 index 00000000..18da076c --- /dev/null +++ b/torch_harmonics/.ipynb_checkpoints/sht-checkpoint.py @@ -0,0 +1,401 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import numpy as np +import torch +import torch.nn as nn +import torch.fft + +from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights +from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly + + +class RealSHT(nn.Module): + r""" + Defines a module for computing the forward (real-valued) SHT. + Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes. + The SHT is applied to the last two dimensions of the input + + [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems. + [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math. + """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True): + r""" + Initializes the SHT Layer, precomputing the necessary quadrature weights + + Parameters: + nlat: input grid resolution in the latitudinal direction + nlon: input grid resolution in the longitudinal direction + grid: grid in the latitude direction (for now only tensor product grids are supported) + """ + + super().__init__() + + self.nlat = nlat + self.nlon = nlon + self.grid = grid + self.norm = norm + self.csphase = csphase + + # TODO: include assertions regarding the dimensions + + # compute quadrature points + if self.grid == "legendre-gauss": + cost, w = legendre_gauss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "lobatto": + cost, w = lobatto_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat-1 + elif self.grid == "equiangular": + cost, w = clenshaw_curtiss_weights(nlat, -1, 1) + # cost, w = fejer2_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + else: + raise(ValueError("Unknown quadrature mode")) + + # apply cosine transform and flip them + tq = np.flip(np.arccos(cost)) + + # determine the dimensions + self.mmax = mmax or self.nlon // 2 + 1 + + # combine quadrature weights with the legendre weights + weights = torch.from_numpy(w) + pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase) + pct = torch.from_numpy(pct) + weights = torch.einsum('mlk,k->mlk', pct, weights) + + # remember quadrature weights + self.register_buffer('weights', weights, persistent=False) + + def extra_repr(self): + r""" + Pretty print module + """ + return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' + + def forward(self, x: torch.Tensor): + + if x.dim() < 2: + raise ValueError(f"Expected tensor with at least 2 dimensions but got {x.dim()} instead") + + assert(x.shape[-2] == self.nlat) + assert(x.shape[-1] == self.nlon) + + # apply real fft in the longitudinal direction + x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward") + + # do the Legendre-Gauss quadrature + x = torch.view_as_real(x) + + # distributed contraction: fork + out_shape = list(x.size()) + out_shape[-3] = self.lmax + out_shape[-2] = self.mmax + xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device) + + # contraction + xout[..., 0] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 0], self.weights.to(x.dtype) ) + xout[..., 1] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 1], self.weights.to(x.dtype) ) + x = torch.view_as_complex(xout) + + return x + +class InverseRealSHT(nn.Module): + r""" + Defines a module for computing the inverse (real-valued) SHT. + Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes. + nlat, nlon: Output dimensions + lmax, mmax: Input dimensions (spherical coefficients). For convenience, these are inferred from the output dimensions + + [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems. + [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math. + """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True): + + super().__init__() + + self.nlat = nlat + self.nlon = nlon + self.grid = grid + self.norm = norm + self.csphase = csphase + + # compute quadrature points + if self.grid == "legendre-gauss": + cost, _ = legendre_gauss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "lobatto": + cost, _ = lobatto_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat-1 + elif self.grid == "equiangular": + cost, _ = clenshaw_curtiss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + else: + raise(ValueError("Unknown quadrature mode")) + + # apply cosine transform and flip them + t = np.flip(np.arccos(cost)) + + # determine the dimensions + self.mmax = mmax or self.nlon // 2 + 1 + + pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase) + pct = torch.from_numpy(pct) + + # register buffer + self.register_buffer('pct', pct, persistent=False) + + def extra_repr(self): + r""" + Pretty print module + """ + return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' + + def forward(self, x: torch.Tensor): + + if len(x.shape) < 2: + raise ValueError(f"Expected tensor with at least 2 dimensions but got {len(x.shape)} instead") + + assert(x.shape[-2] == self.lmax) + assert(x.shape[-1] == self.mmax) + + # Evaluate associated Legendre functions on the output nodes + x = torch.view_as_real(x) + + rl = torch.einsum('...lm, mlk->...km', x[..., 0], self.pct.to(x.dtype) ) + im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct.to(x.dtype) ) + xs = torch.stack((rl, im), -1) + + # apply the inverse (real) FFT + x = torch.view_as_complex(xs) + x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") + + return x + + +class RealVectorSHT(nn.Module): + r""" + Defines a module for computing the forward (real) vector SHT. + Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes. + The SHT is applied to the last three dimensions of the input. + + [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems. + [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math. + """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True): + r""" + Initializes the vector SHT Layer, precomputing the necessary quadrature weights + + Parameters: + nlat: input grid resolution in the latitudinal direction + nlon: input grid resolution in the longitudinal direction + grid: type of grid the data lives on + """ + + super().__init__() + + self.nlat = nlat + self.nlon = nlon + self.grid = grid + self.norm = norm + self.csphase = csphase + + # compute quadrature points + if self.grid == "legendre-gauss": + cost, w = legendre_gauss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "lobatto": + cost, w = lobatto_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat-1 + elif self.grid == "equiangular": + cost, w = clenshaw_curtiss_weights(nlat, -1, 1) + # cost, w = fejer2_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + else: + raise(ValueError("Unknown quadrature mode")) + + # apply cosine transform and flip them + tq = np.flip(np.arccos(cost)) + + # determine the dimensions + self.mmax = mmax or self.nlon // 2 + 1 + + weights = torch.from_numpy(w) + dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase) + dpct = torch.from_numpy(dpct) + + # combine integration weights, normalization factor in to one: + l = torch.arange(0, self.lmax) + norm_factor = 1. / l / (l+1) + norm_factor[0] = 1. + weights = torch.einsum('dmlk,k,l->dmlk', dpct, weights, norm_factor) + # since the second component is imaginary, we need to take complex conjugation into account + weights[1] = -1 * weights[1] + + # remember quadrature weights + self.register_buffer('weights', weights, persistent=False) + + def extra_repr(self): + r""" + Pretty print module + """ + return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' + + def forward(self, x: torch.Tensor): + + if x.dim() < 3: + raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead") + + assert(x.shape[-2] == self.nlat) + assert(x.shape[-1] == self.nlon) + + # apply real fft in the longitudinal direction + x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward") + + # do the Legendre-Gauss quadrature + x = torch.view_as_real(x) + + # distributed contraction: fork + out_shape = list(x.size()) + out_shape[-3] = self.lmax + out_shape[-2] = self.mmax + xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device) + + # contraction - spheroidal component + # real component + xout[..., 0, :, :, 0] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 0], self.weights[0].to(x.dtype)) \ + - torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 1], self.weights[1].to(x.dtype)) + + # iamg component + xout[..., 0, :, :, 1] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 1], self.weights[0].to(x.dtype)) \ + + torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 0], self.weights[1].to(x.dtype)) + + # contraction - toroidal component + # real component + xout[..., 1, :, :, 0] = - torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 1], self.weights[1].to(x.dtype)) \ + - torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 0], self.weights[0].to(x.dtype)) + # imag component + xout[..., 1, :, :, 1] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 0], self.weights[1].to(x.dtype)) \ + - torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 1], self.weights[0].to(x.dtype)) + + return torch.view_as_complex(xout) + + +class InverseRealVectorSHT(nn.Module): + r""" + Defines a module for computing the inverse (real-valued) vector SHT. + Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes. + + [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems. + [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math. + """ + def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True): + + super().__init__() + + self.nlat = nlat + self.nlon = nlon + self.grid = grid + self.norm = norm + self.csphase = csphase + + # compute quadrature points + if self.grid == "legendre-gauss": + cost, _ = legendre_gauss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + elif self.grid == "lobatto": + cost, _ = lobatto_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat-1 + elif self.grid == "equiangular": + cost, _ = clenshaw_curtiss_weights(nlat, -1, 1) + self.lmax = lmax or self.nlat + else: + raise(ValueError("Unknown quadrature mode")) + + # apply cosine transform and flip them + t = np.flip(np.arccos(cost)) + + # determine the dimensions + self.mmax = mmax or self.nlon // 2 + 1 + + dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase) + dpct = torch.from_numpy(dpct) + + # register weights + self.register_buffer('dpct', dpct, persistent=False) + + def extra_repr(self): + r""" + Pretty print module + """ + return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}' + + def forward(self, x: torch.Tensor): + + if x.dim() < 3: + raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead") + + assert(x.shape[-2] == self.lmax) + assert(x.shape[-1] == self.mmax) + + # Evaluate associated Legendre functions on the output nodes + x = torch.view_as_real(x) + + # contraction - spheroidal component + # real component + srl = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[0].to(x.dtype)) \ + - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[1].to(x.dtype)) + # iamg component + sim = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[0].to(x.dtype)) \ + + torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[1].to(x.dtype)) + + # contraction - toroidal component + # real component + trl = - torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[1].to(x.dtype)) \ + - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[0].to(x.dtype)) + # imag component + tim = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[1].to(x.dtype)) \ + - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[0].to(x.dtype)) + + # reassemble + s = torch.stack((srl, sim), -1) + t = torch.stack((trl, tim), -1) + xs = torch.stack((s, t), -4) + + # apply the inverse (real) FFT + x = torch.view_as_complex(xs) + x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") + + return x diff --git a/torch_harmonics/.ipynb_checkpoints/sht_dse-checkpoint.py b/torch_harmonics/.ipynb_checkpoints/sht_dse-checkpoint.py new file mode 100644 index 00000000..3213ced6 --- /dev/null +++ b/torch_harmonics/.ipynb_checkpoints/sht_dse-checkpoint.py @@ -0,0 +1,218 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +import torch +import torch.nn as nn +import torch.nn.functional as F + +from scipy.special import lpmv +import numpy as np + +class RealSHTDSE(): + r""" + Defines a module for computing the forward/backward SHT on arbitrary points. + Requires the collocation points (locations of measurements on the surface of the sphere), + as defined by the polar (theta) and azimuthal (phi) angles. + The SHT is applied to the last two dimensions of the input. + + [1] L. Lingsch, M. Michelis, E. de Bezenac, S. M. Perera, R. K. Katzschmann, S. Mishra; + Beyond Regular Grids: Fourier-Based Neural Operators on Arbitrary Domains; ICML 2024 + """ + + def __init__(self, phi, theta, degree): + """ + Initializes the matrices to compute the forward/backward SHT on arbitrary points. + + Parameters: + phi: input point locations as a azimuthal angle + theta: input grid locations as a polar angle + degree: degree of the spherical harmonics, total number of modes equal to degree^2 + """ + self.theta = theta # between 0 and pi + self.phi = phi # between 0 and 2 pi + + self.degree = degree + + self.num_points = theta.shape[0] + + self.V_fwd, self.V_inv = self.make_matrix() + + def make_matrix(self): + """ + Constructs the matrices to compute spherical harmonics transforms + + Inputs: + class variables + Outputs: + V_fwd computes the forward transform via matrix multiplication + V_inv computes the inverse transform via matrix multiplication + """ + V_forward = torch.zeros((self.num_points, self.degree ** 2), dtype=torch.float) + index = 0 + for l in range(self.degree): + for m in range(-l, l+1): + if index > 0: + c = np.sqrt(2) + else: + c = 1 + if m < 0: + V_forward[:, index] = (lpmv(m, l, torch.cos(self.theta)) * torch.sin(m*self.phi)) + V_forward[:,index] = c * V_forward[:,index] / torch.max( V_forward[:,index]) + else: + V_forward[:, index] = (lpmv(m, l, torch.cos(self.theta)) * torch.cos(m*self.phi)) + V_forward[:,index] = c * V_forward[:,index] / torch.max( V_forward[:,index]) + index += 1 + + return V_forward.cuda(), torch.transpose(V_forward, 0, 1).cuda() + + def forward(self, data): + """ + Computes the spherical harmonics from the data + + Inputs: + class variables + data; vector of inputs in spatial domain + Outputs: + data_fwd; data in spherical harmonic domain up to set degree + """ + data_fwd = torch.matmul(data, self.V_fwd) + + return data_fwd + + def inverse(self, data): + """ + Computes the modified data from the spherical harmonics + Note: This is not technically an inverse, as orthogonality is not preserved. + Nonetheless, we refer to it as such. + + Inputs: + class variables + data; vector of inputs in spherical harmonics domain + Outputs: + data_inv; data in spatial domain + """ + data_inv = torch.matmul(data, self.V_inv) / self.num_points + + return data_inv + + +class BatchedRealSHTDSE(): + r""" + Defines a module for computing the forward/backward SHT on arbitrary points. + Requires the collocation points (locations of measurements on the surface of the sphere), + as defined by the polar (theta) and azimuthal (phi) angles. + The SHT is applied to the last two dimensions of the input. + + [1] L. Lingsch, M. Michelis, E. de Bezenac, S. M. Perera, R. K. Katzschmann, S. Mishra; + Beyond Regular Grids: Fourier-Based Neural Operators on Arbitrary Domains; ICML 2024 + """ + + def __init__(self, phi, theta, degree): + """ + Initializes the matrices to compute the forward/backward SHT on arbitrary points. + + Parameters: + phi: input point locations as a azimuthal angle + theta: input grid locations as a polar angle + degree: degree of the spherical harmonics, total number of modes equal to degree^2 + """ + self.theta = theta # between 0 and pi + self.phi = phi # between 0 and 2 pi + + self.degree = degree + + self.batch_size = theta.shape[0] + self.num_points = theta.shape[1] + + self.V_fwd, self.V_inv = self.make_matrix() + + + def compute_legendre_matrix(self, l): + """ + Compute all associated Legendre polynomials for degree `l` across the batch. + Uses scipy.special.lpmv to generate values in a vectorized way. + """ + theta_cos = torch.cos(self.theta) + P_l_m = [] + for m in range(-l, l + 1): + P_lm = lpmv(m, l, theta_cos.cpu().numpy()) # lpmv operates on numpy arrays + P_l_m.append(torch.tensor(P_lm, dtype=torch.float, device=self.theta.device)) + return torch.stack(P_l_m, dim=0) # Shape: (2l+1, num_points) + + def make_matrix(self): + V_fwd = torch.zeros((self.batch_size, self.num_points, self.degree ** 2), dtype=torch.float, device=self.theta.device) + + index = 0 + + for l in range(self.degree): + P_l_m = self.compute_legendre_matrix(l) # Shape: (2l+1, num_points) + + for m in range(-l, l + 1): + trig_term = torch.sin(m * self.phi) if m < 0 else torch.cos(m * self.phi) + scale_factor = np.sqrt(2) if m != 0 else 1.0 + + V_fwd[:, :, index] = scale_factor * P_l_m[m + l, :] * trig_term + V_fwd[:, :, index] /= torch.max(V_fwd[:, :, index]).clamp(min=1e-6) # Avoid division by zero + index += 1 + + + + return V_fwd.cuda(), V_fwd.permute(0,2,1).cuda() + + def forward(self, data): + """ + Computes the spherical harmonics from the data + + Inputs: + class variables + data; vector of inputs in spatial domain + Outputs: + data_fwd; data in spherical harmonic domain up to set degree + """ + data_fwd = torch.matmul(data, self.V_fwd) + + return data_fwd + + def inverse(self, data): + """ + Computes the modified data from the spherical harmonics + Note: This is not technically an inverse, as orthogonality is not preserved. + Nonetheless, we refer to it as such. + + Inputs: + class variables + data; vector of inputs in spherical harmonics domain + Outputs: + data_inv; data in spatial domain + """ + data_inv = torch.matmul(data, self.V_inv) / self.num_points + + return data_inv + diff --git a/torch_harmonics/__init__.py b/torch_harmonics/__init__.py index 5ab04a17..5b3c5dea 100644 --- a/torch_harmonics/__init__.py +++ b/torch_harmonics/__init__.py @@ -32,8 +32,10 @@ __version__ = "0.7.2" from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT +from .sht_dse import RealSHTDSE, BatchedRealSHTDSE from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 from .resampling import ResampleS2 from . import quadrature from . import random_fields from . import examples +from .random_sampling import RandomSphericalSampling diff --git a/torch_harmonics/__pycache__/__init__.cpython-311.pyc b/torch_harmonics/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..e57abda6 Binary files /dev/null and b/torch_harmonics/__pycache__/__init__.cpython-311.pyc differ diff --git a/torch_harmonics/__pycache__/_disco_convolution.cpython-311.pyc b/torch_harmonics/__pycache__/_disco_convolution.cpython-311.pyc new file mode 100644 index 00000000..f0d54051 Binary files /dev/null and b/torch_harmonics/__pycache__/_disco_convolution.cpython-311.pyc differ diff --git a/torch_harmonics/__pycache__/convolution.cpython-311.pyc b/torch_harmonics/__pycache__/convolution.cpython-311.pyc new file mode 100644 index 00000000..f88e31e8 Binary files /dev/null and b/torch_harmonics/__pycache__/convolution.cpython-311.pyc differ diff --git a/torch_harmonics/__pycache__/legendre.cpython-311.pyc b/torch_harmonics/__pycache__/legendre.cpython-311.pyc new file mode 100644 index 00000000..4d3b1500 Binary files /dev/null and b/torch_harmonics/__pycache__/legendre.cpython-311.pyc differ diff --git a/torch_harmonics/__pycache__/quadrature.cpython-311.pyc b/torch_harmonics/__pycache__/quadrature.cpython-311.pyc new file mode 100644 index 00000000..8adee434 Binary files /dev/null and b/torch_harmonics/__pycache__/quadrature.cpython-311.pyc differ diff --git a/torch_harmonics/__pycache__/random_fields.cpython-311.pyc b/torch_harmonics/__pycache__/random_fields.cpython-311.pyc new file mode 100644 index 00000000..f1042252 Binary files /dev/null and b/torch_harmonics/__pycache__/random_fields.cpython-311.pyc differ diff --git a/torch_harmonics/__pycache__/random_sampling.cpython-311.pyc b/torch_harmonics/__pycache__/random_sampling.cpython-311.pyc new file mode 100644 index 00000000..940cc3fe Binary files /dev/null and b/torch_harmonics/__pycache__/random_sampling.cpython-311.pyc differ diff --git a/torch_harmonics/__pycache__/resampling.cpython-311.pyc b/torch_harmonics/__pycache__/resampling.cpython-311.pyc new file mode 100644 index 00000000..0708d6b4 Binary files /dev/null and b/torch_harmonics/__pycache__/resampling.cpython-311.pyc differ diff --git a/torch_harmonics/__pycache__/sht.cpython-311.pyc b/torch_harmonics/__pycache__/sht.cpython-311.pyc new file mode 100644 index 00000000..80b17178 Binary files /dev/null and b/torch_harmonics/__pycache__/sht.cpython-311.pyc differ diff --git a/torch_harmonics/__pycache__/sht_dse.cpython-311.pyc b/torch_harmonics/__pycache__/sht_dse.cpython-311.pyc new file mode 100644 index 00000000..07645dc3 Binary files /dev/null and b/torch_harmonics/__pycache__/sht_dse.cpython-311.pyc differ diff --git a/torch_harmonics/csrc/disco/.ipynb_checkpoints/disco_cuda_bwd-checkpoint.cu b/torch_harmonics/csrc/disco/.ipynb_checkpoints/disco_cuda_bwd-checkpoint.cu new file mode 100644 index 00000000..d31fc28f --- /dev/null +++ b/torch_harmonics/csrc/disco/.ipynb_checkpoints/disco_cuda_bwd-checkpoint.cu @@ -0,0 +1,373 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "disco.h" +#include "disco_cuda.cuh" + + +template +__device__ void disco_bwd_d(const int Hi, + const int Wi, + const int K, + const int Ho, + const int Wo, + const int pscale, + const int64_t *__restrict__ roff, + const int64_t *__restrict__ kers, + const int64_t *__restrict__ rows, + const int64_t *__restrict__ cols, + const REAL_T *__restrict__ vals, + const REAL_T *__restrict__ inp, + REAL_T *__restrict__ out) { + + const int tid = threadIdx.x; + + const int64_t bidx = blockIdx.x; // gloabl row + const int64_t bidy = blockIdx.y; // bc + + int64_t soff = roff[bidx]; + int64_t eoff = roff[bidx+1]; + + const int64_t ker = kers[soff]; + const int64_t row = rows[soff]; + + inp += bidy*K*Hi*Wi + ker*Hi*Wi + row*Wi; + out += bidy*Ho*Wo; + + // align to larger supported fp type + extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*(BDIM_X*ELXTH)*pscale] + + REAL_T (*__sh)[BDIM_X*ELXTH*2] = reinterpret_cast(__sh_ptr); + + // copy current inp row in regs + REAL_T __reg[ELXTH]; + + #pragma unroll + for(int i = 0; i < ELXTH; i++) { + __reg[i] = (i*BDIM_X+tid < Wi) ? inp[i*BDIM_X +tid] : REAL_T(0); + } + + // reset shared row up to Wo+2, remaining + // ppscale*(BDIM_X*ELXTH - Wo) locations + // will be written to but never copied to + // global mem + for(int i = 0; i < pscale; i++) { + #pragma unroll + for(int j = 0; j < 2*BDIM_X*ELXTH; j += BDIM_X) { + __sh[i][j+tid] = 0; + } + } + __syncthreads(); + + int col_prev = cols[soff]; + + int h_prev = col_prev / Wo; + int w_prev = col_prev % Wo; + + // loops along the colums of CTA's row + for(int64_t nz = soff; nz < eoff; nz++) { + + const int col = cols[nz]; + const REAL_T val = vals[nz]; + + // if we are processing a nz with a col value + // leading to a new row of inp then copy it + // to shmem; + // we read a col that points to a new output + // row if (col / Wo) > (col_prev / Wo) + if (col >= col_prev-w_prev+Wo) { + __syncthreads(); + for(int i = 0; i < pscale; i++) { + for(int j = tid; j < Wi; j += BDIM_X) { + + const REAL_T v = __sh[i][j] + __sh[i][Wi + j]; + + atomicAdd(&out[h_prev*Wo + j*pscale + i], v); + + __sh[i][ j] = 0; + __sh[i][Wi + j] = 0; + } + } + __syncthreads(); + + col_prev = col; + h_prev = col / Wo; + w_prev = col % Wo; + } + + const int w = w_prev + (col-col_prev); + const int w_mod_ps = w % pscale; + const int w_div_ps = w / pscale; + + #pragma unroll + for (int i = 0; i < ELXTH; i++) { + + const int pp = i*BDIM_X + tid; + __sh[w_mod_ps][w_div_ps + pp] += val*__reg[i]; + } + + // to avoid race conditions on __sh[] + // among consecutive iterations along nz + __syncthreads(); + } + __syncthreads(); + + // write last row + for(int i = 0; i < pscale; i++) { + + for(int j = tid; j < Wi; j += BDIM_X) { + + const REAL_T v = __sh[i][j] + __sh[i][Wi + j]; + atomicAdd(&out[h_prev*Wo + j*pscale + i], v); + } + } + return; +} + + +template +__global__ __launch_bounds__(BDIM_X) +void disco_bwd_blk_k(const int Hi, + const int Wi, + const int K, + const int Ho, + const int Wo, + const int pscale, + const int64_t *__restrict__ roff, + const int64_t *__restrict__ kers, + const int64_t *__restrict__ rows, + const int64_t *__restrict__ cols, + const REAL_T *__restrict__ vals, + const REAL_T *__restrict__ inp, + REAL_T *__restrict__ out) { + + if constexpr(PSCALE != 0) { disco_bwd_d(Hi, Wi, K, Ho, Wo, PSCALE, roff, kers, rows, cols, vals, inp, out); } + else { disco_bwd_d(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out); } + + return; +} + + +template +static void launch_kernel(int BC, + int Hi, + int Wi, + int K, + int Ho, + int Wo, + int64_t nrows, + int64_t *roff_d, + int64_t *ker_d, + int64_t *row_d, + int64_t *col_d, + REAL_T *val_d, + REAL_T *inp_d, + REAL_T *out_d, + cudaStream_t stream) { + + static_assert(sizeof(REAL_T) == 2 || + sizeof(REAL_T) == 4 || + sizeof(REAL_T) == 8); + + if constexpr(ELXTH <= ELXTH_MAX) { + if (NTH*ELXTH >= Wi) { + dim3 grid(nrows, BC); + + const int pscale = Wo/Wi; + size_t shmem = sizeof(*out_d)*(2 * (NTH*ELXTH)*pscale); + + switch(pscale) { + case 1: + disco_bwd_blk_k<<>>(Hi, Wi, + K, Ho, Wo, pscale, + roff_d, + ker_d, row_d, col_d, val_d, + inp_d, out_d); + break; + case 2: + disco_bwd_blk_k<<>>(Hi, Wi, + K, Ho, Wo, pscale, + roff_d, + ker_d, row_d, col_d, val_d, + inp_d, out_d); + break; + case 3: + disco_bwd_blk_k<<>>(Hi, Wi, + K, Ho, Wo, pscale, + roff_d, + ker_d, row_d, col_d, val_d, + inp_d, out_d); + break; + default: + disco_bwd_blk_k<<>>(Hi, Wi, + K, Ho, Wo, pscale, + roff_d, + ker_d, row_d, col_d, val_d, + inp_d, out_d); + } + } else { + launch_kernel(BC, + Hi, Wi, + K, Ho, Wo, + nrows, + roff_d, + ker_d, row_d, col_d, val_d, + inp_d, out_d, + stream); + } + } + return; +} + + +torch::Tensor disco_cuda_bwd(torch::Tensor inp, + torch::Tensor roff_idx, + torch::Tensor ker_idx, + torch::Tensor row_idx, + torch::Tensor col_idx, + torch::Tensor val, + int64_t K, + int64_t Ho, + int64_t Wo) { + + // some sanity checks + CHECK_CUDA_INPUT_TENSOR(inp); + CHECK_CUDA_INPUT_TENSOR(roff_idx); + CHECK_CUDA_INPUT_TENSOR(ker_idx); + CHECK_CUDA_INPUT_TENSOR(row_idx); + CHECK_CUDA_INPUT_TENSOR(col_idx); + CHECK_CUDA_INPUT_TENSOR(val); + + // extract some shapes + int64_t B = inp.size(0); + int64_t C = inp.size(1); + int64_t BC = B * C; + int64_t Hi = inp.size(3); + int64_t Wi = inp.size(4); + int64_t nrows = roff_idx.size(0) - 1; + + // allocate output + int64_t out_dims[] = {B, C, Ho, Wo}; + auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype()); + torch::Tensor out = torch::zeros(out_dims, options); + + // get stream + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + // assert + static_assert(0 == (ELXTH_MAX%2)); + + + if (Wo <= 64*ELXTH_MAX) { + AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { + launch_kernel<64, 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows, + roff_idx.data_ptr(), + ker_idx.data_ptr(), + row_idx.data_ptr(), + col_idx.data_ptr(), + val.data_ptr(), + inp.data_ptr(), + out.data_ptr(), + stream); + })); + } + else if (Wo <= 128*ELXTH_MAX) { + AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { + launch_kernel<128, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows, + roff_idx.data_ptr(), + ker_idx.data_ptr(), + row_idx.data_ptr(), + col_idx.data_ptr(), + val.data_ptr(), + inp.data_ptr(), + out.data_ptr(), + stream); + })); + } + else if (Wo <= 256*ELXTH_MAX) { + AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { + launch_kernel<256, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows, + roff_idx.data_ptr(), + ker_idx.data_ptr(), + row_idx.data_ptr(), + col_idx.data_ptr(), + val.data_ptr(), + inp.data_ptr(), + out.data_ptr(), + stream); + })); + } + else if (Wo <= 512*ELXTH_MAX) { + AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { + launch_kernel<512, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows, + roff_idx.data_ptr(), + ker_idx.data_ptr(), + row_idx.data_ptr(), + col_idx.data_ptr(), + val.data_ptr(), + inp.data_ptr(), + out.data_ptr(), + stream); + })); + } + else if (Wo <= 1024*ELXTH_MAX) { + AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { + launch_kernel<1024, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows, + roff_idx.data_ptr(), + ker_idx.data_ptr(), + row_idx.data_ptr(), + col_idx.data_ptr(), + val.data_ptr(), + inp.data_ptr(), + out.data_ptr(), + stream); + })); + } + else { + fprintf(stderr, + "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", + __FILE__, __LINE__, Wo, 1024*ELXTH_MAX); + exit(EXIT_FAILURE); + } + + + return out; +} + +//PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +// m.def("backward", &disco_cuda_bwd, "DISCO backward (CUDA)"); +//} diff --git a/torch_harmonics/distributed/.ipynb_checkpoints/distributed_convolution-checkpoint.py b/torch_harmonics/distributed/.ipynb_checkpoints/distributed_convolution-checkpoint.py new file mode 100644 index 00000000..f3b4d499 --- /dev/null +++ b/torch_harmonics/distributed/.ipynb_checkpoints/distributed_convolution-checkpoint.py @@ -0,0 +1,446 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import abc +from typing import List, Tuple, Union, Optional +from itertools import accumulate +from warnings import warn + +import math + +import torch +import torch.nn as nn + +from functools import partial + +from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes +from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch +from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda + +from torch_harmonics.convolution import ( + _compute_support_vals_isotropic, + _compute_support_vals_anisotropic, + _normalize_convolution_tensor_s2, + DiscreteContinuousConv, +) + +from torch_harmonics.distributed import polar_group_size, azimuth_group_size +from torch_harmonics.distributed import distributed_transpose_azimuth, distributed_transpose_polar +from torch_harmonics.distributed import reduce_from_polar_region, scatter_to_polar_region, gather_from_polar_region, copy_to_polar_region +from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank +from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim + +# import custom C++/CUDA extensions if available +try: + from disco_helpers import preprocess_psi + import disco_cuda_extension + _cuda_extension_available = True +except ImportError as err: + disco_cuda_extension = None + _cuda_extension_available = False + + +def _precompute_distributed_convolution_tensor_s2( + in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False, merge_quadrature=False +): + """ + Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$. + Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al. + The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in). + + The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields + $$ + Y(\alpha) Z(\beta) Y(\gamma) n = + {\begin{bmatrix} + \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\ + \sin(\beta)\sin(\gamma) \\ + \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma) + \end{bmatrix}} + $$ + """ + + assert len(in_shape) == 2 + assert len(out_shape) == 2 + + if len(kernel_shape) == 1: + kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff) + elif len(kernel_shape) == 2: + kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff) + else: + raise ValueError("kernel_shape should be either one- or two-dimensional.") + + nlat_in, nlon_in = in_shape + nlat_out, nlon_out = out_shape + + lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in) + lats_in = torch.from_numpy(lats_in).float() + lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out) + lats_out = torch.from_numpy(lats_out).float() + + # compute the phi differences + # It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0 + lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1] + + out_idx = [] + out_vals = [] + for t in range(nlat_out): + # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis + alpha = -lats_out[t] + beta = lons_in + gamma = lats_in.reshape(-1, 1) + + # compute cartesian coordinates of the rotated position + # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation, + # and therefore applied with a negative sign + z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma) + x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha) + y = torch.sin(beta) * torch.sin(gamma) + + # normalization is emportant to avoid NaNs when arccos and atan are applied + # this can otherwise lead to spurious artifacts in the solution + norm = torch.sqrt(x * x + y * y + z * z) + x = x / norm + y = y / norm + z = z / norm + + # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range + theta = torch.arccos(z) + phi = torch.arctan2(y, x) + torch.pi + + # find the indices where the rotated position falls into the support of the kernel + iidx, vals = kernel_handle(theta, phi) + + # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in) + idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0) + + # append indices and values to the COO datastructure + out_idx.append(idx) + out_vals.append(vals) + + # concatenate the indices and values + out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous() + out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous() + + # perform the normalization over the entire psi matrix + if transpose_normalization: + quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in + else: + quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in + out_vals = _normalize_convolution_tensor_s2(out_idx, out_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature) + + # TODO: this part can be split off into it's own function + # split the latitude indices: + comm_size_polar = polar_group_size() + comm_rank_polar = polar_group_rank() + split_shapes = compute_split_shapes(nlat_in, num_chunks=comm_size_polar) + offsets = [0] + list(accumulate(split_shapes)) + start_idx = offsets[comm_rank_polar] + end_idx = offsets[comm_rank_polar+1] + + # once normalization is done we can throw away the entries which correspond to input latitudes we do not care about + lats = out_idx[2] // nlon_in + lons = out_idx[2] % nlon_in + ilats = torch.argwhere((lats < end_idx) & (lats >= start_idx)).squeeze() + out_vals = out_vals[ilats] + # for the indices we need to recompute them to refer to local indices of the input tenor + out_idx = torch.stack([out_idx[0, ilats], out_idx[1, ilats], (lats[ilats]-start_idx) * nlon_in + lons[ilats]], dim=0) + + return out_idx, out_vals + + +class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv): + """ + Distributed version of Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1]. + + [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603 + + We assume the data can be splitted in polar and azimuthal directions. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + in_shape: Tuple[int], + out_shape: Tuple[int], + kernel_shape: Union[int, List[int]], + groups: Optional[int] = 1, + grid_in: Optional[str] = "equiangular", + grid_out: Optional[str] = "equiangular", + bias: Optional[bool] = True, + theta_cutoff: Optional[float] = None, + ): + super().__init__(in_channels, out_channels, kernel_shape, groups, bias) + + self.nlat_in, self.nlon_in = in_shape + self.nlat_out, self.nlon_out = out_shape + + # get the comms grid: + self.comm_size_polar = polar_group_size() + self.comm_rank_polar = polar_group_rank() + self.comm_size_azimuth = azimuth_group_size() + self.comm_rank_azimuth = azimuth_group_rank() + + # we need those shapes: + self.lat_in_shapes = compute_split_shapes(self.nlat_in, self.comm_size_polar) + self.lon_in_shapes = compute_split_shapes(self.nlon_in, self.comm_size_azimuth) + self.lat_out_shapes = compute_split_shapes(self.nlat_out, self.comm_size_polar) + self.lon_out_shapes = compute_split_shapes(self.nlon_out, self.comm_size_azimuth) + + # compute theta cutoff based on the bandlimit of the input field + if theta_cutoff is None: + theta_cutoff = torch.pi / float(self.nlat_out - 1) + + if theta_cutoff <= 0.0: + raise ValueError("Error, theta_cutoff has to be positive.") + + # Note that the psi matrix is of shape nlat_out x nlat_in * nlon_in. Since the contraction in nlon direction is a convolution, + # we will keep local to all nodes and split the computation up along nlat. We further split the input dim because this reduces the number + # of atomic reduction calls inside the actual kernel + + # set local shapes according to distributed mode: + self.nlat_in_local = self.lat_in_shapes[self.comm_rank_polar] + self.nlat_out_local = self.nlat_out + idx, vals = _precompute_distributed_convolution_tensor_s2( + in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False, merge_quadrature=True + ) + + # sort the values + ker_idx = idx[0, ...].contiguous() + row_idx = idx[1, ...].contiguous() + col_idx = idx[2, ...].contiguous() + vals = vals.contiguous() + + if _cuda_extension_available: + # preprocessed data-structure for GPU kernel + roff_idx = preprocess_psi(self.kernel_size, self.nlat_out_local, ker_idx, row_idx, col_idx, vals).contiguous() + self.register_buffer("psi_roff_idx", roff_idx, persistent=False) + + self.register_buffer("psi_ker_idx", ker_idx, persistent=False) + self.register_buffer("psi_row_idx", row_idx, persistent=False) + self.register_buffer("psi_col_idx", col_idx, persistent=False) + self.register_buffer("psi_vals", vals, persistent=False) + + def extra_repr(self): + r""" + Pretty print module + """ + return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, kernel_shape={self.kernel_shape}, groups={self.groups}" + + @property + def psi_idx(self): + return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous() + + def get_psi(self): + psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out_local, self.nlat_in_local * self.nlon_in)).coalesce() + return psi + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + # store number of channels + num_chans = x.shape[1] + + # h and w is split. First we make w local by transposing into channel dim + if self.comm_size_azimuth > 1: + x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes) + + if x.is_cuda and _cuda_extension_available: + x = _disco_s2_contraction_cuda( + x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out_local, self.nlon_out + ) + else: + if x.is_cuda: + warn("couldn't find CUDA extension, falling back to slow PyTorch implementation") + + psi = self.get_psi() + + x = _disco_s2_contraction_torch(x, psi, self.nlon_out) + + # perform reduce scatter in polar region + x = reduce_from_polar_region(x) + x = scatter_to_polar_region(x, -2) + + # now we can transpose back the result, so that lon is split and channels are local + if self.comm_size_azimuth > 1: + chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth) + x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes) + + # extract shape + B, C, K, H, W = x.shape + x = x.reshape(B, self.groups, self.groupsize, K, H, W) + + # do weight multiplication + out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous() + out = out.reshape(out.shape[0], -1, H, W) + + if self.bias is not None: + out = out + self.bias.reshape(1, -1, 1, 1) + + return out + + +class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv): + """ + Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1]. + + [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + in_shape: Tuple[int], + out_shape: Tuple[int], + kernel_shape: Union[int, List[int]], + groups: Optional[int] = 1, + grid_in: Optional[str] = "equiangular", + grid_out: Optional[str] = "equiangular", + bias: Optional[bool] = True, + theta_cutoff: Optional[float] = None, + ): + super().__init__(in_channels, out_channels, kernel_shape, groups, bias) + + self.nlat_in, self.nlon_in = in_shape + self.nlat_out, self.nlon_out = out_shape + + # get the comms grid: + self.comm_size_polar = polar_group_size() + self.comm_rank_polar = polar_group_rank() + self.comm_size_azimuth = azimuth_group_size() + self.comm_rank_azimuth = azimuth_group_rank() + + # we need those shapes: + self.lat_in_shapes = compute_split_shapes(self.nlat_in, self.comm_size_polar) + self.lon_in_shapes = compute_split_shapes(self.nlon_in, self.comm_size_azimuth) + self.lat_out_shapes = compute_split_shapes(self.nlat_out, self.comm_size_polar) + self.lon_out_shapes = compute_split_shapes(self.nlon_out, self.comm_size_azimuth) + + # bandlimit + if theta_cutoff is None: + theta_cutoff = torch.pi / float(self.nlat_in - 1) + + if theta_cutoff <= 0.0: + raise ValueError("Error, theta_cutoff has to be positive.") + + # Note that the psi matrix is of shape nlat_out x nlat_in * nlon_in. Since the contraction in nlon direction is a convolution, + # we will keep local to all nodes and split the computation up along nlat. We further split the input dim because this reduces the number + # of atomic reduction calls inside the actual kernel + + # set local shapes according to distributed mode: + self.nlat_in_local = self.nlat_in + self.nlat_out_local = self.lat_out_shapes[self.comm_rank_polar] + + # switch in_shape and out_shape since we want transpose conv + # distributed mode here is swapped because of the transpose + idx, vals = _precompute_distributed_convolution_tensor_s2( + out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True, merge_quadrature=True + ) + + # sort the values + ker_idx = idx[0, ...].contiguous() + row_idx = idx[1, ...].contiguous() + col_idx = idx[2, ...].contiguous() + vals = vals.contiguous() + + if _cuda_extension_available: + # preprocessed data-structure for GPU kernel + roff_idx = preprocess_psi(self.kernel_size, self.nlat_in_local, ker_idx, row_idx, col_idx, vals).contiguous() + self.register_buffer("psi_roff_idx", roff_idx, persistent=False) + + self.register_buffer("psi_ker_idx", ker_idx, persistent=False) + self.register_buffer("psi_row_idx", row_idx, persistent=False) + self.register_buffer("psi_col_idx", col_idx, persistent=False) + self.register_buffer("psi_vals", vals, persistent=False) + + def extra_repr(self): + r""" + Pretty print module + """ + return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, kernel_shape={self.kernel_shape}, groups={self.groups}" + + @property + def psi_idx(self): + return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous() + + def get_psi(self, semi_transposed: bool = False): + if semi_transposed: + # do partial transpose + # we do a semi-transposition to faciliate the computation + tout = self.psi_idx[2] // self.nlon_out + pout = self.psi_idx[2] % self.nlon_out + # flip the axis of longitudes + pout = self.nlon_out - 1 - pout + tin = self.psi_idx[1] + idx = torch.stack([self.psi_idx[0], tout, tin * self.nlon_out + pout], dim=0) + psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out_local, self.nlat_in_local * self.nlon_out)).coalesce() + else: + psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in_local, self.nlat_out_local * self.nlon_out)).coalesce() + return psi + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + # extract shape + B, C, H, W = x.shape + x = x.reshape(B, self.groups, self.groupsize, H, W) + + # do weight multiplication + x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous() + x = x.reshape(B, -1, x.shape[-3], H, W) + num_chans = x.shape[1] + + # transpose such that lon is local, channels are split + if self.comm_size_azimuth > 1: + x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes) + + # gather input tensor and set up backward reduction hooks + x = gather_from_polar_region(x, -2, self.lat_in_shapes) + x = copy_to_polar_region(x) + + if x.is_cuda and _cuda_extension_available: + out = _disco_s2_transpose_contraction_cuda( + x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out_local, self.nlon_out + ) + else: + if x.is_cuda: + warn("couldn't find CUDA extension, falling back to slow PyTorch implementation") + psi = self.get_psi(semi_transposed=True) + out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out) + + # now we can transpose back the result, so that lon is split and channels are local + if self.comm_size_azimuth > 1: + chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth) + out = distributed_transpose_azimuth.apply(out, (-1, 1), chan_shapes) + + if self.bias is not None: + out = out + self.bias.reshape(1, -1, 1, 1) + + return out diff --git a/torch_harmonics/distributed/.ipynb_checkpoints/primitives-checkpoint.py b/torch_harmonics/distributed/.ipynb_checkpoints/primitives-checkpoint.py new file mode 100644 index 00000000..60edc9d8 --- /dev/null +++ b/torch_harmonics/distributed/.ipynb_checkpoints/primitives-checkpoint.py @@ -0,0 +1,425 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +from typing import List + +import torch +import torch.distributed as dist +from torch.amp import custom_fwd, custom_bwd + +from .utils import polar_group, azimuth_group, polar_group_size +from .utils import is_initialized, is_distributed_polar + +# helper routine to compute uneven splitting in balanced way: +def compute_split_shapes(size: int, num_chunks: int) -> List[int]: + + # treat trivial case first + if num_chunks == 1: + return [size] + + # first, check if we can split using div-up to balance the load: + chunk_size = (size + num_chunks - 1) // num_chunks + last_chunk_size = max(0, size - chunk_size * (num_chunks - 1)) + if last_chunk_size == 0: + # in this case, the last shard would be empty, split with floor instead: + chunk_size = size // num_chunks + last_chunk_size = size - chunk_size * (num_chunks-1) + + # generate sections list + sections = [chunk_size for _ in range(num_chunks - 1)] + [last_chunk_size] + + return sections + + +def split_tensor_along_dim(tensor, dim, num_chunks): + assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}" + assert (tensor.shape[dim] >= num_chunks), f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into \ + {num_chunks} chunks. Empty slices are currently not supported." + + # get split + sections = compute_split_shapes(tensor.shape[dim], num_chunks) + tensor_list = torch.split(tensor, sections, dim=dim) + + return tensor_list + + +def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False): + # get comm params + comm_size = dist.get_world_size(group=group) + comm_rank = dist.get_rank(group=group) + + # split and local transposition + tsplit = split_tensor_along_dim(tensor, num_chunks=comm_size, dim=dim0) + x_send = [y.contiguous() for y in tsplit] + x_send_shapes = [x.shape for x in x_send] + x_recv = [] + x_shape = list(x_send_shapes[comm_rank]) + for dim1_len in dim1_split_sizes: + x_shape[dim1] = dim1_len + x_recv.append(torch.empty(x_shape, dtype=tensor.dtype, device=tensor.device)) + + # global transposition + req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op) + + # get dim0 split sizes + dim0_split_sizes = [x[dim0] for x in x_send_shapes] + + return x_recv, dim0_split_sizes, req + + +class distributed_transpose_azimuth(torch.autograd.Function): + + @staticmethod + @custom_fwd(device_type="cuda") + def forward(ctx, x, dims, dim1_split_sizes): + # WAR for a potential contig check torch bug for channels last contig tensors + xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group()) + x = torch.cat(xlist, dim=dims[1]) + ctx.dims = dims + ctx.dim0_split_sizes = dim0_split_sizes + + return x + + @staticmethod + @custom_bwd(device_type="cuda") + def backward(ctx, go): + dims = ctx.dims + dim0_split_sizes = ctx.dim0_split_sizes + # WAR for a potential contig check torch bug for channels last contig tensors + gilist, _, _ = _transpose(go, dims[1], dims[0], dim0_split_sizes, group=azimuth_group()) + gi = torch.cat(gilist, dim=dims[0]) + + return gi, None, None + + +class distributed_transpose_polar(torch.autograd.Function): + + @staticmethod + @custom_fwd(device_type="cuda") + def forward(ctx, x, dim, dim1_split_sizes): + # WAR for a potential contig check torch bug for channels last contig tensors + xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group()) + x = torch.cat(xlist, dim=dim[1]) + ctx.dim = dim + ctx.dim0_split_sizes = dim0_split_sizes + return x + + @staticmethod + @custom_bwd(device_type="cuda") + def backward(ctx, go): + dim = ctx.dim + dim0_split_sizes = ctx.dim0_split_sizes + # WAR for a potential contig check torch bug for channels last contig tensors + gilist, _, _ = _transpose(go, dim[1], dim[0], dim0_split_sizes, group=polar_group()) + gi = torch.cat(gilist, dim=dim[0]) + return gi, None, None + + +# we need those additional primitives for distributed matrix multiplications +def _reduce(input_, use_fp32=True, group=None): + """All-reduce the input tensor across model parallel group.""" + + # Bypass the function if we are using only 1 GPU. + if dist.get_world_size(group=group) == 1: + return input_ + + # All-reduce. + if use_fp32: + dtype = input_.dtype + inputf_ = input_.float() + inputf_ = inputf_.contiguous() + dist.all_reduce(inputf_, group=group) + input_ = inputf_.to(dtype) + else: + input_ = input_.contiguous() + dist.all_reduce(input_, group=group) + + return input_ + + +def _split(input_, dim_, group=None): + """Split the tensor along its last dimension and keep the corresponding slice.""" + # Bypass the function if we are using only 1 GPU. + comm_size = dist.get_world_size(group=group) + if comm_size == 1: + return input_ + + # Split along last dimension. + input_list = split_tensor_along_dim(input_, dim_, comm_size) + + # Note: torch.split does not create contiguous tensors by default. + rank = dist.get_rank(group=group) + output = input_list[rank] + + return output + + +def _gather(input_, dim_, shapes_, group=None): + """Gather unevenly split tensors across ranks""" + + comm_size = dist.get_world_size(group=group) + + if (shapes_ is not None) and (len(shapes_) != comm_size): + raise ValueError() + if dim_ >= input_.dim(): + raise ValueError() + + if comm_size == 1: + return input_ + + # make contiguous: + input_ = input_.contiguous() + input_shape = list(input_.shape) + + if shapes_ is not None: + input_list = [] + for src in range(comm_size): + input_shape[dim_] = shapes_[src] + input_list.append(torch.empty(input_shape, dtype=input_.dtype, device=input_.device)) + else: + # assume equal shape on all ranks + input_list = [torch.empty_like(input_) for _ in range(comm_size)] + + dist.all_gather(input_list, input_, group=group) + + output = torch.cat(input_list, dim=dim_) + + return output + + +def _reduce_scatter(input_, dim_, use_fp32=True, group=None): + """All-reduce the input tensor across model parallel group and scatter it back.""" + + # Bypass the function if we are using only 1 GPU. + if dist.get_world_size(group=group) == 1: + return input_ + + # make input contiguous + comm_size = dist.get_world_size(group=group) + comm_rank = dist.get_rank(group=group) + input_list = split_tensor_along_dim(input_, dim_, comm_size) + + dtype = input_.dtype + if (use_fp32 and (dtype != torch.float32)): + input_list = [x.to(torch.float32) for x in input_list] + + input_list = [x.contiguous() for x in input_list] + + # perform reduce_scatter + output = torch.empty_like(input_list[comm_rank]) + dist.reduce_scatter(output, input_list, group=group) + + # convert dtype if necessary + if use_fp32: + output = output.to(dtype=dtype) + + return output + + +class _CopyToPolarRegion(torch.autograd.Function): + """Split the input and keep only the corresponding chunk to the rank.""" + + @staticmethod + def symbolic(graph, input_): + return input_ + + @staticmethod + @custom_fwd(device_type="cuda") + def forward(ctx, input_): + return input_ + + @staticmethod + @custom_bwd(device_type="cuda") + def backward(ctx, grad_output): + if is_distributed_polar(): + return _reduce(grad_output, group=polar_group()) + else: + return grad_output, None + + +class _ScatterToPolarRegion(torch.autograd.Function): + """Split the input and keep only the corresponding chunk to the rank.""" + + @staticmethod + def symbolic(graph, input_, dim_): + return _split(input_, dim_, group=polar_group()) + + @staticmethod + @custom_fwd(device_type="cuda") + def forward(ctx, input_, dim_): + if is_distributed_polar(): + ctx.dim = dim_ + ctx.split_shapes = compute_split_shapes( + input_.shape[dim_], polar_group_size() + ) + return _split(input_, dim_, group=polar_group()) + else: + return input_ + + @staticmethod + @custom_bwd(device_type="cuda") + def backward(ctx, grad_output): + if is_distributed_polar(): + return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None + else: + return grad_output, None + + +class _GatherFromPolarRegion(torch.autograd.Function): + """Gather the input and keep it on the rank.""" + + @staticmethod + def symbolic(graph, input_, dim_, shapes_): + return _gather(input_, dim_, shapes_, polar_group()) + + @staticmethod + @custom_fwd(device_type="cuda") + def forward(ctx, input_, dim_, shapes_): + if is_distributed_polar(): + ctx.dim = dim_ + return _gather(input_, dim_, shapes_, group=polar_group()) + else: + return input_ + + @staticmethod + @custom_bwd(device_type="cuda") + def backward(ctx, grad_output): + if is_distributed_polar(): + return _split(grad_output, ctx.dim, group=polar_group()), None, None + else: + return grad_output, None, None + + +class _ReduceFromPolarRegion(torch.autograd.Function): + """All-reduce the input from the polar region.""" + + @staticmethod + def symbolic(graph, input_): + if is_distributed_polar(): + return _reduce(input_, group=polar_group()) + else: + return input_ + + @staticmethod + @custom_fwd(device_type="cuda") + def forward(ctx, input_): + if is_distributed_polar(): + return _reduce(input_, group=polar_group()) + else: + return input_ + + @staticmethod + @custom_bwd(device_type="cuda") + def backward(ctx, grad_output): + return grad_output + + +class _ReduceFromScatterToPolarRegion(torch.autograd.Function): + """All-reduce the input from the polar region and scatter back to polar region.""" + + @staticmethod + def symbolic(graph, input_, dim_): + if is_distributed_polar(): + return _reduce_scatter(input_, dim_, group=polar_group()) + else: + return input_ + + @staticmethod + @custom_fwd(device_type="cuda") + def forward(ctx, input_, dim_): + if is_distributed_polar(): + ctx.dim = dim_ + ctx.split_shapes = compute_split_shapes( + input_.shape[dim_], polar_group_size() + ) + return _reduce_scatter(input_, dim_, group=polar_group()) + else: + return input_ + + @staticmethod + @custom_bwd(device_type="cuda") + def backward(ctx, grad_output): + if is_distributed_polar(): + return _gather(grad_output, ctx.dim, ctx.split_shapes, polar_group()), None + else: + return grad_output, None + + +class _GatherFromCopyToPolarRegion(torch.autograd.Function): + """Gather the input from the polar region and register BWD AR, basically the inverse of reduce-scatter""" + + @staticmethod + def symbolic(graph, input_, dim_, shapes_): + if is_distributed_polar(): + return _gather(input_, dim_, shapes_, polar_group()) + else: + return input_ + + @staticmethod + @custom_fwd(device_type="cuda") + def forward(ctx, input_, dim_, shapes_): + if is_distributed_polar(): + ctx.dim = dim_ + return _gather(input_, dim_, shapes_, group=polar_group()) + else: + return input_ + + @staticmethod + @custom_bwd(device_type="cuda") + def backward(ctx, grad_output): + if is_distributed_polar(): + return _reduce_scatter(grad_output, ctx.dim, use_fp32=True, group=polar_group()), None, None + else: + return grad_output, None, None + + + +def copy_to_polar_region(input_): + return _CopyToPolarRegion.apply(input_) + + +def reduce_from_polar_region(input_): + return _ReduceFromPolarRegion.apply(input_) + + +def scatter_to_polar_region(input_, dim_): + return _ScatterToPolarRegion.apply(input_, dim_) + + +def gather_from_polar_region(input_, dim_, shapes_): + return _GatherFromPolarRegion.apply(input_, dim_, shapes_) + + +def reduce_from_scatter_to_polar_region(input_, dim_): + return _ReduceFromScatterToPolarRegion.apply(input_, dim_) + + +def gather_from_copy_to_polar_region(input_, dim_, shapes_): + return _GatherFromCopyToPolarRegion.apply(input_, dim_, shapes_) diff --git a/torch_harmonics/distributed/.ipynb_checkpoints/utils-checkpoint.py b/torch_harmonics/distributed/.ipynb_checkpoints/utils-checkpoint.py new file mode 100644 index 00000000..b584d472 --- /dev/null +++ b/torch_harmonics/distributed/.ipynb_checkpoints/utils-checkpoint.py @@ -0,0 +1,92 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +# we need this in order to enable distributed +import torch +import torch.distributed as dist + +# those need to be global +_POLAR_PARALLEL_GROUP = None +_AZIMUTH_PARALLEL_GROUP = None +_IS_INITIALIZED = False + +def polar_group(): + return _POLAR_PARALLEL_GROUP + +def azimuth_group(): + return _AZIMUTH_PARALLEL_GROUP + +def init(polar_process_group, azimuth_process_group): + global _POLAR_PARALLEL_GROUP + global _AZIMUTH_PARALLEL_GROUP + _POLAR_PARALLEL_GROUP = polar_process_group + _AZIMUTH_PARALLEL_GROUP = azimuth_process_group + _IS_INITIALIZED = True + +def finalize(): + if is_initialized(): + if is_distributed_polar(): + dist.destroy_process_group(_POLAR_PARALLEL_GROUP) + if is_distributed_azimuth(): + ist.destroy_process_group(_AZIMUTH_PARALLEL_GROUP) + +def is_initialized() -> bool: + return _IS_INITIALIZED + +def is_distributed_polar() -> bool: + return (_POLAR_PARALLEL_GROUP is not None) + +def is_distributed_azimuth() -> bool: + return (_AZIMUTH_PARALLEL_GROUP is not None) + +def polar_group_size() -> int: + if not is_distributed_polar(): + return 1 + else: + return dist.get_world_size(group = _POLAR_PARALLEL_GROUP) + +def azimuth_group_size() -> int: + if not is_distributed_azimuth(): + return 1 + else: + return dist.get_world_size(group = _AZIMUTH_PARALLEL_GROUP) + +def polar_group_rank() -> int: + if not is_distributed_polar(): + return 0 + else: + return dist.get_rank(group = _POLAR_PARALLEL_GROUP) + +def azimuth_group_rank() -> int: + if not is_distributed_azimuth(): + return 0 + else: + return dist.get_rank(group = _AZIMUTH_PARALLEL_GROUP) diff --git a/torch_harmonics/examples/.ipynb_checkpoints/__init__-checkpoint.py b/torch_harmonics/examples/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 00000000..c4e9073f --- /dev/null +++ b/torch_harmonics/examples/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,33 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from .pde_sphere import SphereSolver +from .shallow_water_equations import ShallowWaterSolver \ No newline at end of file diff --git a/torch_harmonics/examples/.ipynb_checkpoints/pde_sphere-checkpoint.py b/torch_harmonics/examples/.ipynb_checkpoints/pde_sphere-checkpoint.py new file mode 100644 index 00000000..44540a23 --- /dev/null +++ b/torch_harmonics/examples/.ipynb_checkpoints/pde_sphere-checkpoint.py @@ -0,0 +1,178 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + + +import torch +import torch.nn as nn +import torch_harmonics as harmonics + +import numpy as np + + +class SphereSolver(nn.Module): + """ + Solver class on the sphere. Can solve the following PDEs: + - Allen-Cahn eq + """ + + def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid='legendre-gauss', radius=1.0, coeff=0.001): + super().__init__() + + # time stepping param + self.dt = dt + + # grid parameters + self.nlat = nlat + self.nlon = nlon + self.grid = grid + + # physical sonstants + self.register_buffer('radius', torch.as_tensor(radius, dtype=torch.float64)) + self.register_buffer('coeff', torch.as_tensor(coeff, dtype=torch.float64)) + + # SHT + self.sht = harmonics.RealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False) + self.isht = harmonics.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False) + + self.lmax = lmax or self.sht.lmax + self.mmax = lmax or self.sht.mmax + + # compute gridpoints + if self.grid == "legendre-gauss": + cost, _ = harmonics.quadrature.legendre_gauss_weights(self.nlat, -1, 1) + elif self.grid == "lobatto": + cost, _ = harmonics.quadrature.lobatto_weights(self.nlat, -1, 1) + elif self.grid == "equiangular": + cost, _ = harmonics.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1) + + # apply cosine transform and flip them + lats = -torch.as_tensor(np.arcsin(cost)) + lons = torch.linspace(0, 2*np.pi, self.nlon+1, dtype=torch.float64)[:nlon] + + self.lmax = self.sht.lmax + self.mmax = self.sht.mmax + + l = torch.arange(0, self.lmax).reshape(self.lmax, 1).cdouble() + l = l.expand(self.lmax, self.mmax) + # the laplace operator acting on the coefficients is given by l (l + 1) + lap = - l * (l + 1) / self.radius**2 + invlap = - self.radius**2 / l / (l + 1) + invlap[0] = 0. + + # register all + self.register_buffer('lats', lats) + self.register_buffer('lons', lons) + self.register_buffer('l', l) + self.register_buffer('lap', lap) + self.register_buffer('invlap', invlap) + + def grid2spec(self, u): + """spectral coefficients from spatial data""" + + return self.sht(u) + + def spec2grid(self, uspec): + """spatial data from spectral coefficients""" + + return self.isht(uspec) + + def dudtspec(self, uspec, pde='allen-cahn'): + + if pde == 'allen-cahn': + ugrid = self.spec2grid(uspec) + u3spec = self.grid2spec(ugrid**3) + dudtspec = self.coeff*self.lap*uspec + uspec - u3spec + elif pde == 'ginzburg-landau': + ugrid = self.spec2grid(uspec) + u3spec = self.grid2spec(ugrid**3) + dudtspec = uspec + (1. + 2.j)*self.coeff*self.lap*uspec - (1. + 2.j)*u3spec + else: + NotImplementedError + + return dudtspec + + def randspec(self): + """random data on the sphere""" + + rspec = torch.randn_like(self.lap) / 4 / torch.pi + return rspec + + + def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False): + """ + plotting routine for data on the grid. Requires cartopy for 3d plots. + """ + import matplotlib.pyplot as plt + + lons = self.lons.squeeze() - torch.pi + lats = self.lats.squeeze() + + if data.is_cuda: + data = data.cpu() + lons = lons.cpu() + lats = lats.cpu() + + Lons, Lats = np.meshgrid(lons, lats) + + if projection == 'mollweide': + + #ax = plt.gca(projection=projection) + ax = fig.add_subplot(projection=projection) + im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, vmax=vmax, vmin=vmin) + # ax.set_title("Elevation map of mars") + ax.grid(True) + ax.set_xticklabels([]) + ax.set_yticklabels([]) + plt.colorbar(im, orientation='horizontal') + plt.title(title) + + elif projection == '3d': + + import cartopy.crs as ccrs + + proj = ccrs.Orthographic(central_longitude=0.0, central_latitude=25.0) + + #ax = plt.gca(projection=proj, frameon=True) + ax = fig.add_subplot(projection=proj) + Lons = Lons*180/np.pi + Lats = Lats*180/np.pi + + # contour data over the map. + im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=antialiased, vmax=vmax, vmin=vmin) + plt.title(title, y=1.05) + + else: + raise NotImplementedError + + return im + + def plot_specdata(self, data, fig, **kwargs): + return self.plot_griddata(self.isht(data), fig, **kwargs) \ No newline at end of file diff --git a/torch_harmonics/examples/.ipynb_checkpoints/shallow_water_equations-checkpoint.py b/torch_harmonics/examples/.ipynb_checkpoints/shallow_water_equations-checkpoint.py new file mode 100644 index 00000000..0b23e437 --- /dev/null +++ b/torch_harmonics/examples/.ipynb_checkpoints/shallow_water_equations-checkpoint.py @@ -0,0 +1,378 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + + +import torch +import torch.nn as nn +import sys +sys.path.append("../") +import torch_harmonics as harmonics +from torch_harmonics.quadrature import * + +import numpy as np + + +class ShallowWaterSolver(nn.Module): + """ + SWE solver class. Interface inspired bu pyspharm and SHTns + """ + + def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid='legendre-gauss', radius=6.37122E6, \ + omega=7.292E-5, gravity=9.80616, havg=10.e3, hamp=120.): + super().__init__() + + # time stepping param + self.dt = dt + + # grid parameters + self.nlat = nlat + self.nlon = nlon + self.grid = grid + + # physical sonstants + self.register_buffer('radius', torch.as_tensor(radius, dtype=torch.float64)) + self.register_buffer('omega', torch.as_tensor(omega, dtype=torch.float64)) + self.register_buffer('gravity', torch.as_tensor(gravity, dtype=torch.float64)) + self.register_buffer('havg', torch.as_tensor(havg, dtype=torch.float64)) + self.register_buffer('hamp', torch.as_tensor(hamp, dtype=torch.float64)) + + # SHT + self.sht = harmonics.RealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False) + self.isht = harmonics.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False) + self.vsht = harmonics.RealVectorSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False) + self.ivsht = harmonics.InverseRealVectorSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False) + + self.lmax = lmax or self.sht.lmax + self.mmax = lmax or self.sht.mmax + + # compute gridpoints + if self.grid == "legendre-gauss": + cost, quad_weights = harmonics.quadrature.legendre_gauss_weights(self.nlat, -1, 1) + elif self.grid == "lobatto": + cost, quad_weights = harmonics.quadrature.lobatto_weights(self.nlat, -1, 1) + elif self.grid == "equiangular": + cost, quad_weights = harmonics.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1) + + quad_weights = torch.as_tensor(quad_weights).reshape(-1, 1) + + # apply cosine transform and flip them + lats = -torch.as_tensor(np.arcsin(cost)) + lons = torch.linspace(0, 2*np.pi, self.nlon+1, dtype=torch.float64)[:nlon] + + self.lmax = self.sht.lmax + self.mmax = self.sht.mmax + + # compute the laplace and inverse laplace operators + l = torch.arange(0, self.lmax).reshape(self.lmax, 1).double() + l = l.expand(self.lmax, self.mmax) + # the laplace operator acting on the coefficients is given by - l (l + 1) + lap = - l * (l + 1) / self.radius**2 + invlap = - self.radius**2 / l / (l + 1) + invlap[0] = 0. + + # compute coriolis force + coriolis = 2 * self.omega * torch.sin(lats).reshape(self.nlat, 1) + + # hyperdiffusion + hyperdiff = torch.exp(torch.asarray((-self.dt / 2 / 3600.)*(lap / lap[-1, 0])**4)) + + # register all + self.register_buffer('lats', lats) + self.register_buffer('lons', lons) + self.register_buffer('l', l) + self.register_buffer('lap', lap) + self.register_buffer('invlap', invlap) + self.register_buffer('coriolis', coriolis) + self.register_buffer('hyperdiff', hyperdiff) + self.register_buffer('quad_weights', quad_weights) + + def grid2spec(self, ugrid): + """ + spectral coefficients from spatial data + """ + return self.sht(ugrid) + + def spec2grid(self, uspec): + """ + spatial data from spectral coefficients + """ + return self.isht(uspec) + + def vrtdivspec(self, ugrid): + """spatial data from spectral coefficients""" + vrtdivspec = self.lap * self.radius * self.vsht(ugrid) + return vrtdivspec + + def getuv(self, vrtdivspec): + """ + compute wind vector from spectral coeffs of vorticity and divergence + """ + return self.ivsht( self.invlap * vrtdivspec / self.radius) + + def gethuv(self, uspec): + """ + compute wind vector from spectral coeffs of vorticity and divergence + """ + hgrid = self.spec2grid(uspec[:1]) + uvgrid = self.getuv(uspec[1:]) + return torch.cat((hgrid, uvgrid), dim=-3) + + def potential_vorticity(self, uspec): + """ + Compute potential vorticity + """ + ugrid = self.spec2grid(uspec) + pvrt = (0.5 * self.havg * self.gravity / self.omega) * (ugrid[1] + self.coriolis) / ugrid[0] + return pvrt + + def dimensionless(self, uspec): + """ + Remove dimensions from variables + """ + uspec[0] = (uspec[0] - self.havg * self.gravity) / self.hamp / self.gravity + # vorticity is measured in 1/s so we normalize using sqrt(g h) / r + uspec[1:] = uspec[1:] * self.radius / torch.sqrt(self.gravity * self.havg) + return uspec + + def dudtspec(self, uspec): + """ + Compute time derivatives from solution represented in spectral coefficients + """ + + dudtspec = torch.zeros_like(uspec) + + # compute the derivatives - this should be incorporated into the solver: + ugrid = self.spec2grid(uspec) + uvgrid = self.getuv(uspec[1:]) + + # phi = ugrid[0] + # vrtdiv = ugrid[1:] + + tmp = uvgrid * (ugrid[1] + self.coriolis) + tmpspec = self.vrtdivspec(tmp) + dudtspec[2] = tmpspec[0] + dudtspec[1] = -1 * tmpspec[1] + + tmp = uvgrid * ugrid[0] + tmp = self.vrtdivspec(tmp) + dudtspec[0] = -1 * tmp[1] + + tmpspec = self.grid2spec(ugrid[0] + 0.5 * (uvgrid[0]**2 + uvgrid[1]**2)) + dudtspec[2] = dudtspec[2] - self.lap * tmpspec + + return dudtspec + + def galewsky_initial_condition(self): + """ + Initializes non-linear barotropically unstable shallow water test case of Galewsky et al. (2004, Tellus, 56A, 429-440). + + [1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations; + DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf + """ + device = self.lap.device + + umax = 80. + phi0 = torch.asarray(torch.pi / 7., device=device) + phi1 = torch.asarray(0.5 * torch.pi - phi0, device=device) + phi2 = 0.25 * torch.pi + en = torch.exp(torch.asarray(-4.0 / (phi1 - phi0)**2, device=device)) + alpha = 1. / 3. + beta = 1. / 15. + + lats, lons = torch.meshgrid(self.lats, self.lons) + + u1 = (umax/en)*torch.exp(1./((lats-phi0)*(lats-phi1))) + ugrid = torch.where(torch.logical_and(lats < phi1, lats > phi0), u1, torch.zeros(self.nlat, self.nlon, device=device)) + vgrid = torch.zeros((self.nlat, self.nlon), device=device) + hbump = self.hamp * torch.cos(lats) * torch.exp(-((lons-torch.pi)/alpha)**2) * torch.exp(-(phi2-lats)**2/beta) + + # intial velocity field + ugrid = torch.stack((ugrid, vgrid)) + # intial vorticity/divergence field + vrtdivspec = self.vrtdivspec(ugrid) + vrtdivgrid = self.spec2grid(vrtdivspec) + + # solve balance eqn to get initial zonal geopotential with a localized bump (not balanced). + tmp = ugrid * (vrtdivgrid + self.coriolis) + tmpspec = self.vrtdivspec(tmp) + tmpspec[1] = self.grid2spec(0.5 * torch.sum(ugrid**2, dim=0)) + phispec = self.invlap*tmpspec[0] - tmpspec[1] + self.grid2spec(self.gravity*(self.havg + hbump)) + + # assemble solution + uspec = torch.zeros(3, self.lmax, self.mmax, dtype=vrtdivspec.dtype, device=device) + uspec[0] = phispec + uspec[1:] = vrtdivspec + + return torch.tril(uspec) + + def random_initial_condition(self, mach=0.1) -> torch.Tensor: + """ + random initial condition on the sphere + """ + device = self.lap.device + ctype = torch.complex128 if self.lap.dtype == torch.float64 else torch.complex64 + + # mach number relative to wave speed + llimit = mlimit = 20 + + # hgrid = self.havg + hamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype) + # ugrid = uamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype) + # vgrid = vamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype) + # ugrid = torch.stack((ugrid, vgrid)) + + # initial geopotential + uspec = torch.zeros(3, self.lmax, self.mmax, dtype=ctype, device=self.lap.device) + uspec[:, :llimit, :mlimit] = torch.sqrt(torch.tensor(4 * torch.pi / llimit / (llimit+1), device=device, dtype=ctype)) * torch.randn_like(uspec[:, :llimit, :mlimit]) + + uspec[0] = self.gravity * self.hamp * uspec[0] + uspec[0, 0, 0] += torch.sqrt(torch.tensor(4 * torch.pi, device=device, dtype=ctype)) * self.havg * self.gravity + uspec[1:] = mach * uspec[1:] * torch.sqrt(self.gravity * self.havg) / self.radius + # uspec[1:] = self.vrtdivspec(self.spec2grid(uspec[1:]) * torch.cos(self.lats.reshape(-1, 1))) + + # # intial velocity field + # ugrid = uamp * self.spec2grid(uspec[1]) + # vgrid = vamp * self.spec2grid(uspec[2]) + # ugrid = torch.stack((ugrid, vgrid)) + + + + # # intial vorticity/divergence field + # vrtdivspec = self.vrtdivspec(ugrid) + # vrtdivgrid = self.spec2grid(vrtdivspec) + + # # solve balance eqn to get initial zonal geopotential with a localized bump (not balanced). + # tmp = ugrid * (vrtdivgrid + self.coriolis) + # tmpspec = self.vrtdivspec(tmp) + # tmpspec[1] = self.grid2spec(0.5 * torch.sum(ugrid**2, dim=0)) + # phispec = self.invlap*tmpspec[0] - tmpspec[1] + self.grid2spec(self.gravity * hgrid) + + # # assemble solution + # uspec = torch.zeros(3, self.lmax, self.mmax, dtype=phispec.dtype, device=device) + # uspec[0] = phispec + # uspec[1:] = vrtdivspec + + return torch.tril(uspec) + + def timestep(self, uspec: torch.Tensor, nsteps: int) -> torch.Tensor: + """ + Integrate the solution using Adams-Bashforth / forward Euler for nsteps steps. + """ + + dudtspec = torch.zeros(3, 3, self.lmax, self.mmax, dtype=uspec.dtype, device=uspec.device) + + # pointers to indicate the most current result + inew = 0 + inow = 1 + iold = 2 + + for iter in range(nsteps): + dudtspec[inew] = self.dudtspec(uspec) + + # update vort,div,phiv with third-order adams-bashforth. + # forward euler, then 2nd-order adams-bashforth time steps to start. + if iter == 0: + dudtspec[inow] = dudtspec[inew] + dudtspec[iold] = dudtspec[inew] + elif iter == 1: + dudtspec[iold] = dudtspec[inew] + + uspec = uspec + self.dt*( (23./12.) * dudtspec[inew] - (16./12.) * dudtspec[inow] + (5./12.) * dudtspec[iold] ) + + # implicit hyperdiffusion for vort and div. + uspec[1:] = self.hyperdiff * uspec[1:] + + # cycle through the indices + inew = (inew - 1) % 3 + inow = (inow - 1) % 3 + iold = (iold - 1) % 3 + + return uspec + + def integrate_grid(self, ugrid, dimensionless=False, polar_opt=0): + dlon = 2 * torch.pi / self.nlon + radius = 1 if dimensionless else self.radius + if polar_opt > 0: + out = torch.sum(ugrid[..., polar_opt:-polar_opt, :] * self.quad_weights[polar_opt:-polar_opt] * dlon * radius**2, dim=(-2, -1)) + else: + out = torch.sum(ugrid * self.quad_weights * dlon * radius**2, dim=(-2, -1)) + return out + + + def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False): + """ + plotting routine for data on the grid. Requires cartopy for 3d plots. + """ + import matplotlib.pyplot as plt + + lons = self.lons.squeeze() - torch.pi + lats = self.lats.squeeze() + + if data.is_cuda: + data = data.cpu() + lons = lons.cpu() + lats = lats.cpu() + + Lons, Lats = np.meshgrid(lons, lats) + + if projection == 'mollweide': + + #ax = plt.gca(projection=projection) + ax = fig.add_subplot(projection=projection) + im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, vmax=vmax, vmin=vmin) + # ax.set_title("Elevation map of mars") + ax.grid(True) + ax.set_xticklabels([]) + ax.set_yticklabels([]) + plt.colorbar(im, orientation='horizontal') + plt.title(title) + + elif projection == '3d': + + import cartopy.crs as ccrs + + proj = ccrs.Orthographic(central_longitude=0.0, central_latitude=25.0) + + #ax = plt.gca(projection=proj, frameon=True) + ax = fig.add_subplot(projection=proj) + Lons = Lons*180/np.pi + Lats = Lats*180/np.pi + + # contour data over the map. + im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=antialiased, vmax=vmax, vmin=vmin) + plt.title(title, y=1.05) + + else: + raise NotImplementedError + + return im + + def plot_specdata(self, data, fig, **kwargs): + return self.plot_griddata(self.isht(data), fig, **kwargs) diff --git a/torch_harmonics/examples/__pycache__/__init__.cpython-311.pyc b/torch_harmonics/examples/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..cb5f3f4e Binary files /dev/null and b/torch_harmonics/examples/__pycache__/__init__.cpython-311.pyc differ diff --git a/torch_harmonics/examples/__pycache__/pde_sphere.cpython-311.pyc b/torch_harmonics/examples/__pycache__/pde_sphere.cpython-311.pyc new file mode 100644 index 00000000..6a2aa4e3 Binary files /dev/null and b/torch_harmonics/examples/__pycache__/pde_sphere.cpython-311.pyc differ diff --git a/torch_harmonics/examples/__pycache__/shallow_water_equations.cpython-311.pyc b/torch_harmonics/examples/__pycache__/shallow_water_equations.cpython-311.pyc new file mode 100644 index 00000000..00f0209c Binary files /dev/null and b/torch_harmonics/examples/__pycache__/shallow_water_equations.cpython-311.pyc differ diff --git a/torch_harmonics/examples/sfno/.ipynb_checkpoints/__init__-checkpoint.py b/torch_harmonics/examples/sfno/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 00000000..6a7bea8d --- /dev/null +++ b/torch_harmonics/examples/sfno/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,33 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from .utils.pde_dataset import PdeDataset +from .models.sfno import SphericalFourierNeuralOperatorNet diff --git a/torch_harmonics/examples/sfno/__pycache__/__init__.cpython-311.pyc b/torch_harmonics/examples/sfno/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..a640c519 Binary files /dev/null and b/torch_harmonics/examples/sfno/__pycache__/__init__.cpython-311.pyc differ diff --git a/torch_harmonics/examples/sfno/models/.ipynb_checkpoints/__init__-checkpoint.py b/torch_harmonics/examples/sfno/models/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 00000000..ebaf52ae --- /dev/null +++ b/torch_harmonics/examples/sfno/models/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,30 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# \ No newline at end of file diff --git a/torch_harmonics/examples/sfno/models/.ipynb_checkpoints/sfno-checkpoint.py b/torch_harmonics/examples/sfno/models/.ipynb_checkpoints/sfno-checkpoint.py new file mode 100644 index 00000000..18a66b3b --- /dev/null +++ b/torch_harmonics/examples/sfno/models/.ipynb_checkpoints/sfno-checkpoint.py @@ -0,0 +1,536 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import torch +import torch.nn as nn + +import sys + +sys.path.append("../../../../") +from torch_harmonics import * + +from .layers import * + +from functools import partial + + +class SpectralFilterLayer(nn.Module): + """ + Fourier layer. Contains the convolution part of the FNO/SFNO + """ + + def __init__( + self, + forward_transform, + inverse_transform, + input_dim, + output_dim, + gain = 2., + operator_type = "diagonal", + hidden_size_factor = 2, + factorization = None, + separable = False, + rank = 1e-2, + bias = True): + super(SpectralFilterLayer, self).__init__() + + if factorization is None: + self.filter = SpectralConvS2(forward_transform, + inverse_transform, + input_dim, + output_dim, + gain = gain, + operator_type = operator_type, + bias = bias) + + elif factorization is not None: + self.filter = FactorizedSpectralConvS2(forward_transform, + inverse_transform, + input_dim, + output_dim, + gain = gain, + operator_type = operator_type, + rank = rank, + factorization = factorization, + separable = separable, + bias = bias) + + else: + raise(NotImplementedError) + + def forward(self, x): + return self.filter(x) + +class SphericalFourierNeuralOperatorBlock(nn.Module): + """ + Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks. + """ + def __init__( + self, + forward_transform, + inverse_transform, + input_dim, + output_dim, + operator_type = "driscoll-healy", + mlp_ratio = 2., + drop_rate = 0., + drop_path = 0., + act_layer = nn.ReLU, + norm_layer = nn.Identity, + factorization = None, + separable = False, + rank = 128, + inner_skip = "linear", + outer_skip = None, + use_mlp = True): + super(SphericalFourierNeuralOperatorBlock, self).__init__() + + if act_layer == nn.Identity: + gain_factor = 1.0 + else: + gain_factor = 2.0 + + if inner_skip == "linear" or inner_skip == "identity": + gain_factor /= 2.0 + + # convolution layer + self.filter = SpectralFilterLayer(forward_transform, + inverse_transform, + input_dim, + output_dim, + gain = gain_factor, + operator_type = operator_type, + hidden_size_factor = mlp_ratio, + factorization = factorization, + separable = separable, + rank = rank, + bias = True) + + if inner_skip == "linear": + self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1) + nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor/input_dim)) + elif inner_skip == "identity": + assert input_dim == output_dim + self.inner_skip = nn.Identity() + elif inner_skip == "none": + pass + else: + raise ValueError(f"Unknown skip connection type {inner_skip}") + + self.act_layer = act_layer() + + # first normalisation layer + self.norm0 = norm_layer() + + # dropout + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + gain_factor = 1.0 + if outer_skip == "linear" or inner_skip == "identity": + gain_factor /= 2. + + if use_mlp == True: + mlp_hidden_dim = int(output_dim * mlp_ratio) + self.mlp = MLP(in_features = output_dim, + out_features = input_dim, + hidden_features = mlp_hidden_dim, + act_layer = act_layer, + drop_rate = drop_rate, + checkpointing = False, + gain = gain_factor) + + if outer_skip == "linear": + self.outer_skip = nn.Conv2d(input_dim, input_dim, 1, 1) + torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor/input_dim)) + elif outer_skip == "identity": + assert input_dim == output_dim + self.outer_skip = nn.Identity() + elif outer_skip == "none": + pass + else: + raise ValueError(f"Unknown skip connection type {outer_skip}") + + # second normalisation layer + self.norm1 = norm_layer() + + # def init_weights(self, scale): + # if hasattr(self, "inner_skip") and isinstance(self.inner_skip, nn.Conv2d): + # gain_factor = 1. + # scale = (gain_factor / embed_dim)**0.5 + # nn.init.normal_(self.inner_skip.weight, mean=0., std=scale) + # self.filter.filter.init_weights(scale) + # else: + # gain_factor = 2. + # scale = (gain_factor / embed_dim)**0.5 + # self.filter.filter.init_weights(scale) + + def forward(self, x): + + x, residual = self.filter(x) + + x = self.norm0(x) + + if hasattr(self, "inner_skip"): + x = x + self.inner_skip(residual) + + if hasattr(self, "act_layer"): + x = self.act_layer(x) + + if hasattr(self, "mlp"): + x = self.mlp(x) + + x = self.norm1(x) + + x = self.drop_path(x) + + if hasattr(self, "outer_skip"): + x = x + self.outer_skip(residual) + + return x + +class SphericalFourierNeuralOperatorNet(nn.Module): + """ + SphericalFourierNeuralOperator module. Can use both FFTs and SHTs to represent either FNO or SFNO, + both linear and non-linear variants. + + Parameters + ---------- + spectral_transform : str, optional + Type of spectral transformation to use, by default "sht" + operator_type : str, optional + Type of operator to use ('driscoll-healy', 'diagonal'), by default "driscoll-healy" + img_shape : tuple, optional + Shape of the input channels, by default (128, 256) + scale_factor : int, optional + Scale factor to use, by default 3 + in_chans : int, optional + Number of input channels, by default 3 + out_chans : int, optional + Number of output channels, by default 3 + embed_dim : int, optional + Dimension of the embeddings, by default 256 + num_layers : int, optional + Number of layers in the network, by default 4 + activation_function : str, optional + Activation function to use, by default "gelu" + encoder_layers : int, optional + Number of layers in the encoder, by default 1 + use_mlp : int, optional + Whether to use MLPs in the SFNO blocks, by default True + mlp_ratio : int, optional + Ratio of MLP to use, by default 2.0 + drop_rate : float, optional + Dropout rate, by default 0.0 + drop_path_rate : float, optional + Dropout path rate, by default 0.0 + normalization_layer : str, optional + Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm" + hard_thresholding_fraction : float, optional + Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0 + big_skip : bool, optional + Whether to add a single large skip connection, by default True + rank : float, optional + Rank of the approximation, by default 1.0 + factorization : Any, optional + Type of factorization to use, by default None + separable : bool, optional + Whether to use separable convolutions, by default False + rank : (int, Tuple[int]), optional + If a factorization is used, which rank to use. Argument is passed to tensorly + pos_embed : bool, optional + Whether to use positional embedding, by default True + + Example: + -------- + >>> model = SphericalFourierNeuralOperatorNet( + ... img_shape=(128, 256), + ... scale_factor=4, + ... in_chans=2, + ... out_chans=2, + ... embed_dim=16, + ... num_layers=4, + ... use_mlp=True,) + >>> model(torch.randn(1, 2, 128, 256)).shape + torch.Size([1, 2, 128, 256]) + """ + + def __init__( + self, + spectral_transform = "sht", + operator_type = "driscoll-healy", + img_size = (128, 256), + grid = "equiangular", + scale_factor = 3, + in_chans = 3, + out_chans = 3, + embed_dim = 256, + num_layers = 4, + activation_function = "relu", + encoder_layers = 1, + use_mlp = True, + mlp_ratio = 2., + drop_rate = 0., + drop_path_rate = 0., + normalization_layer = "none", + hard_thresholding_fraction = 1.0, + use_complex_kernels = True, + big_skip = False, + factorization = None, + separable = False, + rank = 128, + pos_embed = False): + + super(SphericalFourierNeuralOperatorNet, self).__init__() + + self.spectral_transform = spectral_transform + self.operator_type = operator_type + self.img_size = img_size + self.grid = grid + self.scale_factor = scale_factor + self.in_chans = in_chans + self.out_chans = out_chans + self.embed_dim = embed_dim + self.num_layers = num_layers + self.hard_thresholding_fraction = hard_thresholding_fraction + self.normalization_layer = normalization_layer + self.use_mlp = use_mlp + self.encoder_layers = encoder_layers + self.big_skip = big_skip + self.factorization = factorization + self.separable = separable, + self.rank = rank + + # activation function + if activation_function == "relu": + self.activation_function = nn.ReLU + elif activation_function == "gelu": + self.activation_function = nn.GELU + # for debugging purposes + elif activation_function == "identity": + self.activation_function = nn.Identity + else: + raise ValueError(f"Unknown activation function {activation_function}") + + # compute downsampled image size + self.h = self.img_size[0] // scale_factor + self.w = self.img_size[1] // scale_factor + + # dropout + self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0. else nn.Identity() + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)] + + # pick norm layer + if self.normalization_layer == "layer_norm": + norm_layer0 = partial(nn.LayerNorm, normalized_shape=(self.img_size[0], self.img_size[1]), eps=1e-6) + norm_layer1 = partial(nn.LayerNorm, normalized_shape=(self.h, self.w), eps=1e-6) + elif self.normalization_layer == "instance_norm": + norm_layer0 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False) + norm_layer1 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False) + elif self.normalization_layer == "none": + norm_layer0 = nn.Identity + norm_layer1 = norm_layer0 + else: + raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.") + + if pos_embed == "latlon" or pos_embed==True: + self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1])) + nn.init.constant_(self.pos_embed, 0.0) + elif pos_embed == "lat": + self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], 1)) + nn.init.constant_(self.pos_embed, 0.0) + elif pos_embed == "const": + self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, 1, 1)) + nn.init.constant_(self.pos_embed, 0.0) + else: + self.pos_embed = None + + # # encoder + # encoder_hidden_dim = int(self.embed_dim * mlp_ratio) + # encoder = MLP(in_features = self.in_chans, + # out_features = self.embed_dim, + # hidden_features = encoder_hidden_dim, + # act_layer = self.activation_function, + # drop_rate = drop_rate, + # checkpointing = False) + # self.encoder = encoder + + + # construct an encoder with num_encoder_layers + num_encoder_layers = 1 + encoder_hidden_dim = int(self.embed_dim * mlp_ratio) + current_dim = self.in_chans + encoder_layers = [] + for l in range(num_encoder_layers-1): + fc = nn.Conv2d(current_dim, encoder_hidden_dim, 1, bias=True) + # initialize the weights correctly + scale = math.sqrt(2. / current_dim) + nn.init.normal_(fc.weight, mean=0., std=scale) + if fc.bias is not None: + nn.init.constant_(fc.bias, 0.0) + encoder_layers.append(fc) + encoder_layers.append(self.activation_function()) + current_dim = encoder_hidden_dim + fc = nn.Conv2d(current_dim, self.embed_dim, 1, bias=False) + scale = math.sqrt(1. / current_dim) + nn.init.normal_(fc.weight, mean=0., std=scale) + if fc.bias is not None: + nn.init.constant_(fc.bias, 0.0) + encoder_layers.append(fc) + self.encoder = nn.Sequential(*encoder_layers) + + # prepare the spectral transform + if self.spectral_transform == "sht": + + modes_lat = int(self.h * self.hard_thresholding_fraction) + modes_lon = int(self.w//2 * self.hard_thresholding_fraction) + modes_lat = modes_lon = min(modes_lat, modes_lon) + + self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float() + self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float() + self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float() + self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float() + + elif self.spectral_transform == "fft": + + modes_lat = int(self.h * self.hard_thresholding_fraction) + modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction) + + self.trans_down = RealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float() + self.itrans_up = InverseRealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float() + self.trans = RealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float() + self.itrans = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float() + + else: + raise(ValueError("Unknown spectral transform")) + + self.blocks = nn.ModuleList([]) + for i in range(self.num_layers): + + first_layer = i == 0 + last_layer = i == self.num_layers-1 + + forward_transform = self.trans_down if first_layer else self.trans + inverse_transform = self.itrans_up if last_layer else self.itrans + + inner_skip = "none" + outer_skip = "identity" + + if first_layer: + norm_layer = norm_layer1 + elif last_layer: + norm_layer = norm_layer0 + else: + norm_layer = norm_layer1 + + block = SphericalFourierNeuralOperatorBlock(forward_transform, + inverse_transform, + self.embed_dim, + self.embed_dim, + operator_type = self.operator_type, + mlp_ratio = mlp_ratio, + drop_rate = drop_rate, + drop_path = dpr[i], + act_layer = self.activation_function, + norm_layer = norm_layer, + inner_skip = inner_skip, + outer_skip = outer_skip, + use_mlp = use_mlp, + factorization = self.factorization, + separable = self.separable, + rank = self.rank) + + self.blocks.append(block) + + # # decoder + # decoder_hidden_dim = int(self.embed_dim * mlp_ratio) + # self.decoder = MLP(in_features = self.embed_dim + self.big_skip*self.in_chans, + # out_features = self.out_chans, + # hidden_features = decoder_hidden_dim, + # act_layer = self.activation_function, + # drop_rate = drop_rate, + # checkpointing = False) + + # construct an decoder with num_decoder_layers + num_decoder_layers = 1 + decoder_hidden_dim = int(self.embed_dim * mlp_ratio) + current_dim = self.embed_dim + self.big_skip*self.in_chans + decoder_layers = [] + for l in range(num_decoder_layers-1): + fc = nn.Conv2d(current_dim, decoder_hidden_dim, 1, bias=True) + # initialize the weights correctly + scale = math.sqrt(2. / current_dim) + nn.init.normal_(fc.weight, mean=0., std=scale) + if fc.bias is not None: + nn.init.constant_(fc.bias, 0.0) + decoder_layers.append(fc) + decoder_layers.append(self.activation_function()) + current_dim = decoder_hidden_dim + fc = nn.Conv2d(current_dim, self.out_chans, 1, bias=False) + scale = math.sqrt(1. / current_dim) + nn.init.normal_(fc.weight, mean=0., std=scale) + if fc.bias is not None: + nn.init.constant_(fc.bias, 0.0) + decoder_layers.append(fc) + self.decoder = nn.Sequential(*decoder_layers) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed", "cls_token"} + + def forward_features(self, x): + + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + return x + + def forward(self, x): + + if self.big_skip: + residual = x + + x = self.encoder(x) + + if self.pos_embed is not None: + x = x + self.pos_embed + + x = self.forward_features(x) + + if self.big_skip: + x = torch.cat((x, residual), dim=1) + + x = self.decoder(x) + + return x + + diff --git a/torch_harmonics/examples/sfno/models/__pycache__/__init__.cpython-311.pyc b/torch_harmonics/examples/sfno/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..5e7928ec Binary files /dev/null and b/torch_harmonics/examples/sfno/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/torch_harmonics/examples/sfno/models/__pycache__/activations.cpython-311.pyc b/torch_harmonics/examples/sfno/models/__pycache__/activations.cpython-311.pyc new file mode 100644 index 00000000..9f3a68aa Binary files /dev/null and b/torch_harmonics/examples/sfno/models/__pycache__/activations.cpython-311.pyc differ diff --git a/torch_harmonics/examples/sfno/models/__pycache__/contractions.cpython-311.pyc b/torch_harmonics/examples/sfno/models/__pycache__/contractions.cpython-311.pyc new file mode 100644 index 00000000..c2b9e9b1 Binary files /dev/null and b/torch_harmonics/examples/sfno/models/__pycache__/contractions.cpython-311.pyc differ diff --git a/torch_harmonics/examples/sfno/models/__pycache__/layers.cpython-311.pyc b/torch_harmonics/examples/sfno/models/__pycache__/layers.cpython-311.pyc new file mode 100644 index 00000000..02e8e499 Binary files /dev/null and b/torch_harmonics/examples/sfno/models/__pycache__/layers.cpython-311.pyc differ diff --git a/torch_harmonics/examples/sfno/models/__pycache__/sfno.cpython-311.pyc b/torch_harmonics/examples/sfno/models/__pycache__/sfno.cpython-311.pyc new file mode 100644 index 00000000..401cf4d8 Binary files /dev/null and b/torch_harmonics/examples/sfno/models/__pycache__/sfno.cpython-311.pyc differ diff --git a/torch_harmonics/examples/sfno/models/sfno.py b/torch_harmonics/examples/sfno/models/sfno.py index 96e07a37..18a66b3b 100644 --- a/torch_harmonics/examples/sfno/models/sfno.py +++ b/torch_harmonics/examples/sfno/models/sfno.py @@ -32,6 +32,9 @@ import torch import torch.nn as nn +import sys + +sys.path.append("../../../../") from torch_harmonics import * from .layers import * diff --git a/torch_harmonics/examples/sfno/utils/.ipynb_checkpoints/pde_dataset-checkpoint.py b/torch_harmonics/examples/sfno/utils/.ipynb_checkpoints/pde_dataset-checkpoint.py new file mode 100644 index 00000000..96ee30fd --- /dev/null +++ b/torch_harmonics/examples/sfno/utils/.ipynb_checkpoints/pde_dataset-checkpoint.py @@ -0,0 +1,115 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import torch + +from math import ceil + +from ...shallow_water_equations import ShallowWaterSolver + +class PdeDataset(torch.utils.data.Dataset): + """Custom Dataset class for PDE training data""" + def __init__(self, dt, nsteps, dims=(384, 768), pde='shallow water equations', initial_condition='random', + num_examples=32, device=torch.device('cpu'), normalize=True, stream=None): + self.num_examples = num_examples + self.device = device + self.stream = stream + + self.nlat = dims[0] + self.nlon = dims[1] + + # number of solver steps used to compute the target + self.nsteps = nsteps + self.normalize = normalize + + if pde == 'shallow water equations': + lmax = ceil(self.nlat/3) + mmax = lmax + dt_solver = dt / float(self.nsteps) + self.solver = ShallowWaterSolver(self.nlat, self.nlon, dt_solver, lmax=lmax, mmax=mmax, grid='equiangular').to(self.device).float() + else: + raise NotImplementedError + + self.set_initial_condition(ictype=initial_condition) + + if self.normalize: + inp0, _ = self._get_sample() + self.inp_mean = torch.mean(inp0, dim=(-1, -2)).reshape(-1, 1, 1) + self.inp_var = torch.var(inp0, dim=(-1, -2)).reshape(-1, 1, 1) + + def __len__(self): + length = self.num_examples if self.ictype == 'random' else 1 + return length + + def set_initial_condition(self, ictype='random'): + self.ictype = ictype + + def set_num_examples(self, num_examples=32): + self.num_examples = num_examples + + def _get_sample(self): + if self.ictype == 'random': + inp = self.solver.random_initial_condition(mach=0.2) + elif self.ictype == 'galewsky': + inp = self.solver.galewsky_initial_condition() + + # solve pde for n steps to return the target + tar = self.solver.timestep(inp, self.nsteps) + inp = self.solver.spec2grid(inp) + tar = self.solver.spec2grid(tar) + + return inp, tar + + def __getitem__(self, index): + + # if self.stream is None: + # self.stream = torch.cuda.Stream() + + # with torch.cuda.stream(self.stream): + # with torch.inference_mode(): + # with torch.no_grad(): + # inp, tar = self._get_sample() + + # if self.normalize: + # inp = (inp - self.inp_mean) / torch.sqrt(self.inp_var) + # tar = (tar - self.inp_mean) / torch.sqrt(self.inp_var) + + # self.stream.synchronize() + + with torch.inference_mode(): + with torch.no_grad(): + inp, tar = self._get_sample() + + if self.normalize: + inp = (inp - self.inp_mean) / torch.sqrt(self.inp_var) + tar = (tar - self.inp_mean) / torch.sqrt(self.inp_var) + + return inp.clone(), tar.clone() diff --git a/torch_harmonics/examples/sfno/utils/__pycache__/pde_dataset.cpython-311.pyc b/torch_harmonics/examples/sfno/utils/__pycache__/pde_dataset.cpython-311.pyc new file mode 100644 index 00000000..b305dc3d Binary files /dev/null and b/torch_harmonics/examples/sfno/utils/__pycache__/pde_dataset.cpython-311.pyc differ diff --git a/torch_harmonics/examples/sfno_dse/.ipynb_checkpoints/__init__-checkpoint.py b/torch_harmonics/examples/sfno_dse/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 00000000..37f92692 --- /dev/null +++ b/torch_harmonics/examples/sfno_dse/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,33 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from .utils.pde_dataset import PdeDataset +from .models.sfno_dse import SFNODSEFp, SFNODSEVp diff --git a/torch_harmonics/examples/sfno_dse/__init__.py b/torch_harmonics/examples/sfno_dse/__init__.py new file mode 100644 index 00000000..37f92692 --- /dev/null +++ b/torch_harmonics/examples/sfno_dse/__init__.py @@ -0,0 +1,33 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from .utils.pde_dataset import PdeDataset +from .models.sfno_dse import SFNODSEFp, SFNODSEVp diff --git a/torch_harmonics/examples/sfno_dse/models/.ipynb_checkpoints/__init__-checkpoint.py b/torch_harmonics/examples/sfno_dse/models/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 00000000..ebaf52ae --- /dev/null +++ b/torch_harmonics/examples/sfno_dse/models/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,30 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# \ No newline at end of file diff --git a/torch_harmonics/examples/sfno_dse/models/.ipynb_checkpoints/sfno_dse-checkpoint.py b/torch_harmonics/examples/sfno_dse/models/.ipynb_checkpoints/sfno_dse-checkpoint.py new file mode 100644 index 00000000..4b115e07 --- /dev/null +++ b/torch_harmonics/examples/sfno_dse/models/.ipynb_checkpoints/sfno_dse-checkpoint.py @@ -0,0 +1,311 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch_harmonics import * + + +################################################## +# Classes for data with fixed collocation points +################################################## + +class SHTDSEConvFp(nn.Module): + r""" + Defines a module for computing the convolution in the Spherical Harmonic Domain. + This is computed as a vector-vector multiplication between a vector of spherical harmonic + coefficients and a vector of learnable weights. + + """ + def __init__(self, in_channels, out_channels, degree, sht_transform): + """ + Initializes the SHT Convolution Layer, + + Parameters: + in_channels: number of channels to be taken as input for the SH convolution + out_channels: number of channels produced by the spherical harmonics convolution + degree: degree of the spherical harmonics, total number of modes equal to degree^2 + sht_transform: computes the SH transform itself + """ + super(SHTDSEConvFp, self).__init__() + + + self.in_channels = in_channels + self.out_channels = out_channels + self.degree = degree + + self.scale = (1 / (in_channels * out_channels)) + self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.degree**2, dtype=torch.float)) + + self.sht_transform = sht_transform + + def compl_mul1d(self, input, weights): + """ + Computes the SH via multiplication in SH domain + + Inputs: + input; SH data + weights; trainable weights for convolution in SH domain + Outputs: + data convolved via multiplication in SH domain + """ + # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) + return torch.einsum("bix,iox->box", input, weights) + + def forward(self, x): + """ + Computes the forward pass of the SH Convolution layer + + Inputs: + x; data in spatial domain + Outputs: + x; data convolved in SH domain with learnable weights + """ + # Calculate SH for given data + x_ft = self.sht_transform.forward(x) + + # Multiply relevant SH modes + out_ft = self.compl_mul1d(x_ft, self.weights1) + + # Return to physical space + x = self.sht_transform.inverse(out_ft) + + return x + + +class SFNODSEFp(nn.Module): + r""" + Defines a module for training a SFNO on FIXED arbitrary collocation (measurement) points. + + [1] L. Lingsch, M. Michelis, E. de Bezenac, S. M. Perera, R. K. Katzschmann, S. Mishra; + Beyond Regular Grids: Fourier-Based Neural Operators on Arbitrary Domains; ICML 2024 + """ + def __init__(self, in_channels, out_channels, degree, width, sht_transform, num_layers): + """ + Initializes the class to learn the SFNO. + + Parameters: + in_channels: number of channels of the input data + out_channels: number of channels for the output data (to be learned) + degree: degree of the spherical harmonics, total number of modes equal to degree^2 + sht_transform: computes the SH transform itself + num_layers: number of trainable SFNO layers (SH convolution + Pointwise convolution) + """ + super(SFNODSEFp, self).__init__() + + self.degree = degree + self.width = width + self.num_layers = num_layers + + # Input layer + self.fc0 = nn.Linear(in_channels, self.width) + + # Dynamically create convolutional and pointwise linear layers + self.conv_layers = nn.ModuleList([ + SHTDSEConvFp(self.width, self.width, self.degree, sht_transform) + for _ in range(num_layers) + ]) + + self.pointwise_layers = nn.ModuleList([ + nn.Conv1d(self.width, self.width, 1) + for _ in range(num_layers) + ]) + + # Output layers + self.fc1 = nn.Linear(self.width, 128) + self.fc2 = nn.Linear(128, out_channels) + + def forward(self, x): + """ + Learns to predict an output as a function of the input data via SH + + Inputs: + class variables + x; vector of inputs in spatial domain with dimensions [batchsize, in_channels, N] + Outputs: + x; vector of outputs in spatial domain with dimensions [batchsize, out_channels, N] + """ + x = x.permute(0, 2, 1) + x = self.fc0(x) + x = x.permute(0, 2, 1) + + # Apply dynamically created layers + for i in range(self.num_layers): + x1 = self.conv_layers[i](x) + x2 = self.pointwise_layers[i](x) + x = x1 + x2 + x = F.gelu(x) + + x = x.permute(0, 2, 1) + x = self.fc1(x) + x = F.gelu(x) + x = self.fc2(x) + x = x.permute(0, 2, 1) + return x + + +################################################## +# Classes for data with varying collocation points +################################################## + +class SHTDSEConvVp(nn.Module): + r""" + Defines a module for computing the convolution in the Spherical Harmonic Domain. + This is computed as a vector-vector multiplication between a vector of spherical harmonic + coefficients and a vector of learnable weights. + + """ + def __init__(self, in_channels, out_channels, degree): + r""" + Initializes the SHT Convolution Layer, + + Parameters: + in_channels: number of channels to be taken as input for the SH convolution + out_channels: number of channels produced by the spherical harmonics convolution + degree: degree of the spherical harmonics, total number of modes equal to degree^2 + sht_transform: computes the SH transform itself + """ + super(SHTDSEConvVp, self).__init__() + + + self.in_channels = in_channels + self.out_channels = out_channels + self.degree = degree + + self.scale = (1 / (in_channels * out_channels)) + self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.degree**2, dtype=torch.float)) + + + def compl_mul1d(self, input, weights): + """ + Computes the SH via multiplication in SH domain + + Inputs: + input; SH data + weights; trainable weights for convolution in SH domain + Outputs: + data convolved via multiplication in SH domain + """ + # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) + return torch.einsum("bix,iox->box", input, weights) + + def forward(self, x, sht_transform): + """ + Computes the forward pass of the SH Convolution layer + + Inputs: + x; data in spatial domain + Outputs: + x; data convolved in SH domain with learnable weights + """ + # Calculate SH for given data + x_ft = sht_transform.forward(x) + + # Multiply relevant SH modes + out_ft = self.compl_mul1d(x_ft, self.weights1) + + # Return to physical space + x = sht_transform.inverse(out_ft) + + return x + +class SFNODSEVp(nn.Module): + r""" + Defines a module for training a SFNO on VARIABLE arbitrary collocation (measurement) points. + + [1] L. Lingsch, M. Michelis, E. de Bezenac, S. M. Perera, R. K. Katzschmann, S. Mishra; + Beyond Regular Grids: Fourier-Based Neural Operators on Arbitrary Domains; ICML 2024 + """ + def __init__(self, in_channels, out_channels, degree, width, num_layers): + """ + Initializes the class to learn the SFNO. + + Parameters: + in_channels: number of channels of the input data + out_channels: number of channels for the output data (to be learned) + degree: degree of the spherical harmonics, total number of modes equal to degree^2 + sht_transform: computes the SH transform itself + num_layers: number of trainable SFNO layers (SH convolution + Pointwise convolution) + """ + super(SFNODSEVp, self).__init__() + + self.degree = degree + self.width = width + self.num_layers = num_layers + + # Input layer + self.fc0 = nn.Linear(in_channels, self.width) + + # Dynamically create convolutional and pointwise linear layers + self.conv_layers = nn.ModuleList([ + SHTDSEConvVp(self.width, self.width, self.degree) + for _ in range(num_layers) + ]) + + self.pointwise_layers = nn.ModuleList([ + nn.Conv1d(self.width, self.width, 1) + for _ in range(num_layers) + ]) + + # Output layers + self.fc1 = nn.Linear(self.width, 128) + self.fc2 = nn.Linear(128, out_channels) + + def forward(self, x, sht_transform): + """ + Learns to predict an output as a function of the input data via SH + + Inputs: + class variables + x; vector of inputs in spatial domain with dimensions [batchsize, in_channels, N] + Outputs: + x; vector of outputs in spatial domain with dimensions [batchsize, out_channels, N] + """ + + x = x.permute(0, 2, 1) + x = self.fc0(x) + x = x.permute(0, 2, 1) + + # Apply dynamically created layers + for i in range(self.num_layers): + x1 = self.conv_layers[i](x, sht_transform) + x2 = self.pointwise_layers[i](x) + x = x1 + x2 + x = F.gelu(x) + + x = x.permute(0, 2, 1) + x = self.fc1(x) + x = F.gelu(x) + x = self.fc2(x) + x = x.permute(0, 2, 1) + return x diff --git a/torch_harmonics/examples/sfno_dse/models/__init__.py b/torch_harmonics/examples/sfno_dse/models/__init__.py new file mode 100644 index 00000000..ebaf52ae --- /dev/null +++ b/torch_harmonics/examples/sfno_dse/models/__init__.py @@ -0,0 +1,30 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# \ No newline at end of file diff --git a/torch_harmonics/examples/sfno_dse/models/sfno_dse.py b/torch_harmonics/examples/sfno_dse/models/sfno_dse.py new file mode 100644 index 00000000..4b115e07 --- /dev/null +++ b/torch_harmonics/examples/sfno_dse/models/sfno_dse.py @@ -0,0 +1,311 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch_harmonics import * + + +################################################## +# Classes for data with fixed collocation points +################################################## + +class SHTDSEConvFp(nn.Module): + r""" + Defines a module for computing the convolution in the Spherical Harmonic Domain. + This is computed as a vector-vector multiplication between a vector of spherical harmonic + coefficients and a vector of learnable weights. + + """ + def __init__(self, in_channels, out_channels, degree, sht_transform): + """ + Initializes the SHT Convolution Layer, + + Parameters: + in_channels: number of channels to be taken as input for the SH convolution + out_channels: number of channels produced by the spherical harmonics convolution + degree: degree of the spherical harmonics, total number of modes equal to degree^2 + sht_transform: computes the SH transform itself + """ + super(SHTDSEConvFp, self).__init__() + + + self.in_channels = in_channels + self.out_channels = out_channels + self.degree = degree + + self.scale = (1 / (in_channels * out_channels)) + self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.degree**2, dtype=torch.float)) + + self.sht_transform = sht_transform + + def compl_mul1d(self, input, weights): + """ + Computes the SH via multiplication in SH domain + + Inputs: + input; SH data + weights; trainable weights for convolution in SH domain + Outputs: + data convolved via multiplication in SH domain + """ + # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) + return torch.einsum("bix,iox->box", input, weights) + + def forward(self, x): + """ + Computes the forward pass of the SH Convolution layer + + Inputs: + x; data in spatial domain + Outputs: + x; data convolved in SH domain with learnable weights + """ + # Calculate SH for given data + x_ft = self.sht_transform.forward(x) + + # Multiply relevant SH modes + out_ft = self.compl_mul1d(x_ft, self.weights1) + + # Return to physical space + x = self.sht_transform.inverse(out_ft) + + return x + + +class SFNODSEFp(nn.Module): + r""" + Defines a module for training a SFNO on FIXED arbitrary collocation (measurement) points. + + [1] L. Lingsch, M. Michelis, E. de Bezenac, S. M. Perera, R. K. Katzschmann, S. Mishra; + Beyond Regular Grids: Fourier-Based Neural Operators on Arbitrary Domains; ICML 2024 + """ + def __init__(self, in_channels, out_channels, degree, width, sht_transform, num_layers): + """ + Initializes the class to learn the SFNO. + + Parameters: + in_channels: number of channels of the input data + out_channels: number of channels for the output data (to be learned) + degree: degree of the spherical harmonics, total number of modes equal to degree^2 + sht_transform: computes the SH transform itself + num_layers: number of trainable SFNO layers (SH convolution + Pointwise convolution) + """ + super(SFNODSEFp, self).__init__() + + self.degree = degree + self.width = width + self.num_layers = num_layers + + # Input layer + self.fc0 = nn.Linear(in_channels, self.width) + + # Dynamically create convolutional and pointwise linear layers + self.conv_layers = nn.ModuleList([ + SHTDSEConvFp(self.width, self.width, self.degree, sht_transform) + for _ in range(num_layers) + ]) + + self.pointwise_layers = nn.ModuleList([ + nn.Conv1d(self.width, self.width, 1) + for _ in range(num_layers) + ]) + + # Output layers + self.fc1 = nn.Linear(self.width, 128) + self.fc2 = nn.Linear(128, out_channels) + + def forward(self, x): + """ + Learns to predict an output as a function of the input data via SH + + Inputs: + class variables + x; vector of inputs in spatial domain with dimensions [batchsize, in_channels, N] + Outputs: + x; vector of outputs in spatial domain with dimensions [batchsize, out_channels, N] + """ + x = x.permute(0, 2, 1) + x = self.fc0(x) + x = x.permute(0, 2, 1) + + # Apply dynamically created layers + for i in range(self.num_layers): + x1 = self.conv_layers[i](x) + x2 = self.pointwise_layers[i](x) + x = x1 + x2 + x = F.gelu(x) + + x = x.permute(0, 2, 1) + x = self.fc1(x) + x = F.gelu(x) + x = self.fc2(x) + x = x.permute(0, 2, 1) + return x + + +################################################## +# Classes for data with varying collocation points +################################################## + +class SHTDSEConvVp(nn.Module): + r""" + Defines a module for computing the convolution in the Spherical Harmonic Domain. + This is computed as a vector-vector multiplication between a vector of spherical harmonic + coefficients and a vector of learnable weights. + + """ + def __init__(self, in_channels, out_channels, degree): + r""" + Initializes the SHT Convolution Layer, + + Parameters: + in_channels: number of channels to be taken as input for the SH convolution + out_channels: number of channels produced by the spherical harmonics convolution + degree: degree of the spherical harmonics, total number of modes equal to degree^2 + sht_transform: computes the SH transform itself + """ + super(SHTDSEConvVp, self).__init__() + + + self.in_channels = in_channels + self.out_channels = out_channels + self.degree = degree + + self.scale = (1 / (in_channels * out_channels)) + self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.degree**2, dtype=torch.float)) + + + def compl_mul1d(self, input, weights): + """ + Computes the SH via multiplication in SH domain + + Inputs: + input; SH data + weights; trainable weights for convolution in SH domain + Outputs: + data convolved via multiplication in SH domain + """ + # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) + return torch.einsum("bix,iox->box", input, weights) + + def forward(self, x, sht_transform): + """ + Computes the forward pass of the SH Convolution layer + + Inputs: + x; data in spatial domain + Outputs: + x; data convolved in SH domain with learnable weights + """ + # Calculate SH for given data + x_ft = sht_transform.forward(x) + + # Multiply relevant SH modes + out_ft = self.compl_mul1d(x_ft, self.weights1) + + # Return to physical space + x = sht_transform.inverse(out_ft) + + return x + +class SFNODSEVp(nn.Module): + r""" + Defines a module for training a SFNO on VARIABLE arbitrary collocation (measurement) points. + + [1] L. Lingsch, M. Michelis, E. de Bezenac, S. M. Perera, R. K. Katzschmann, S. Mishra; + Beyond Regular Grids: Fourier-Based Neural Operators on Arbitrary Domains; ICML 2024 + """ + def __init__(self, in_channels, out_channels, degree, width, num_layers): + """ + Initializes the class to learn the SFNO. + + Parameters: + in_channels: number of channels of the input data + out_channels: number of channels for the output data (to be learned) + degree: degree of the spherical harmonics, total number of modes equal to degree^2 + sht_transform: computes the SH transform itself + num_layers: number of trainable SFNO layers (SH convolution + Pointwise convolution) + """ + super(SFNODSEVp, self).__init__() + + self.degree = degree + self.width = width + self.num_layers = num_layers + + # Input layer + self.fc0 = nn.Linear(in_channels, self.width) + + # Dynamically create convolutional and pointwise linear layers + self.conv_layers = nn.ModuleList([ + SHTDSEConvVp(self.width, self.width, self.degree) + for _ in range(num_layers) + ]) + + self.pointwise_layers = nn.ModuleList([ + nn.Conv1d(self.width, self.width, 1) + for _ in range(num_layers) + ]) + + # Output layers + self.fc1 = nn.Linear(self.width, 128) + self.fc2 = nn.Linear(128, out_channels) + + def forward(self, x, sht_transform): + """ + Learns to predict an output as a function of the input data via SH + + Inputs: + class variables + x; vector of inputs in spatial domain with dimensions [batchsize, in_channels, N] + Outputs: + x; vector of outputs in spatial domain with dimensions [batchsize, out_channels, N] + """ + + x = x.permute(0, 2, 1) + x = self.fc0(x) + x = x.permute(0, 2, 1) + + # Apply dynamically created layers + for i in range(self.num_layers): + x1 = self.conv_layers[i](x, sht_transform) + x2 = self.pointwise_layers[i](x) + x = x1 + x2 + x = F.gelu(x) + + x = x.permute(0, 2, 1) + x = self.fc1(x) + x = F.gelu(x) + x = self.fc2(x) + x = x.permute(0, 2, 1) + return x diff --git a/torch_harmonics/examples/sfno_dse/utils/.ipynb_checkpoints/pde_dataset-checkpoint.py b/torch_harmonics/examples/sfno_dse/utils/.ipynb_checkpoints/pde_dataset-checkpoint.py new file mode 100644 index 00000000..96ee30fd --- /dev/null +++ b/torch_harmonics/examples/sfno_dse/utils/.ipynb_checkpoints/pde_dataset-checkpoint.py @@ -0,0 +1,115 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import torch + +from math import ceil + +from ...shallow_water_equations import ShallowWaterSolver + +class PdeDataset(torch.utils.data.Dataset): + """Custom Dataset class for PDE training data""" + def __init__(self, dt, nsteps, dims=(384, 768), pde='shallow water equations', initial_condition='random', + num_examples=32, device=torch.device('cpu'), normalize=True, stream=None): + self.num_examples = num_examples + self.device = device + self.stream = stream + + self.nlat = dims[0] + self.nlon = dims[1] + + # number of solver steps used to compute the target + self.nsteps = nsteps + self.normalize = normalize + + if pde == 'shallow water equations': + lmax = ceil(self.nlat/3) + mmax = lmax + dt_solver = dt / float(self.nsteps) + self.solver = ShallowWaterSolver(self.nlat, self.nlon, dt_solver, lmax=lmax, mmax=mmax, grid='equiangular').to(self.device).float() + else: + raise NotImplementedError + + self.set_initial_condition(ictype=initial_condition) + + if self.normalize: + inp0, _ = self._get_sample() + self.inp_mean = torch.mean(inp0, dim=(-1, -2)).reshape(-1, 1, 1) + self.inp_var = torch.var(inp0, dim=(-1, -2)).reshape(-1, 1, 1) + + def __len__(self): + length = self.num_examples if self.ictype == 'random' else 1 + return length + + def set_initial_condition(self, ictype='random'): + self.ictype = ictype + + def set_num_examples(self, num_examples=32): + self.num_examples = num_examples + + def _get_sample(self): + if self.ictype == 'random': + inp = self.solver.random_initial_condition(mach=0.2) + elif self.ictype == 'galewsky': + inp = self.solver.galewsky_initial_condition() + + # solve pde for n steps to return the target + tar = self.solver.timestep(inp, self.nsteps) + inp = self.solver.spec2grid(inp) + tar = self.solver.spec2grid(tar) + + return inp, tar + + def __getitem__(self, index): + + # if self.stream is None: + # self.stream = torch.cuda.Stream() + + # with torch.cuda.stream(self.stream): + # with torch.inference_mode(): + # with torch.no_grad(): + # inp, tar = self._get_sample() + + # if self.normalize: + # inp = (inp - self.inp_mean) / torch.sqrt(self.inp_var) + # tar = (tar - self.inp_mean) / torch.sqrt(self.inp_var) + + # self.stream.synchronize() + + with torch.inference_mode(): + with torch.no_grad(): + inp, tar = self._get_sample() + + if self.normalize: + inp = (inp - self.inp_mean) / torch.sqrt(self.inp_var) + tar = (tar - self.inp_mean) / torch.sqrt(self.inp_var) + + return inp.clone(), tar.clone() diff --git a/torch_harmonics/examples/sfno_dse/utils/pde_dataset.py b/torch_harmonics/examples/sfno_dse/utils/pde_dataset.py new file mode 100644 index 00000000..96ee30fd --- /dev/null +++ b/torch_harmonics/examples/sfno_dse/utils/pde_dataset.py @@ -0,0 +1,115 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import torch + +from math import ceil + +from ...shallow_water_equations import ShallowWaterSolver + +class PdeDataset(torch.utils.data.Dataset): + """Custom Dataset class for PDE training data""" + def __init__(self, dt, nsteps, dims=(384, 768), pde='shallow water equations', initial_condition='random', + num_examples=32, device=torch.device('cpu'), normalize=True, stream=None): + self.num_examples = num_examples + self.device = device + self.stream = stream + + self.nlat = dims[0] + self.nlon = dims[1] + + # number of solver steps used to compute the target + self.nsteps = nsteps + self.normalize = normalize + + if pde == 'shallow water equations': + lmax = ceil(self.nlat/3) + mmax = lmax + dt_solver = dt / float(self.nsteps) + self.solver = ShallowWaterSolver(self.nlat, self.nlon, dt_solver, lmax=lmax, mmax=mmax, grid='equiangular').to(self.device).float() + else: + raise NotImplementedError + + self.set_initial_condition(ictype=initial_condition) + + if self.normalize: + inp0, _ = self._get_sample() + self.inp_mean = torch.mean(inp0, dim=(-1, -2)).reshape(-1, 1, 1) + self.inp_var = torch.var(inp0, dim=(-1, -2)).reshape(-1, 1, 1) + + def __len__(self): + length = self.num_examples if self.ictype == 'random' else 1 + return length + + def set_initial_condition(self, ictype='random'): + self.ictype = ictype + + def set_num_examples(self, num_examples=32): + self.num_examples = num_examples + + def _get_sample(self): + if self.ictype == 'random': + inp = self.solver.random_initial_condition(mach=0.2) + elif self.ictype == 'galewsky': + inp = self.solver.galewsky_initial_condition() + + # solve pde for n steps to return the target + tar = self.solver.timestep(inp, self.nsteps) + inp = self.solver.spec2grid(inp) + tar = self.solver.spec2grid(tar) + + return inp, tar + + def __getitem__(self, index): + + # if self.stream is None: + # self.stream = torch.cuda.Stream() + + # with torch.cuda.stream(self.stream): + # with torch.inference_mode(): + # with torch.no_grad(): + # inp, tar = self._get_sample() + + # if self.normalize: + # inp = (inp - self.inp_mean) / torch.sqrt(self.inp_var) + # tar = (tar - self.inp_mean) / torch.sqrt(self.inp_var) + + # self.stream.synchronize() + + with torch.inference_mode(): + with torch.no_grad(): + inp, tar = self._get_sample() + + if self.normalize: + inp = (inp - self.inp_mean) / torch.sqrt(self.inp_var) + tar = (tar - self.inp_mean) / torch.sqrt(self.inp_var) + + return inp.clone(), tar.clone() diff --git a/torch_harmonics/examples/shallow_water_equations.py b/torch_harmonics/examples/shallow_water_equations.py index 642460b1..0b23e437 100644 --- a/torch_harmonics/examples/shallow_water_equations.py +++ b/torch_harmonics/examples/shallow_water_equations.py @@ -32,6 +32,8 @@ import torch import torch.nn as nn +import sys +sys.path.append("../") import torch_harmonics as harmonics from torch_harmonics.quadrature import * @@ -239,7 +241,7 @@ def random_initial_condition(self, mach=0.1) -> torch.Tensor: ctype = torch.complex128 if self.lap.dtype == torch.float64 else torch.complex64 # mach number relative to wave speed - llimit = mlimit = 80 + llimit = mlimit = 20 # hgrid = self.havg + hamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype) # ugrid = uamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype) diff --git a/torch_harmonics/random_sampling.py b/torch_harmonics/random_sampling.py new file mode 100644 index 00000000..9ad43dd5 --- /dev/null +++ b/torch_harmonics/random_sampling.py @@ -0,0 +1,130 @@ +import numpy as np +import torch + +class RandomSphericalSampling: + r""" + Defines a module for sampling a (uniformly) random set of measurement points from a grid. + + [1] L. Lingsch, M. Michelis, E. de Bezenac, S. M. Perera, R. K. Katzschmann, S. Mishra; + Beyond Regular Grids: Fourier-Based Neural Operators on Arbitrary Domains; ICML 2024 + """ + def __init__(self, number_points_x, number_points_y): + # the data must be equispaced + self.number_points_x = number_points_x + self.number_points_y = number_points_y + np.random.seed(0) + + def random_points_on_sphere(self, n): + r""" + This function generates points within a 2x2x2 cube, centered at the origin. + Points with a radius<=1 are projected to a sphere with radius 1, centered at the origin. + Points with a radius > 1 are excluded. The newly generated random points are used to select + the closest points from the original grid, removing any duplicate points in this selection. + + Inputs: + class variables + n; approximate number of points to be selected (doubled, as about half the randomly generated points must be removed) + Outputs: + theta_index; vector indices of the original grid points to be selected along polar angle + phi_index; vector indices of the original grid points to be selected along azimuthal angle + >> used for selecting the points from the original data + theta_angle; vector of polar angles for points, ranging from 0 to pi + phi_angle; vector of azimuthal angles for points, ranging from 0 to 2*pi + """ + # Double the number of points to be selected, as approximately half will not be valid + n = n*2 + + # Generate random points in 3D space + x = np.random.uniform(-1, 1, n) + y = np.random.uniform(-1, 1, n) + z = np.random.uniform(-1, 1, n) + + # remove all points with radius greater than 1 (slightly less than half of all points) + magnitude = np.sqrt(x**2 + y**2 + z**2) + mask = magnitude <= 1.0 + magnitude_filtered = magnitude[mask] + x = x[mask] + y = y[mask] + z = z[mask] + + # Normalize the points to lie on the unit sphere + x /= magnitude_filtered + y /= magnitude_filtered + z /= magnitude_filtered + + # Return the points on the sphere + r = np.sqrt(x**2 + y**2 + z**2) + theta = np.arccos(z / r) + phi = np.arctan2(y, x) + np.pi + + theta = np.floor(theta*self.number_points_y / np.pi) + phi = np.floor(phi*self.number_points_x / (2*np.pi)) + + # remove duplicate points (there are about 2% duplicates, generally) + # Combine phi and theta into a 2D array + positions = np.column_stack((phi, theta)) + # Remove duplicate positions + unique_positions = np.unique(positions, axis=0) + + # Extract the cleaned phi and theta vectors + phi_index = unique_positions[:, 0] + theta_index = unique_positions[:, 1] + + phi_angle = torch.from_numpy(phi_index) / self.number_points_x * 2 * torch.pi + theta_angle = torch.from_numpy(theta_index) / self.number_points_y * torch.pi + + self.theta_index = theta_index + self.phi_index = phi_index + + return theta_index, phi_index, theta_angle.to(torch.float), phi_angle.to(torch.float) + + def random_sets_on_sphere(self, n, batch_size): + theta_indices = [] + phi_indices = [] + thetas = [] + phis = [] + + for _ in range(batch_size): + theta_index, phi_index, theta, phi = self.random_points_on_sphere(n) + theta_indices.append(theta_index) + phi_indices.append(phi_index) + thetas.append(theta) + phis.append(phi) + + + max_cols = min(matrix.size for matrix in theta_indices) + + padded_vectors = torch.zeros(len(theta_indices), max_cols, dtype=torch.float32, requires_grad=True) + theta_index = np.zeros((len(theta_indices), max_cols)) + phi_index = np.zeros((len(theta_indices), max_cols)) + theta = torch.zeros_like(padded_vectors) + phi = torch.zeros_like(padded_vectors) + + + for i in range(batch_size): + theta_index[i, :max_cols] = theta_indices[i][ :max_cols] + phi_index[i, :max_cols] = phi_indices[i][ :max_cols] + theta[i, :max_cols] = thetas[i][ :max_cols] + phi[i, :max_cols] = phis[i][ :max_cols] + + # for i in range(batch_size): + # n_points = theta_indices[i].size + # theta_index[i, :n_points] = theta_indices[i] + # phi_index[i, :n_points] = phi_indices[i] + # theta[i, :n_points] = thetas[i] + # phi[i, :n_points] = phis[i] + + return theta_index, phi_index, theta, phi + + + def get_random_sphere_data(self, data, thetas, phis): + + batch_size = thetas.shape[0] + num_points = thetas.shape[1] + + batch_indices = torch.arange(batch_size).unsqueeze(1).expand(-1, num_points) + + data_sparse = data[batch_indices,:,thetas,phis] + + return data_sparse.permute(0,2,1) + \ No newline at end of file diff --git a/torch_harmonics/sht_dse.py b/torch_harmonics/sht_dse.py new file mode 100644 index 00000000..3213ced6 --- /dev/null +++ b/torch_harmonics/sht_dse.py @@ -0,0 +1,218 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +import torch +import torch.nn as nn +import torch.nn.functional as F + +from scipy.special import lpmv +import numpy as np + +class RealSHTDSE(): + r""" + Defines a module for computing the forward/backward SHT on arbitrary points. + Requires the collocation points (locations of measurements on the surface of the sphere), + as defined by the polar (theta) and azimuthal (phi) angles. + The SHT is applied to the last two dimensions of the input. + + [1] L. Lingsch, M. Michelis, E. de Bezenac, S. M. Perera, R. K. Katzschmann, S. Mishra; + Beyond Regular Grids: Fourier-Based Neural Operators on Arbitrary Domains; ICML 2024 + """ + + def __init__(self, phi, theta, degree): + """ + Initializes the matrices to compute the forward/backward SHT on arbitrary points. + + Parameters: + phi: input point locations as a azimuthal angle + theta: input grid locations as a polar angle + degree: degree of the spherical harmonics, total number of modes equal to degree^2 + """ + self.theta = theta # between 0 and pi + self.phi = phi # between 0 and 2 pi + + self.degree = degree + + self.num_points = theta.shape[0] + + self.V_fwd, self.V_inv = self.make_matrix() + + def make_matrix(self): + """ + Constructs the matrices to compute spherical harmonics transforms + + Inputs: + class variables + Outputs: + V_fwd computes the forward transform via matrix multiplication + V_inv computes the inverse transform via matrix multiplication + """ + V_forward = torch.zeros((self.num_points, self.degree ** 2), dtype=torch.float) + index = 0 + for l in range(self.degree): + for m in range(-l, l+1): + if index > 0: + c = np.sqrt(2) + else: + c = 1 + if m < 0: + V_forward[:, index] = (lpmv(m, l, torch.cos(self.theta)) * torch.sin(m*self.phi)) + V_forward[:,index] = c * V_forward[:,index] / torch.max( V_forward[:,index]) + else: + V_forward[:, index] = (lpmv(m, l, torch.cos(self.theta)) * torch.cos(m*self.phi)) + V_forward[:,index] = c * V_forward[:,index] / torch.max( V_forward[:,index]) + index += 1 + + return V_forward.cuda(), torch.transpose(V_forward, 0, 1).cuda() + + def forward(self, data): + """ + Computes the spherical harmonics from the data + + Inputs: + class variables + data; vector of inputs in spatial domain + Outputs: + data_fwd; data in spherical harmonic domain up to set degree + """ + data_fwd = torch.matmul(data, self.V_fwd) + + return data_fwd + + def inverse(self, data): + """ + Computes the modified data from the spherical harmonics + Note: This is not technically an inverse, as orthogonality is not preserved. + Nonetheless, we refer to it as such. + + Inputs: + class variables + data; vector of inputs in spherical harmonics domain + Outputs: + data_inv; data in spatial domain + """ + data_inv = torch.matmul(data, self.V_inv) / self.num_points + + return data_inv + + +class BatchedRealSHTDSE(): + r""" + Defines a module for computing the forward/backward SHT on arbitrary points. + Requires the collocation points (locations of measurements on the surface of the sphere), + as defined by the polar (theta) and azimuthal (phi) angles. + The SHT is applied to the last two dimensions of the input. + + [1] L. Lingsch, M. Michelis, E. de Bezenac, S. M. Perera, R. K. Katzschmann, S. Mishra; + Beyond Regular Grids: Fourier-Based Neural Operators on Arbitrary Domains; ICML 2024 + """ + + def __init__(self, phi, theta, degree): + """ + Initializes the matrices to compute the forward/backward SHT on arbitrary points. + + Parameters: + phi: input point locations as a azimuthal angle + theta: input grid locations as a polar angle + degree: degree of the spherical harmonics, total number of modes equal to degree^2 + """ + self.theta = theta # between 0 and pi + self.phi = phi # between 0 and 2 pi + + self.degree = degree + + self.batch_size = theta.shape[0] + self.num_points = theta.shape[1] + + self.V_fwd, self.V_inv = self.make_matrix() + + + def compute_legendre_matrix(self, l): + """ + Compute all associated Legendre polynomials for degree `l` across the batch. + Uses scipy.special.lpmv to generate values in a vectorized way. + """ + theta_cos = torch.cos(self.theta) + P_l_m = [] + for m in range(-l, l + 1): + P_lm = lpmv(m, l, theta_cos.cpu().numpy()) # lpmv operates on numpy arrays + P_l_m.append(torch.tensor(P_lm, dtype=torch.float, device=self.theta.device)) + return torch.stack(P_l_m, dim=0) # Shape: (2l+1, num_points) + + def make_matrix(self): + V_fwd = torch.zeros((self.batch_size, self.num_points, self.degree ** 2), dtype=torch.float, device=self.theta.device) + + index = 0 + + for l in range(self.degree): + P_l_m = self.compute_legendre_matrix(l) # Shape: (2l+1, num_points) + + for m in range(-l, l + 1): + trig_term = torch.sin(m * self.phi) if m < 0 else torch.cos(m * self.phi) + scale_factor = np.sqrt(2) if m != 0 else 1.0 + + V_fwd[:, :, index] = scale_factor * P_l_m[m + l, :] * trig_term + V_fwd[:, :, index] /= torch.max(V_fwd[:, :, index]).clamp(min=1e-6) # Avoid division by zero + index += 1 + + + + return V_fwd.cuda(), V_fwd.permute(0,2,1).cuda() + + def forward(self, data): + """ + Computes the spherical harmonics from the data + + Inputs: + class variables + data; vector of inputs in spatial domain + Outputs: + data_fwd; data in spherical harmonic domain up to set degree + """ + data_fwd = torch.matmul(data, self.V_fwd) + + return data_fwd + + def inverse(self, data): + """ + Computes the modified data from the spherical harmonics + Note: This is not technically an inverse, as orthogonality is not preserved. + Nonetheless, we refer to it as such. + + Inputs: + class variables + data; vector of inputs in spherical harmonics domain + Outputs: + data_inv; data in spatial domain + """ + data_inv = torch.matmul(data, self.V_inv) / self.num_points + + return data_inv +