From c3fd8597a973af084b7acc3e854e9936bdca009b Mon Sep 17 00:00:00 2001 From: Mauro Bisson Date: Fri, 22 Aug 2025 13:21:40 -0700 Subject: [PATCH 01/31] Test for a new disco kernel using vectorized operations on channel-last tensors and eisum called in C++ function on already permuted tensors. --- setup.py | 3 +- torch_harmonics/disco/_disco_utils.py | 26 +- torch_harmonics/disco/convolution.py | 51 +- torch_harmonics/disco/csrc/cudamacro.h | 47 + torch_harmonics/disco/csrc/disco_cpu.cpp | 4 +- torch_harmonics/disco/csrc/disco_cuda.cuh | 8 +- torch_harmonics/disco/csrc/disco_cuda_bwd.cu | 4 +- torch_harmonics/disco/csrc/disco_cuda_fwd.cu | 1180 ++++++++++++++++- .../disco/csrc/disco_cuda_utils.cu | 184 +++ .../disco/csrc/disco_cuda_utils.cuh | 382 ++++++ .../disco/csrc/disco_interface.cpp | 4 +- 11 files changed, 1851 insertions(+), 42 deletions(-) create mode 100644 torch_harmonics/disco/csrc/cudamacro.h create mode 100644 torch_harmonics/disco/csrc/disco_cuda_utils.cu create mode 100644 torch_harmonics/disco/csrc/disco_cuda_utils.cuh diff --git a/setup.py b/setup.py index b10a7e2e..23d8e4f6 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ def get_compile_args(module_name): cpp_extra_flags.append("-fopenmp") nvcc_extra_flags = [] - if profile_mode: + if True or profile_mode: nvcc_extra_flags.append("-lineinfo") nvcc_extra_flags.append("-Xptxas=-v") @@ -128,6 +128,7 @@ def get_ext_modules(): if BUILD_CUDA: print(f"Compiling custom CUDA kernels for torch-harmonics.") disco_sources.extend([ + "torch_harmonics/disco/csrc/disco_cuda_utils.cu", "torch_harmonics/disco/csrc/disco_cuda_fwd.cu", "torch_harmonics/disco/csrc/disco_cuda_bwd.cu", ]) diff --git a/torch_harmonics/disco/_disco_utils.py b/torch_harmonics/disco/_disco_utils.py index 4ad1effb..a870356c 100644 --- a/torch_harmonics/disco/_disco_utils.py +++ b/torch_harmonics/disco/_disco_utils.py @@ -42,7 +42,7 @@ @torch.library.register_fake("disco_kernels::forward") def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, - kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + weights: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: out_shape = (inp.shape[0], inp.shape[1], kernel_size, nlat_out, nlon_out) return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) @@ -50,7 +50,7 @@ def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, @torch.library.register_fake("disco_kernels::backward") def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, - kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + weights: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: out_shape = (inp.shape[0], inp.shape[1], nlat_out, nlon_out) return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) @@ -59,10 +59,10 @@ def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, def _disco_s2_contraction_optimized( inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, - kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + weights: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: itype = inp.dtype inp = inp.to(torch.float32).contiguous() - out = disco_kernels.forward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out) + 
out = disco_kernels.forward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, weights, kernel_size, nlat_out, nlon_out) out = out.to(itype) return out @@ -71,10 +71,10 @@ def _disco_s2_contraction_optimized( def _disco_s2_transpose_contraction_optimized( inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, - kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + weights: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: itype = inp.dtype inp = inp.to(torch.float32).contiguous() - out = disco_kernels.backward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out) + out = disco_kernels.backward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, weights, kernel_size, nlat_out, nlon_out) out = out.to(itype) return out @@ -82,7 +82,7 @@ def _disco_s2_transpose_contraction_optimized( @torch.library.register_fake("disco_kernels::_disco_s2_contraction_optimized") def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, - kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + weights: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: out_shape = (inp.shape[0], inp.shape[1], kernel_size, nlat_out, nlon_out) return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) @@ -90,13 +90,13 @@ def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, @torch.library.register_fake("disco_kernels::_disco_s2_transpose_contraction_optimized") def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, - kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + weights: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: out_shape = (inp.shape[0], inp.shape[1], nlat_out, nlon_out) return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) #general routines: this is the same for forward and transpose def _setup_context_conv_backward(ctx, inputs, output): - inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out = inputs + inp, roff_idx, ker_idx, row_idx, col_idx, vals, _, kernel_size, nlat_out, nlon_out = inputs ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals) ctx.kernel_size = kernel_size ctx.nlat_in = inp.shape[-2] @@ -110,12 +110,12 @@ def _disco_s2_contraction_bwd_optimized(ctx, grad_output): gtype = grad_output.dtype grad_output = grad_output.to(torch.float32).contiguous() grad_input = disco_kernels.backward.default(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals, - ctx.kernel_size, ctx.nlat_in, ctx.nlon_in) + torch.empty(0), ctx.kernel_size, ctx.nlat_in, ctx.nlon_in) # Mauro grad_input = grad_input.to(gtype) else: grad_input = None - return grad_input, None, None, None, None, None, None, None, None + return grad_input, None, None, None, None, None, None, None, None, None # Mauro: added a None for weights if optimized_kernels_is_available(): torch.library.register_autograd( @@ -129,12 +129,12 @@ def _disco_s2_transpose_contraction_bwd_optimized(ctx, grad_output): gtype = grad_output.dtype grad_output = grad_output.to(torch.float32).contiguous() grad_input = disco_kernels.forward.default(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals, - ctx.kernel_size, ctx.nlat_in, ctx.nlon_in) + torch.empty(0), ctx.kernel_size, ctx.nlat_in, 
ctx.nlon_in) # Mauro grad_input = grad_input.to(gtype) else: grad_input = None - return grad_input, None, None, None, None, None, None, None, None + return grad_input, None, None, None, None, None, None, None, None, None # Mauro: added a None for weights if optimized_kernels_is_available(): torch.library.register_autograd( diff --git a/torch_harmonics/disco/convolution.py b/torch_harmonics/disco/convolution.py index 83f11c82..cedfdef8 100644 --- a/torch_harmonics/disco/convolution.py +++ b/torch_harmonics/disco/convolution.py @@ -38,6 +38,10 @@ import torch import torch.nn as nn +import nvtx + +import numpy as np + from functools import partial from torch_harmonics.cache import lru_cache @@ -502,22 +506,49 @@ def extra_repr(self): def psi_idx(self): return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous() + @nvtx.annotate("forward", color="purple") def forward(self, x: torch.Tensor) -> torch.Tensor: + #print("input x.shape:", x.shape) + if self.optimized_kernel: - x = _disco_s2_contraction_optimized( - x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out - ) + with nvtx.annotate("_disco_s2_contraction_optimized", color="red"): + out = _disco_s2_contraction_optimized( + #x = _disco_s2_contraction_optimized( + x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, + self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2]), + self.kernel_size, self.nlat_out, self.nlon_out + ) else: x = _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out) + + #print("y.shape:", x.shape) + #print("self.groups:", self.groups, "self.groupsize:", self.groupsize) + #print("weight.shape:", self.weight.shape) + #pippo = self.weight.clone() + #pippo = pippo.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2]) + #print("after reshape, weight.shape:", pippo.shape) + + # extract shape + B, C, K, H, W = x.shape + with nvtx.annotate("reshape", color="blue"): + x = x.reshape(B, self.groups, self.groupsize, K, H, W) + + #print("after reshape, x.shape:", x.shape) + + # do weight multiplication + with nvtx.annotate("einsum", color="blue"): + out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous() + #print("out.shape:", out.shape) + out = out.reshape(B, -1, H, W) + + #cpu_tensor = out.detach().cpu().numpy() + #np.savetxt('yout_einsum.ref.txt', cpu_tensor.flatten(), fmt='%.6f') + + print("weight.shape:", self.weight.shape) + print("after reshape, out.shape:", out.shape) + print("\n") - # extract shape - B, C, K, H, W = x.shape - x = x.reshape(B, self.groups, self.groupsize, K, H, W) - - # do weight multiplication - out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous() - out = out.reshape(B, -1, H, W) if self.bias is not None: out = out + self.bias.reshape(1, -1, 1, 1) diff --git a/torch_harmonics/disco/csrc/cudamacro.h b/torch_harmonics/disco/csrc/cudamacro.h new file mode 100644 index 00000000..0edef184 --- /dev/null +++ b/torch_harmonics/disco/csrc/cudamacro.h @@ -0,0 +1,47 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. 
+// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once + +#define CHECK_CUDA(call) { \ + cudaError_t err = call; \ + if( cudaSuccess != err) { \ + fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \ + __FILE__, __LINE__, cudaGetErrorString( err) ); \ + exit(EXIT_FAILURE); \ + }} + +#define CHECK_ERROR(errorMessage) { \ + cudaError_t err = cudaGetLastError(); \ + if( cudaSuccess != err) { \ + fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \ + errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) );\ + exit(EXIT_FAILURE); \ + }} diff --git a/torch_harmonics/disco/csrc/disco_cpu.cpp b/torch_harmonics/disco/csrc/disco_cpu.cpp index 8a544e66..b959802f 100644 --- a/torch_harmonics/disco/csrc/disco_cpu.cpp +++ b/torch_harmonics/disco/csrc/disco_cpu.cpp @@ -34,7 +34,7 @@ namespace disco_kernels { // cpu ops torch::Tensor disco_cpu_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, - torch::Tensor col_idx, torch::Tensor vals, int64_t K, int64_t Ho, int64_t Wo) { + torch::Tensor col_idx, torch::Tensor vals, torch::Tensor weights, int64_t K, int64_t Ho, int64_t Wo) { // sanity checks CHECK_CPU_INPUT_TENSOR(inp); @@ -64,7 +64,7 @@ namespace disco_kernels { } torch::Tensor disco_cpu_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, - torch::Tensor col_idx, torch::Tensor vals, int64_t K, int64_t Ho, int64_t Wo) { + torch::Tensor col_idx, torch::Tensor vals, torch::Tensor weights, int64_t K, int64_t Ho, int64_t Wo) { // sanity checks CHECK_CPU_INPUT_TENSOR(inp); diff --git a/torch_harmonics/disco/csrc/disco_cuda.cuh b/torch_harmonics/disco/csrc/disco_cuda.cuh index ad8e276d..2a7fa7e5 100644 --- a/torch_harmonics/disco/csrc/disco_cuda.cuh +++ b/torch_harmonics/disco/csrc/disco_cuda.cuh @@ -40,7 +40,11 @@ CHECK_CUDA_TENSOR(x); \ CHECK_CONTIGUOUS_TENSOR(x) +// will come from ../../attention/csrc/attention_cuda_utils.cuh +#ifndef DIV_UP #define DIV_UP(a, b) (((a) + ((b)-1)) / (b)) +#endif + 
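For reference, a minimal PyTorch sketch (not part of the patch; all sizes illustrative) of the unfused computation that motivates passing `weights` into the forward/backward entry points: the DISCO contraction produces a [B, C, K, Ho, Wo] tensor which is then contracted with the grouped weights via einsum, exactly as done in convolution.py earlier in this patch. The contraction output is mocked with a random tensor.

import torch

# illustrative sizes: batch, groups, in/out channels per group, kernel size, output grid
B, G, Ci_g, Co_g, K, Ho, Wo = 2, 1, 4, 6, 3, 8, 16
C = G * Ci_g

y = torch.randn(B, C, K, Ho, Wo)                 # stand-in for disco_kernels.forward output
w = torch.randn(G, Co_g, Ci_g, K)                # weights as reshaped in convolution.py forward

y = y.reshape(B, G, Ci_g, K, Ho, Wo)             # split channels into (groups, channels per group)
out = torch.einsum("bgckxy,gock->bgoxy", y, w)   # grouped weight multiplication
out = out.reshape(B, G * Co_g, Ho, Wo)           # final shape [B, C_out, Ho, Wo]
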
#define MIN_THREADS (64) #define ELXTH_MAX (32) @@ -49,10 +53,10 @@ namespace disco_kernels { // forward kernel torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, - torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo); + torch::Tensor col_idx, torch::Tensor val, torch::Tensor weights, int64_t K, int64_t Ho, int64_t Wo); // backward kernel torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, - torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo); + torch::Tensor col_idx, torch::Tensor val, torch::Tensor weights, int64_t K, int64_t Ho, int64_t Wo); } diff --git a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu index 00f5e514..13d5cdda 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu @@ -199,7 +199,7 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t } torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, - torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo) + torch::Tensor col_idx, torch::Tensor val, torch::Tensor weights, int64_t K, int64_t Ho, int64_t Wo) { // some sanity checks @@ -282,4 +282,4 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t m.impl("backward", &disco_cuda_bwd); } -} \ No newline at end of file +} diff --git a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu index 8482f76d..4c31e9f8 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu @@ -30,9 +30,34 @@ #include "disco.h" #include "disco_cuda.cuh" +#include "disco_cuda_utils.cuh" + +#define CHECK_CUDA(call) { \ + cudaError_t err = call; \ + if( cudaSuccess != err) { \ + fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \ + __FILE__, __LINE__, cudaGetErrorString( err) ); \ + exit(EXIT_FAILURE); \ + }} + +#define CHECK_ERROR(errorMessage) { \ + cudaError_t err = cudaGetLastError(); \ + if( cudaSuccess != err) { \ + fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \ + errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) ); \ + exit(EXIT_FAILURE); \ + }} + +#define THREADS (64) + +#define MAX_LOCAL_ARR_LEN (16) namespace disco_kernels { +void dump_tensor(const char *fname, at::Tensor t); +void dump_csr(const char *fname, at::Tensor roff, at::Tensor cols); +void dump_csr_linear(const char *fname, at::Tensor roff, at::Tensor kers, at::Tensor rows, at::Tensor cols, at::Tensor vals); + template __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int Ho, const int Wo, const int pscale, const int64_t *__restrict__ roff, const int64_t *__restrict__ kers, @@ -51,14 +76,13 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H const int64_t ker = kers[soff]; const int64_t row = rows[soff]; - inp += bidy * Hi * Wi; - out += bidy * K * Ho * Wo + ker * Ho * Wo + row * Wo; + inp += bidy*Hi*Wi; + out += bidy*K*Ho*Wo + ker*Ho*Wo + row*Wo; REAL_T __reg[ELXTH] = {0}; // align to larger supported fp type - extern __shared__ __align__( - sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)] + extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + 
ppscale*(BDIM_X*ELXTH - Wo)] REAL_T *__sh = reinterpret_cast(__sh_ptr); int col_prev = cols[soff]; @@ -185,8 +209,769 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t return; } + +template +__global__ void pack_vals_k(const int64_t K, + const int64_t nlat_out, + const int64_t *__restrict__ row_off, + const VAL_T *__restrict__ val_dat, + VAL_T *__restrict__ val_pck) { + + const int tidx = threadIdx.x; + const int wid = blockIdx.x*blockDim.y + threadIdx.y; + if (wid >= nlat_out) { + return; + } + + const int64_t rbeg = row_off[wid]; + const int64_t rend = row_off[wid+1]; + + const int rlen = rend-rbeg; + + val_pck += rbeg*K; + + for(int off = tidx; off < rlen; off += blockDim.x) { + for(int ker = 0; ker < K; ker++) { + + val_pck[off*K + ker] = val_dat[ row_off[ker*nlat_out + wid] + off]; + } + } + + return; +} + + +// BEGIN VERSION WITH CHANNEL-LAST WITH 2D BLOCKS, 2ND DIM IDENTIFYING CHANNLES, NO EINSUM +#if 1 +template +__device__ void processCSR_Kpow2_shm_d(const int wo, + const int rlen, + const int nchan_in, // no. of input floats (not FLOATV_T!) elements along channel dim + const int nlon_in, + const int pscale, + const int K, + const float *__restrict__ x, + const int64_t *__restrict__ cols, + const FLOATV_T *__restrict__ vals, + FLOATV_T *__restrict__ shy) { + const int tidx = threadIdx.x; + + // only used in K_POWER_2==1 branch + const int log2_K = __ffs(K)-1; + + x += tidx >> log2_K; + vals += tidx & (K-1); + + const int BDIM_XdivK = BDIM_X >> log2_K; + + for(int off = 0; off < rlen; off++) { + + const int64_t col = cols[off]; + + const int hi = col / nlon_in; + const int wi = col - (hi*nlon_in); + + //const int wip = (wi + pscale*wo) % nlon_in; + // value of (wi + pscale*wo) < (Wi + (Wi/Wo)*Wo) = 2*Wi + // so we can replace the modulo with: + int wip = wi + pscale*wo; + if (wip >= nlon_in) wip -= nlon_in; + + const float *_x = x + int64_t(hi)*nlon_in*nchan_in + int64_t(wip)*nchan_in; + + // if BDIM_X is a multiple of K then "i*(j*BDIM_X) % K = const", + // so thread "i" only needs to read vals[off*K + (i % K)] to update the + // whole channel array + + const FLOATV_T myval = vals[0]; //vals[off*K + tidxModK]; + + for(int chan = tidx; chan < nchan_in*K; chan += BDIM_X) { // no. of vectors in nchan_in*K dim on intermediate out + + shy[chan] = __vadd(shy[chan], + __vmul(myval, + __vset(_x[0]))); + _x += BDIM_XdivK; + } + + vals += K; + } + return; +} + +template +__device__ void processCSR_Kanyv_shm_d(const int wo, + const int rlen, + const int nchan_in, // no. of input floats (not FLOATV_T!) elements along channel dim + const int nlon_in, + const int pscale, + const int K, + const float *__restrict__ x, + const int64_t *__restrict__ cols, + const FLOATV_T *__restrict__ vals, + FLOATV_T *__restrict__ shy) { + const int tidx = threadIdx.x; + + for(int off = 0; off < rlen; off++) { + + const int64_t col = cols[off]; + + const int hi = col / nlon_in; + const int wi = col - (hi*nlon_in); + + //const int wip = (wi + pscale*wo) % nlon_in; + // value of (wi + pscale*wo) < (Wi + (Wi/Wo)*Wo) = 2*Wi + // so we can replace the modulo with: + int wip = wi + pscale*wo; + if (wip >= nlon_in) wip -= nlon_in; + + const float *_x = x + int64_t(hi)*nlon_in*nchan_in + int64_t(wip)*nchan_in; + + // if BDIM_X is not a multiple of K then "i*(j*BDIM_X) % K = f(i,j)", + // so the mod need to be recomputed at each iteration of update the update loop + for(int chan = tidx; chan < nchan_in*K; chan += BDIM_X) { // no. 
of vectors in nchan_in*K dim on intermediate out + + const int iDivK = chan / K; + const int iModK = chan - (iDivK*K); + + shy[chan] = __vadd(shy[chan], + __vmul(vals[iModK], + __vset(_x[iDivK]))); + } + + vals += K; + } + return; +} + +template // either float or float4 +__global__ +__launch_bounds__(BDIM) +void s2_disco_fwd_generic_vec_k(int nchan_in, // no. of input float (not FLOATV_T!) elements along channel dim + int nlat_in, + int nlon_in, + int nlat_out, + int nlon_out, + int pscale, + int K, // no. of output FLOATV_T elem along K dim (kernel size) + const float *__restrict__ x, + const int32_t *__restrict__ row_idx, + const int64_t *__restrict__ row_off, + const int64_t *__restrict__ col_idx, + const FLOATV_T *__restrict__ val_pck, + FLOATV_T *__restrict__ y) { + + constexpr int VEC_SIZE = sizeof(FLOATV_T) / sizeof(float); + + const int tidx = threadIdx.x; + + const int batch = blockIdx.y; + const int ctaid = blockIdx.x*blockDim.y + threadIdx.y; + + if (ctaid >= nlat_out*nlon_out) { + return; + } + +#if 1 + const int h = ctaid / nlon_out; + const int wo = ctaid - (h*nlon_out); + const int ho = row_idx[h]; +#else + // for now don't use row_idx + const int ho = wid / nlon_out; + const int wo = wid - (ho*nlon_out); +#endif + + const int nchan_out = nchan_in*K; + + extern __shared__ __align__(sizeof(float4)) float shext[]; + FLOATV_T *shy = reinterpret_cast(shext) + threadIdx.y*nchan_out; + + for(int chan = tidx; chan < nchan_out; chan += WARP_SIZE) { + shy[chan] = __vset(0.f); + } + + x += int64_t(batch)*nlat_in*nlon_in*nchan_in; + y += int64_t(batch)*nlat_out*nlon_out*nchan_out + int64_t(ho)*nlon_out*nchan_out + int64_t(wo)*nchan_out; + + const int64_t rbeg = row_off[ho]; + const int64_t rend = row_off[ho+1]; + + col_idx += rbeg; + val_pck += rbeg*K; // val_pck CSR contains K values per element + + const int rlen = rend-rbeg; + + // check if BDIM_X is a multiple of K; since BDIM_X is a power of 2, check if K is also a power of two + if (!(K & K-1)) { processCSR_Kpow2_shm_d(wo, rlen, nchan_in, nlon_in, pscale, K, x, col_idx, val_pck, shy); } + else { processCSR_Kanyv_shm_d(wo, rlen, nchan_in, nlon_in, pscale, K, x, col_idx, val_pck, shy); } + + for(int chan = tidx; chan < nchan_out; chan += WARP_SIZE) { + y[chan] = shy[chan]; + } + + return; +} + +template +__device__ void processCSR_Kpow2_reg_d(const int wo, + const int rlen, + const int nchan_in, // no. of input floats (not FLOATV_T!) elements along channel dim + const int nlon_in, + const int pscale, + const int K, // kernel size + const float *__restrict__ x, + const int64_t *__restrict__ cols, + const FLOATV_T *__restrict__ vals, + const float *(&shXOff)[BDIM_X+SHPAD], + FLOATV_T (&locy)[NLOC]) { + + constexpr int NLOC_M1 = NLOC-1; + + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + + // unused if BDIM_X > WARP_SIZE + unsigned int subwarp_mask = FULL_MASK; + + if constexpr(BDIM_X <= WARP_SIZE) { + constexpr unsigned int MASK = (1ull << BDIM_X)-1; + subwarp_mask = MASK << (tidy*BDIM_X); + } + + // only used in K_POWER_2==1 branch + const int log2_K = __ffs(K)-1; + + const int tidxDivK = tidx >> log2_K; + const int tidxModK = tidx & (K-1); + + cols += tidx; + vals += tidxModK; + + const int BDIM_XdivK = BDIM_X >> log2_K; + + for(int off = 0; off < rlen; off++) { + + if ((off % BDIM_X) == 0) { + if constexpr(BDIM_X <= WARP_SIZE) { __syncwarp(subwarp_mask); } + else { __syncthreads(); } + + const int64_t col = (off+tidx < rlen) ? 
cols[0] : 0; + + const int hi = col / nlon_in; + const int wi = col - (hi*nlon_in); + + //const int wip = (wi + pscale*wo) % nlon_in; + // value of (wi + pscale*wo) < (Wi + (Wi/Wo)*Wo) = 2*Wi + // so we can replace the modulo with: + int wip = wi + pscale*wo; + if (wip >= nlon_in) wip -= nlon_in; + + shXOff[tidx] = x + int64_t(hi)*nlon_in*nchan_in + int64_t(wip)*nchan_in; + cols += BDIM_X; + + if constexpr(BDIM_X <= WARP_SIZE) { __syncwarp(subwarp_mask); } + else { __syncthreads(); } + } + + const float *_x = shXOff[off % BDIM_X] + tidxDivK; + + // if BDIM_X is a multiple of K then "i*(j*BDIM_X) % K = const", + // so thread "i" only needs to read vals[off*K + (i % K)] to update the + // whole channel array + + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + locy[i] = __vadd(locy[i], + __vmul(vals[0], + __vset(_x[i*BDIM_XdivK]))); + } + if (NLOC_M1*BDIM_X+tidx < nchan_in*K) { + locy[NLOC_M1] = __vadd(locy[NLOC_M1], + __vmul(vals[0], + __vset(_x[NLOC_M1*BDIM_XdivK]))); + } + + vals += K; + } + return; +} + +template +__device__ void processCSR_Kanyv_reg_d(const int wo, + const int rlen, + const int nchan_in, // no. of input floats (not FLOATV_T!) elements along channel dim + const int nlon_in, + const int pscale, + const int K, // kernel size + const float *__restrict__ x, + const int64_t *__restrict__ cols, + const FLOATV_T *__restrict__ vals, + const float *(&shXOff)[BDIM_X+SHPAD], + FLOATV_T (&locy)[NLOC]) { + + constexpr int NLOC_M1 = NLOC-1; + + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + + // unused if BDIM_X > WARP_SIZE + unsigned int subwarp_mask = 0xFFFFFFFF; + + if constexpr(BDIM_X <= WARP_SIZE) { + constexpr unsigned int MASK = (1ull << BDIM_X)-1; + subwarp_mask = MASK << (tidy*BDIM_X); + } + + cols += tidx; + + for(int off = 0; off < rlen; off++) { + + if ((off % BDIM_X) == 0) { + if constexpr(BDIM_X <= WARP_SIZE) { __syncwarp(subwarp_mask); } + else { __syncthreads(); } + + const int64_t col = (off+tidx < rlen) ? cols[0] : 0; + + const int hi = col / nlon_in; + const int wi = col - (hi*nlon_in); + + //const int wip = (wi + pscale*wo) % nlon_in; + // value of (wi + pscale*wo) < (Wi + (Wi/Wo)*Wo) = 2*Wi + // so we can replace the modulo with: + int wip = wi + pscale*wo; + if (wip >= nlon_in) wip -= nlon_in; + + shXOff[tidx] = x + int64_t(hi)*nlon_in*nchan_in + int64_t(wip)*nchan_in; + cols += BDIM_X; + + if constexpr(BDIM_X <= WARP_SIZE) { __syncwarp(subwarp_mask); } + else { __syncthreads(); } + } + + const float *_x = shXOff[off % BDIM_X]; + + // if BDIM_X is not a multiple of K then "i*(j*BDIM_X) % K = f(i,j)", + // so the mod need to be recomputed at each iteration of update the update loop + + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + + const int chan = i*BDIM_X+tidx; + const int iDivK = chan / K; + const int iModK = chan - (iDivK*K); + + const FLOATV_T vval = vals[iModK]; //vals[off*K + iModK]; + const FLOATV_T xval = __vset(_x[iDivK]); + + locy[i] = __vadd(locy[i], __vmul(vval, xval)); + } + if (NLOC_M1*BDIM_X+tidx < nchan_in*K) { + + const int chan = NLOC_M1*BDIM_X+tidx; + const int iDivK = chan / K; + const int iModK = chan - (iDivK*K); + + const FLOATV_T vval = vals[iModK]; //vals[off*K + iModK]; + const FLOATV_T xval = __vset(_x[iDivK]); + + locy[NLOC_M1] = __vadd(locy[NLOC_M1], __vmul(vval, xval)); + } + + vals += K; + } + return; +} + +template // either float or float4 +__global__ +__launch_bounds__(BDIM_X*BDIM_Y) +void s2_disco_fwd_special_vec_k(const int nchan_in, // no. of input float (not FLOATV_T!) 
elements along channel dim + const int nlat_in, + const int nlon_in, + const int nlat_out, + const int nlon_out, + const int pscale, + const int K, // no. of output FLOATV_T elem along K dim (kernel size) + const float *__restrict__ x, + const int32_t *__restrict__ row_idx, + const int64_t *__restrict__ row_off, + const int64_t *__restrict__ col_idx, + const FLOATV_T *__restrict__ val_pck, + FLOATV_T *__restrict__ y) { + + static_assert(0 == (BDIM_X & (BDIM_X-1))); + static_assert(0 == (BDIM_Y & (BDIM_Y-1))); + static_assert((BDIM_X <= 32 && BDIM_Y > 1) || + (BDIM_X > 32 && BDIM_Y == 1)) ; + + constexpr int NLOC_M1 = NLOC-1; + + constexpr int VEC_SIZE = sizeof(FLOATV_T) / sizeof(float); + + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + + const int batch = blockIdx.y; + const int ctaid = blockIdx.x*blockDim.y + threadIdx.y; + + if (ctaid >= nlat_out*nlon_out) { + return; + } + +#if 1 + const int h = ctaid / nlon_out; + const int wo = ctaid - (h*nlon_out); + const int ho = row_idx[h]; +#else + // for now don't use row_idx + const int ho = ctaid / nlon_out; + const int wo = ctaid - (ho*nlon_out); +#endif + + const int nchan_out = nchan_in*K; + + FLOATV_T locy[NLOC]; + + x += int64_t(batch)*nlat_in*nlon_in*nchan_in; + y += int64_t(batch)*nlat_out*nlon_out*nchan_out + int64_t(ho)*nlon_out*nchan_out + int64_t(wo)*nchan_out + tidx; + + #pragma unroll + for(int i = 0; i < NLOC; i++) { + locy[i] = __vset(0.f); + } + + const int64_t rbeg = row_off[ho]; + const int64_t rend = row_off[ho+1]; + + col_idx += rbeg; + val_pck += rbeg*K; // val_pck CSR contains K values per element + + const int rlen = rend-rbeg; + + constexpr int PAD = (BDIM_X < WARP_SIZE) ? 1 : 0; + __shared__ const float *shXOffAll[BDIM_Y][BDIM_X+PAD]; + + // check if BDIM_X is a multiple of K; since BDIM_X is a power of 2, check if K is also a power of two + const int isKpow2 = !(K & (K-1)); + if (isKpow2) { processCSR_Kpow2_reg_d(wo, rlen, nchan_in, nlon_in, pscale, K, x, col_idx, val_pck, shXOffAll[tidy], locy); } + else { processCSR_Kanyv_reg_d(wo, rlen, nchan_in, nlon_in, pscale, K, x, col_idx, val_pck, shXOffAll[tidy], locy); } + + + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + y[i*BDIM_X] = locy[i]; + } + if (NLOC_M1*BDIM_X+tidx < nchan_out) { + y[NLOC_M1*BDIM_X] = locy[NLOC_M1]; + } + + return; +} + +template +void launch_gen_disco_fwd(int64_t batch_size, + int64_t nchan_in, + int64_t nlat_in, + int64_t nlon_in, + int64_t nlat_out, + int64_t nlon_out, + int64_t K, + float *__restrict__ _xp, + int32_t *_row_idx, + int64_t *_row_off, + int64_t *_col_idx, + FLOATV_T *_val_pck, + FLOATV_T *__restrict__ _yp, + cudaStream_t stream) { + + dim3 block(WARP_SIZE, THREADS/WARP_SIZE); + dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size); + + size_t shsize = sizeof(FLOATV_T)*(nchan_in*K)*block.y; + + const int pscale = nlon_in / nlon_out; +#if 0 + printf("Launching s2_disco_fwd_generic_vec_k<%d, float%s><<<..., ..., %zu, ...>>> with:\n" + "\tngroup: %ld\n" + "\tnchan_in: %ld\n" + "\tK: %ld\n\n", + THREADS, sizeof(FLOATV_T)==16?"4":"", shsize, ngroup, nchan_in, K); +#endif + // will use only the first 1/K-th of the CSR, i.e. 
only the first nlat_out rows + s2_disco_fwd_generic_vec_k + <<>>(nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K, + _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp); + CHECK_ERROR("s2_disco_fwd_generic_vec_k"); + + return; +} + +template +void launch_spc_disco_fwd(int nloc, // "BDIM_X*nloc" >= nchans + int64_t batch_size, + int64_t nchan_in, + int64_t nlat_in, + int64_t nlon_in, + int64_t nlat_out, + int64_t nlon_out, + int64_t K, + float *__restrict__ _xp, + int32_t *_row_idx, + int64_t *_row_off, + int64_t *_col_idx, + FLOATV_T *_val_pck, + FLOATV_T *__restrict__ _yp, + cudaStream_t stream) { + + if (CUR_LOC_SIZE == nloc) { + + // block size set to 64 threads + constexpr int BDIM_Y = (BDIM_X <= WARP_SIZE) ? THREADS / BDIM_X : 1; + + // groups in gridDim.y + dim3 block(BDIM_X, BDIM_Y); + dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size); + + size_t shsize = 0; //sizeof(float)*chxgrp_out * block.y; + + const int pscale = nlon_in / nlon_out; +#if 0 + printf("Launching s2_disco_fwd_special_vec_k<%d, %d, %d, float%s><<<(%d, %d, %d), (%d, %d), ..., %zu, ...>>> with:\n" + "\tngroup: %ld\n" + "\tnchan_in: %ld\n" + "\tK: %ld\n\n", + BDIM_X, BDIM_Y, CUR_LOC_SIZE, sizeof(FLOATV_T)==16?"4":"", grid.x, grid.y, grid.z, block.x, block.y, shsize, ngroup, nchan_in, K); +#endif + s2_disco_fwd_special_vec_k + <<>>(nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K, + _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp); + + CHECK_ERROR("s2_disco_fwd_special_vec_k"); + + return; + } + if constexpr(CUR_LOC_SIZE < MAX_LOC_SIZE) { + launch_spc_disco_fwd(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, + K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); + } + return; +} + +static void s2_disco_fwd_dispatch(int64_t batch_size, + int64_t nchan_in, + int64_t nlon_in, + int64_t nlat_out, + int64_t nlon_out, + int64_t K, + at::Tensor xP, + at::Tensor row_off, // CSR row offsets + at::Tensor col_idx, // CSR column indices + at::Tensor val_dat, // CSR value data + at::Tensor yP) { + + static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1))); + + if (batch_size <= 0 || + nchan_in <= 0 || + nlon_in <= 0 || + nlat_out <= 0 || + nlon_out <= 0 || + K <= 0) { + + fprintf(stderr, + ":%s:%d: invalid value of one or more input parameters!\n", + __FILE__, __LINE__); + exit(EXIT_FAILURE); + } + + // get stream + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + // sort row indices (ho-s) in descending order + // based on (row_off[ho+1]-row_off[ho]) + at::Tensor row_idx = sortRows(nlat_out, row_off, stream); + + + // replace the K sequential CRSs in "val_dat": + // + // val_dat[ 0: nnz/K) for ker = 0 + // val_dat[nnz/K:2*nnz/K) for ker = 1 + // ... + // val_dat[nnz/K:2*nnz/K) for ker = K-1 + // + // with a packed CSR: + // + // val_dat[nnz/K][K], i.e. 
with a CSR where elements of the original K CSRs are packed in consecutive elements + assert(0 == (val_idx.size(0) % K)); + + // move into "disco_cuda_utils.cu" IF val_dat format won't be changed upstream in the call chain + int64_t val_dims[] = {val_dat.size(0)}; + auto options = torch::TensorOptions().device(val_dat.device()).dtype(val_dat.dtype()); + torch::Tensor val_pck = torch::zeros(val_dims, options); + { + dim3 block(WARP_SIZE, THREADS/WARP_SIZE); + dim3 grid(DIV_UP(nlat_out, block.y)); + + pack_vals_k<<>>(K, nlat_out, + row_off.data_ptr(), + val_dat.data_ptr(), + val_pck.data_ptr()); + } + // if K is a multiple of VEC_SIZE it will be read with vector lds + + const int nlat_in = xP.size(1); + + // smallest power of two "bdimx" (>=4) s.t. bdimx*MAX_LOCAL_ARR_LEN >= nchan_in*K + int bdimx; + bdimx = DIV_UP(nchan_in*K, MAX_LOCAL_ARR_LEN); + bdimx = max(bdimx, WARP_SIZE/8); // min 4 threads per group + bdimx = next_pow2(bdimx); + + float *_xp = reinterpret_cast(xP.data_ptr()); + float *_yp = reinterpret_cast(yP.data_ptr()); + + int32_t *_row_idx = reinterpret_cast(row_idx.data_ptr()); + int64_t *_row_off = reinterpret_cast(row_off.data_ptr()); + int64_t *_col_idx = reinterpret_cast(col_idx.data_ptr()); + float *_val_pck = reinterpret_cast(val_pck.data_ptr()); + + constexpr int VEC_SIZE = sizeof(float4) / sizeof(float); + + if (!is_aligned(_yp) || + !is_aligned(_val_pck) || + (K % VEC_SIZE) != 0) { + + //printf("%s:%d: VEC_SIZE: %d, nchan_in: %d, K: %d, _xp: %p, _yp: %p\n", __func__, __LINE__, VEC_SIZE, nchan_in, K, _xp, _yp); + + const int nloc = DIV_UP(nchan_in*K, bdimx); + + // to avoid the compilation of unused template instances; + // we use a block size BDIM_X that is the smallest power of 2 + // such that BDIM_X*MAX_LOCAL_ARR_LEN >= nchan_in*K, so + // BDIM_X > 32 are used only for: + // + // (BDIM_X-1)*MAX_LOCAL_ARR_LEN < nchan <= BDIM_X*MAX_LOCAL_ARR_LEN + constexpr int MIN_LOC_ARR_LEN = MAX_LOCAL_ARR_LEN/2+1; + + // use 2D blocks only if 32 threads are enough + switch(bdimx) { + case 8: launch_spc_disco_fwd< 8, 1, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + case 16: launch_spc_disco_fwd< 16, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + case 32: launch_spc_disco_fwd< 32, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + case 64: launch_spc_disco_fwd< 64, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + case 128: launch_spc_disco_fwd< 128, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + case 256: launch_spc_disco_fwd< 256, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + case 512: launch_spc_disco_fwd< 512, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + case 1024: launch_spc_disco_fwd<1024, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, 
batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + default: launch_gen_disco_fwd ( batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + } + + } else { + + //printf("%s:%d: VEC_SIZE: %d, nchan_in: %d, K: %d, _xp: %p, _yp: %p\n", __func__, __LINE__, VEC_SIZE, nchan_in, K, _xp, _yp); + + //float4 *_xp4 = reinterpret_cast(_xp); + float4 *_yp4 = reinterpret_cast(_yp); + + float4 *_val_pck4 = reinterpret_cast(_val_pck); + + K /= VEC_SIZE; + const int nloc = DIV_UP(nchan_in*K, bdimx); + + constexpr int MAX_LOCAL_VEC_LEN = MAX_LOCAL_ARR_LEN / VEC_SIZE; + constexpr int MIN_LOC_VEC_LEN = MAX_LOCAL_VEC_LEN/2+1; + + // use 2D blocks only if 32 threads are enough + switch(bdimx) { + case 8: launch_spc_disco_fwd< 8, 1, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck4, _yp4, stream); break; + case 16: launch_spc_disco_fwd< 16, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck4, _yp4, stream); break; + case 32: launch_spc_disco_fwd< 32, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck4, _yp4, stream); break; + case 64: launch_spc_disco_fwd< 64, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck4, _yp4, stream); break; + case 128: launch_spc_disco_fwd< 128, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck4, _yp4, stream); break; + case 256: launch_spc_disco_fwd< 256, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck4, _yp4, stream); break; + case 512: launch_spc_disco_fwd< 512, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck4, _yp4, stream); break; + case 1024: launch_spc_disco_fwd<1024, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck4, _yp4, stream); break; + default: launch_gen_disco_fwd ( batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck4, _yp4, stream); break; + } + } + return; +} +#endif +// END VERSION WITH CHANNEL-LAST WITH 2D BLOCKS, 2ND DIM IDENTIFYING CHANNLES, NO EINSUM + + + + + + + + + // utility functions + void dump_out_kers(const char *fprefix, at::Tensor t) { + + int64_t B = t.size(0); + int64_t C = t.size(1); + int64_t K = t.size(2); + int64_t Ho = t.size(3); + int64_t Wo = t.size(4); + + at::Tensor t_h = t.to(torch::kCPU); + + auto accessor = t_h.accessor(); + + printf("Writing data to file..."); + + char fname[256]; + + for(size_t k = 0; k < K; k++) { + + snprintf(fname, sizeof(fname), "%s_%ld.txt", fprefix, k); + + FILE *fp = fopen(fname, "w"); + if (!fp) { + fprintf(stderr, "Cannot open file %s for writing!\n", fname); + exit(EXIT_FAILURE); + } + for(int64_t b = 0; b < B; b++) { + fprintf(fp, "b: %ld\n", b); + for(int64_t c = 0; c < C; c++) { + fprintf(fp, "c: %ld\n", c); + for(int64_t h = 0; h < Ho; h++) { 
+ for(int64_t w = 0; w < Wo; w++) { + fprintf(fp, " %f", accessor[b][c][k][h][w]); + } + fprintf(fp, "\n"); + } + fprintf(fp, "\n"); + } + fprintf(fp, "\n"); + } + fclose(fp); + } + printf("done\n"); + + return; + } + torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, - torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo) + torch::Tensor col_idx, torch::Tensor val, torch::Tensor weights, int64_t K, int64_t Ho, int64_t Wo) { // some sanity checks @@ -205,16 +990,33 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t int64_t Wi = inp.size(3); int64_t nrows = roff_idx.size(0) - 1; - // allocate output - int64_t out_dims[] = {B, C, K, Ho, Wo}; - auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype()); - torch::Tensor out = torch::zeros(out_dims, options); + // rename dimensions consistent with attention + int64_t batch_size = inp.size(0); + int64_t nchan = inp.size(1); + int64_t nlat_in = inp.size(2); + int64_t nlon_in = inp.size(3); + int64_t nlat_out = Ho; + int64_t nlon_out = Wo; +/* + int64_t ngroup = 1; + if (std::getenv("S2_NGROUP")) { + ngroup = atoi(std::getenv("S2_NGROUP")); + } +*/ + printf("%s:%d: batch_size: %ld, nchan: %ld, nlat_in: %ld, nlon_in: %ld, nlat_out: %ld, nlon_out: %ld, nrows: %ld, nnz_tot: %ld, K: %ld\n", + __func__, __LINE__, batch_size, nchan, nlat_in, nlon_in, nlat_out, nlon_out, nrows, col_idx.size(0), K); // get stream auto stream = at::cuda::getCurrentCUDAStream().stream(); // assert static_assert(0 == (ELXTH_MAX % 2)); +#if 0 + // allocate output + int64_t out_dims[] = {B, C, K, Ho, Wo}; + auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype()); + torch::Tensor out = torch::zeros(out_dims, options); + // pick the correct launch config if (Wo <= 64 * ELXTH_MAX) { @@ -262,7 +1064,203 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t 1024 * ELXTH_MAX); exit(EXIT_FAILURE); } +#else + +#if 0 // FUSED VERSION + // switch to channel-last + // version with fused enisum + + int64_t ngroup = weights.size(0); + int64_t chan_x_grp_out = weights.size(1); + int64_t chan_x_grp_in = weights.size(2); + int64_t weight_k = weights.size(3); + + int64_t nchan_out = ngroup*chan_x_grp_out; + + printf("weight tensor shape: %ld, %ld, %ld, %ld\n", ngroup, chan_x_grp_out, chan_x_grp_in, weight_k); fflush(stdout); + + if (nchan != chan_x_grp_in*ngroup || K != weight_k) { + fprintf(stderr, + "%s:%d: error, dimension mismatch for weight tensor!\n", + __func__, __LINE__); + exit(EXIT_FAILURE); + } + + + // input: inp[B][Ci][Hi][Wi] -> inp[B][Hi][Wi][Ci] + // + // output: out[[B][Ho][Wo][Co] -> out[B][Co][Ho][Wo] + // with Co = ngroup*chan_x_grp_out + + // switch to channel-last + + // extract dtype + auto x_type = inp.dtype(); + torch::Tensor xP = inp.to(torch::kFloat32); + + // exract memory format: this is much safer than checking is_contiguous(at::MemoryFormat::ChannelsLast) + // the former fails for num_channels == 1 + bool x_is_channels_last = xP.strides()[1] == 1; + + // transpose if required + if (!x_is_channels_last) { xP = permute_4D_to0231(xP); } + +#if 1 + int64_t out_dims[] = {batch_size, nlat_out, nlon_out, nchan_out}; + auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype()); + torch::Tensor yP = torch::zeros(out_dims, options); // this will be empty_like() + // y is {batch_size, nlat_out, nlon_out, nchan_out}, +#else + // to test before fusion + int64_t 
out_dims[] = {batch_size, nlat_out, nlon_out, nchan*K}; + auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype()); + torch::Tensor yP = torch::zeros(out_dims, options); + // y is {batch_size, nlat_out, nlon_out, nchan*K}, +#endif + // call channel-last kernel implementation + s2_disco_fwd_dispatch(batch_size, + nchan, + nchan_out, + ngroup, + nlon_in, + //nlat_in, + nlat_out, + nlon_out, + K, + xP, + roff_idx, + col_idx, + val, + weights, + yP); + +#if 1 + // switch back to original layout; + // I'm assuming that if x was passed as channel last, then + // the output tensor should be K last + torch::Tensor y = yP; + if (!x_is_channels_last) { + y = permute_4D_to0312(y); + // make y {batch_size, nchan_out, nlat_out, nlon_out} + } +#else + // to test before fusion + torch::Tensor y = yP; + if (!x_is_channels_last) { + y = permute_4D_to0312(y); + // make y {batch_size, nchan, K, nlat_out, nlon_out} + y = y.reshape({batch_size, nchan, K, nlat_out, nlon_out}); + } else { + // make y {batch_size, nlat_out, nlon_out, nchan, K} + y = y.reshape({batch_size, nlat_out, nlon_out, nchan, K}); + } +#endif + +#else // VERSION WITH SEPARATED EINSUM + // switch to channel-last + // version with fused enisum + + int64_t ngroup = weights.size(0); + int64_t chan_x_grp_out = weights.size(1); + int64_t chan_x_grp_in = weights.size(2); + int64_t weight_k = weights.size(3); + + int64_t nchan_out = ngroup*chan_x_grp_out; + + printf("weight tensor shape: %ld, %ld, %ld, %ld\n", ngroup, chan_x_grp_out, chan_x_grp_in, weight_k); fflush(stdout); + + if (nchan != chan_x_grp_in*ngroup || K != weight_k) { + fprintf(stderr, + "%s:%d: error, dimension mismatch for weight tensor!\n", + __func__, __LINE__); + exit(EXIT_FAILURE); + } + + + // input: inp[B][Ci][Hi][Wi] -> inp[B][Hi][Wi][Ci] + // + // output: out[[B][Ho][Wo][Co] -> out[B][Co][Ho][Wo] + // with Co = ngroup*chan_x_grp_out + + + // switch to channel-last + + // extract dtype + auto x_type = inp.dtype(); + torch::Tensor xP = inp.to(torch::kFloat32); + + // exract memory format: this is much safer than checking is_contiguous(at::MemoryFormat::ChannelsLast) + // the former fails for num_channels == 1 + bool x_is_channels_last = xP.strides()[1] == 1; + + // transpose if required + if (!x_is_channels_last) { xP = permute_4D_to0231(xP); } + + // to test before fusion + int64_t out_dims[] = {batch_size, nlat_out, nlon_out, nchan*K}; + auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype()); + torch::Tensor yP = torch::empty(out_dims, options); + + // call channel-last kernel implementation + s2_disco_fwd_dispatch(batch_size, + nchan, + nlon_in, + //nlat_in, + nlat_out, + nlon_out, + K, + xP, + roff_idx, + col_idx, + val, + yP); + + // call einsum + + // yP is {batch_size, nlat_out, nlon_out, nchan_in*K}, + // reshape it to {batch_size, nlat_out, nlon_out, ngroup, chan_x_grp_in*K} + auto yP_resh = yP.reshape({batch_size, nlat_out, nlon_out, ngroup, -1}); + + // weight is {ngroup, chan_x_grp_out, chan_x_grp_in, K} + // reshape weight to {ngroup, chan_x_grp_out, chan_x_grp_in*K} + auto weights_resh = weights.reshape({ngroup, chan_x_grp_out, -1}); + + auto out_sum = torch::einsum("bxygc,goc->bxygo", {yP_resh, weights_resh}).contiguous(); + + // out is {batch_size, nlat_out, nlon_out, ngroup, chan_x_grp_out}, + // reshape it ot {batch_size, nlat_out, nlon_out, nchan_out}, + auto out_resh = out_sum.reshape({batch_size, nlat_out, nlon_out, -1}); + + // switch back to original layout; + // I'm assuming that if x was passed as channel 
last, then + // the output tensor should be K last + torch::Tensor y = out_resh; + if (!x_is_channels_last) { + y = permute_4D_to0312(y); + // make y {batch_size, nchan_out, nlat_out, nlon_out} + } + + //CHECK_CUDA(cudaDeviceSynchronize()); + + // convert precision back to starting + y = y.to(x_type); + + torch::Tensor out = y; +#endif + +#endif // closes ORIGINAL if +#if 1 + if (std::getenv("S2_DISCO_DUMP_Y")) { + printf("waiting for kernel to finish..."); + CHECK_CUDA(cudaStreamSynchronize(stream)); + printf("done\n"); + fflush(stdout); + dump_tensor("yout.txt", out); + //dump_csr_linear("csr_disco.txt", roff_idx, ker_idx, row_idx, col_idx, val); + //dump_out_kers("out_kers", out); + } +#endif return out; } @@ -271,4 +1269,166 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t m.impl("forward", &disco_cuda_fwd); } -} \ No newline at end of file + // utility functions + void dump_tensor(const char *fname, at::Tensor t) { + + size_t n = 1; + for(int i = 0; i < t.dim(); i++) { + n *= t.size(i); + } + + float *data_h = (float *)malloc(sizeof(*data_h)*n); + if (!data_h) { + fprintf(stderr, "Cannot allcoate %zu bytes!\n", sizeof(*data_h)*n); + exit(EXIT_FAILURE); + } + + float *float_d = t.data_ptr(); + + CHECK_CUDA(cudaMemcpy(data_h, float_d, sizeof(*data_h)*n, cudaMemcpyDeviceToHost)); + + printf("Writing data to file..."); + + FILE *fp = fopen(fname, "w"); + if (!fp) { + fprintf(stderr, "Cannot open file %s for writing!\n", fname); + exit(EXIT_FAILURE); + } + + for(size_t i = 0; i < n; i++) { + fprintf(fp, "%f\n", data_h[i]); + } + + fclose(fp); + printf("done\n"); + + free(data_h); + + return; + } + + void dump_csr(const char *fname, + at::Tensor roff, + at::Tensor cols) { + + int64_t nrows = roff.size(0)-1; + int64_t nnz = cols.size(0); + + int64_t *roff_h = new int64_t[nrows+1]; + int64_t *cols_h = new int64_t[nnz]; + + int64_t *roff_d = roff.data_ptr(); + int64_t *cols_d = cols.data_ptr(); + + CHECK_CUDA(cudaMemcpy(roff_h, roff_d, sizeof(*roff_h)*(nrows+1), cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(cols_h, cols_d, sizeof(*cols_d)*nnz , cudaMemcpyDeviceToHost)); + + printf("Writing data to file..."); + + FILE *fp = fopen(fname, "w"); + if (!fp) { + fprintf(stderr, "Cannot open file %s for writing!\n", fname); + exit(EXIT_FAILURE); + } + for(int64_t r = 0; r < nrows; r++) { + + fprintf(fp, "%10ld %10ld", r, roff_h[r+1]-roff_h[r]); + + for(int64_t o = roff_h[r]; o < roff_h[r+1]; o++) { + fprintf(fp, "%10ld", cols_h[o]); + } + fprintf(fp, "\n"); + } + fclose(fp); + printf("done\n"); + + delete [] roff_h; + delete [] cols_h; + } + + void dump_csr_linear(const char *fname, + at::Tensor roff, + at::Tensor kers, + at::Tensor rows, + at::Tensor cols, + at::Tensor vals) { + + int64_t nrows = roff.size(0)-1; + int64_t nnz = cols.size(0); + + int64_t *roff_h = new int64_t[nrows+1]; + int64_t *kers_h = new int64_t[nnz]; + int64_t *rows_h = new int64_t[nnz]; + int64_t *cols_h = new int64_t[nnz]; + float *vals_h = new float[nnz]; + + int64_t *roff_d = roff.data_ptr(); + int64_t *kers_d = kers.data_ptr(); + int64_t *rows_d = rows.data_ptr(); + int64_t *cols_d = cols.data_ptr(); + float *vals_d = vals.data_ptr(); + + CHECK_CUDA(cudaMemcpy(roff_h, roff_d, sizeof(*roff_h)*(nrows+1), cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(kers_h, kers_d, sizeof(*kers_h)*nnz , cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(rows_h, rows_d, sizeof(*rows_h)*nnz , cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(cols_h, cols_d, sizeof(*cols_h)*nnz , 
cudaMemcpyDeviceToHost)); + CHECK_CUDA(cudaMemcpy(vals_h, vals_d, sizeof(*vals_h)*nnz , cudaMemcpyDeviceToHost)); + + printf("Writing data to file..."); + + FILE *fp = fopen(fname, "w"); + if (!fp) { + fprintf(stderr, "Cannot open file %s for writing!\n", fname); + exit(EXIT_FAILURE); + } + fprintf(fp, "COLS:\n"); + for(int64_t r = 0; r < nrows; r++) { + + fprintf(fp, "%10ld %10ld", r, roff_h[r+1]-roff_h[r]); + + for(int64_t o = roff_h[r]; o < roff_h[r+1]; o++) { + fprintf(fp, "%10ld", cols_h[o]); + } + fprintf(fp, "\n"); + } + fprintf(fp, "KERS:\n"); + for(int64_t r = 0; r < nrows; r++) { + + fprintf(fp, "%10ld %10ld", r, roff_h[r+1]-roff_h[r]); + + for(int64_t o = roff_h[r]; o < roff_h[r+1]; o++) { + fprintf(fp, "%10ld", kers_h[o]); + } + fprintf(fp, "\n"); + } + fprintf(fp, "ROWS:\n"); + for(int64_t r = 0; r < nrows; r++) { + + fprintf(fp, "%10ld %10ld", r, roff_h[r+1]-roff_h[r]); + + for(int64_t o = roff_h[r]; o < roff_h[r+1]; o++) { + fprintf(fp, "%10ld", rows_h[o]); + } + fprintf(fp, "\n"); + } + fprintf(fp, "VALS:\n"); + for(int64_t r = 0; r < nrows; r++) { + + fprintf(fp, "%10ld %10ld", r, roff_h[r+1]-roff_h[r]); + + for(int64_t o = roff_h[r]; o < roff_h[r+1]; o++) { + fprintf(fp, "%10f", vals_h[o]); + } + fprintf(fp, "\n"); + } + fclose(fp); + printf("done\n"); + + delete [] roff_h; + delete [] kers_h; + delete [] rows_h; + delete [] cols_h; + delete [] vals_h; + } +} + diff --git a/torch_harmonics/disco/csrc/disco_cuda_utils.cu b/torch_harmonics/disco/csrc/disco_cuda_utils.cu new file mode 100644 index 00000000..b8ba3ca4 --- /dev/null +++ b/torch_harmonics/disco/csrc/disco_cuda_utils.cu @@ -0,0 +1,184 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
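For reference, a hedged PyTorch sketch (function names are illustrative, not part of the patch) of what the helpers defined in this new file compute: sortRows orders output-row indices by descending CSR row length (ties may be ordered differently than by the CUB-based radix sort), and the two permute helpers are plain channels-last / channels-first transposes.

import torch

def sort_rows_ref(row_off: torch.Tensor) -> torch.Tensor:
    # row_off: CSR row offsets of shape [nlat_out + 1]
    rlen = row_off[1:] - row_off[:-1]                  # nonzeros per output row
    return torch.argsort(rlen, descending=True).to(torch.int32)

def permute_4D_to0231_ref(x: torch.Tensor) -> torch.Tensor:
    # [B, C, H, W] -> [B, H, W, C], i.e. channels-last
    return x.permute(0, 2, 3, 1).contiguous()

def permute_4D_to0312_ref(x: torch.Tensor) -> torch.Tensor:
    # [B, H, W, C] -> [B, C, H, W], i.e. back to channels-first
    return x.permute(0, 3, 1, 2).contiguous()
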
+ +#include "disco_cuda_utils.cuh" + +#include +#include +#include + +#include + +#include +#include + +#include "cudamacro.h" +#include "disco_cuda.cuh" + +#define THREADS (64) + +#define TRANSP_WARPS_X_TILE_GENERIC (32) +#define TRANSP_WARPS_X_TILE_SM100 (4) + +namespace disco_kernels { + +// BEGIN - CSR rows sorting kernels and functions +__global__ void set_rlen_rids_k(const int n, + const int64_t *__restrict__ offs, + int *__restrict__ rids, + int *__restrict__ rlen) { + + const int nth = gridDim.x*blockDim.x; + const int tid = blockIdx.x*blockDim.x + threadIdx.x; + + for(int i = tid; i < n; i += nth) { + rids[i] = i; + rlen[i] = offs[i+1]-offs[i]; + } + + return; +} + +at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream) { + + int64_t *_row_off_d = reinterpret_cast(row_off.data_ptr()); + + auto options = torch::TensorOptions().dtype(torch::kInt32).device(row_off.device()); + + torch::Tensor rids_d = torch::empty({nlat_out}, options); + torch::Tensor rlen_d = torch::empty({nlat_out}, options); + + int *_rids_d = reinterpret_cast(rids_d.data_ptr()); + int *_rlen_d = reinterpret_cast(rlen_d.data_ptr()); + + const int grid = DIV_UP(nlat_out, THREADS); + const int block = THREADS; + + set_rlen_rids_k<<>>(nlat_out, + _row_off_d, + _rids_d, + _rlen_d); + + torch::Tensor rids_sort_d = torch::empty({nlat_out}, options); + torch::Tensor rlen_sort_d = torch::empty({nlat_out}, options); + + int *_rids_sort_d = reinterpret_cast(rids_sort_d.data_ptr()); + int *_rlen_sort_d = reinterpret_cast(rlen_sort_d.data_ptr()); + + size_t temp_storage_bytes = 0; + CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(NULL, temp_storage_bytes, + _rlen_d, _rlen_sort_d, + _rids_d, _rids_sort_d, + nlat_out, 0, sizeof(*_rlen_d)*8, stream)); + + options = torch::TensorOptions().dtype(torch::kByte).device(row_off.device()); + torch::Tensor temp_storage_d = torch::empty({int64_t(temp_storage_bytes)}, options); + + void *_temp_storage_d = reinterpret_cast(temp_storage_d.data_ptr()); + + CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(_temp_storage_d, temp_storage_bytes, + _rlen_d, _rlen_sort_d, + _rids_d, _rids_sort_d, + nlat_out, 0, sizeof(*_rlen_d)*8, stream)); + return rids_sort_d; +} +// END - CSR rows sorting kernels and functions + + +// BEGIN - 4D tensor permutation kernels and functions +__global__ void empty_k() {} + +static int getPtxver() { + cudaFuncAttributes attrs; + CHECK_CUDA(cudaFuncGetAttributes(&attrs, empty_k)); + return attrs.ptxVersion*10; +} + +at::Tensor permute_4D_to0231(at::Tensor src) { + + auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device()); + torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options); + + const int ptxv = getPtxver(); + + // to be further specialized for additional archs, if necessary + if (ptxv < 100) { + AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_generic", ([&] { + launch_permute_to0231(src, dst); + })); + CHECK_ERROR("permute_to0231_k_tile_generic"); + } else { + AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_sm100", ([&] { + launch_permute_to0231(src, dst); + })); + CHECK_ERROR("permute_to0231_k_tile_sm100"); + } + + return dst; +} + +at::Tensor permute_4D_to0312(at::Tensor src) { + + auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device()); + torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options); + + const int ptxv = getPtxver(); + + // to be further specialized for 
additional archs, if necessary + if (ptxv < 100) { + AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_generic", ([&] { + launch_permute_to0312(src, dst); + })); + CHECK_ERROR("permute_to0312_k_tile_generic"); + } else { + AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_sm100", ([&] { + launch_permute_to0312(src, dst); + })); + CHECK_ERROR("permute_to0312_k_tile_sm100"); + } + + return dst; +} +// END - tensor permutation kernels and functions + +// BEGIN - general host-side functions +unsigned int next_pow2(unsigned int x) { + + x -= 1; + + #pragma unroll + for(int i = 1; i <= sizeof(x)*8 / 2; i *= 2) { + x |= x >> i; + } + return x+1; +} +// END - general host-side functions + +} diff --git a/torch_harmonics/disco/csrc/disco_cuda_utils.cuh b/torch_harmonics/disco/csrc/disco_cuda_utils.cuh new file mode 100644 index 00000000..76ccb527 --- /dev/null +++ b/torch_harmonics/disco/csrc/disco_cuda_utils.cuh @@ -0,0 +1,382 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
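The permute_4D_to0231 / permute_4D_to0312 pair above converts activations between channels-first (B, C, H, W) and channels-last (B, H, W, C) layouts with a shared-memory tiled transpose, picking the tile width by PTX version. As a plain-PyTorch sketch of the intended semantics (inferred from the destination shapes above, not an API exposed by the extension), the two permutations should be inverses of each other:

    import torch

    def permute_4D_to0231_ref(x: torch.Tensor) -> torch.Tensor:
        # channels-first (B, C, H, W) -> channels-last (B, H, W, C)
        return x.permute(0, 2, 3, 1).contiguous()

    def permute_4D_to0312_ref(x: torch.Tensor) -> torch.Tensor:
        # channels-last (B, H, W, C) -> channels-first (B, C, H, W)
        return x.permute(0, 3, 1, 2).contiguous()

    x = torch.randn(2, 8, 6, 12)
    assert torch.equal(permute_4D_to0312_ref(permute_4D_to0231_ref(x)), x)

The (BDIM_X+1)-wide shared-memory tile in the kernels is the usual padding trick to avoid bank conflicts during the transpose.
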
+ +#pragma once + +#include +#include +#include + +#define WARP_SIZE (32) +#define FULL_MASK (0xFFFFFFFF) + +#ifndef DIV_UP +#define DIV_UP(a,b) (((a)+((b)-1))/(b)) +#endif + +namespace disco_kernels { + +// CSR rows sorting kernels and functions +at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream); + +// 4D tensor permutation kernels and functions +at::Tensor permute_4D_to0231(at::Tensor src); +at::Tensor permute_4D_to0312(at::Tensor src); + +// Host tensor dump and CSR manipulation functions +void dump_tensor(const char *fname, at::Tensor t); +void dump_csr(const char *fname, at::Tensor roff, at::Tensor cols); + +int part_csr_rows(int *row_perm, + const at::Tensor roff, + const at::Tensor cols, + int **part_off, + int **part_val); + +int verify_part(const int npart, + const int *part_off, + const int *part_val, + const at::Tensor roff, + const at::Tensor cols); + +void verify_part_new(const int nlon_out, + const int nlat_in, + const int nlon_in, + const int npart, // partitioning data + const int *part_off, + const int *part_val, + const at::Tensor roff, + const at::Tensor cols); + +unsigned int next_pow2(unsigned int x); + + +// utility host functions and templates + +template +int is_aligned(const void *ptr) { + + static_assert(0 == (ALIGN & (ALIGN-1))); + return (0 == (uintptr_t(ptr) & (ALIGN-1))); +} + + +// utility device functions and templates + +template +__device__ FLOATV_T __vset(float x) { + static_assert(sizeof(FLOATV_T) == 0, "Unsupported type for __vset"); + return FLOATV_T{}; +} + +template<> +__device__ float __forceinline__ __vset(float x) { + return x; +} + +__device__ float __forceinline__ __vmul(float a, float b) { + return a*b; +} + +__device__ float __forceinline__ __vadd(float a, float b) { + return a+b; +} + +__device__ float __forceinline__ __vsub(float a, float b) { + return a-b; +} + +__device__ float __forceinline__ __vred(float a) { + return a; +} + +__device__ float __forceinline__ __vscale(float s, float v) { + return v*s; +} + +__device__ float __forceinline__ __vdiv(float s, float v) { + return v/s; +} + +template<> +__device__ float4 __forceinline__ __vset(float x) { + return make_float4(x, x, x, x); +} + +__device__ float4 __forceinline__ __vmul(float4 a, float4 b) { + return make_float4(a.x*b.x, a.y*b.y, a.z*b.z, a.w*b.w); +} + +__device__ float4 __forceinline__ __vadd(float4 a, float4 b) { + return make_float4(a.x+b.x, a.y+b.y, a.z+b.z, a.w+b.w); +} + +__device__ float4 __forceinline__ __vsub(float4 a, float4 b) { + return make_float4(a.x-b.x, a.y-b.y, a.z-b.z, a.w-b.w); +} + +__device__ float __forceinline__ __vred(float4 a) { + return a.x + a.y + a.z + a.w; +} + +__device__ float4 __forceinline__ __vscale(float s, float4 v) { + return make_float4(s*v.x, s*v.y, s*v.z, s*v.w); +} + +__device__ float4 __forceinline__ __vdiv(float s, float4 v) { + return make_float4(s/v.x, s/v.y, s/v.z, s/v.w);; +} + +template +__device__ VAL_T __warp_sum(VAL_T val) { + + #pragma unroll + for(int i = WARP_SIZE/2; i; i /= 2) { + val += __shfl_xor_sync(FULL_MASK, val, i, WARP_SIZE); + } + return val; +} + +template +__device__ VAL_T __block_sum(VAL_T val) { + + const int NWARP = (BDIM_X*BDIM_Y*BDIM_Z) / WARP_SIZE; + + val = __warp_sum(val); + + if constexpr(NWARP > 1) { + + int tid = threadIdx.x; + if constexpr(BDIM_Y > 1) { tid += threadIdx.y*BDIM_X; } + if constexpr(BDIM_Z > 1) { tid += threadIdx.z*BDIM_X*BDIM_Y; } + + const int lid = tid%WARP_SIZE; + const int wid = tid/WARP_SIZE; + + __shared__ VAL_T sh[NWARP]; + + if (lid == 0) { + sh[wid] = 
val; + } + __syncthreads(); + + if (wid == 0) { + val = (lid < NWARP) ? sh[lid] : 0; + + val = __warp_sum(val); + __syncwarp(); + + if (!lid) { + sh[0] = val; + } + } + __syncthreads(); + + val = sh[0]; + __syncthreads(); + } + return val; +} + +// transpose utils +template +__global__ +__launch_bounds__(BDIM_X*BDIM_Y) +void permute_to0231_k(const int nchn, + const int nlat, + const int nlon, + const at::PackedTensorAccessor32 src, + at::PackedTensorAccessor32 dst) { + + static_assert(!(BDIM_X & (BDIM_X-1))); + static_assert(!(BDIM_Y & (BDIM_Y-1))); + static_assert(BDIM_X >= BDIM_Y); + + __shared__ VAL_T sh[BDIM_X][BDIM_X+1]; + + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + + const int coff = blockIdx.x*BDIM_X; // channel offset + const int woff = blockIdx.y*BDIM_X; // width offset + const int batch = blockIdx.z / nlat; // batch (same for all block) + const int h = blockIdx.z - (batch * nlat); // height (same for all block) + + const int nchn_full = (nchn-coff) >= BDIM_X; + const int nlon_full = (nlon-woff) >= BDIM_X; + + if (nchn_full && nlon_full) { + #pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + sh[j+tidy][tidx] = src[batch][coff + j+tidy][h][woff+tidx]; + } + __syncthreads(); + + #pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy]; + } + } else { + if (woff+tidx < nlon) { + #pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + sh[j+tidy][tidx] = (coff + j+tidy < nchn) ? src[batch][coff + j+tidy][h][woff+tidx] : VAL_T(0); + } + } + __syncthreads(); + + if (coff+tidx < nchn) { + #pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + if (woff + j+tidy < nlon) { + dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy]; + } + } + } + } + return; +} + +template +void launch_permute_to0231(at::Tensor src, at::Tensor dst){ + dim3 block; + dim3 grid; + + block.x = WARP_SIZE; + block.y = WARPS_X_TILE; + grid.x = DIV_UP(src.size(1), block.x); + grid.y = DIV_UP(src.size(3), block.x); + grid.z = src.size(2)*src.size(0); + + assert(grid.y < 65536); + assert(grid.z < 65536); + + // get stream + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + permute_to0231_k + <<>>(src.size(1), + src.size(2), + src.size(3), + src.packed_accessor32(), + dst.packed_accessor32()); +} + +template +__global__ +__launch_bounds__(BDIM_X*BDIM_Y) +void permute_to0312_k(const int nchn, + const int nlat, + const int nlon, + const at::PackedTensorAccessor32 src, + at::PackedTensorAccessor32 dst) { + + static_assert(!(BDIM_X & (BDIM_X-1))); + static_assert(!(BDIM_Y & (BDIM_Y-1))); + static_assert(BDIM_X >= BDIM_Y); + + __shared__ VAL_T sh[BDIM_X][BDIM_X+1]; + + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + + const int woff = blockIdx.x*BDIM_X; // width offset + const int coff = blockIdx.y*BDIM_X; // channel offset + const int batch = blockIdx.z / nlat; // batch (same for all block) + const int h = blockIdx.z - (batch * nlat); // height (same for all block) + + const int nchn_full = (nchn-coff) >= BDIM_X; + const int nlon_full = (nlon-woff) >= BDIM_X; + + if (nchn_full && nlon_full) { + #pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + sh[j+tidy][tidx] = src[batch][h][woff + j+tidy][coff+tidx]; + } + __syncthreads(); + + #pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy]; + } + } else { + if (coff+tidx < nchn) { + #pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + sh[j+tidy][tidx] = (woff + 
j+tidy < nlon) ? src[batch][h][woff + j+tidy][coff+tidx] : VAL_T(0); + } + } + __syncthreads(); + + if (woff+tidx < nlon) { + #pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + if (coff + j+tidy < nchn) { + dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];; + } + } + } + } + return; +} + +template +void launch_permute_to0312(at::Tensor src, at::Tensor dst){ + dim3 block; + dim3 grid; + + block.x = WARP_SIZE; + block.y = WARPS_X_TILE; + grid.x = DIV_UP(src.size(2), block.x); + grid.y = DIV_UP(src.size(3), block.x); + grid.z = src.size(1)*src.size(0); + + assert(grid.y < 65536); + assert(grid.z < 65536); + + // get stream + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + permute_to0312_k + <<>>(src.size(3), + src.size(1), + src.size(2), + src.packed_accessor32(), + dst.packed_accessor32()); +} + +} diff --git a/torch_harmonics/disco/csrc/disco_interface.cpp b/torch_harmonics/disco/csrc/disco_interface.cpp index 779616f4..0eeb122c 100644 --- a/torch_harmonics/disco/csrc/disco_interface.cpp +++ b/torch_harmonics/disco/csrc/disco_interface.cpp @@ -54,8 +54,8 @@ namespace disco_kernels { // Declare the operators TORCH_LIBRARY(disco_kernels, m) { - m.def("forward(Tensor inp, Tensor roff_idx, Tensor ker_idx, Tensor row_idx, Tensor col_idx, Tensor vals, int kernel_size, int nlat_out, int nlon_out) -> Tensor", {at::Tag::pt2_compliant_tag}); - m.def("backward(Tensor inp, Tensor roff_idx, Tensor ker_idx, Tensor row_idx, Tensor col_idx, Tensor vals, int kernel_size, int nlat_out, int nlon_out) -> Tensor", {at::Tag::pt2_compliant_tag}); + m.def("forward(Tensor inp, Tensor roff_idx, Tensor ker_idx, Tensor row_idx, Tensor col_idx, Tensor vals, Tensor weights, int kernel_size, int nlat_out, int nlon_out) -> Tensor", {at::Tag::pt2_compliant_tag}); + m.def("backward(Tensor inp, Tensor roff_idx, Tensor ker_idx, Tensor row_idx, Tensor col_idx, Tensor vals, Tensor weights, int kernel_size, int nlat_out, int nlon_out) -> Tensor", {at::Tag::pt2_compliant_tag}); } } From 2932853da91756e94f6047472d44c0cfb1c86857 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Tue, 26 Aug 2025 01:12:50 -0700 Subject: [PATCH 02/31] there is more work to do --- setup.py | 4 +- torch_harmonics/__init__.py | 2 +- .../attention/csrc/attention_interface.cu | 38 --- torch_harmonics/disco/__init__.py | 2 +- torch_harmonics/disco/_disco_utils.py | 203 ++++++------ torch_harmonics/disco/convolution.py | 297 +++++++++--------- torch_harmonics/disco/csrc/disco_cuda.cuh | 8 +- torch_harmonics/disco/csrc/disco_cuda_bwd.cu | 70 +++-- torch_harmonics/disco/csrc/disco_cuda_fwd.cu | 40 ++- .../disco/csrc/disco_interface.cpp | 4 +- 10 files changed, 337 insertions(+), 331 deletions(-) delete mode 100644 torch_harmonics/attention/csrc/attention_interface.cu diff --git a/setup.py b/setup.py index 23d8e4f6..6d0e4514 100644 --- a/setup.py +++ b/setup.py @@ -122,7 +122,7 @@ def get_ext_modules(): # Create a single extension that includes both CPU and CUDA code disco_sources = [ "torch_harmonics/disco/csrc/disco_interface.cpp", - "torch_harmonics/disco/csrc/disco_cpu.cpp" + #"torch_harmonics/disco/csrc/disco_cpu.cpp" ] if BUILD_CUDA: @@ -157,7 +157,7 @@ def get_ext_modules(): "torch_harmonics/attention/csrc/attention_cpu_bwd.cpp", ] - if BUILD_CUDA: + if False: #BUILD_CUDA: print(f"Compiling attention CUDA kernels for torch-harmonics.") attention_sources.extend([ "torch_harmonics/attention/csrc/attention_cuda_utils.cu", diff --git a/torch_harmonics/__init__.py b/torch_harmonics/__init__.py index 2d390b54..77e43db0 100644 
--- a/torch_harmonics/__init__.py +++ b/torch_harmonics/__init__.py @@ -32,7 +32,7 @@ __version__ = "0.8.1" from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT -from .disco import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 +from .disco import DiscreteContinuousConvS2 #, DiscreteContinuousConvTransposeS2 from .resample import ResampleS2 from .attention import AttentionS2, NeighborhoodAttentionS2 from . import quadrature diff --git a/torch_harmonics/attention/csrc/attention_interface.cu b/torch_harmonics/attention/csrc/attention_interface.cu deleted file mode 100644 index b18c315c..00000000 --- a/torch_harmonics/attention/csrc/attention_interface.cu +++ /dev/null @@ -1,38 +0,0 @@ -// coding=utf-8 -// -// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. -// SPDX-License-Identifier: BSD-3-Clause -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -//#include "attention.cuh" -//#include - -//PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -//{ -// m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2"); -// m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)"); -//} diff --git a/torch_harmonics/disco/__init__.py b/torch_harmonics/disco/__init__.py index 50d3268b..f8f09f14 100644 --- a/torch_harmonics/disco/__init__.py +++ b/torch_harmonics/disco/__init__.py @@ -42,4 +42,4 @@ disco_kernels = None warnings.warn("No optimized kernels are available. Please compile the extension first setting BUILD_CPP and BUILD_CUDA to 1.") -from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 +from .convolution import DiscreteContinuousConvS2 #, DiscreteContinuousConvTransposeS2 diff --git a/torch_harmonics/disco/_disco_utils.py b/torch_harmonics/disco/_disco_utils.py index a870356c..ed1979a2 100644 --- a/torch_harmonics/disco/_disco_utils.py +++ b/torch_harmonics/disco/_disco_utils.py @@ -29,7 +29,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
# -from typing import Optional +from typing import Optional, Tuple import math import torch @@ -40,105 +40,110 @@ if optimized_kernels_is_available(): # raw forward fake @torch.library.register_fake("disco_kernels::forward") - def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, - row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, - weights: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - out_shape = (inp.shape[0], inp.shape[1], kernel_size, nlat_out, nlon_out) - return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) - - # raw backward fake - @torch.library.register_fake("disco_kernels::backward") - def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, - row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, - weights: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - out_shape = (inp.shape[0], inp.shape[1], nlat_out, nlon_out) - return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) + def _(inp: torch.Tensor, weights: torch.Tensor, + roff_idx: torch.Tensor, ker_idx: torch.Tensor, + row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, + nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor]: + out_shape = (inp.shape[0], weights.shape[0] * weights.shape[1], nlat_out, nlon_out) + dout_shape = (inp.shape[0], weights.shape[0] * weights.shape[2], nlat_out, nlon_out) + return torch.empty(out_shape, dtype=inp.dtype, device=inp.device), torch.empty(dout_shape, dtype=inp.dtype, device=inp.device) + + # # raw backward fake + # @torch.library.register_fake("disco_kernels::backward") + # def _(inp: torch.Tensor, weights: torch.Tensor, + # roff_idx: torch.Tensor, ker_idx: torch.Tensor, + # row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, + # nlat_out: int, nlon_out: int) -> torch.Tensor: + # out_shape = (inp.shape[0], weights.shape[0] * weights.shape[2], nlat_out, nlon_out) + # return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) # forward @torch.library.custom_op("disco_kernels::_disco_s2_contraction_optimized", mutates_args=()) def _disco_s2_contraction_optimized( - inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, + inp: torch.Tensor, weights: torch.Tensor, + roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, - weights: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - itype = inp.dtype - inp = inp.to(torch.float32).contiguous() - out = disco_kernels.forward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, weights, kernel_size, nlat_out, nlon_out) - out = out.to(itype) - return out - - # transpose - @torch.library.custom_op("disco_kernels::_disco_s2_transpose_contraction_optimized", mutates_args=()) - def _disco_s2_transpose_contraction_optimized( - inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, - row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, - weights: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - itype = inp.dtype - inp = inp.to(torch.float32).contiguous() - out = disco_kernels.backward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, weights, kernel_size, nlat_out, nlon_out) - out = out.to(itype) - return out + nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor]: + out, dout = disco_kernels.forward.default(inp, weights, roff_idx, ker_idx, row_idx, col_idx, vals, nlat_out, nlon_out) + 
return out, dout + + # # transpose + # @torch.library.custom_op("disco_kernels::_disco_s2_transpose_contraction_optimized", mutates_args=()) + # def _disco_s2_transpose_contraction_optimized( + # inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, + # row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, + # weights: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + # itype = inp.dtype + # inp = inp.to(torch.float32).contiguous() + # out = disco_kernels.backward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, weights, kernel_size, nlat_out, nlon_out) + # out = out.to(itype) + # return out # forward fake @torch.library.register_fake("disco_kernels::_disco_s2_contraction_optimized") - def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, + def _(inp: torch.Tensor, weights: torch.Tensor, + roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, - weights: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - out_shape = (inp.shape[0], inp.shape[1], kernel_size, nlat_out, nlon_out) - return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) - - # transpose fake - @torch.library.register_fake("disco_kernels::_disco_s2_transpose_contraction_optimized") - def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, - row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, - weights: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - out_shape = (inp.shape[0], inp.shape[1], nlat_out, nlon_out) - return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) + nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor]: + out_shape = (inp.shape[0], weights.shape[0] * weights.shape[1], nlat_out, nlon_out) + dout_shape = (inp.shape[0], weights.shape[0] * weights.shape[2], nlat_out, nlon_out) + return torch.empty(out_shape, dtype=inp.dtype, device=inp.device), torch.empty(dout_shape, dtype=inp.dtype, device=inp.device) + + # # transpose fake + # @torch.library.register_fake("disco_kernels::_disco_s2_transpose_contraction_optimized") + # def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, + # row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, + # weights: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + # out_shape = (inp.shape[0], weights.shape[0] * weights.shape[1], nlat_out, nlon_out) + # return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) #general routines: this is the same for forward and transpose def _setup_context_conv_backward(ctx, inputs, output): - inp, roff_idx, ker_idx, row_idx, col_idx, vals, _, kernel_size, nlat_out, nlon_out = inputs - ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals) - ctx.kernel_size = kernel_size + inp, weights, roff_idx, ker_idx, row_idx, col_idx, vals, _, _ = inputs + _, dinp = output + ctx.save_for_backward(dinp, weights, roff_idx, ker_idx, row_idx, col_idx, vals) ctx.nlat_in = inp.shape[-2] ctx.nlon_in = inp.shape[-1] # convolution related def _disco_s2_contraction_bwd_optimized(ctx, grad_output): - roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors + dinp, weights, roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors + + print("grad_output", grad_output) + + print("SHAPE CHECK", grad_output.shape, dinp.shape, weights.shape) if ctx.needs_input_grad[0]: - gtype = grad_output.dtype - grad_output = 
grad_output.to(torch.float32).contiguous() - grad_input = disco_kernels.backward.default(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals, - torch.empty(0), ctx.kernel_size, ctx.nlat_in, ctx.nlon_in) # Mauro - grad_input = grad_input.to(gtype) + grad_input, wgrad = disco_kernels.backward.default(grad_output, dinp, weights, roff_idx, ker_idx, row_idx, col_idx, vals, + ctx.nlat_in, ctx.nlon_in) # Mauro else: grad_input = None + wgrad = None - return grad_input, None, None, None, None, None, None, None, None, None # Mauro: added a None for weights + return grad_input, wgrad, None, None, None, None, None, None, None # Mauro: added a None for weights if optimized_kernels_is_available(): torch.library.register_autograd( "disco_kernels::_disco_s2_contraction_optimized", _disco_s2_contraction_bwd_optimized, setup_context=_setup_context_conv_backward) -# Transpose convolution related -def _disco_s2_transpose_contraction_bwd_optimized(ctx, grad_output): - roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors +# # Transpose convolution related +# def _disco_s2_transpose_contraction_bwd_optimized(ctx, grad_output): +# roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors - if ctx.needs_input_grad[0]: - gtype = grad_output.dtype - grad_output = grad_output.to(torch.float32).contiguous() - grad_input = disco_kernels.forward.default(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals, - torch.empty(0), ctx.kernel_size, ctx.nlat_in, ctx.nlon_in) # Mauro - grad_input = grad_input.to(gtype) - else: - grad_input = None +# if ctx.needs_input_grad[0]: +# gtype = grad_output.dtype +# grad_output = grad_output.to(torch.float32).contiguous() +# grad_input = disco_kernels.forward.default(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals, +# torch.empty(0), ctx.kernel_size, ctx.nlat_in, ctx.nlon_in) # Mauro +# grad_input = grad_input.to(gtype) +# else: +# grad_input = None - return grad_input, None, None, None, None, None, None, None, None, None # Mauro: added a None for weights +# return grad_input, None, None, None, None, None, None, None, None, None # Mauro: added a None for weights -if optimized_kernels_is_available(): - torch.library.register_autograd( - "disco_kernels::_disco_s2_transpose_contraction_optimized", _disco_s2_transpose_contraction_bwd_optimized, setup_context=_setup_context_conv_backward) +# if optimized_kernels_is_available(): +# torch.library.register_autograd( +# "disco_kernels::_disco_s2_transpose_contraction_optimized", _disco_s2_transpose_contraction_bwd_optimized, setup_context=_setup_context_conv_backward) # torch kernel related functions def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nlat_in: int, nlon_in: int, nlat_out: int, nlon_out: int, nlat_in_local: Optional[int] = None, nlat_out_local: Optional[int] = None, semi_transposed: Optional[bool] = False): @@ -182,7 +187,7 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in # add a dummy dimension for nkernel and move the batch and channel dims to the end x = x.reshape(1, batch_size * n_chans, nlat_in, nlon_in).permute(0, 2, 3, 1) - x = x.expand(kernel_size, -1, -1, -1) + x = x.expand(kernel_size, -1, -1, -1).contiguous() y = torch.zeros(nlon_out, kernel_size, nlat_out, batch_size * n_chans, device=x.device, dtype=x.dtype) @@ -193,43 +198,43 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in x = torch.roll(x, -pscale, dims=2) # reshape y back to expose the correct dimensions - y = y.permute(3, 1, 2, 
0).reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out) + y = y.permute(3, 1, 2, 0).reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out).contiguous() return y -# transpose convolution -def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int): - assert len(psi.shape) == 3 - assert len(x.shape) == 5 - psi = psi.to(x.device) +# # transpose convolution +# def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int): +# assert len(psi.shape) == 3 +# assert len(x.shape) == 5 +# psi = psi.to(x.device) - batch_size, n_chans, kernel_size, nlat_in, nlon_in = x.shape - kernel_size, nlat_out, n_out = psi.shape +# batch_size, n_chans, kernel_size, nlat_in, nlon_in = x.shape +# kernel_size, nlat_out, n_out = psi.shape - assert n_out % nlon_out == 0 - assert nlon_out >= nlon_in - pscale = nlon_out // nlon_in +# assert n_out % nlon_out == 0 +# assert nlon_out >= nlon_in +# pscale = nlon_out // nlon_in - # interleave zeros along the longitude dimension to allow for fractional offsets to be considered - x_ext = torch.zeros(kernel_size, nlat_in, nlon_out, batch_size * n_chans, device=x.device, dtype=x.dtype) - x = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0) +# # interleave zeros along the longitude dimension to allow for fractional offsets to be considered +# x_ext = torch.zeros(kernel_size, nlat_in, nlon_out, batch_size * n_chans, device=x.device, dtype=x.dtype) +# x = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0).contiguous() - # x has shape kernel_size x nlat_in x nlon_in x batch_size * n_chans - # we only need to apoply the nlon stride here, since nlat stride is taken care of by the kernel - x_ext[:, :, ::pscale, :] = x[...] +# # x has shape kernel_size x nlat_in x nlon_in x batch_size * n_chans +# # we only need to apoply the nlon stride here, since nlat stride is taken care of by the kernel +# x_ext[:, :, ::pscale, :] = x[...] 
- # create output tensor - y = torch.zeros(kernel_size, nlon_out, nlat_out, batch_size * n_chans, device=x.device, dtype=x.dtype) +# # create output tensor +# y = torch.zeros(kernel_size, nlon_out, nlat_out, batch_size * n_chans, device=x.device, dtype=x.dtype) - for pout in range(nlon_out): - # we need to repeatedly roll the input tensor to faciliate the shifted multiplication - # TODO: double-check why this has to happen first - x_ext = torch.roll(x_ext, -1, dims=2) - # sparse contraction with the modified psi - y[:, pout, :, :] = torch.bmm(psi, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1)) +# for pout in range(nlon_out): +# # we need to repeatedly roll the input tensor to faciliate the shifted multiplication +# # TODO: double-check why this has to happen first +# x_ext = torch.roll(x_ext, -1, dims=2) +# # sparse contraction with the modified psi +# y[:, pout, :, :] = torch.bmm(psi, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1)) - # sum over the kernel dimension and reshape to the correct output size - y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out).contiguous() +# # sum over the kernel dimension and reshape to the correct output size +# y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out).contiguous() - return y +# return y diff --git a/torch_harmonics/disco/convolution.py b/torch_harmonics/disco/convolution.py index cedfdef8..b55ac27a 100644 --- a/torch_harmonics/disco/convolution.py +++ b/torch_harmonics/disco/convolution.py @@ -46,8 +46,8 @@ from torch_harmonics.cache import lru_cache from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes -from ._disco_utils import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch -from ._disco_utils import _disco_s2_contraction_optimized, _disco_s2_transpose_contraction_optimized +from ._disco_utils import _get_psi, _disco_s2_contraction_torch #, _disco_s2_transpose_contraction_torch +from ._disco_utils import _disco_s2_contraction_optimized #, _disco_s2_transpose_contraction_optimized from torch_harmonics.filter_basis import FilterBasis, get_filter_basis from disco_helpers import optimized_kernels_is_available, preprocess_psi @@ -374,9 +374,10 @@ def __init__( raise ValueError("Error, the number of input channels has to be an integer multiple of the group size") if out_channels % self.groups != 0: raise ValueError("Error, the number of output channels has to be an integer multiple of the group size") - self.groupsize = in_channels // self.groups - scale = math.sqrt(1.0 / self.groupsize / self.kernel_size) - self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size)) + self.groupsize_in = in_channels // self.groups + self.groupsize_out = out_channels // self.groups + scale = math.sqrt(1.0 / self.groupsize_in / self.kernel_size) + self.weight = nn.Parameter(scale * torch.randn(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size)) if bias: self.bias = nn.Parameter(torch.zeros(out_channels)) @@ -513,11 +514,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.optimized_kernel: with nvtx.annotate("_disco_s2_contraction_optimized", color="red"): - out = _disco_s2_contraction_optimized( - #x = _disco_s2_contraction_optimized( - x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, - self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2]), - self.kernel_size, self.nlat_out, 
self.nlon_out + out, _ = _disco_s2_contraction_optimized( + x, self.weight, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, + self.nlat_out, self.nlon_out ) else: x = _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out) @@ -530,15 +529,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: #print("after reshape, weight.shape:", pippo.shape) # extract shape - B, C, K, H, W = x.shape + B, _, K, H, W = x.shape with nvtx.annotate("reshape", color="blue"): - x = x.reshape(B, self.groups, self.groupsize, K, H, W) + x = x.reshape(B, self.groups, self.groupsize_in, K, H, W) #print("after reshape, x.shape:", x.shape) # do weight multiplication with nvtx.annotate("einsum", color="blue"): - out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous() + out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight).contiguous() #print("out.shape:", out.shape) out = out.reshape(B, -1, H, W) @@ -556,139 +555,139 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out -class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv): - """ - Discrete-continuous (DISCO) transpose convolutions on the 2-Sphere as described in [1]. - - Parameters - ----------- - in_channels: int - Number of input channels - out_channels: int - Number of output channels - in_shape: Tuple[int] - Input shape of the convolution tensor - out_shape: Tuple[int] - Output shape of the convolution tensor - kernel_shape: Union[int, Tuple[int], Tuple[int, int]] - Shape of the kernel - basis_type: Optional[str] - Type of the basis functions - basis_norm_mode: Optional[str] - Mode for basis normalization - groups: Optional[int] - Number of groups - grid_in: Optional[str] - Input grid type - grid_out: Optional[str] - Output grid type - bias: Optional[bool] - Whether to use bias - theta_cutoff: Optional[float] - Theta cutoff for the filter basis functions - optimized_kernel: Optional[bool] - Whether to use the optimized kernel (if available) - - Returns - -------- - out: torch.Tensor - Output tensor - - References - ---------- - [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603 - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - in_shape: Tuple[int], - out_shape: Tuple[int], - kernel_shape: Union[int, Tuple[int], Tuple[int, int]], - basis_type: Optional[str] = "piecewise linear", - basis_norm_mode: Optional[str] = "mean", - groups: Optional[int] = 1, - grid_in: Optional[str] = "equiangular", - grid_out: Optional[str] = "equiangular", - bias: Optional[bool] = True, - theta_cutoff: Optional[float] = None, - optimized_kernel: Optional[bool] = True, - ): - super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias, optimized_kernel) - - self.nlat_in, self.nlon_in = in_shape - self.nlat_out, self.nlon_out = out_shape - - # make sure the p-shift works by checking that longitudes are divisible - assert self.nlon_out % self.nlon_in == 0 - - # bandlimit - if theta_cutoff is None: - theta_cutoff = torch.pi / float(self.nlat_in - 1) - - if theta_cutoff <= 0.0: - raise ValueError("Error, theta_cutoff has to be positive.") - - # switch in_shape and out_shape since we want the transpose convolution - idx, vals, _ = _precompute_convolution_tensor_s2( - out_shape, - in_shape, - self.filter_basis, - grid_in=grid_out, - grid_out=grid_in, - theta_cutoff=theta_cutoff, - 
transpose_normalization=True, - basis_norm_mode=basis_norm_mode, - merge_quadrature=True, - ) - - # sort the values - ker_idx = idx[0, ...].contiguous() - row_idx = idx[1, ...].contiguous() - col_idx = idx[2, ...].contiguous() - vals = vals.contiguous() - - if self.optimized_kernel: - # preprocessed data-structure for GPU kernel - roff_idx = preprocess_psi(self.kernel_size, self.nlat_in, ker_idx, row_idx, col_idx, vals).contiguous() - self.register_buffer("psi_roff_idx", roff_idx, persistent=False) - - # save all datastructures - self.register_buffer("psi_ker_idx", ker_idx, persistent=False) - self.register_buffer("psi_row_idx", row_idx, persistent=False) - self.register_buffer("psi_col_idx", col_idx, persistent=False) - self.register_buffer("psi_vals", vals, persistent=False) - - # also store psi just in case - if not self.optimized_kernel: - self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True) - - def extra_repr(self): - return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" - - @property - def psi_idx(self): - return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - - # extract shape - B, C, H, W = x.shape - x = x.reshape(B, self.groups, self.groupsize, H, W) - - # do weight multiplication - x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous() - x = x.reshape(B, -1, x.shape[-3], H, W) - - if self.optimized_kernel: - out = _disco_s2_transpose_contraction_optimized( - x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out - ) - else: - out = _disco_s2_transpose_contraction_torch(x, self.psi_st.to(x.device), self.nlon_out) - - if self.bias is not None: - out = out + self.bias.reshape(1, -1, 1, 1) - - return out +# class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv): +# """ +# Discrete-continuous (DISCO) transpose convolutions on the 2-Sphere as described in [1]. 
+ +# Parameters +# ----------- +# in_channels: int +# Number of input channels +# out_channels: int +# Number of output channels +# in_shape: Tuple[int] +# Input shape of the convolution tensor +# out_shape: Tuple[int] +# Output shape of the convolution tensor +# kernel_shape: Union[int, Tuple[int], Tuple[int, int]] +# Shape of the kernel +# basis_type: Optional[str] +# Type of the basis functions +# basis_norm_mode: Optional[str] +# Mode for basis normalization +# groups: Optional[int] +# Number of groups +# grid_in: Optional[str] +# Input grid type +# grid_out: Optional[str] +# Output grid type +# bias: Optional[bool] +# Whether to use bias +# theta_cutoff: Optional[float] +# Theta cutoff for the filter basis functions +# optimized_kernel: Optional[bool] +# Whether to use the optimized kernel (if available) + +# Returns +# -------- +# out: torch.Tensor +# Output tensor + +# References +# ---------- +# [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603 +# """ + +# def __init__( +# self, +# in_channels: int, +# out_channels: int, +# in_shape: Tuple[int], +# out_shape: Tuple[int], +# kernel_shape: Union[int, Tuple[int], Tuple[int, int]], +# basis_type: Optional[str] = "piecewise linear", +# basis_norm_mode: Optional[str] = "mean", +# groups: Optional[int] = 1, +# grid_in: Optional[str] = "equiangular", +# grid_out: Optional[str] = "equiangular", +# bias: Optional[bool] = True, +# theta_cutoff: Optional[float] = None, +# optimized_kernel: Optional[bool] = True, +# ): +# super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias, optimized_kernel) + +# self.nlat_in, self.nlon_in = in_shape +# self.nlat_out, self.nlon_out = out_shape + +# # make sure the p-shift works by checking that longitudes are divisible +# assert self.nlon_out % self.nlon_in == 0 + +# # bandlimit +# if theta_cutoff is None: +# theta_cutoff = torch.pi / float(self.nlat_in - 1) + +# if theta_cutoff <= 0.0: +# raise ValueError("Error, theta_cutoff has to be positive.") + +# # switch in_shape and out_shape since we want the transpose convolution +# idx, vals, _ = _precompute_convolution_tensor_s2( +# out_shape, +# in_shape, +# self.filter_basis, +# grid_in=grid_out, +# grid_out=grid_in, +# theta_cutoff=theta_cutoff, +# transpose_normalization=True, +# basis_norm_mode=basis_norm_mode, +# merge_quadrature=True, +# ) + +# # sort the values +# ker_idx = idx[0, ...].contiguous() +# row_idx = idx[1, ...].contiguous() +# col_idx = idx[2, ...].contiguous() +# vals = vals.contiguous() + +# if self.optimized_kernel: +# # preprocessed data-structure for GPU kernel +# roff_idx = preprocess_psi(self.kernel_size, self.nlat_in, ker_idx, row_idx, col_idx, vals).contiguous() +# self.register_buffer("psi_roff_idx", roff_idx, persistent=False) + +# # save all datastructures +# self.register_buffer("psi_ker_idx", ker_idx, persistent=False) +# self.register_buffer("psi_row_idx", row_idx, persistent=False) +# self.register_buffer("psi_col_idx", col_idx, persistent=False) +# self.register_buffer("psi_vals", vals, persistent=False) + +# # also store psi just in case +# if not self.optimized_kernel: +# self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True) + +# def extra_repr(self): +# return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, 
out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" + +# @property +# def psi_idx(self): +# return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous() + +# def forward(self, x: torch.Tensor) -> torch.Tensor: + +# # extract shape +# B, C, H, W = x.shape +# x = x.reshape(B, self.groups, self.groupsize, H, W) + +# # do weight multiplication +# x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous() +# x = x.reshape(B, -1, x.shape[-3], H, W) + +# if self.optimized_kernel: +# out = _disco_s2_transpose_contraction_optimized( +# x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out +# ) +# else: +# out = _disco_s2_transpose_contraction_torch(x, self.psi_st.to(x.device), self.nlon_out) + +# if self.bias is not None: +# out = out + self.bias.reshape(1, -1, 1, 1) + +# return out diff --git a/torch_harmonics/disco/csrc/disco_cuda.cuh b/torch_harmonics/disco/csrc/disco_cuda.cuh index 2a7fa7e5..f7c05708 100644 --- a/torch_harmonics/disco/csrc/disco_cuda.cuh +++ b/torch_harmonics/disco/csrc/disco_cuda.cuh @@ -52,11 +52,11 @@ namespace disco_kernels { // forward kernel - torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, - torch::Tensor col_idx, torch::Tensor val, torch::Tensor weights, int64_t K, int64_t Ho, int64_t Wo); + std::tuple disco_cuda_fwd(torch::Tensor inp, torch::Tensor weights, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, + torch::Tensor col_idx, torch::Tensor val, int64_t Ho, int64_t Wo); // backward kernel - torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, - torch::Tensor col_idx, torch::Tensor val, torch::Tensor weights, int64_t K, int64_t Ho, int64_t Wo); + std::tuple disco_cuda_bwd(torch::Tensor ograd, torch::Tensor dinp, torch::Tensor weights, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, + torch::Tensor col_idx, torch::Tensor val, int64_t Ho, int64_t Wo); } diff --git a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu index 13d5cdda..8b1c2f9c 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu @@ -198,12 +198,14 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t return; } - torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, - torch::Tensor col_idx, torch::Tensor val, torch::Tensor weights, int64_t K, int64_t Ho, int64_t Wo) + std::tuple disco_cuda_bwd(torch::Tensor inp, torch::Tensor dinp, torch::Tensor weights, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, + torch::Tensor col_idx, torch::Tensor val, int64_t Ho, int64_t Wo) { // some sanity checks CHECK_CUDA_INPUT_TENSOR(inp); + CHECK_CUDA_INPUT_TENSOR(dinp); + CHECK_CUDA_INPUT_TENSOR(weights); CHECK_CUDA_INPUT_TENSOR(roff_idx); CHECK_CUDA_INPUT_TENSOR(ker_idx); CHECK_CUDA_INPUT_TENSOR(row_idx); @@ -212,69 +214,97 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t // extract some shapes int64_t B = inp.size(0); - int64_t C = inp.size(1); - int64_t BC = B * C; - int64_t Hi = inp.size(3); - int64_t Wi = inp.size(4); + int64_t Cin = 
inp.size(1); + int64_t Cout = weights.size(0) * weights.size(2); + int64_t K = weights.size(3); + int64_t BC = B * Cout; + int64_t Hi = inp.size(2); + int64_t Wi = inp.size(3); int64_t nrows = roff_idx.size(0) - 1; // allocate output - int64_t out_dims[] = {B, C, Ho, Wo}; - auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype()); - torch::Tensor out = torch::zeros(out_dims, options); + int64_t out_dims[] = {B, Cout, Ho, Wo}; // get stream auto stream = at::cuda::getCurrentCUDAStream().stream(); + // print shapes + std::cout << "dinp.dim(): " << dinp.dim() << std::endl; + for (int i=0; ibgikxy", {inp_resh, weights}).reshape({B, Cout, K, Hi, Wi}).contiguous(); + auto wgrad = torch::einsum("bgixy,bgokxy->giok", {inp_resh, dinp_resh}).to(torch::kFloat32).contiguous(); + + // extract dtype + auto x_type = x.dtype(); + torch::Tensor xP = x.to(torch::kFloat32); + + torch::Tensor out = torch::zeros(out_dims, xP.options()); + // assert static_assert(0 == (ELXTH_MAX % 2)); if (Wo <= 64 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(xP.scalar_type(), "disco_backward_cuda", ([&] { launch_kernel<64, 1, scalar_t>( BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), ker_idx.data_ptr(), row_idx.data_ptr(), col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); + xP.data_ptr(), out.data_ptr(), stream); })); } else if (Wo <= 128 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(xP.scalar_type(), "disco_backward_cuda", ([&] { launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>( BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), ker_idx.data_ptr(), row_idx.data_ptr(), col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); + xP.data_ptr(), out.data_ptr(), stream); })); } else if (Wo <= 256 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(xP.scalar_type(), "disco_backward_cuda", ([&] { launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>( BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), ker_idx.data_ptr(), row_idx.data_ptr(), col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); + xP.data_ptr(), out.data_ptr(), stream); })); } else if (Wo <= 512 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(xP.scalar_type(), "disco_backward_cuda", ([&] { launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>( BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), ker_idx.data_ptr(), row_idx.data_ptr(), col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); + xP.data_ptr(), out.data_ptr(), stream); })); } else if (Wo <= 1024 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES(xP.scalar_type(), "disco_backward_cuda", ([&] { launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>( BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), ker_idx.data_ptr(), row_idx.data_ptr(), col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); + xP.data_ptr(), out.data_ptr(), stream); })); } else { fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo, 1024 * ELXTH_MAX); exit(EXIT_FAILURE); } - return out; + + // convert back to original dtype + out = out.to(x_type); + + return std::make_tuple(out, wgrad); } 
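The weight gradient above is a single grouped einsum between the incoming gradient and the per-kernel convolution output saved by the forward pass, contracted over the batch and both spatial dimensions. A minimal PyTorch sketch of that contraction; only the subscripts are taken from the kernel, all names and sizes here are illustrative:

    import torch

    B, g, i, o, K, H, W = 2, 1, 4, 5, 3, 6, 12       # illustrative sizes

    inp_resh  = torch.randn(B, g, i, H, W)           # incoming gradient, split into groups
    dinp_resh = torch.randn(B, g, o, K, H, W)        # per-kernel conv output kept from the forward pass

    # outer product of the two channel dimensions, summed over batch and spatial dims
    wgrad = torch.einsum("bgixy,bgokxy->giok", inp_resh, dinp_resh)
    assert wgrad.shape == (g, i, o, K)

Saving the per-kernel output in the forward pass is presumably what lets the backward form this gradient without a second pass over the sparse psi tensor.
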
TORCH_LIBRARY_IMPL(disco_kernels, CUDA, m) diff --git a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu index 4c31e9f8..e108ea60 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu @@ -970,12 +970,13 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, return; } - torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, - torch::Tensor col_idx, torch::Tensor val, torch::Tensor weights, int64_t K, int64_t Ho, int64_t Wo) + std::tuple disco_cuda_fwd(torch::Tensor inp, torch::Tensor weights, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, + torch::Tensor col_idx, torch::Tensor val, int64_t Ho, int64_t Wo) { // some sanity checks CHECK_CUDA_INPUT_TENSOR(inp); + CHECK_CUDA_INPUT_TENSOR(weights); CHECK_CUDA_INPUT_TENSOR(roff_idx); CHECK_CUDA_INPUT_TENSOR(ker_idx); CHECK_CUDA_INPUT_TENSOR(row_idx); @@ -989,6 +990,7 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, int64_t Hi = inp.size(2); int64_t Wi = inp.size(3); int64_t nrows = roff_idx.size(0) - 1; + int64_t K = weights.size(3); // rename dimensions consistent with attention int64_t batch_size = inp.size(0); @@ -1186,21 +1188,21 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, // switch to channel-last - // extract dtype + // extract dtype and memory format + bool x_is_channels_last = inp.strides()[1] == 1; auto x_type = inp.dtype(); - torch::Tensor xP = inp.to(torch::kFloat32); - - // exract memory format: this is much safer than checking is_contiguous(at::MemoryFormat::ChannelsLast) - // the former fails for num_channels == 1 - bool x_is_channels_last = xP.strides()[1] == 1; // transpose if required + torch::Tensor xP = inp; if (!x_is_channels_last) { xP = permute_4D_to0231(xP); } + // convert datatype + xP = xP.to(torch::kFloat32); + // to test before fusion int64_t out_dims[] = {batch_size, nlat_out, nlon_out, nchan*K}; - auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype()); - torch::Tensor yP = torch::empty(out_dims, options); + //auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype()); + torch::Tensor yP = torch::empty(out_dims, xP.options()); // call channel-last kernel implementation s2_disco_fwd_dispatch(batch_size, @@ -1216,8 +1218,10 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, val, yP); - // call einsum + // store output of convolution + auto disco_y = yP; + // call einsum // yP is {batch_size, nlat_out, nlon_out, nchan_in*K}, // reshape it to {batch_size, nlat_out, nlon_out, ngroup, chan_x_grp_in*K} auto yP_resh = yP.reshape({batch_size, nlat_out, nlon_out, ngroup, -1}); @@ -1241,12 +1245,18 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, // make y {batch_size, nchan_out, nlat_out, nlon_out} } - //CHECK_CUDA(cudaDeviceSynchronize()); - + // some hacky reshuffling now: + if (!x_is_channels_last) { + disco_y = permute_4D_to0312(yP).reshape({batch_size, ngroup*chan_x_grp_in, K, nlat_out, nlon_out}); + } else { + disco_y = yP.reshape({batch_size, ngroup*chan_x_grp_in, K, nlat_out, nlon_out}); + } + // convert precision back to starting y = y.to(x_type); + disco_y = disco_y.to(x_type); - torch::Tensor out = y; + std::tuple out = std::make_tuple(y, disco_y); #endif #endif // closes ORIGINAL if @@ -1256,7 +1266,7 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, CHECK_CUDA(cudaStreamSynchronize(stream)); printf("done\n"); fflush(stdout); - dump_tensor("yout.txt", 
out); + //dump_tensor("yout.txt", out); //dump_csr_linear("csr_disco.txt", roff_idx, ker_idx, row_idx, col_idx, val); //dump_out_kers("out_kers", out); } diff --git a/torch_harmonics/disco/csrc/disco_interface.cpp b/torch_harmonics/disco/csrc/disco_interface.cpp index 0eeb122c..8269efb0 100644 --- a/torch_harmonics/disco/csrc/disco_interface.cpp +++ b/torch_harmonics/disco/csrc/disco_interface.cpp @@ -54,8 +54,8 @@ namespace disco_kernels { // Declare the operators TORCH_LIBRARY(disco_kernels, m) { - m.def("forward(Tensor inp, Tensor roff_idx, Tensor ker_idx, Tensor row_idx, Tensor col_idx, Tensor vals, Tensor weights, int kernel_size, int nlat_out, int nlon_out) -> Tensor", {at::Tag::pt2_compliant_tag}); - m.def("backward(Tensor inp, Tensor roff_idx, Tensor ker_idx, Tensor row_idx, Tensor col_idx, Tensor vals, Tensor weights, int kernel_size, int nlat_out, int nlon_out) -> Tensor", {at::Tag::pt2_compliant_tag}); + m.def("forward(Tensor inp, Tensor weights, Tensor roff_idx, Tensor ker_idx, Tensor row_idx, Tensor col_idx, Tensor vals, int nlat_out, int nlon_out) -> (Tensor, Tensor)"); //, {at::Tag::pt2_compliant_tag}); + m.def("backward(Tensor inp, Tensor dinp, Tensor weights, Tensor roff_idx, Tensor ker_idx, Tensor row_idx, Tensor col_idx, Tensor vals, int nlat_out, int nlon_out) -> (Tensor, Tensor)"); //, {at::Tag::pt2_compliant_tag}); } } From 97daca4dac9bac351e3de4bbf7586ce4f4b31656 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Thu, 28 Aug 2025 04:28:18 -0700 Subject: [PATCH 03/31] dbg --- setup.py | 6 +- torch_harmonics/disco/csrc/disco_cuda_fwd.cu | 19 +- .../disco/csrc/disco_cuda_utils.cu | 184 --------- .../disco/csrc/disco_cuda_utils.cuh | 382 ------------------ 4 files changed, 7 insertions(+), 584 deletions(-) delete mode 100644 torch_harmonics/disco/csrc/disco_cuda_utils.cu delete mode 100644 torch_harmonics/disco/csrc/disco_cuda_utils.cuh diff --git a/setup.py b/setup.py index 6d0e4514..cce9dc42 100644 --- a/setup.py +++ b/setup.py @@ -93,6 +93,9 @@ def get_helpers_compile_args(): def get_ext_modules(): """Get list of extension modules to compile.""" + # define setup dir + setup_dir = os.path.abspath(os.path.dirname(__file__)) + ext_modules = [] cmdclass = {} @@ -128,7 +131,7 @@ def get_ext_modules(): if BUILD_CUDA: print(f"Compiling custom CUDA kernels for torch-harmonics.") disco_sources.extend([ - "torch_harmonics/disco/csrc/disco_cuda_utils.cu", + "torch_harmonics/utils/csrc/cuda_utils.cu", "torch_harmonics/disco/csrc/disco_cuda_fwd.cu", "torch_harmonics/disco/csrc/disco_cuda_bwd.cu", ]) @@ -136,6 +139,7 @@ def get_ext_modules(): CUDAExtension( "torch_harmonics.disco._C", disco_sources, + include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")], extra_compile_args=get_compile_args("disco") ) ) diff --git a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu index e108ea60..3cc37797 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu @@ -28,25 +28,10 @@ // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
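After the channel-last kernel finishes, the forward above permutes its (B, Ho, Wo, C*K) output back to channels-first and exposes the kernel dimension so the existing einsum path can consume it. A small PyTorch sketch of that reshuffle for the channels-first input case; the sizes are illustrative, and the channel-major, kernel-minor packing of the fused last dimension is an assumption read off the reshape in the diff:

    import torch

    B, C, K, Ho, Wo = 2, 4, 3, 6, 12                 # illustrative sizes

    yP = torch.randn(B, Ho, Wo, C * K)               # channel-last output of the sparse kernel
    # equivalent of permute_4D_to0312, then split the fused channel*kernel dimension
    disco_y = yP.permute(0, 3, 1, 2).reshape(B, C, K, Ho, Wo)
    assert disco_y.shape == (B, C, K, Ho, Wo)
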
+#include "cudamacro.h" #include "disco.h" #include "disco_cuda.cuh" -#include "disco_cuda_utils.cuh" - -#define CHECK_CUDA(call) { \ - cudaError_t err = call; \ - if( cudaSuccess != err) { \ - fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \ - __FILE__, __LINE__, cudaGetErrorString( err) ); \ - exit(EXIT_FAILURE); \ - }} - -#define CHECK_ERROR(errorMessage) { \ - cudaError_t err = cudaGetLastError(); \ - if( cudaSuccess != err) { \ - fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \ - errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) ); \ - exit(EXIT_FAILURE); \ - }} +#include "cuda_utils.cuh" #define THREADS (64) diff --git a/torch_harmonics/disco/csrc/disco_cuda_utils.cu b/torch_harmonics/disco/csrc/disco_cuda_utils.cu deleted file mode 100644 index b8ba3ca4..00000000 --- a/torch_harmonics/disco/csrc/disco_cuda_utils.cu +++ /dev/null @@ -1,184 +0,0 @@ -// coding=utf-8 -// -// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. -// SPDX-License-Identifier: BSD-3-Clause -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -#include "disco_cuda_utils.cuh" - -#include -#include -#include - -#include - -#include -#include - -#include "cudamacro.h" -#include "disco_cuda.cuh" - -#define THREADS (64) - -#define TRANSP_WARPS_X_TILE_GENERIC (32) -#define TRANSP_WARPS_X_TILE_SM100 (4) - -namespace disco_kernels { - -// BEGIN - CSR rows sorting kernels and functions -__global__ void set_rlen_rids_k(const int n, - const int64_t *__restrict__ offs, - int *__restrict__ rids, - int *__restrict__ rlen) { - - const int nth = gridDim.x*blockDim.x; - const int tid = blockIdx.x*blockDim.x + threadIdx.x; - - for(int i = tid; i < n; i += nth) { - rids[i] = i; - rlen[i] = offs[i+1]-offs[i]; - } - - return; -} - -at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream) { - - int64_t *_row_off_d = reinterpret_cast(row_off.data_ptr()); - - auto options = torch::TensorOptions().dtype(torch::kInt32).device(row_off.device()); - - torch::Tensor rids_d = torch::empty({nlat_out}, options); - torch::Tensor rlen_d = torch::empty({nlat_out}, options); - - int *_rids_d = reinterpret_cast(rids_d.data_ptr()); - int *_rlen_d = reinterpret_cast(rlen_d.data_ptr()); - - const int grid = DIV_UP(nlat_out, THREADS); - const int block = THREADS; - - set_rlen_rids_k<<>>(nlat_out, - _row_off_d, - _rids_d, - _rlen_d); - - torch::Tensor rids_sort_d = torch::empty({nlat_out}, options); - torch::Tensor rlen_sort_d = torch::empty({nlat_out}, options); - - int *_rids_sort_d = reinterpret_cast(rids_sort_d.data_ptr()); - int *_rlen_sort_d = reinterpret_cast(rlen_sort_d.data_ptr()); - - size_t temp_storage_bytes = 0; - CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(NULL, temp_storage_bytes, - _rlen_d, _rlen_sort_d, - _rids_d, _rids_sort_d, - nlat_out, 0, sizeof(*_rlen_d)*8, stream)); - - options = torch::TensorOptions().dtype(torch::kByte).device(row_off.device()); - torch::Tensor temp_storage_d = torch::empty({int64_t(temp_storage_bytes)}, options); - - void *_temp_storage_d = reinterpret_cast(temp_storage_d.data_ptr()); - - CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(_temp_storage_d, temp_storage_bytes, - _rlen_d, _rlen_sort_d, - _rids_d, _rids_sort_d, - nlat_out, 0, sizeof(*_rlen_d)*8, stream)); - return rids_sort_d; -} -// END - CSR rows sorting kernels and functions - - -// BEGIN - 4D tensor permutation kernels and functions -__global__ void empty_k() {} - -static int getPtxver() { - cudaFuncAttributes attrs; - CHECK_CUDA(cudaFuncGetAttributes(&attrs, empty_k)); - return attrs.ptxVersion*10; -} - -at::Tensor permute_4D_to0231(at::Tensor src) { - - auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device()); - torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options); - - const int ptxv = getPtxver(); - - // to be further specialized for additional archs, if necessary - if (ptxv < 100) { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_generic", ([&] { - launch_permute_to0231(src, dst); - })); - CHECK_ERROR("permute_to0231_k_tile_generic"); - } else { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_sm100", ([&] { - launch_permute_to0231(src, dst); - })); - CHECK_ERROR("permute_to0231_k_tile_sm100"); - } - - return dst; -} - -at::Tensor permute_4D_to0312(at::Tensor src) { - - auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device()); - torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options); - - const int ptxv = getPtxver(); - - // to be further specialized for 
additional archs, if necessary - if (ptxv < 100) { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_generic", ([&] { - launch_permute_to0312(src, dst); - })); - CHECK_ERROR("permute_to0312_k_tile_generic"); - } else { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_sm100", ([&] { - launch_permute_to0312(src, dst); - })); - CHECK_ERROR("permute_to0312_k_tile_sm100"); - } - - return dst; -} -// END - tensor permutation kernels and functions - -// BEGIN - general host-side functions -unsigned int next_pow2(unsigned int x) { - - x -= 1; - - #pragma unroll - for(int i = 1; i <= sizeof(x)*8 / 2; i *= 2) { - x |= x >> i; - } - return x+1; -} -// END - general host-side functions - -} diff --git a/torch_harmonics/disco/csrc/disco_cuda_utils.cuh b/torch_harmonics/disco/csrc/disco_cuda_utils.cuh deleted file mode 100644 index 76ccb527..00000000 --- a/torch_harmonics/disco/csrc/disco_cuda_utils.cuh +++ /dev/null @@ -1,382 +0,0 @@ -// coding=utf-8 -// -// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. -// SPDX-License-Identifier: BSD-3-Clause -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -#pragma once - -#include -#include -#include - -#define WARP_SIZE (32) -#define FULL_MASK (0xFFFFFFFF) - -#ifndef DIV_UP -#define DIV_UP(a,b) (((a)+((b)-1))/(b)) -#endif - -namespace disco_kernels { - -// CSR rows sorting kernels and functions -at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream); - -// 4D tensor permutation kernels and functions -at::Tensor permute_4D_to0231(at::Tensor src); -at::Tensor permute_4D_to0312(at::Tensor src); - -// Host tensor dump and CSR manipulation functions -void dump_tensor(const char *fname, at::Tensor t); -void dump_csr(const char *fname, at::Tensor roff, at::Tensor cols); - -int part_csr_rows(int *row_perm, - const at::Tensor roff, - const at::Tensor cols, - int **part_off, - int **part_val); - -int verify_part(const int npart, - const int *part_off, - const int *part_val, - const at::Tensor roff, - const at::Tensor cols); - -void verify_part_new(const int nlon_out, - const int nlat_in, - const int nlon_in, - const int npart, // partitioning data - const int *part_off, - const int *part_val, - const at::Tensor roff, - const at::Tensor cols); - -unsigned int next_pow2(unsigned int x); - - -// utility host functions and templates - -template -int is_aligned(const void *ptr) { - - static_assert(0 == (ALIGN & (ALIGN-1))); - return (0 == (uintptr_t(ptr) & (ALIGN-1))); -} - - -// utility device functions and templates - -template -__device__ FLOATV_T __vset(float x) { - static_assert(sizeof(FLOATV_T) == 0, "Unsupported type for __vset"); - return FLOATV_T{}; -} - -template<> -__device__ float __forceinline__ __vset(float x) { - return x; -} - -__device__ float __forceinline__ __vmul(float a, float b) { - return a*b; -} - -__device__ float __forceinline__ __vadd(float a, float b) { - return a+b; -} - -__device__ float __forceinline__ __vsub(float a, float b) { - return a-b; -} - -__device__ float __forceinline__ __vred(float a) { - return a; -} - -__device__ float __forceinline__ __vscale(float s, float v) { - return v*s; -} - -__device__ float __forceinline__ __vdiv(float s, float v) { - return v/s; -} - -template<> -__device__ float4 __forceinline__ __vset(float x) { - return make_float4(x, x, x, x); -} - -__device__ float4 __forceinline__ __vmul(float4 a, float4 b) { - return make_float4(a.x*b.x, a.y*b.y, a.z*b.z, a.w*b.w); -} - -__device__ float4 __forceinline__ __vadd(float4 a, float4 b) { - return make_float4(a.x+b.x, a.y+b.y, a.z+b.z, a.w+b.w); -} - -__device__ float4 __forceinline__ __vsub(float4 a, float4 b) { - return make_float4(a.x-b.x, a.y-b.y, a.z-b.z, a.w-b.w); -} - -__device__ float __forceinline__ __vred(float4 a) { - return a.x + a.y + a.z + a.w; -} - -__device__ float4 __forceinline__ __vscale(float s, float4 v) { - return make_float4(s*v.x, s*v.y, s*v.z, s*v.w); -} - -__device__ float4 __forceinline__ __vdiv(float s, float4 v) { - return make_float4(s/v.x, s/v.y, s/v.z, s/v.w);; -} - -template -__device__ VAL_T __warp_sum(VAL_T val) { - - #pragma unroll - for(int i = WARP_SIZE/2; i; i /= 2) { - val += __shfl_xor_sync(FULL_MASK, val, i, WARP_SIZE); - } - return val; -} - -template -__device__ VAL_T __block_sum(VAL_T val) { - - const int NWARP = (BDIM_X*BDIM_Y*BDIM_Z) / WARP_SIZE; - - val = __warp_sum(val); - - if constexpr(NWARP > 1) { - - int tid = threadIdx.x; - if constexpr(BDIM_Y > 1) { tid += threadIdx.y*BDIM_X; } - if constexpr(BDIM_Z > 1) { tid += threadIdx.z*BDIM_X*BDIM_Y; } - - const int lid = tid%WARP_SIZE; - const int wid = tid/WARP_SIZE; - - __shared__ VAL_T sh[NWARP]; - - if (lid == 0) { - sh[wid] = 
val; - } - __syncthreads(); - - if (wid == 0) { - val = (lid < NWARP) ? sh[lid] : 0; - - val = __warp_sum(val); - __syncwarp(); - - if (!lid) { - sh[0] = val; - } - } - __syncthreads(); - - val = sh[0]; - __syncthreads(); - } - return val; -} - -// transpose utils -template -__global__ -__launch_bounds__(BDIM_X*BDIM_Y) -void permute_to0231_k(const int nchn, - const int nlat, - const int nlon, - const at::PackedTensorAccessor32 src, - at::PackedTensorAccessor32 dst) { - - static_assert(!(BDIM_X & (BDIM_X-1))); - static_assert(!(BDIM_Y & (BDIM_Y-1))); - static_assert(BDIM_X >= BDIM_Y); - - __shared__ VAL_T sh[BDIM_X][BDIM_X+1]; - - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - - const int coff = blockIdx.x*BDIM_X; // channel offset - const int woff = blockIdx.y*BDIM_X; // width offset - const int batch = blockIdx.z / nlat; // batch (same for all block) - const int h = blockIdx.z - (batch * nlat); // height (same for all block) - - const int nchn_full = (nchn-coff) >= BDIM_X; - const int nlon_full = (nlon-woff) >= BDIM_X; - - if (nchn_full && nlon_full) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - sh[j+tidy][tidx] = src[batch][coff + j+tidy][h][woff+tidx]; - } - __syncthreads(); - - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy]; - } - } else { - if (woff+tidx < nlon) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - sh[j+tidy][tidx] = (coff + j+tidy < nchn) ? src[batch][coff + j+tidy][h][woff+tidx] : VAL_T(0); - } - } - __syncthreads(); - - if (coff+tidx < nchn) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - if (woff + j+tidy < nlon) { - dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy]; - } - } - } - } - return; -} - -template -void launch_permute_to0231(at::Tensor src, at::Tensor dst){ - dim3 block; - dim3 grid; - - block.x = WARP_SIZE; - block.y = WARPS_X_TILE; - grid.x = DIV_UP(src.size(1), block.x); - grid.y = DIV_UP(src.size(3), block.x); - grid.z = src.size(2)*src.size(0); - - assert(grid.y < 65536); - assert(grid.z < 65536); - - // get stream - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - permute_to0231_k - <<>>(src.size(1), - src.size(2), - src.size(3), - src.packed_accessor32(), - dst.packed_accessor32()); -} - -template -__global__ -__launch_bounds__(BDIM_X*BDIM_Y) -void permute_to0312_k(const int nchn, - const int nlat, - const int nlon, - const at::PackedTensorAccessor32 src, - at::PackedTensorAccessor32 dst) { - - static_assert(!(BDIM_X & (BDIM_X-1))); - static_assert(!(BDIM_Y & (BDIM_Y-1))); - static_assert(BDIM_X >= BDIM_Y); - - __shared__ VAL_T sh[BDIM_X][BDIM_X+1]; - - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - - const int woff = blockIdx.x*BDIM_X; // width offset - const int coff = blockIdx.y*BDIM_X; // channel offset - const int batch = blockIdx.z / nlat; // batch (same for all block) - const int h = blockIdx.z - (batch * nlat); // height (same for all block) - - const int nchn_full = (nchn-coff) >= BDIM_X; - const int nlon_full = (nlon-woff) >= BDIM_X; - - if (nchn_full && nlon_full) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - sh[j+tidy][tidx] = src[batch][h][woff + j+tidy][coff+tidx]; - } - __syncthreads(); - - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy]; - } - } else { - if (coff+tidx < nchn) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - sh[j+tidy][tidx] = (woff + 
j+tidy < nlon) ? src[batch][h][woff + j+tidy][coff+tidx] : VAL_T(0); - } - } - __syncthreads(); - - if (woff+tidx < nlon) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - if (coff + j+tidy < nchn) { - dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];; - } - } - } - } - return; -} - -template -void launch_permute_to0312(at::Tensor src, at::Tensor dst){ - dim3 block; - dim3 grid; - - block.x = WARP_SIZE; - block.y = WARPS_X_TILE; - grid.x = DIV_UP(src.size(2), block.x); - grid.y = DIV_UP(src.size(3), block.x); - grid.z = src.size(1)*src.size(0); - - assert(grid.y < 65536); - assert(grid.z < 65536); - - // get stream - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - permute_to0312_k - <<>>(src.size(3), - src.size(1), - src.size(2), - src.packed_accessor32(), - dst.packed_accessor32()); -} - -} From 6509e68ff3777a8b0e3c4864e14efd3ca1ea30a9 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Thu, 28 Aug 2025 07:44:48 -0700 Subject: [PATCH 04/31] compiling code --- setup.py | 42 +- torch_harmonics/attention/__init__.py | 2 +- torch_harmonics/disco/__init__.py | 2 +- torch_harmonics/disco/_disco_utils.py | 45 +- torch_harmonics/disco/convolution.py | 12 +- torch_harmonics/disco/csrc/disco_cuda.cuh | 8 +- torch_harmonics/disco/csrc/disco_cuda_bwd.cu | 63 +-- torch_harmonics/disco/csrc/disco_cuda_fwd.cu | 111 +---- .../disco/csrc/disco_interface.cpp | 4 +- torch_harmonics/utils/__init__.py | 45 ++ torch_harmonics/utils/_utils.py | 76 ++++ torch_harmonics/utils/csrc/cuda_utils.cu | 190 +++++++++ torch_harmonics/utils/csrc/cuda_utils.cuh | 387 ++++++++++++++++++ torch_harmonics/utils/csrc/cudamacro.h | 47 +++ torch_harmonics/utils/csrc/utils_helpers.cpp | 58 +++ .../utils/csrc/utils_interface.cpp | 62 +++ 16 files changed, 982 insertions(+), 172 deletions(-) create mode 100644 torch_harmonics/utils/__init__.py create mode 100644 torch_harmonics/utils/_utils.py create mode 100644 torch_harmonics/utils/csrc/cuda_utils.cu create mode 100644 torch_harmonics/utils/csrc/cuda_utils.cuh create mode 100644 torch_harmonics/utils/csrc/cudamacro.h create mode 100644 torch_harmonics/utils/csrc/utils_helpers.cpp create mode 100644 torch_harmonics/utils/csrc/utils_interface.cpp diff --git a/setup.py b/setup.py index cce9dc42..c93a81c8 100644 --- a/setup.py +++ b/setup.py @@ -100,6 +100,19 @@ def get_ext_modules(): cmdclass = {} print(f"Compiling helper routines for torch-harmonics.") + + # UTILITIES + ext_modules.append( + CppExtension( + "utility_helpers", + [ + "torch_harmonics/utils/csrc/utils_helpers.cpp", + ], + extra_compile_args=get_helpers_compile_args(), + ) + ) + + # DISCO ext_modules.append( CppExtension( "disco_helpers", @@ -121,6 +134,32 @@ def get_ext_modules(): ) if BUILD_CPP: + # HELPERS + utility_sources = [ + "torch_harmonics/utils/csrc/utils_interface.cpp", + ] + + if BUILD_CUDA: + print(f"Compiling custom CUDA kernels for torch-harmonics.") + utility_sources.extend([ + "torch_harmonics/utils/csrc/cuda_utils.cu", + ]) + ext_modules.append( + CUDAExtension( + "torch_harmonics.utils._C", + utility_sources, + extra_compile_args=get_compile_args("utils") + ) + ) + else: + ext_modules.append( + CppExtension( + "torch_harmonics.utils._C", + utility_sources, + extra_compile_args=get_compile_args("utils") + ) + ) + # DISCO # Create a single extension that includes both CPU and CUDA code disco_sources = [ @@ -151,7 +190,6 @@ def get_ext_modules(): extra_compile_args=get_compile_args("disco") ) ) - cmdclass["build_ext"] = BuildExtension # ATTENTION # Create a single 
extension that includes both CPU and CUDA code @@ -183,6 +221,8 @@ def get_ext_modules(): extra_compile_args=get_compile_args("attention") ) ) + + # set cmdclass cmdclass["build_ext"] = BuildExtension return ext_modules, cmdclass diff --git a/torch_harmonics/attention/__init__.py b/torch_harmonics/attention/__init__.py index ee69cb83..4fea72b9 100644 --- a/torch_harmonics/attention/__init__.py +++ b/torch_harmonics/attention/__init__.py @@ -39,6 +39,6 @@ from torch.ops import attention_kernels else: attention_kernels = None - warnings.warn("No optimized kernels are available. Please compile the extension first setting BUILD_CPP and BUILD_CUDA to 1.") + warnings.warn("No optimized attention kernels are available. Please compile the extension first setting BUILD_CPP and BUILD_CUDA to 1.") from .attention import AttentionS2, NeighborhoodAttentionS2 diff --git a/torch_harmonics/disco/__init__.py b/torch_harmonics/disco/__init__.py index f8f09f14..3ece8042 100644 --- a/torch_harmonics/disco/__init__.py +++ b/torch_harmonics/disco/__init__.py @@ -40,6 +40,6 @@ from torch.ops import disco_kernels else: disco_kernels = None - warnings.warn("No optimized kernels are available. Please compile the extension first setting BUILD_CPP and BUILD_CUDA to 1.") + warnings.warn("No optimized DISCO kernels are available. Please compile the extension first setting BUILD_CPP and BUILD_CUDA to 1.") from .convolution import DiscreteContinuousConvS2 #, DiscreteContinuousConvTransposeS2 diff --git a/torch_harmonics/disco/_disco_utils.py b/torch_harmonics/disco/_disco_utils.py index ed1979a2..97369e66 100644 --- a/torch_harmonics/disco/_disco_utils.py +++ b/torch_harmonics/disco/_disco_utils.py @@ -40,13 +40,11 @@ if optimized_kernels_is_available(): # raw forward fake @torch.library.register_fake("disco_kernels::forward") - def _(inp: torch.Tensor, weights: torch.Tensor, - roff_idx: torch.Tensor, ker_idx: torch.Tensor, + def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, - nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor]: - out_shape = (inp.shape[0], weights.shape[0] * weights.shape[1], nlat_out, nlon_out) - dout_shape = (inp.shape[0], weights.shape[0] * weights.shape[2], nlat_out, nlon_out) - return torch.empty(out_shape, dtype=inp.dtype, device=inp.device), torch.empty(dout_shape, dtype=inp.dtype, device=inp.device) + kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + out_shape = (inp.shape[0], nlat_out, nlon_out, kernel_size*inp.shape[3]) + return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) # # raw backward fake # @torch.library.register_fake("disco_kernels::backward") @@ -60,12 +58,11 @@ def _(inp: torch.Tensor, weights: torch.Tensor, # forward @torch.library.custom_op("disco_kernels::_disco_s2_contraction_optimized", mutates_args=()) def _disco_s2_contraction_optimized( - inp: torch.Tensor, weights: torch.Tensor, - roff_idx: torch.Tensor, ker_idx: torch.Tensor, + inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, - nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor]: - out, dout = disco_kernels.forward.default(inp, weights, roff_idx, ker_idx, row_idx, col_idx, vals, nlat_out, nlon_out) - return out, dout + kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + out = disco_kernels.forward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, 
nlon_out) + return out # # transpose # @torch.library.custom_op("disco_kernels::_disco_s2_transpose_contraction_optimized", mutates_args=()) @@ -81,13 +78,11 @@ def _disco_s2_contraction_optimized( # forward fake @torch.library.register_fake("disco_kernels::_disco_s2_contraction_optimized") - def _(inp: torch.Tensor, weights: torch.Tensor, - roff_idx: torch.Tensor, ker_idx: torch.Tensor, + def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, - nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor]: - out_shape = (inp.shape[0], weights.shape[0] * weights.shape[1], nlat_out, nlon_out) - dout_shape = (inp.shape[0], weights.shape[0] * weights.shape[2], nlat_out, nlon_out) - return torch.empty(out_shape, dtype=inp.dtype, device=inp.device), torch.empty(dout_shape, dtype=inp.dtype, device=inp.device) + kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + out_shape = (inp.shape[0], nlat_out, nlon_out, kernel_size*inp.shape[3]) + return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) # # transpose fake # @torch.library.register_fake("disco_kernels::_disco_s2_transpose_contraction_optimized") @@ -99,28 +94,22 @@ def _(inp: torch.Tensor, weights: torch.Tensor, #general routines: this is the same for forward and transpose def _setup_context_conv_backward(ctx, inputs, output): - inp, weights, roff_idx, ker_idx, row_idx, col_idx, vals, _, _ = inputs - _, dinp = output - ctx.save_for_backward(dinp, weights, roff_idx, ker_idx, row_idx, col_idx, vals) + inp, roff_idx, ker_idx, row_idx, col_idx, vals, _, _, _ = inputs + ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals) ctx.nlat_in = inp.shape[-2] ctx.nlon_in = inp.shape[-1] # convolution related def _disco_s2_contraction_bwd_optimized(ctx, grad_output): - dinp, weights, roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors - - print("grad_output", grad_output) - - print("SHAPE CHECK", grad_output.shape, dinp.shape, weights.shape) + roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors if ctx.needs_input_grad[0]: - grad_input, wgrad = disco_kernels.backward.default(grad_output, dinp, weights, roff_idx, ker_idx, row_idx, col_idx, vals, + grad_input = disco_kernels.backward.default(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals, ctx.nlat_in, ctx.nlon_in) # Mauro else: grad_input = None - wgrad = None - return grad_input, wgrad, None, None, None, None, None, None, None # Mauro: added a None for weights + return grad_input, None, None, None, None, None, None, None, None # Mauro: added a None for weights if optimized_kernels_is_available(): torch.library.register_autograd( diff --git a/torch_harmonics/disco/convolution.py b/torch_harmonics/disco/convolution.py index b55ac27a..a1d4a49e 100644 --- a/torch_harmonics/disco/convolution.py +++ b/torch_harmonics/disco/convolution.py @@ -34,18 +34,18 @@ from warnings import warn import math +import xmlrpc import torch import torch.nn as nn import nvtx -import numpy as np - from functools import partial from torch_harmonics.cache import lru_cache from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes +from utils import permute_to_0231, permute_to_0312 from ._disco_utils import _get_psi, _disco_s2_contraction_torch #, _disco_s2_transpose_contraction_torch from ._disco_utils import _disco_s2_contraction_optimized #, _disco_s2_transpose_contraction_optimized from torch_harmonics.filter_basis import FilterBasis, 
get_filter_basis @@ -514,10 +514,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.optimized_kernel: with nvtx.annotate("_disco_s2_contraction_optimized", color="red"): - out, _ = _disco_s2_contraction_optimized( - x, self.weight, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, + xp = permute_to_0231(x) + xpc = _disco_s2_contraction_optimized( + xp, self.weight, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.nlat_out, self.nlon_out ) + xpc = xpc.reshape(xpc.shape[0], self.nlat_out, self.nlon_out, self.groups, self.groupsize_in, self.kernel_size) + yp = torch.einsum("bxygck,gock->bxygo", xpc, self.weight).reshape(xpc.shape[0], self.nlat_out, self.nlon_out, -1).contiguous() + out = permute_to_0312(yp) else: x = _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out) diff --git a/torch_harmonics/disco/csrc/disco_cuda.cuh b/torch_harmonics/disco/csrc/disco_cuda.cuh index f7c05708..2d342a24 100644 --- a/torch_harmonics/disco/csrc/disco_cuda.cuh +++ b/torch_harmonics/disco/csrc/disco_cuda.cuh @@ -52,11 +52,11 @@ namespace disco_kernels { // forward kernel - std::tuple disco_cuda_fwd(torch::Tensor inp, torch::Tensor weights, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, - torch::Tensor col_idx, torch::Tensor val, int64_t Ho, int64_t Wo); + torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, + torch::Tensor col_idx, torch::Tensor val, int64_t kernel_size, int64_t Ho, int64_t Wo); // backward kernel - std::tuple disco_cuda_bwd(torch::Tensor ograd, torch::Tensor dinp, torch::Tensor weights, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, - torch::Tensor col_idx, torch::Tensor val, int64_t Ho, int64_t Wo); + torch::Tensor disco_cuda_bwd(torch::Tensor ograd, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, + torch::Tensor col_idx, torch::Tensor val, int64_t kernel_size, int64_t Ho, int64_t Wo); } diff --git a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu index 8b1c2f9c..cda05836 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu @@ -30,9 +30,12 @@ #include "disco.h" #include "disco_cuda.cuh" +#include "cuda_utils.cuh" namespace disco_kernels { +using namespace utility_kernels; + template __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int Ho, const int Wo, const int pscale, const int64_t *__restrict__ roff, const int64_t *__restrict__ kers, @@ -198,14 +201,12 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t return; } - std::tuple disco_cuda_bwd(torch::Tensor inp, torch::Tensor dinp, torch::Tensor weights, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, - torch::Tensor col_idx, torch::Tensor val, int64_t Ho, int64_t Wo) + torch::Tensor disco_cuda_bwd(torch::Tensor ograd, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, + torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo) { // some sanity checks - CHECK_CUDA_INPUT_TENSOR(inp); - CHECK_CUDA_INPUT_TENSOR(dinp); - CHECK_CUDA_INPUT_TENSOR(weights); + CHECK_CUDA_INPUT_TENSOR(ograd); CHECK_CUDA_INPUT_TENSOR(roff_idx); CHECK_CUDA_INPUT_TENSOR(ker_idx); CHECK_CUDA_INPUT_TENSOR(row_idx); @@ -213,44 +214,24 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t 
CHECK_CUDA_INPUT_TENSOR(val); // extract some shapes - int64_t B = inp.size(0); - int64_t Cin = inp.size(1); - int64_t Cout = weights.size(0) * weights.size(2); - int64_t K = weights.size(3); - int64_t BC = B * Cout; - int64_t Hi = inp.size(2); - int64_t Wi = inp.size(3); + int64_t B = ograd.size(0); + int64_t Hi = ograd.size(1); + int64_t Wi = ograd.size(2); + int64_t C = ograd.size(3); + int64_t BC = B * C; int64_t nrows = roff_idx.size(0) - 1; // allocate output - int64_t out_dims[] = {B, Cout, Ho, Wo}; + int64_t out_dims[] = {B, Ho, Wo, C}; // get stream auto stream = at::cuda::getCurrentCUDAStream().stream(); - // print shapes - std::cout << "dinp.dim(): " << dinp.dim() << std::endl; - for (int i=0; ibgikxy", {inp_resh, weights}).reshape({B, Cout, K, Hi, Wi}).contiguous(); - auto wgrad = torch::einsum("bgixy,bgokxy->giok", {inp_resh, dinp_resh}).to(torch::kFloat32).contiguous(); - // extract dtype - auto x_type = x.dtype(); - torch::Tensor xP = x.to(torch::kFloat32); + auto x_type = ograd.dtype(); + torch::Tensor xP = ograd.to(torch::kFloat32); - torch::Tensor out = torch::zeros(out_dims, xP.options()); + torch::Tensor igrad = torch::zeros(out_dims, xP.options()); // assert static_assert(0 == (ELXTH_MAX % 2)); @@ -261,7 +242,7 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), ker_idx.data_ptr(), row_idx.data_ptr(), col_idx.data_ptr(), val.data_ptr(), - xP.data_ptr(), out.data_ptr(), stream); + xP.data_ptr(), igrad.data_ptr(), stream); })); } else if (Wo <= 128 * ELXTH_MAX) { AT_DISPATCH_FLOATING_TYPES(xP.scalar_type(), "disco_backward_cuda", ([&] { @@ -269,7 +250,7 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), ker_idx.data_ptr(), row_idx.data_ptr(), col_idx.data_ptr(), val.data_ptr(), - xP.data_ptr(), out.data_ptr(), stream); + xP.data_ptr(), igrad.data_ptr(), stream); })); } else if (Wo <= 256 * ELXTH_MAX) { AT_DISPATCH_FLOATING_TYPES(xP.scalar_type(), "disco_backward_cuda", ([&] { @@ -277,7 +258,7 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), ker_idx.data_ptr(), row_idx.data_ptr(), col_idx.data_ptr(), val.data_ptr(), - xP.data_ptr(), out.data_ptr(), stream); + xP.data_ptr(), igrad.data_ptr(), stream); })); } else if (Wo <= 512 * ELXTH_MAX) { AT_DISPATCH_FLOATING_TYPES(xP.scalar_type(), "disco_backward_cuda", ([&] { @@ -285,7 +266,7 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), ker_idx.data_ptr(), row_idx.data_ptr(), col_idx.data_ptr(), val.data_ptr(), - xP.data_ptr(), out.data_ptr(), stream); + xP.data_ptr(), igrad.data_ptr(), stream); })); } else if (Wo <= 1024 * ELXTH_MAX) { AT_DISPATCH_FLOATING_TYPES(xP.scalar_type(), "disco_backward_cuda", ([&] { @@ -293,7 +274,7 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), ker_idx.data_ptr(), row_idx.data_ptr(), col_idx.data_ptr(), val.data_ptr(), - xP.data_ptr(), out.data_ptr(), stream); + xP.data_ptr(), igrad.data_ptr(), stream); })); } else { fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo, @@ -302,9 +283,9 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t } // convert back to original dtype - out = out.to(x_type); + 
igrad = igrad.to(x_type); - return std::make_tuple(out, wgrad); + return igrad; } TORCH_LIBRARY_IMPL(disco_kernels, CUDA, m) diff --git a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu index 3cc37797..778764fe 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu @@ -39,6 +39,8 @@ namespace disco_kernels { +using namespace utility_kernels; + void dump_tensor(const char *fname, at::Tensor t); void dump_csr(const char *fname, at::Tensor roff, at::Tensor cols); void dump_csr_linear(const char *fname, at::Tensor roff, at::Tensor kers, at::Tensor rows, at::Tensor cols, at::Tensor vals); @@ -955,33 +957,31 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, return; } - std::tuple disco_cuda_fwd(torch::Tensor inp, torch::Tensor weights, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, - torch::Tensor col_idx, torch::Tensor val, int64_t Ho, int64_t Wo) + torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, + torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo) { // some sanity checks CHECK_CUDA_INPUT_TENSOR(inp); - CHECK_CUDA_INPUT_TENSOR(weights); CHECK_CUDA_INPUT_TENSOR(roff_idx); CHECK_CUDA_INPUT_TENSOR(ker_idx); CHECK_CUDA_INPUT_TENSOR(row_idx); CHECK_CUDA_INPUT_TENSOR(col_idx); CHECK_CUDA_INPUT_TENSOR(val); - // extract some shapes + // assume input is B, H, W, C int64_t B = inp.size(0); - int64_t C = inp.size(1); - int64_t BC = B * C; - int64_t Hi = inp.size(2); - int64_t Wi = inp.size(3); + int64_t Hi = inp.size(1); + int64_t Wi = inp.size(2); + int64_t C = inp.size(3); + //int64_t BC = B * C; int64_t nrows = roff_idx.size(0) - 1; - int64_t K = weights.size(3); // rename dimensions consistent with attention - int64_t batch_size = inp.size(0); - int64_t nchan = inp.size(1); - int64_t nlat_in = inp.size(2); - int64_t nlon_in = inp.size(3); + int64_t batch_size = B; + int64_t nchan = C; + int64_t nlat_in = Hi; + int64_t nlon_in = Wi; int64_t nlat_out = Ho; int64_t nlon_out = Wo; /* @@ -1126,11 +1126,11 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, // switch back to original layout; // I'm assuming that if x was passed as channel last, then // the output tensor should be K last - torch::Tensor y = yP; - if (!x_is_channels_last) { - y = permute_4D_to0312(y); - // make y {batch_size, nchan_out, nlat_out, nlon_out} - } + //torch::Tensor y = yP; + //if (!x_is_channels_last) { + // y = permute_4D_to0312(y); + // // make y {batch_size, nchan_out, nlat_out, nlon_out} + //} #else // to test before fusion torch::Tensor y = yP; @@ -1148,41 +1148,8 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, // switch to channel-last // version with fused enisum - int64_t ngroup = weights.size(0); - int64_t chan_x_grp_out = weights.size(1); - int64_t chan_x_grp_in = weights.size(2); - int64_t weight_k = weights.size(3); - - int64_t nchan_out = ngroup*chan_x_grp_out; - - printf("weight tensor shape: %ld, %ld, %ld, %ld\n", ngroup, chan_x_grp_out, chan_x_grp_in, weight_k); fflush(stdout); - - if (nchan != chan_x_grp_in*ngroup || K != weight_k) { - fprintf(stderr, - "%s:%d: error, dimension mismatch for weight tensor!\n", - __func__, __LINE__); - exit(EXIT_FAILURE); - } - - - // input: inp[B][Ci][Hi][Wi] -> inp[B][Hi][Wi][Ci] - // - // output: out[[B][Ho][Wo][Co] -> out[B][Co][Ho][Wo] - // with Co = ngroup*chan_x_grp_out - - - // switch to channel-last - - // extract dtype and memory format 
- bool x_is_channels_last = inp.strides()[1] == 1; auto x_type = inp.dtype(); - - // transpose if required - torch::Tensor xP = inp; - if (!x_is_channels_last) { xP = permute_4D_to0231(xP); } - - // convert datatype - xP = xP.to(torch::kFloat32); + auto xP = inp.to(torch::kFloat32); // to test before fusion int64_t out_dims[] = {batch_size, nlat_out, nlon_out, nchan*K}; @@ -1203,45 +1170,9 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, val, yP); - // store output of convolution - auto disco_y = yP; - - // call einsum - // yP is {batch_size, nlat_out, nlon_out, nchan_in*K}, - // reshape it to {batch_size, nlat_out, nlon_out, ngroup, chan_x_grp_in*K} - auto yP_resh = yP.reshape({batch_size, nlat_out, nlon_out, ngroup, -1}); - - // weight is {ngroup, chan_x_grp_out, chan_x_grp_in, K} - // reshape weight to {ngroup, chan_x_grp_out, chan_x_grp_in*K} - auto weights_resh = weights.reshape({ngroup, chan_x_grp_out, -1}); - - auto out_sum = torch::einsum("bxygc,goc->bxygo", {yP_resh, weights_resh}).contiguous(); - - // out is {batch_size, nlat_out, nlon_out, ngroup, chan_x_grp_out}, - // reshape it ot {batch_size, nlat_out, nlon_out, nchan_out}, - auto out_resh = out_sum.reshape({batch_size, nlat_out, nlon_out, -1}); - - // switch back to original layout; - // I'm assuming that if x was passed as channel last, then - // the output tensor should be K last - torch::Tensor y = out_resh; - if (!x_is_channels_last) { - y = permute_4D_to0312(y); - // make y {batch_size, nchan_out, nlat_out, nlon_out} - } - - // some hacky reshuffling now: - if (!x_is_channels_last) { - disco_y = permute_4D_to0312(yP).reshape({batch_size, ngroup*chan_x_grp_in, K, nlat_out, nlon_out}); - } else { - disco_y = yP.reshape({batch_size, ngroup*chan_x_grp_in, K, nlat_out, nlon_out}); - } - - // convert precision back to starting - y = y.to(x_type); - disco_y = disco_y.to(x_type); + auto y = yP.to(x_type); - std::tuple out = std::make_tuple(y, disco_y); + torch::Tensor out = y; #endif #endif // closes ORIGINAL if diff --git a/torch_harmonics/disco/csrc/disco_interface.cpp b/torch_harmonics/disco/csrc/disco_interface.cpp index 8269efb0..11bf0249 100644 --- a/torch_harmonics/disco/csrc/disco_interface.cpp +++ b/torch_harmonics/disco/csrc/disco_interface.cpp @@ -54,8 +54,8 @@ namespace disco_kernels { // Declare the operators TORCH_LIBRARY(disco_kernels, m) { - m.def("forward(Tensor inp, Tensor weights, Tensor roff_idx, Tensor ker_idx, Tensor row_idx, Tensor col_idx, Tensor vals, int nlat_out, int nlon_out) -> (Tensor, Tensor)"); //, {at::Tag::pt2_compliant_tag}); - m.def("backward(Tensor inp, Tensor dinp, Tensor weights, Tensor roff_idx, Tensor ker_idx, Tensor row_idx, Tensor col_idx, Tensor vals, int nlat_out, int nlon_out) -> (Tensor, Tensor)"); //, {at::Tag::pt2_compliant_tag}); + m.def("forward(Tensor inp, Tensor roff_idx, Tensor ker_idx, Tensor row_idx, Tensor col_idx, Tensor vals, int kernel_size, int nlat_out, int nlon_out) -> Tensor"); //, {at::Tag::pt2_compliant_tag}); + m.def("backward(Tensor inp, Tensor roff_idx, Tensor ker_idx, Tensor row_idx, Tensor col_idx, Tensor vals, int kernel_size, int nlat_out, int nlon_out) -> Tensor"); //, {at::Tag::pt2_compliant_tag}); } } diff --git a/torch_harmonics/utils/__init__.py b/torch_harmonics/utils/__init__.py new file mode 100644 index 00000000..1acbd722 --- /dev/null +++ b/torch_harmonics/utils/__init__.py @@ -0,0 +1,45 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import warnings +import torch + +# we need those helpers +from utility_helpers import optimized_kernels_is_available + +if optimized_kernels_is_available(): + from . import _C + from torch.ops import utility_kernels +else: + utility_kernels = None + warnings.warn("No optimized utility kernels are available. Please compile the extension first setting BUILD_CPP and BUILD_CUDA to 1.") + +from ._utils import permute_to_0231, permute_to_0312 \ No newline at end of file diff --git a/torch_harmonics/utils/_utils.py b/torch_harmonics/utils/_utils.py new file mode 100644 index 00000000..2141d9cc --- /dev/null +++ b/torch_harmonics/utils/_utils.py @@ -0,0 +1,76 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from typing import Optional + +import torch +from utility_helpers import optimized_kernels_is_available +from . import utility_kernels + +# custom kernels +if optimized_kernels_is_available(): + + # fake permutations + @torch.library.register_fake("utility_kernels::permute_0231") + def _(inp: torch.Tensor) -> torch.Tensor: + B, C, H, W = inp.shape + out_shape = (B, H, W, C) + return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) + + @torch.library.register_fake("utility_kernels::permute_0312") + def _(inp: torch.Tensor) -> torch.Tensor: + B, H, W, C = inp.shape + out_shape = (B, C, H, W) + return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) + + # autograds + torch.library.register_autograd( + "utility_kernels::permute_0231", utility_kernels.permute_0312) + + torch.library.register_autograd( + "utility_kernels::permute_0312", utility_kernels.permute_0231) + + +def permute_to_0231(inp: torch.Tensor) -> torch.Tensor: + if optimized_kernels_is_available() and inp.is_cuda: + out = utility_kernels.permute_0231.default(inp) + else: + out = inp.permute(0, 2, 3, 1).contiguous() + return out + +def permute_to_0312(inp: torch.Tensor) -> torch.Tensor: + if optimized_kernels_is_available() and inp.is_cuda: + out = utility_kernels.permute_0312.default(inp) + else: + out = inp.permute(0, 3, 1, 2).contiguous() + return out + + diff --git a/torch_harmonics/utils/csrc/cuda_utils.cu b/torch_harmonics/utils/csrc/cuda_utils.cu new file mode 100644 index 00000000..0a1e3bac --- /dev/null +++ b/torch_harmonics/utils/csrc/cuda_utils.cu @@ -0,0 +1,190 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +//#include +//#include +//#include +#include + +#include + +#include +#include + +#include "cudamacro.h" +#include "cuda_utils.cuh" + +#define THREADS (64) + +#define TRANSP_WARPS_X_TILE_GENERIC (32) +#define TRANSP_WARPS_X_TILE_SM100 (4) + +namespace utility_kernels { + +// BEGIN - CSR rows sorting kernels and functions +__global__ void set_rlen_rids_k(const int n, + const int64_t *__restrict__ offs, + int *__restrict__ rids, + int *__restrict__ rlen) { + + const int nth = gridDim.x*blockDim.x; + const int tid = blockIdx.x*blockDim.x + threadIdx.x; + + for(int i = tid; i < n; i += nth) { + rids[i] = i; + rlen[i] = offs[i+1]-offs[i]; + } + + return; +} + +torch::Tensor sortRows(int nlat_out, torch::Tensor row_off, cudaStream_t stream) { + + int64_t *_row_off_d = reinterpret_cast(row_off.data_ptr()); + + auto options = torch::TensorOptions().dtype(torch::kInt32).device(row_off.device()); + + torch::Tensor rids_d = torch::empty({nlat_out}, options); + torch::Tensor rlen_d = torch::empty({nlat_out}, options); + + int *_rids_d = reinterpret_cast(rids_d.data_ptr()); + int *_rlen_d = reinterpret_cast(rlen_d.data_ptr()); + + const int grid = DIV_UP(nlat_out, THREADS); + const int block = THREADS; + + set_rlen_rids_k<<>>(nlat_out, + _row_off_d, + _rids_d, + _rlen_d); + + torch::Tensor rids_sort_d = torch::empty({nlat_out}, options); + torch::Tensor rlen_sort_d = torch::empty({nlat_out}, options); + + int *_rids_sort_d = reinterpret_cast(rids_sort_d.data_ptr()); + int *_rlen_sort_d = reinterpret_cast(rlen_sort_d.data_ptr()); + + size_t temp_storage_bytes = 0; + CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(NULL, temp_storage_bytes, + _rlen_d, _rlen_sort_d, + _rids_d, _rids_sort_d, + nlat_out, 0, sizeof(*_rlen_d)*8, stream)); + + options = torch::TensorOptions().dtype(torch::kByte).device(row_off.device()); + torch::Tensor temp_storage_d = torch::empty({int64_t(temp_storage_bytes)}, options); + + void *_temp_storage_d = reinterpret_cast(temp_storage_d.data_ptr()); + + CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(_temp_storage_d, temp_storage_bytes, + _rlen_d, _rlen_sort_d, + _rids_d, _rids_sort_d, + nlat_out, 0, sizeof(*_rlen_d)*8, stream)); + return rids_sort_d; +} +// END - CSR rows sorting kernels and functions + + +// BEGIN - 4D tensor permutation kernels and functions +__global__ void empty_k() {} + +static int getPtxver() { + cudaFuncAttributes attrs; + CHECK_CUDA(cudaFuncGetAttributes(&attrs, empty_k)); + return attrs.ptxVersion*10; +} + +at::Tensor permute_4D_to0231(at::Tensor src) { + + auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device()); + torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options); + + const int ptxv = getPtxver(); + + // to be further specialized for additional archs, if necessary + if (ptxv < 100) { + AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_generic", ([&] { + launch_permute_to0231(src, dst); + })); + 
CHECK_ERROR("permute_to0231_k_tile_generic"); + } else { + AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_sm100", ([&] { + launch_permute_to0231(src, dst); + })); + CHECK_ERROR("permute_to0231_k_tile_sm100"); + } + + return dst; +} + +at::Tensor permute_4D_to0312(at::Tensor src) { + + auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device()); + torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options); + + const int ptxv = getPtxver(); + + // to be further specialized for additional archs, if necessary + if (ptxv < 100) { + AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_generic", ([&] { + launch_permute_to0312(src, dst); + })); + CHECK_ERROR("permute_to0312_k_tile_generic"); + } else { + AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_sm100", ([&] { + launch_permute_to0312(src, dst); + })); + CHECK_ERROR("permute_to0312_k_tile_sm100"); + } + + return dst; +} +// END - tensor permutation kernels and functions + +// BEGIN - general host-side functions +unsigned int next_pow2(unsigned int x) { + + x -= 1; + + #pragma unroll + for(int i = 1; i <= sizeof(x)*8 / 2; i *= 2) { + x |= x >> i; + } + return x+1; +} +// END - general host-side functions + +TORCH_LIBRARY_IMPL(utils_kernels, CUDA, m) + { + m.impl("permute_0231", &permute_4D_to0231); + m.impl("permute_0312", &permute_4D_to0312); + } + +} diff --git a/torch_harmonics/utils/csrc/cuda_utils.cuh b/torch_harmonics/utils/csrc/cuda_utils.cuh new file mode 100644 index 00000000..54d77a95 --- /dev/null +++ b/torch_harmonics/utils/csrc/cuda_utils.cuh @@ -0,0 +1,387 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +#pragma once + +// +//#include +//#include + +#include +#include +#include +#include + +#define WARP_SIZE (32) +#define FULL_MASK (0xFFFFFFFF) + +#ifndef DIV_UP +#define DIV_UP(a,b) (((a)+((b)-1))/(b)) +#endif + +namespace utility_kernels { + +// CSR rows sorting kernels and functions +torch::Tensor sortRows(int nlat_out, torch::Tensor row_off, cudaStream_t stream); + +// 4D tensor permutation kernels and functions +torch::Tensor permute_4D_to0231(torch::Tensor src); +torch::Tensor permute_4D_to0312(torch::Tensor src); + +// Host tensor dump and CSR manipulation functions +void dump_tensor(const char *fname, torch::Tensor t); +void dump_csr(const char *fname, torch::Tensor roff, torch::Tensor cols); + +int part_csr_rows(int *row_perm, + const torch::Tensor roff, + const torch::Tensor cols, + int **part_off, + int **part_val); + +int verify_part(const int npart, + const int *part_off, + const int *part_val, + const torch::Tensor roff, + const torch::Tensor cols); + +void verify_part_new(const int nlon_out, + const int nlat_in, + const int nlon_in, + const int npart, // partitioning data + const int *part_off, + const int *part_val, + const torch::Tensor roff, + const torch::Tensor cols); + +unsigned int next_pow2(unsigned int x); + + +// utility host functions and templates + +template +int is_aligned(const void *ptr) { + + static_assert(0 == (ALIGN & (ALIGN-1))); + return (0 == (uintptr_t(ptr) & (ALIGN-1))); +} + + +// utility device functions and templates + +template +__device__ FLOATV_T __vset(float x) { + static_assert(sizeof(FLOATV_T) == 0, "Unsupported type for __vset"); + return FLOATV_T{}; +} + +template<> +__device__ float __forceinline__ __vset(float x) { + return x; +} + +__device__ float __forceinline__ __vmul(float a, float b) { + return a*b; +} + +__device__ float __forceinline__ __vadd(float a, float b) { + return a+b; +} + +__device__ float __forceinline__ __vsub(float a, float b) { + return a-b; +} + +__device__ float __forceinline__ __vred(float a) { + return a; +} + +__device__ float __forceinline__ __vscale(float s, float v) { + return v*s; +} + +__device__ float __forceinline__ __vdiv(float s, float v) { + return v/s; +} + +template<> +__device__ float4 __forceinline__ __vset(float x) { + return make_float4(x, x, x, x); +} + +__device__ float4 __forceinline__ __vmul(float4 a, float4 b) { + return make_float4(a.x*b.x, a.y*b.y, a.z*b.z, a.w*b.w); +} + +__device__ float4 __forceinline__ __vadd(float4 a, float4 b) { + return make_float4(a.x+b.x, a.y+b.y, a.z+b.z, a.w+b.w); +} + +__device__ float4 __forceinline__ __vsub(float4 a, float4 b) { + return make_float4(a.x-b.x, a.y-b.y, a.z-b.z, a.w-b.w); +} + +__device__ float __forceinline__ __vred(float4 a) { + return a.x + a.y + a.z + a.w; +} + +__device__ float4 __forceinline__ __vscale(float s, float4 v) { + return make_float4(s*v.x, s*v.y, s*v.z, s*v.w); +} + +__device__ float4 __forceinline__ __vdiv(float s, float4 v) { + return make_float4(s/v.x, s/v.y, s/v.z, s/v.w);; +} + +template +__device__ VAL_T __warp_sum(VAL_T val) { + + #pragma unroll + for(int i = WARP_SIZE/2; i; i /= 2) { + val += __shfl_xor_sync(FULL_MASK, val, i, WARP_SIZE); + } + return val; +} + +template +__device__ VAL_T __block_sum(VAL_T val) { + + const int NWARP = (BDIM_X*BDIM_Y*BDIM_Z) / WARP_SIZE; + + val = __warp_sum(val); + + if constexpr(NWARP > 1) { + + int tid = threadIdx.x; + if constexpr(BDIM_Y > 1) { tid += threadIdx.y*BDIM_X; } + if constexpr(BDIM_Z > 1) { tid += threadIdx.z*BDIM_X*BDIM_Y; } + + const int lid = tid%WARP_SIZE; + const int 
wid = tid/WARP_SIZE; + + __shared__ VAL_T sh[NWARP]; + + if (lid == 0) { + sh[wid] = val; + } + __syncthreads(); + + if (wid == 0) { + val = (lid < NWARP) ? sh[lid] : 0; + + val = __warp_sum(val); + __syncwarp(); + + if (!lid) { + sh[0] = val; + } + } + __syncthreads(); + + val = sh[0]; + __syncthreads(); + } + return val; +} + +// transpose utils +template +__global__ +__launch_bounds__(BDIM_X*BDIM_Y) +void permute_to0231_k(const int nchn, + const int nlat, + const int nlon, + const at::PackedTensorAccessor32 src, + at::PackedTensorAccessor32 dst) { + + static_assert(!(BDIM_X & (BDIM_X-1))); + static_assert(!(BDIM_Y & (BDIM_Y-1))); + static_assert(BDIM_X >= BDIM_Y); + + __shared__ VAL_T sh[BDIM_X][BDIM_X+1]; + + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + + const int coff = blockIdx.x*BDIM_X; // channel offset + const int woff = blockIdx.y*BDIM_X; // width offset + const int batch = blockIdx.z / nlat; // batch (same for all block) + const int h = blockIdx.z - (batch * nlat); // height (same for all block) + + const int nchn_full = (nchn-coff) >= BDIM_X; + const int nlon_full = (nlon-woff) >= BDIM_X; + + if (nchn_full && nlon_full) { + #pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + sh[j+tidy][tidx] = src[batch][coff + j+tidy][h][woff+tidx]; + } + __syncthreads(); + + #pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy]; + } + } else { + if (woff+tidx < nlon) { + #pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + sh[j+tidy][tidx] = (coff + j+tidy < nchn) ? src[batch][coff + j+tidy][h][woff+tidx] : VAL_T(0); + } + } + __syncthreads(); + + if (coff+tidx < nchn) { + #pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + if (woff + j+tidy < nlon) { + dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy]; + } + } + } + } + return; +} + +template +void launch_permute_to0231(torch::Tensor src, torch::Tensor dst){ + dim3 block; + dim3 grid; + + block.x = WARP_SIZE; + block.y = WARPS_X_TILE; + grid.x = DIV_UP(src.size(1), block.x); + grid.y = DIV_UP(src.size(3), block.x); + grid.z = src.size(2)*src.size(0); + + assert(grid.y < 65536); + assert(grid.z < 65536); + + // get stream + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + permute_to0231_k + <<>>(src.size(1), + src.size(2), + src.size(3), + src.packed_accessor32(), + dst.packed_accessor32()); +} + +template +__global__ +__launch_bounds__(BDIM_X*BDIM_Y) +void permute_to0312_k(const int nchn, + const int nlat, + const int nlon, + const at::PackedTensorAccessor32 src, + at::PackedTensorAccessor32 dst) { + + static_assert(!(BDIM_X & (BDIM_X-1))); + static_assert(!(BDIM_Y & (BDIM_Y-1))); + static_assert(BDIM_X >= BDIM_Y); + + __shared__ VAL_T sh[BDIM_X][BDIM_X+1]; + + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + + const int woff = blockIdx.x*BDIM_X; // width offset + const int coff = blockIdx.y*BDIM_X; // channel offset + const int batch = blockIdx.z / nlat; // batch (same for all block) + const int h = blockIdx.z - (batch * nlat); // height (same for all block) + + const int nchn_full = (nchn-coff) >= BDIM_X; + const int nlon_full = (nlon-woff) >= BDIM_X; + + if (nchn_full && nlon_full) { + #pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + sh[j+tidy][tidx] = src[batch][h][woff + j+tidy][coff+tidx]; + } + __syncthreads(); + + #pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy]; + } + } else { + if (coff+tidx < nchn) { + 
#pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + sh[j+tidy][tidx] = (woff + j+tidy < nlon) ? src[batch][h][woff + j+tidy][coff+tidx] : VAL_T(0); + } + } + __syncthreads(); + + if (woff+tidx < nlon) { + #pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + if (coff + j+tidy < nchn) { + dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];; + } + } + } + } + return; +} + +template +void launch_permute_to0312(torch::Tensor src, torch::Tensor dst){ + dim3 block; + dim3 grid; + + block.x = WARP_SIZE; + block.y = WARPS_X_TILE; + grid.x = DIV_UP(src.size(2), block.x); + grid.y = DIV_UP(src.size(3), block.x); + grid.z = src.size(1)*src.size(0); + + assert(grid.y < 65536); + assert(grid.z < 65536); + + // get stream + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + permute_to0312_k + <<>>(src.size(3), + src.size(1), + src.size(2), + src.packed_accessor32(), + dst.packed_accessor32()); +} + +} diff --git a/torch_harmonics/utils/csrc/cudamacro.h b/torch_harmonics/utils/csrc/cudamacro.h new file mode 100644 index 00000000..0edef184 --- /dev/null +++ b/torch_harmonics/utils/csrc/cudamacro.h @@ -0,0 +1,47 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +#pragma once + +#define CHECK_CUDA(call) { \ + cudaError_t err = call; \ + if( cudaSuccess != err) { \ + fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \ + __FILE__, __LINE__, cudaGetErrorString( err) ); \ + exit(EXIT_FAILURE); \ + }} + +#define CHECK_ERROR(errorMessage) { \ + cudaError_t err = cudaGetLastError(); \ + if( cudaSuccess != err) { \ + fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \ + errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) );\ + exit(EXIT_FAILURE); \ + }} diff --git a/torch_harmonics/utils/csrc/utils_helpers.cpp b/torch_harmonics/utils/csrc/utils_helpers.cpp new file mode 100644 index 00000000..d700ef6a --- /dev/null +++ b/torch_harmonics/utils/csrc/utils_helpers.cpp @@ -0,0 +1,58 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include + +// set default values for BUILD_CPP and BUILD_CUDA +#ifndef BUILD_CPP +#define BUILD_CPP 0 +#endif + +#ifndef BUILD_CUDA +#define BUILD_CUDA 0 +#endif + +bool cpp_kernels_is_available() { + return static_cast(BUILD_CPP); +} + +bool cuda_kernels_is_available() { + return static_cast(BUILD_CUDA); +} + +bool optimized_kernels_is_available() { + return cuda_kernels_is_available() || cpp_kernels_is_available(); +} + +PYBIND11_MODULE(utility_helpers, m) +{ + m.def("optimized_kernels_is_available", &optimized_kernels_is_available, "Check if optimized kernels (CUDA or C++) are available."); +} + diff --git a/torch_harmonics/utils/csrc/utils_interface.cpp b/torch_harmonics/utils/csrc/utils_interface.cpp new file mode 100644 index 00000000..ecad921e --- /dev/null +++ b/torch_harmonics/utils/csrc/utils_interface.cpp @@ -0,0 +1,62 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved. 
+// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include +#include + +extern "C" { + /* Creates a dummy empty _C module that can be imported from Python. + The import from Python will load the .so consisting of this file + in this extension, so that the TORCH_LIBRARY static initializers + below are run. */ + PyMODINIT_FUNC PyInit__C(void) + { + static struct PyModuleDef module_def = { + PyModuleDef_HEAD_INIT, + "_C", /* name of module */ + NULL, /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. 
*/ + NULL, /* methods */ + }; + return PyModule_Create(&module_def); + } +} + +namespace utility_kernels { + + // Declare the operators + TORCH_LIBRARY(utility_kernels, m) { + m.def("permute_0231(Tensor inp) -> Tensor"); //, {at::Tag::pt2_compliant_tag}); + m.def("permute_0312(Tensor inp) -> Tensor"); + } + +} + From 67389648827f771b8786b98a5adc72f96b64163f Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Thu, 28 Aug 2025 08:58:12 -0700 Subject: [PATCH 05/31] almost working refactor --- torch_harmonics/disco/convolution.py | 2 +- torch_harmonics/utils/csrc/cuda_utils.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_harmonics/disco/convolution.py b/torch_harmonics/disco/convolution.py index a1d4a49e..2d265ce5 100644 --- a/torch_harmonics/disco/convolution.py +++ b/torch_harmonics/disco/convolution.py @@ -45,7 +45,7 @@ from torch_harmonics.cache import lru_cache from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes -from utils import permute_to_0231, permute_to_0312 +from torch_harmonics.utils import permute_to_0231, permute_to_0312 from ._disco_utils import _get_psi, _disco_s2_contraction_torch #, _disco_s2_transpose_contraction_torch from ._disco_utils import _disco_s2_contraction_optimized #, _disco_s2_transpose_contraction_optimized from torch_harmonics.filter_basis import FilterBasis, get_filter_basis diff --git a/torch_harmonics/utils/csrc/cuda_utils.cu b/torch_harmonics/utils/csrc/cuda_utils.cu index 0a1e3bac..40a1f270 100644 --- a/torch_harmonics/utils/csrc/cuda_utils.cu +++ b/torch_harmonics/utils/csrc/cuda_utils.cu @@ -181,7 +181,7 @@ unsigned int next_pow2(unsigned int x) { } // END - general host-side functions -TORCH_LIBRARY_IMPL(utils_kernels, CUDA, m) +TORCH_LIBRARY_IMPL(utility_kernels, CUDA, m) { m.impl("permute_0231", &permute_4D_to0231); m.impl("permute_0312", &permute_4D_to0312); From 7389a95e0e9069e68d988616273a02f56492b2ba Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Fri, 29 Aug 2025 03:27:42 -0700 Subject: [PATCH 06/31] fixed a lot of issues --- setup.py | 5 +- tests/test_permute.py | 116 ++++++ torch_harmonics/disco/convolution.py | 4 +- torch_harmonics/disco/csrc/disco_cuda_bwd.cu | 2 +- torch_harmonics/disco/csrc/disco_cuda_fwd.cu | 2 +- torch_harmonics/utils/_utils.py | 40 +- .../utils/csrc/{cuda_utils.cu => csr_cuda.cu} | 64 +-- torch_harmonics/utils/csrc/csr_cuda.cuh | 213 ++++++++++ torch_harmonics/utils/csrc/cuda_utils.cuh | 387 ------------------ torch_harmonics/utils/csrc/permute_cpu.cpp | 52 +++ torch_harmonics/utils/csrc/permute_cuda.cu | 114 ++++++ torch_harmonics/utils/csrc/permute_cuda.cuh | 225 ++++++++++ .../utils/csrc/utils_interface.cpp | 4 +- 13 files changed, 752 insertions(+), 476 deletions(-) create mode 100644 tests/test_permute.py rename torch_harmonics/utils/csrc/{cuda_utils.cu => csr_cuda.cu} (69%) create mode 100644 torch_harmonics/utils/csrc/csr_cuda.cuh delete mode 100644 torch_harmonics/utils/csrc/cuda_utils.cuh create mode 100644 torch_harmonics/utils/csrc/permute_cpu.cpp create mode 100644 torch_harmonics/utils/csrc/permute_cuda.cu create mode 100644 torch_harmonics/utils/csrc/permute_cuda.cuh diff --git a/setup.py b/setup.py index c93a81c8..526a23b3 100644 --- a/setup.py +++ b/setup.py @@ -137,12 +137,13 @@ def get_ext_modules(): # HELPERS utility_sources = [ "torch_harmonics/utils/csrc/utils_interface.cpp", + "torch_harmonics/utils/csrc/permute_cpu.cpp", ] if BUILD_CUDA: print(f"Compiling custom CUDA kernels for torch-harmonics.") 
utility_sources.extend([ - "torch_harmonics/utils/csrc/cuda_utils.cu", + "torch_harmonics/utils/csrc/permute_cuda.cu", ]) ext_modules.append( CUDAExtension( @@ -170,7 +171,7 @@ def get_ext_modules(): if BUILD_CUDA: print(f"Compiling custom CUDA kernels for torch-harmonics.") disco_sources.extend([ - "torch_harmonics/utils/csrc/cuda_utils.cu", + "torch_harmonics/utils/csrc/csr_cuda.cu", "torch_harmonics/disco/csrc/disco_cuda_fwd.cu", "torch_harmonics/disco/csrc/disco_cuda_bwd.cu", ]) diff --git a/tests/test_permute.py b/tests/test_permute.py new file mode 100644 index 00000000..b9cd82c6 --- /dev/null +++ b/tests/test_permute.py @@ -0,0 +1,116 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+#
+
+import os
+import unittest
+from parameterized import parameterized, parameterized_class
+
+
+import torch
+from torch.library import opcheck
+from torch_harmonics.utils import permute_to_0231, permute_to_0312
+
+_devices = [(torch.device("cpu"),)]
+if torch.cuda.is_available():
+    _devices.append((torch.device("cuda"),))
+
+@parameterized_class(("device"), _devices)
+class TestPermutation(unittest.TestCase):
+    """Test the optimized permutation kernels (CPU/CUDA if available)."""
+
+    def setUp(self):
+        torch.manual_seed(333)
+        if self.device.type == "cuda":
+            torch.cuda.manual_seed(333)
+
+    @parameterized.expand(
+        [
+            [8, 8, 16, 32, "0231"],
+            [8, 1, 16, 32, "0231"],
+            [1, 8, 16, 32, "0231"],
+            [8, 8, 16, 32, "0312"],
+            [8, 1, 16, 32, "0312"],
+            [1, 8, 16, 32, "0312"],
+        ],
+        skip_on_empty=True,
+    )
+    def test_permutation(
+        self, batch_size, channels, nlat, nlon, mode
+    ):
+        # create input
+        if mode == "0231":
+            inp = torch.randn(batch_size, channels, nlat, nlon, device=self.device)
+            permute_fn = permute_to_0231
+            permute_shape = (0, 2, 3, 1)
+        else:
+            inp = torch.randn(batch_size, nlat, nlon, channels, device=self.device)
+            permute_fn = permute_to_0312
+            permute_shape = (0, 3, 1, 2)
+        inp.requires_grad = True
+
+        # forward test
+        out_opt = permute_fn(inp)
+        out_naive = inp.permute(*permute_shape).contiguous().clone()
+        self.assertTrue(torch.allclose(out_opt, out_naive))
+
+        # backward test
+        ograd = torch.randn_like(out_opt)
+        out_opt.backward(ograd)
+        igrad_opt = inp.grad.clone()
+        inp.grad = None
+        out_naive.backward(ograd)
+        igrad_naive = inp.grad.clone()
+        self.assertTrue(torch.allclose(igrad_opt, igrad_naive))
+
+
+    @parameterized.expand(
+        [
+            [8, 8, 16, 32, "0231"],
+            [8, 8, 16, 32, "0312"],
+        ],
+        skip_on_empty=True,
+    )
+    def test_pt2_compatibility(self, batch_size, channels, nlat, nlon, mode):
+
+        if mode == "0231":
+            inp = torch.randn(batch_size, channels, nlat, nlon, device=self.device)
+            permute_fn = torch.ops.utility_kernels.permute_to_0231
+        else:
+            inp = torch.randn(batch_size, nlat, nlon, channels, device=self.device)
+            permute_fn = torch.ops.utility_kernels.permute_to_0312
+
+        test_inputs = (inp, )
+
+        opcheck(permute_fn, test_inputs)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/torch_harmonics/disco/convolution.py b/torch_harmonics/disco/convolution.py
index 2d265ce5..d374c671 100644
--- a/torch_harmonics/disco/convolution.py
+++ b/torch_harmonics/disco/convolution.py
@@ -516,8 +516,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         with nvtx.annotate("_disco_s2_contraction_optimized", color="red"):
             xp = permute_to_0231(x)
             xpc = _disco_s2_contraction_optimized(
-                xp, self.weight, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals,
-                self.nlat_out, self.nlon_out
+                xp, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals,
+                self.kernel_size, self.nlat_out, self.nlon_out
             )
             xpc = xpc.reshape(xpc.shape[0], self.nlat_out, self.nlon_out, self.groups, self.groupsize_in, self.kernel_size)
             yp = torch.einsum("bxygck,gock->bxygo", xpc, self.weight).reshape(xpc.shape[0], self.nlat_out, self.nlon_out, -1).contiguous()
diff --git a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu
index cda05836..710f7ae8 100644
--- a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu
+++ b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu
@@ -30,7 +30,7 @@
 #include "disco.h"
 #include "disco_cuda.cuh"
-#include "cuda_utils.cuh"
+#include "csr_cuda.cuh"
 namespace 
disco_kernels { diff --git a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu index 778764fe..3328ae39 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu @@ -31,7 +31,7 @@ #include "cudamacro.h" #include "disco.h" #include "disco_cuda.cuh" -#include "cuda_utils.cuh" +#include "csr_cuda.cuh" #define THREADS (64) diff --git a/torch_harmonics/utils/_utils.py b/torch_harmonics/utils/_utils.py index 2141d9cc..f79aa013 100644 --- a/torch_harmonics/utils/_utils.py +++ b/torch_harmonics/utils/_utils.py @@ -39,38 +39,42 @@ if optimized_kernels_is_available(): # fake permutations - @torch.library.register_fake("utility_kernels::permute_0231") + @torch.library.register_fake("utility_kernels::permute_to_0231") def _(inp: torch.Tensor) -> torch.Tensor: B, C, H, W = inp.shape out_shape = (B, H, W, C) return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) - @torch.library.register_fake("utility_kernels::permute_0312") + @torch.library.register_fake("utility_kernels::permute_to_0312") def _(inp: torch.Tensor) -> torch.Tensor: B, H, W, C = inp.shape out_shape = (B, C, H, W) return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) - # autograds - torch.library.register_autograd( - "utility_kernels::permute_0231", utility_kernels.permute_0312) + # autograds: shallow wrappers around the default kernels + def _permute_to_0231_bwd(ctx, grad_output): + return utility_kernels.permute_to_0312.default(grad_output) + + def _permute_to_0312_bwd(ctx, grad_output): + return utility_kernels.permute_to_0231.default(grad_output) torch.library.register_autograd( - "utility_kernels::permute_0312", utility_kernels.permute_0231) + "utility_kernels::permute_to_0231", _permute_to_0231_bwd) + torch.library.register_autograd( + "utility_kernels::permute_to_0312", _permute_to_0312_bwd) -def permute_to_0231(inp: torch.Tensor) -> torch.Tensor: - if optimized_kernels_is_available() and inp.is_cuda: - out = utility_kernels.permute_0231.default(inp) - else: - out = inp.permute(0, 2, 3, 1).contiguous() - return out +if optimized_kernels_is_available(): + def permute_to_0231(inp: torch.Tensor) -> torch.Tensor: + return utility_kernels.permute_to_0231.default(inp) + + def permute_to_0312(inp: torch.Tensor) -> torch.Tensor: + return utility_kernels.permute_to_0312.default(inp) +else: + def permute_to_0231(inp: torch.Tensor) -> torch.Tensor: + return inp.permute(0, 2, 3, 1).contiguous() -def permute_to_0312(inp: torch.Tensor) -> torch.Tensor: - if optimized_kernels_is_available() and inp.is_cuda: - out = utility_kernels.permute_0312.default(inp) - else: - out = inp.permute(0, 3, 1, 2).contiguous() - return out + def permute_to_0312(inp: torch.Tensor) -> torch.Tensor: + return inp.permute(0, 3, 1, 2).contiguous() diff --git a/torch_harmonics/utils/csrc/cuda_utils.cu b/torch_harmonics/utils/csrc/csr_cuda.cu similarity index 69% rename from torch_harmonics/utils/csrc/cuda_utils.cu rename to torch_harmonics/utils/csrc/csr_cuda.cu index 40a1f270..f6f0f5b1 100644 --- a/torch_harmonics/utils/csrc/cuda_utils.cu +++ b/torch_harmonics/utils/csrc/csr_cuda.cu @@ -40,7 +40,7 @@ #include #include "cudamacro.h" -#include "cuda_utils.cuh" +#include "csr_cuda.cuh" #define THREADS (64) @@ -112,62 +112,6 @@ torch::Tensor sortRows(int nlat_out, torch::Tensor row_off, cudaStream_t stream) // END - CSR rows sorting kernels and functions -// BEGIN - 4D tensor permutation kernels and functions -__global__ void empty_k() {} - -static int 
getPtxver() { - cudaFuncAttributes attrs; - CHECK_CUDA(cudaFuncGetAttributes(&attrs, empty_k)); - return attrs.ptxVersion*10; -} - -at::Tensor permute_4D_to0231(at::Tensor src) { - - auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device()); - torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options); - - const int ptxv = getPtxver(); - - // to be further specialized for additional archs, if necessary - if (ptxv < 100) { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_generic", ([&] { - launch_permute_to0231(src, dst); - })); - CHECK_ERROR("permute_to0231_k_tile_generic"); - } else { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_sm100", ([&] { - launch_permute_to0231(src, dst); - })); - CHECK_ERROR("permute_to0231_k_tile_sm100"); - } - - return dst; -} - -at::Tensor permute_4D_to0312(at::Tensor src) { - - auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device()); - torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options); - - const int ptxv = getPtxver(); - - // to be further specialized for additional archs, if necessary - if (ptxv < 100) { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_generic", ([&] { - launch_permute_to0312(src, dst); - })); - CHECK_ERROR("permute_to0312_k_tile_generic"); - } else { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_sm100", ([&] { - launch_permute_to0312(src, dst); - })); - CHECK_ERROR("permute_to0312_k_tile_sm100"); - } - - return dst; -} -// END - tensor permutation kernels and functions - // BEGIN - general host-side functions unsigned int next_pow2(unsigned int x) { @@ -181,10 +125,4 @@ unsigned int next_pow2(unsigned int x) { } // END - general host-side functions -TORCH_LIBRARY_IMPL(utility_kernels, CUDA, m) - { - m.impl("permute_0231", &permute_4D_to0231); - m.impl("permute_0312", &permute_4D_to0312); - } - } diff --git a/torch_harmonics/utils/csrc/csr_cuda.cuh b/torch_harmonics/utils/csrc/csr_cuda.cuh new file mode 100644 index 00000000..956346b9 --- /dev/null +++ b/torch_harmonics/utils/csrc/csr_cuda.cuh @@ -0,0 +1,213 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once + +// +//#include +//#include + +#include +#include +#include +#include + +#define WARP_SIZE (32) +#define FULL_MASK (0xFFFFFFFF) + +#ifndef DIV_UP +#define DIV_UP(a,b) (((a)+((b)-1))/(b)) +#endif + +namespace utility_kernels { + +// CSR rows sorting kernels and functions +torch::Tensor sortRows(int nlat_out, torch::Tensor row_off, cudaStream_t stream); + + +// Host tensor dump and CSR manipulation functions +void dump_tensor(const char *fname, torch::Tensor t); +void dump_csr(const char *fname, torch::Tensor roff, torch::Tensor cols); + +int part_csr_rows(int *row_perm, + const torch::Tensor roff, + const torch::Tensor cols, + int **part_off, + int **part_val); + +int verify_part(const int npart, + const int *part_off, + const int *part_val, + const torch::Tensor roff, + const torch::Tensor cols); + +void verify_part_new(const int nlon_out, + const int nlat_in, + const int nlon_in, + const int npart, // partitioning data + const int *part_off, + const int *part_val, + const torch::Tensor roff, + const torch::Tensor cols); + +unsigned int next_pow2(unsigned int x); + + +// utility host functions and templates + +template +int is_aligned(const void *ptr) { + + static_assert(0 == (ALIGN & (ALIGN-1))); + return (0 == (uintptr_t(ptr) & (ALIGN-1))); +} + + +// utility device functions and templates + +template +__device__ FLOATV_T __vset(float x) { + static_assert(sizeof(FLOATV_T) == 0, "Unsupported type for __vset"); + return FLOATV_T{}; +} + +template<> +__device__ float __forceinline__ __vset(float x) { + return x; +} + +__device__ float __forceinline__ __vmul(float a, float b) { + return a*b; +} + +__device__ float __forceinline__ __vadd(float a, float b) { + return a+b; +} + +__device__ float __forceinline__ __vsub(float a, float b) { + return a-b; +} + +__device__ float __forceinline__ __vred(float a) { + return a; +} + +__device__ float __forceinline__ __vscale(float s, float v) { + return v*s; +} + +__device__ float __forceinline__ __vdiv(float s, float v) { + return v/s; +} + +template<> +__device__ float4 __forceinline__ __vset(float x) { + return make_float4(x, x, x, x); +} + +__device__ float4 __forceinline__ __vmul(float4 a, float4 b) { + return make_float4(a.x*b.x, a.y*b.y, a.z*b.z, a.w*b.w); +} + +__device__ float4 __forceinline__ __vadd(float4 a, float4 b) { + return make_float4(a.x+b.x, a.y+b.y, a.z+b.z, a.w+b.w); +} + +__device__ float4 __forceinline__ __vsub(float4 a, float4 b) { + return make_float4(a.x-b.x, a.y-b.y, a.z-b.z, a.w-b.w); +} + +__device__ float __forceinline__ __vred(float4 a) { + return a.x + a.y + a.z + a.w; +} + +__device__ float4 __forceinline__ __vscale(float s, float4 v) { + return make_float4(s*v.x, s*v.y, s*v.z, s*v.w); +} + +__device__ float4 __forceinline__ __vdiv(float s, float4 v) { + return make_float4(s/v.x, s/v.y, s/v.z, s/v.w);; +} + +template +__device__ VAL_T __warp_sum(VAL_T val) { + + #pragma unroll + for(int i = WARP_SIZE/2; i; i /= 2) { + val += __shfl_xor_sync(FULL_MASK, val, i, WARP_SIZE); + } 
+ return val; +} + +template +__device__ VAL_T __block_sum(VAL_T val) { + + const int NWARP = (BDIM_X*BDIM_Y*BDIM_Z) / WARP_SIZE; + + val = __warp_sum(val); + + if constexpr(NWARP > 1) { + + int tid = threadIdx.x; + if constexpr(BDIM_Y > 1) { tid += threadIdx.y*BDIM_X; } + if constexpr(BDIM_Z > 1) { tid += threadIdx.z*BDIM_X*BDIM_Y; } + + const int lid = tid%WARP_SIZE; + const int wid = tid/WARP_SIZE; + + __shared__ VAL_T sh[NWARP]; + + if (lid == 0) { + sh[wid] = val; + } + __syncthreads(); + + if (wid == 0) { + val = (lid < NWARP) ? sh[lid] : 0; + + val = __warp_sum(val); + __syncwarp(); + + if (!lid) { + sh[0] = val; + } + } + __syncthreads(); + + val = sh[0]; + __syncthreads(); + } + return val; +} + +} diff --git a/torch_harmonics/utils/csrc/cuda_utils.cuh b/torch_harmonics/utils/csrc/cuda_utils.cuh deleted file mode 100644 index 54d77a95..00000000 --- a/torch_harmonics/utils/csrc/cuda_utils.cuh +++ /dev/null @@ -1,387 +0,0 @@ -// coding=utf-8 -// -// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. -// SPDX-License-Identifier: BSD-3-Clause -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -#pragma once - -// -//#include -//#include - -#include -#include -#include -#include - -#define WARP_SIZE (32) -#define FULL_MASK (0xFFFFFFFF) - -#ifndef DIV_UP -#define DIV_UP(a,b) (((a)+((b)-1))/(b)) -#endif - -namespace utility_kernels { - -// CSR rows sorting kernels and functions -torch::Tensor sortRows(int nlat_out, torch::Tensor row_off, cudaStream_t stream); - -// 4D tensor permutation kernels and functions -torch::Tensor permute_4D_to0231(torch::Tensor src); -torch::Tensor permute_4D_to0312(torch::Tensor src); - -// Host tensor dump and CSR manipulation functions -void dump_tensor(const char *fname, torch::Tensor t); -void dump_csr(const char *fname, torch::Tensor roff, torch::Tensor cols); - -int part_csr_rows(int *row_perm, - const torch::Tensor roff, - const torch::Tensor cols, - int **part_off, - int **part_val); - -int verify_part(const int npart, - const int *part_off, - const int *part_val, - const torch::Tensor roff, - const torch::Tensor cols); - -void verify_part_new(const int nlon_out, - const int nlat_in, - const int nlon_in, - const int npart, // partitioning data - const int *part_off, - const int *part_val, - const torch::Tensor roff, - const torch::Tensor cols); - -unsigned int next_pow2(unsigned int x); - - -// utility host functions and templates - -template -int is_aligned(const void *ptr) { - - static_assert(0 == (ALIGN & (ALIGN-1))); - return (0 == (uintptr_t(ptr) & (ALIGN-1))); -} - - -// utility device functions and templates - -template -__device__ FLOATV_T __vset(float x) { - static_assert(sizeof(FLOATV_T) == 0, "Unsupported type for __vset"); - return FLOATV_T{}; -} - -template<> -__device__ float __forceinline__ __vset(float x) { - return x; -} - -__device__ float __forceinline__ __vmul(float a, float b) { - return a*b; -} - -__device__ float __forceinline__ __vadd(float a, float b) { - return a+b; -} - -__device__ float __forceinline__ __vsub(float a, float b) { - return a-b; -} - -__device__ float __forceinline__ __vred(float a) { - return a; -} - -__device__ float __forceinline__ __vscale(float s, float v) { - return v*s; -} - -__device__ float __forceinline__ __vdiv(float s, float v) { - return v/s; -} - -template<> -__device__ float4 __forceinline__ __vset(float x) { - return make_float4(x, x, x, x); -} - -__device__ float4 __forceinline__ __vmul(float4 a, float4 b) { - return make_float4(a.x*b.x, a.y*b.y, a.z*b.z, a.w*b.w); -} - -__device__ float4 __forceinline__ __vadd(float4 a, float4 b) { - return make_float4(a.x+b.x, a.y+b.y, a.z+b.z, a.w+b.w); -} - -__device__ float4 __forceinline__ __vsub(float4 a, float4 b) { - return make_float4(a.x-b.x, a.y-b.y, a.z-b.z, a.w-b.w); -} - -__device__ float __forceinline__ __vred(float4 a) { - return a.x + a.y + a.z + a.w; -} - -__device__ float4 __forceinline__ __vscale(float s, float4 v) { - return make_float4(s*v.x, s*v.y, s*v.z, s*v.w); -} - -__device__ float4 __forceinline__ __vdiv(float s, float4 v) { - return make_float4(s/v.x, s/v.y, s/v.z, s/v.w);; -} - -template -__device__ VAL_T __warp_sum(VAL_T val) { - - #pragma unroll - for(int i = WARP_SIZE/2; i; i /= 2) { - val += __shfl_xor_sync(FULL_MASK, val, i, WARP_SIZE); - } - return val; -} - -template -__device__ VAL_T __block_sum(VAL_T val) { - - const int NWARP = (BDIM_X*BDIM_Y*BDIM_Z) / WARP_SIZE; - - val = __warp_sum(val); - - if constexpr(NWARP > 1) { - - int tid = threadIdx.x; - if constexpr(BDIM_Y > 1) { tid += threadIdx.y*BDIM_X; } - if constexpr(BDIM_Z > 1) { tid += threadIdx.z*BDIM_X*BDIM_Y; } - - const int lid = tid%WARP_SIZE; - const int 
wid = tid/WARP_SIZE; - - __shared__ VAL_T sh[NWARP]; - - if (lid == 0) { - sh[wid] = val; - } - __syncthreads(); - - if (wid == 0) { - val = (lid < NWARP) ? sh[lid] : 0; - - val = __warp_sum(val); - __syncwarp(); - - if (!lid) { - sh[0] = val; - } - } - __syncthreads(); - - val = sh[0]; - __syncthreads(); - } - return val; -} - -// transpose utils -template -__global__ -__launch_bounds__(BDIM_X*BDIM_Y) -void permute_to0231_k(const int nchn, - const int nlat, - const int nlon, - const at::PackedTensorAccessor32 src, - at::PackedTensorAccessor32 dst) { - - static_assert(!(BDIM_X & (BDIM_X-1))); - static_assert(!(BDIM_Y & (BDIM_Y-1))); - static_assert(BDIM_X >= BDIM_Y); - - __shared__ VAL_T sh[BDIM_X][BDIM_X+1]; - - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - - const int coff = blockIdx.x*BDIM_X; // channel offset - const int woff = blockIdx.y*BDIM_X; // width offset - const int batch = blockIdx.z / nlat; // batch (same for all block) - const int h = blockIdx.z - (batch * nlat); // height (same for all block) - - const int nchn_full = (nchn-coff) >= BDIM_X; - const int nlon_full = (nlon-woff) >= BDIM_X; - - if (nchn_full && nlon_full) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - sh[j+tidy][tidx] = src[batch][coff + j+tidy][h][woff+tidx]; - } - __syncthreads(); - - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy]; - } - } else { - if (woff+tidx < nlon) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - sh[j+tidy][tidx] = (coff + j+tidy < nchn) ? src[batch][coff + j+tidy][h][woff+tidx] : VAL_T(0); - } - } - __syncthreads(); - - if (coff+tidx < nchn) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - if (woff + j+tidy < nlon) { - dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy]; - } - } - } - } - return; -} - -template -void launch_permute_to0231(torch::Tensor src, torch::Tensor dst){ - dim3 block; - dim3 grid; - - block.x = WARP_SIZE; - block.y = WARPS_X_TILE; - grid.x = DIV_UP(src.size(1), block.x); - grid.y = DIV_UP(src.size(3), block.x); - grid.z = src.size(2)*src.size(0); - - assert(grid.y < 65536); - assert(grid.z < 65536); - - // get stream - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - permute_to0231_k - <<>>(src.size(1), - src.size(2), - src.size(3), - src.packed_accessor32(), - dst.packed_accessor32()); -} - -template -__global__ -__launch_bounds__(BDIM_X*BDIM_Y) -void permute_to0312_k(const int nchn, - const int nlat, - const int nlon, - const at::PackedTensorAccessor32 src, - at::PackedTensorAccessor32 dst) { - - static_assert(!(BDIM_X & (BDIM_X-1))); - static_assert(!(BDIM_Y & (BDIM_Y-1))); - static_assert(BDIM_X >= BDIM_Y); - - __shared__ VAL_T sh[BDIM_X][BDIM_X+1]; - - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - - const int woff = blockIdx.x*BDIM_X; // width offset - const int coff = blockIdx.y*BDIM_X; // channel offset - const int batch = blockIdx.z / nlat; // batch (same for all block) - const int h = blockIdx.z - (batch * nlat); // height (same for all block) - - const int nchn_full = (nchn-coff) >= BDIM_X; - const int nlon_full = (nlon-woff) >= BDIM_X; - - if (nchn_full && nlon_full) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - sh[j+tidy][tidx] = src[batch][h][woff + j+tidy][coff+tidx]; - } - __syncthreads(); - - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy]; - } - } else { - if (coff+tidx < nchn) { - 
#pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - sh[j+tidy][tidx] = (woff + j+tidy < nlon) ? src[batch][h][woff + j+tidy][coff+tidx] : VAL_T(0); - } - } - __syncthreads(); - - if (woff+tidx < nlon) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - if (coff + j+tidy < nchn) { - dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];; - } - } - } - } - return; -} - -template -void launch_permute_to0312(torch::Tensor src, torch::Tensor dst){ - dim3 block; - dim3 grid; - - block.x = WARP_SIZE; - block.y = WARPS_X_TILE; - grid.x = DIV_UP(src.size(2), block.x); - grid.y = DIV_UP(src.size(3), block.x); - grid.z = src.size(1)*src.size(0); - - assert(grid.y < 65536); - assert(grid.z < 65536); - - // get stream - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - permute_to0312_k - <<>>(src.size(3), - src.size(1), - src.size(2), - src.packed_accessor32(), - dst.packed_accessor32()); -} - -} diff --git a/torch_harmonics/utils/csrc/permute_cpu.cpp b/torch_harmonics/utils/csrc/permute_cpu.cpp new file mode 100644 index 00000000..e604b927 --- /dev/null +++ b/torch_harmonics/utils/csrc/permute_cpu.cpp @@ -0,0 +1,52 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +#include +#include + +namespace utility_kernels { + +torch::Tensor permute_4D_to0231_cpu(torch::Tensor src) { + // CPU implementation using standard permute + return src.permute({0, 2, 3, 1}).contiguous(); +} + +torch::Tensor permute_4D_to0312_cpu(torch::Tensor src) { + // CPU implementation using standard permute + return src.permute({0, 3, 1, 2}).contiguous(); +} + +TORCH_LIBRARY_IMPL(utility_kernels, CPU, m) +{ + m.impl("permute_to_0231", &permute_4D_to0231_cpu); + m.impl("permute_to_0312", &permute_4D_to0312_cpu); +} + +} diff --git a/torch_harmonics/utils/csrc/permute_cuda.cu b/torch_harmonics/utils/csrc/permute_cuda.cu new file mode 100644 index 00000000..e97eb574 --- /dev/null +++ b/torch_harmonics/utils/csrc/permute_cuda.cu @@ -0,0 +1,114 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + +//#include +//#include +//#include +#include + +#include + +#include +#include + +#include "cudamacro.h" +#include "permute_cuda.cuh" + +// Define the missing macros +#define TRANSP_WARPS_X_TILE_GENERIC (32) +#define TRANSP_WARPS_X_TILE_SM100 (4) + +namespace utility_kernels { + + // BEGIN - 4D tensor permutation kernels and functions +__global__ void empty_k() {} + +static int getPtxver() { + cudaFuncAttributes attrs; + CHECK_CUDA(cudaFuncGetAttributes(&attrs, empty_k)); + return attrs.ptxVersion*10; +} + +torch::Tensor permute_4D_to0231(torch::Tensor src) { + + auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device()); + torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options); + + const int ptxv = getPtxver(); + + // to be further specialized for additional archs, if necessary + if (ptxv < 100) { + AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_generic", ([&] { + launch_permute_to0231(src, dst); + })); + CHECK_ERROR("permute_to0231_k_tile_generic"); + } else { + AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_sm100", ([&] { + launch_permute_to0231(src, dst); + })); + CHECK_ERROR("permute_to0231_k_tile_sm100"); + } + + return dst; +} + +torch::Tensor permute_4D_to0312(torch::Tensor src) { + + auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device()); + torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options); + + const int ptxv = getPtxver(); + + // to be further specialized for additional archs, if necessary + if (ptxv < 100) { + AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_generic", ([&] { + launch_permute_to0312(src, dst); + })); + CHECK_ERROR("permute_to0312_k_tile_generic"); + } else { + AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_sm100", ([&] { + launch_permute_to0312(src, dst); + })); + CHECK_ERROR("permute_to0312_k_tile_sm100"); + } + + return dst; +} + +TORCH_LIBRARY_IMPL(utility_kernels, CUDA, m) +{ + m.impl("permute_to_0231", &permute_4D_to0231); + m.impl("permute_to_0312", &permute_4D_to0312); +} + +// END - tensor permutation kernels and functions + +} \ No newline at end of file diff --git a/torch_harmonics/utils/csrc/permute_cuda.cuh b/torch_harmonics/utils/csrc/permute_cuda.cuh new file mode 100644 index 00000000..768b84db --- /dev/null +++ b/torch_harmonics/utils/csrc/permute_cuda.cuh @@ -0,0 +1,225 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once + +// +//#include +//#include + +#include +#include +#include +#include + +#define WARP_SIZE (32) +#define FULL_MASK (0xFFFFFFFF) + +#ifndef DIV_UP +#define DIV_UP(a,b) (((a)+((b)-1))/(b)) +#endif + +namespace utility_kernels { + +// transpose utils +template +__global__ +__launch_bounds__(BDIM_X*BDIM_Y) +void permute_to0231_k(const int nchn, + const int nlat, + const int nlon, + const at::PackedTensorAccessor32 src, + at::PackedTensorAccessor32 dst) { + +static_assert(!(BDIM_X & (BDIM_X-1))); +static_assert(!(BDIM_Y & (BDIM_Y-1))); +static_assert(BDIM_X >= BDIM_Y); + +__shared__ VAL_T sh[BDIM_X][BDIM_X+1]; + +const int tidx = threadIdx.x; +const int tidy = threadIdx.y; + +const int coff = blockIdx.x*BDIM_X; // channel offset +const int woff = blockIdx.y*BDIM_X; // width offset +const int batch = blockIdx.z / nlat; // batch (same for all block) +const int h = blockIdx.z - (batch * nlat); // height (same for all block) + +const int nchn_full = (nchn-coff) >= BDIM_X; +const int nlon_full = (nlon-woff) >= BDIM_X; + +if (nchn_full && nlon_full) { +#pragma unroll +for(int j = 0; j < BDIM_X; j += BDIM_Y) { + sh[j+tidy][tidx] = src[batch][coff + j+tidy][h][woff+tidx]; +} +__syncthreads(); + +#pragma unroll +for(int j = 0; j < BDIM_X; j += BDIM_Y) { + dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy]; +} +} else { +if (woff+tidx < nlon) { + #pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + sh[j+tidy][tidx] = (coff + j+tidy < nchn) ? 
src[batch][coff + j+tidy][h][woff+tidx] : VAL_T(0); + } +} +__syncthreads(); + +if (coff+tidx < nchn) { + #pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + if (woff + j+tidy < nlon) { + dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy]; + } + } +} +} +return; +} + +template +void launch_permute_to0231(torch::Tensor src, torch::Tensor dst){ +dim3 block; +dim3 grid; + +block.x = WARP_SIZE; +block.y = WARPS_X_TILE; +grid.x = DIV_UP(src.size(1), block.x); +grid.y = DIV_UP(src.size(3), block.x); +grid.z = src.size(2)*src.size(0); + +assert(grid.y < 65536); +assert(grid.z < 65536); + +// get stream +auto stream = at::cuda::getCurrentCUDAStream().stream(); + +permute_to0231_k + <<>>(src.size(1), + src.size(2), + src.size(3), + src.packed_accessor32(), + dst.packed_accessor32()); +} + +template +__global__ +__launch_bounds__(BDIM_X*BDIM_Y) +void permute_to0312_k(const int nchn, + const int nlat, + const int nlon, + const at::PackedTensorAccessor32 src, + at::PackedTensorAccessor32 dst) { + +static_assert(!(BDIM_X & (BDIM_X-1))); +static_assert(!(BDIM_Y & (BDIM_Y-1))); +static_assert(BDIM_X >= BDIM_Y); + +__shared__ VAL_T sh[BDIM_X][BDIM_X+1]; + +const int tidx = threadIdx.x; +const int tidy = threadIdx.y; + +const int woff = blockIdx.x*BDIM_X; // width offset +const int coff = blockIdx.y*BDIM_X; // channel offset +const int batch = blockIdx.z / nlat; // batch (same for all block) +const int h = blockIdx.z - (batch * nlat); // height (same for all block) + +const int nchn_full = (nchn-coff) >= BDIM_X; +const int nlon_full = (nlon-woff) >= BDIM_X; + +if (nchn_full && nlon_full) { +#pragma unroll +for(int j = 0; j < BDIM_X; j += BDIM_Y) { + sh[j+tidy][tidx] = src[batch][h][woff + j+tidy][coff+tidx]; +} +__syncthreads(); + +#pragma unroll +for(int j = 0; j < BDIM_X; j += BDIM_Y) { + dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy]; +} +} else { +if (coff+tidx < nchn) { + #pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + sh[j+tidy][tidx] = (woff + j+tidy < nlon) ? 
src[batch][h][woff + j+tidy][coff+tidx] : VAL_T(0); + } +} +__syncthreads(); + +if (woff+tidx < nlon) { + #pragma unroll + for(int j = 0; j < BDIM_X; j += BDIM_Y) { + if (coff + j+tidy < nchn) { + dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];; + } + } +} +} +return; +} + +template +void launch_permute_to0312(torch::Tensor src, torch::Tensor dst){ +dim3 block; +dim3 grid; + +block.x = WARP_SIZE; +block.y = WARPS_X_TILE; +grid.x = DIV_UP(src.size(2), block.x); +grid.y = DIV_UP(src.size(3), block.x); +grid.z = src.size(1)*src.size(0); + +assert(grid.y < 65536); +assert(grid.z < 65536); + +// get stream +auto stream = at::cuda::getCurrentCUDAStream().stream(); + +permute_to0312_k + <<>>(src.size(3), + src.size(1), + src.size(2), + src.packed_accessor32(), + dst.packed_accessor32()); +} + +torch::Tensor permute_4D_to0312(torch::Tensor src); +torch::Tensor permute_4D_to0231(torch::Tensor src); + +} \ No newline at end of file diff --git a/torch_harmonics/utils/csrc/utils_interface.cpp b/torch_harmonics/utils/csrc/utils_interface.cpp index ecad921e..d7d901b2 100644 --- a/torch_harmonics/utils/csrc/utils_interface.cpp +++ b/torch_harmonics/utils/csrc/utils_interface.cpp @@ -54,8 +54,8 @@ namespace utility_kernels { // Declare the operators TORCH_LIBRARY(utility_kernels, m) { - m.def("permute_0231(Tensor inp) -> Tensor"); //, {at::Tag::pt2_compliant_tag}); - m.def("permute_0312(Tensor inp) -> Tensor"); + m.def("permute_to_0231(Tensor inp) -> Tensor", {at::Tag::pt2_compliant_tag}); + m.def("permute_to_0312(Tensor inp) -> Tensor", {at::Tag::pt2_compliant_tag}); } } From da3c206e09f704ea64e1fcb97806c2627b98563c Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Fri, 29 Aug 2025 04:25:04 -0700 Subject: [PATCH 07/31] fixed fwd pass --- torch_harmonics/disco/_disco_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_harmonics/disco/_disco_utils.py b/torch_harmonics/disco/_disco_utils.py index 97369e66..b4a9c111 100644 --- a/torch_harmonics/disco/_disco_utils.py +++ b/torch_harmonics/disco/_disco_utils.py @@ -94,10 +94,11 @@ def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, #general routines: this is the same for forward and transpose def _setup_context_conv_backward(ctx, inputs, output): - inp, roff_idx, ker_idx, row_idx, col_idx, vals, _, _, _ = inputs + inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, _, _ = inputs ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals) ctx.nlat_in = inp.shape[-2] ctx.nlon_in = inp.shape[-1] + ctx.kernel_size = kernel_size # convolution related def _disco_s2_contraction_bwd_optimized(ctx, grad_output): @@ -105,7 +106,7 @@ def _disco_s2_contraction_bwd_optimized(ctx, grad_output): if ctx.needs_input_grad[0]: grad_input = disco_kernels.backward.default(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals, - ctx.nlat_in, ctx.nlon_in) # Mauro + ctx.kernel_size, ctx.nlat_in, ctx.nlon_in) # Mauro else: grad_input = None From 29679836a671268bcc730ada899a792692a0c339 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Fri, 29 Aug 2025 05:16:01 -0700 Subject: [PATCH 08/31] re-implementing transpose conv --- torch_harmonics/disco/_disco_utils.py | 129 ++++++++++++-------------- torch_harmonics/disco/convolution.py | 4 +- 2 files changed, 63 insertions(+), 70 deletions(-) diff --git a/torch_harmonics/disco/_disco_utils.py b/torch_harmonics/disco/_disco_utils.py index b4a9c111..01ecbd9d 100644 --- a/torch_harmonics/disco/_disco_utils.py +++ b/torch_harmonics/disco/_disco_utils.py @@ 
-46,14 +46,13 @@ def _(inp: torch.Tensor, roff_idx: torch.Tensor,
     out_shape = (inp.shape[0], nlat_out, nlon_out, kernel_size*inp.shape[3])
     return torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
 
-# # raw backward fake
-# @torch.library.register_fake("disco_kernels::backward")
-# def _(inp: torch.Tensor, weights: torch.Tensor,
-#       roff_idx: torch.Tensor, ker_idx: torch.Tensor,
-#       row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
-#       nlat_out: int, nlon_out: int) -> torch.Tensor:
-#     out_shape = (inp.shape[0], weights.shape[0] * weights.shape[2], nlat_out, nlon_out)
-#     return torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
+# raw backward fake
+@torch.library.register_fake("disco_kernels::backward")
+def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
+      row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
+      kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
+    out_shape = (inp.shape[0], nlat_out, nlon_out, kernel_size*inp.shape[3])
+    return torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
 
 # forward
 @torch.library.custom_op("disco_kernels::_disco_s2_contraction_optimized", mutates_args=())
@@ -64,17 +63,14 @@ def _disco_s2_contraction_optimized(
     out = disco_kernels.forward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
     return out
 
-# # transpose
-# @torch.library.custom_op("disco_kernels::_disco_s2_transpose_contraction_optimized", mutates_args=())
-# def _disco_s2_transpose_contraction_optimized(
-#     inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
-#     row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
-#     weights: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
-#     itype = inp.dtype
-#     inp = inp.to(torch.float32).contiguous()
-#     out = disco_kernels.backward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, weights, kernel_size, nlat_out, nlon_out)
-#     out = out.to(itype)
-#     return out
+# transpose
+@torch.library.custom_op("disco_kernels::_disco_s2_transpose_contraction_optimized", mutates_args=())
+def _disco_s2_transpose_contraction_optimized(
+    inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
+    row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
+    kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
+    out = disco_kernels.backward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
+    return out
 
 # forward fake
 @torch.library.register_fake("disco_kernels::_disco_s2_contraction_optimized")
@@ -84,13 +80,13 @@ def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
     out_shape = (inp.shape[0], nlat_out, nlon_out, kernel_size*inp.shape[3])
     return torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
 
-# # transpose fake
-# @torch.library.register_fake("disco_kernels::_disco_s2_transpose_contraction_optimized")
-# def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
-#       row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
-#       weights: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
-#     out_shape = (inp.shape[0], weights.shape[0] * weights.shape[1], nlat_out, nlon_out)
-#     return torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
+# transpose fake
+@torch.library.register_fake("disco_kernels::_disco_s2_transpose_contraction_optimized")
+def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
+      row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
+      kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
+    out_shape = (inp.shape[0], nlat_out, nlon_out, kernel_size*inp.shape[3])
+    return torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
 
 #general routines: this is the same for forward and transpose
 def _setup_context_conv_backward(ctx, inputs, output):
@@ -116,24 +112,21 @@ def _disco_s2_contraction_bwd_optimized(ctx, grad_output):
     torch.library.register_autograd(
         "disco_kernels::_disco_s2_contraction_optimized", _disco_s2_contraction_bwd_optimized, setup_context=_setup_context_conv_backward)
 
-# # Transpose convolution related
-# def _disco_s2_transpose_contraction_bwd_optimized(ctx, grad_output):
-#     roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
+# Transpose convolution related
+def _disco_s2_transpose_contraction_bwd_optimized(ctx, grad_output):
+    roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
 
-#     if ctx.needs_input_grad[0]:
-#         gtype = grad_output.dtype
-#         grad_output = grad_output.to(torch.float32).contiguous()
-#         grad_input = disco_kernels.forward.default(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals,
-#                                                    torch.empty(0), ctx.kernel_size, ctx.nlat_in, ctx.nlon_in) # Mauro
-#         grad_input = grad_input.to(gtype)
-#     else:
-#         grad_input = None
+    if ctx.needs_input_grad[0]:
+        grad_input = disco_kernels.forward.default(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals,
+                                                   ctx.kernel_size, ctx.nlat_in, ctx.nlon_in) # Mauro
+    else:
+        grad_input = None
 
-#     return grad_input, None, None, None, None, None, None, None, None, None # Mauro: added a None for weights
+    return grad_input, None, None, None, None, None, None, None, None, None # Mauro: added a None for weights
 
-# if optimized_kernels_is_available():
-#     torch.library.register_autograd(
-#         "disco_kernels::_disco_s2_transpose_contraction_optimized", _disco_s2_transpose_contraction_bwd_optimized, setup_context=_setup_context_conv_backward)
+if optimized_kernels_is_available():
+    torch.library.register_autograd(
+        "disco_kernels::_disco_s2_transpose_contraction_optimized", _disco_s2_transpose_contraction_bwd_optimized, setup_context=_setup_context_conv_backward)
 
 # torch kernel related functions
 def _get_psi(kernel_size: int, psi_idx: torch.Tensor, psi_vals: torch.Tensor, nlat_in: int, nlon_in: int, nlat_out: int, nlon_out: int, nlat_in_local: Optional[int] = None, nlat_out_local: Optional[int] = None, semi_transposed: Optional[bool] = False):
@@ -192,39 +185,39 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in
 
     return y
 
-# # transpose convolution
-# def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
-#     assert len(psi.shape) == 3
-#     assert len(x.shape) == 5
-#     psi = psi.to(x.device)
+# transpose convolution
+def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
+    assert len(psi.shape) == 3
+    assert len(x.shape) == 5
+    psi = psi.to(x.device)
 
-#     batch_size, n_chans, kernel_size, nlat_in, nlon_in = x.shape
-#     kernel_size, nlat_out, n_out = psi.shape
+    batch_size, n_chans, kernel_size, nlat_in, nlon_in = x.shape
+    kernel_size, nlat_out, n_out = psi.shape
 
-#     assert n_out % nlon_out == 0
-#     assert nlon_out >= nlon_in
-#     pscale = nlon_out // nlon_in
+    assert n_out % nlon_out == 0
+    assert nlon_out >= nlon_in
+    pscale = nlon_out // nlon_in
 
-#     # interleave zeros along the longitude dimension to allow for fractional offsets to be considered
-#     x_ext = torch.zeros(kernel_size, nlat_in, nlon_out, batch_size * n_chans, device=x.device, dtype=x.dtype)
-#     x = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0).contiguous()
+    # interleave zeros along the longitude dimension to allow for fractional offsets to be considered
+    x_ext = torch.zeros(kernel_size, nlat_in, nlon_out, batch_size * n_chans, device=x.device, dtype=x.dtype)
+    x = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0).contiguous()
 
-#     # x has shape kernel_size x nlat_in x nlon_in x batch_size * n_chans
-#     # we only need to apoply the nlon stride here, since nlat stride is taken care of by the kernel
-#     x_ext[:, :, ::pscale, :] = x[...]
+    # x has shape kernel_size x nlat_in x nlon_in x batch_size * n_chans
+    # we only need to apply the nlon stride here, since nlat stride is taken care of by the kernel
+    x_ext[:, :, ::pscale, :] = x[...]
 
-#     # create output tensor
-#     y = torch.zeros(kernel_size, nlon_out, nlat_out, batch_size * n_chans, device=x.device, dtype=x.dtype)
+    # create output tensor
+    y = torch.zeros(kernel_size, nlon_out, nlat_out, batch_size * n_chans, device=x.device, dtype=x.dtype)
 
-#     for pout in range(nlon_out):
-#         # we need to repeatedly roll the input tensor to faciliate the shifted multiplication
-#         # TODO: double-check why this has to happen first
-#         x_ext = torch.roll(x_ext, -1, dims=2)
-#         # sparse contraction with the modified psi
-#         y[:, pout, :, :] = torch.bmm(psi, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1))
+    for pout in range(nlon_out):
+        # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
+        # TODO: double-check why this has to happen first
+        x_ext = torch.roll(x_ext, -1, dims=2)
+        # sparse contraction with the modified psi
+        y[:, pout, :, :] = torch.bmm(psi, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1))
 
-#     # sum over the kernel dimension and reshape to the correct output size
-#     y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out).contiguous()
+    # sum over the kernel dimension and reshape to the correct output size
+    y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out).contiguous()
 
-#     return y
+    return y
diff --git a/torch_harmonics/disco/convolution.py b/torch_harmonics/disco/convolution.py
index d374c671..e75c8ba8 100644
--- a/torch_harmonics/disco/convolution.py
+++ b/torch_harmonics/disco/convolution.py
@@ -46,8 +46,8 @@
 from torch_harmonics.cache import lru_cache
 from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
 from torch_harmonics.utils import permute_to_0231, permute_to_0312
-from ._disco_utils import _get_psi, _disco_s2_contraction_torch #, _disco_s2_transpose_contraction_torch
-from ._disco_utils import _disco_s2_contraction_optimized #, _disco_s2_transpose_contraction_optimized
+from ._disco_utils import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
+from ._disco_utils import _disco_s2_contraction_optimized, _disco_s2_transpose_contraction_optimized
 from torch_harmonics.filter_basis import FilterBasis, get_filter_basis
 from disco_helpers import optimized_kernels_is_available, preprocess_psi

From b0b4d44698bb86af1e3b9598ab2fb7baa93991f1 Mon Sep 17 00:00:00 2001
From: Thorsten Kurth
Date: Mon, 1 Sep 2025 05:29:34 -0700
Subject: [PATCH 09/31] working CPU branch

---
 setup.py                              |   2 +-
 tests/test_convolution.py             |  19 +-
 torch_harmonics/__init__.py           |   2 +-
 torch_harmonics/disco/__init__.py     |   2
+- torch_harmonics/disco/_disco_utils.py | 12 +- torch_harmonics/disco/convolution.py | 307 +++++++++++------------ torch_harmonics/disco/csrc/disco_cpu.cpp | 30 ++- torch_harmonics/disco/csrc/disco_cpu.h | 14 +- 8 files changed, 196 insertions(+), 192 deletions(-) diff --git a/setup.py b/setup.py index 526a23b3..a5bd9e78 100644 --- a/setup.py +++ b/setup.py @@ -165,7 +165,7 @@ def get_ext_modules(): # Create a single extension that includes both CPU and CUDA code disco_sources = [ "torch_harmonics/disco/csrc/disco_interface.cpp", - #"torch_harmonics/disco/csrc/disco_cpu.cpp" + "torch_harmonics/disco/csrc/disco_cpu.cpp" ] if BUILD_CUDA: diff --git a/tests/test_convolution.py b/tests/test_convolution.py index a4261614..b798cd9b 100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -47,8 +47,8 @@ _devices = [(torch.device("cpu"),)] -if torch.cuda.is_available(): - _devices.append((torch.device("cuda"),)) +#if torch.cuda.is_available(): +# _devices.append((torch.device("cuda"),)) # perf thresholds # CPU results normalized to 16 OpenMP threads, @@ -198,7 +198,7 @@ def setUp(self): [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", False, 1e-4, False], [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", False, 1e-4, False], [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4, False], - # transpose convolution + # # transpose convolution [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], [8, 4, 2, (12, 24), (24, 48), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], @@ -302,6 +302,7 @@ def test_sparse_against_dense( w_ref = torch.empty_like(conv.weight) with torch.no_grad(): w_ref.copy_(conv.weight) + w_ref = w_ref.reshape(-1, w_ref.shape[2], w_ref.shape[3]) w_ref.requires_grad = True # create an input signal @@ -331,7 +332,7 @@ def test_sparse_against_dense( # compare self.assertTrue(torch.allclose(x_grad, x_ref_grad, rtol=tol, atol=tol)) - self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol)) + self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad.unsqueeze(0), rtol=tol, atol=tol)) @parameterized.expand( @@ -517,10 +518,10 @@ def test_device_instantiation(self, batch_size, in_channels, out_channels, in_sh @parameterized.expand( [ - [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False], - [8, 4, 2, (16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False], - [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], - [8, 4, 2, (8, 16), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], + # [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False], + # [8, 4, 2, (16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False], + # [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], + # [8, 4, 2, (8, 16), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], ], skip_on_empty=True, @@ -602,7 +603,7 @@ def 
test_optimized_pt2_compatibility( @parameterized.expand( [ - [8, 4, 2, (91, 180), (91, 180), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4], + # [8, 4, 2, (91, 180), (91, 180), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4], ], skip_on_empty=True, ) diff --git a/torch_harmonics/__init__.py b/torch_harmonics/__init__.py index 77e43db0..2d390b54 100644 --- a/torch_harmonics/__init__.py +++ b/torch_harmonics/__init__.py @@ -32,7 +32,7 @@ __version__ = "0.8.1" from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT -from .disco import DiscreteContinuousConvS2 #, DiscreteContinuousConvTransposeS2 +from .disco import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 from .resample import ResampleS2 from .attention import AttentionS2, NeighborhoodAttentionS2 from . import quadrature diff --git a/torch_harmonics/disco/__init__.py b/torch_harmonics/disco/__init__.py index 3ece8042..14ae7462 100644 --- a/torch_harmonics/disco/__init__.py +++ b/torch_harmonics/disco/__init__.py @@ -42,4 +42,4 @@ disco_kernels = None warnings.warn("No optimized DISCO kernels are available. Please compile the extension first setting BUILD_CPP and BUILD_CUDA to 1.") -from .convolution import DiscreteContinuousConvS2 #, DiscreteContinuousConvTransposeS2 +from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 diff --git a/torch_harmonics/disco/_disco_utils.py b/torch_harmonics/disco/_disco_utils.py index 01ecbd9d..b591a23b 100644 --- a/torch_harmonics/disco/_disco_utils.py +++ b/torch_harmonics/disco/_disco_utils.py @@ -43,7 +43,7 @@ def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - out_shape = (inp.shape[0], nlat_out, nlon_out, kernel_size*inp.shape[3]) + out_shape = (inp.shape[0], nlat_out, nlon_out, inp.shape[3], kernel_size) return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) # raw backward fake @@ -51,7 +51,7 @@ def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - out_shape = (inp.shape[0], nlat_out, nlon_out, kernel_size*inp.shape[3]) + out_shape = (inp.shape[0], nlat_out, nlon_out, inp.shape[3]) return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) # forward @@ -77,7 +77,7 @@ def _disco_s2_transpose_contraction_optimized( def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - out_shape = (inp.shape[0], nlat_out, nlon_out, kernel_size*inp.shape[3]) + out_shape = (inp.shape[0], nlat_out, nlon_out, inp.shape[3], kernel_size) return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) # transpose fake @@ -85,15 +85,15 @@ def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - out_shape = (inp.shape[0], nlat_out, nlon_out, kernel_size*inp.shape[3]) + out_shape = (inp.shape[0], nlat_out, nlon_out, inp.shape[3]) return 
torch.empty(out_shape, dtype=inp.dtype, device=inp.device) #general routines: this is the same for forward and transpose def _setup_context_conv_backward(ctx, inputs, output): inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, _, _ = inputs ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals) - ctx.nlat_in = inp.shape[-2] - ctx.nlon_in = inp.shape[-1] + ctx.nlat_in = inp.shape[1] + ctx.nlon_in = inp.shape[2] ctx.kernel_size = kernel_size # convolution related diff --git a/torch_harmonics/disco/convolution.py b/torch_harmonics/disco/convolution.py index e75c8ba8..bee495cc 100644 --- a/torch_harmonics/disco/convolution.py +++ b/torch_harmonics/disco/convolution.py @@ -510,188 +510,177 @@ def psi_idx(self): @nvtx.annotate("forward", color="purple") def forward(self, x: torch.Tensor) -> torch.Tensor: - #print("input x.shape:", x.shape) - if self.optimized_kernel: with nvtx.annotate("_disco_s2_contraction_optimized", color="red"): xp = permute_to_0231(x) + xpc = _disco_s2_contraction_optimized( xp, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out - ) - xpc = xpc.reshape(xpc.shape[0], self.nlat_out, self.nlon_out, self.groups, self.groupsize_in, self.kernel_size) - yp = torch.einsum("bxygck,gock->bxygo", xpc, self.weight).reshape(xpc.shape[0], self.nlat_out, self.nlon_out, -1).contiguous() - out = permute_to_0312(yp) + ).reshape(x.shape[0], self.nlat_out, self.nlon_out, self.groups, self.groupsize_in, self.kernel_size) + + outp = torch.einsum("bxygck,gock->bxygo", xpc, self.weight).reshape(xpc.shape[0], self.nlat_out, self.nlon_out, -1).contiguous() + + out = permute_to_0312(outp) else: x = _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out) - - #print("y.shape:", x.shape) - #print("self.groups:", self.groups, "self.groupsize:", self.groupsize) - #print("weight.shape:", self.weight.shape) - #pippo = self.weight.clone() - #pippo = pippo.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2]) - #print("after reshape, weight.shape:", pippo.shape) # extract shape B, _, K, H, W = x.shape with nvtx.annotate("reshape", color="blue"): x = x.reshape(B, self.groups, self.groupsize_in, K, H, W) - #print("after reshape, x.shape:", x.shape) - # do weight multiplication with nvtx.annotate("einsum", color="blue"): out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight).contiguous() - #print("out.shape:", out.shape) + out = out.reshape(B, -1, H, W) - #cpu_tensor = out.detach().cpu().numpy() - #np.savetxt('yout_einsum.ref.txt', cpu_tensor.flatten(), fmt='%.6f') + if self.bias is not None: + out = out + self.bias.reshape(1, -1, 1, 1) + + return out + + +class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv): + """ + Discrete-continuous (DISCO) transpose convolutions on the 2-Sphere as described in [1]. 
+ + Parameters + ----------- + in_channels: int + Number of input channels + out_channels: int + Number of output channels + in_shape: Tuple[int] + Input shape of the convolution tensor + out_shape: Tuple[int] + Output shape of the convolution tensor + kernel_shape: Union[int, Tuple[int], Tuple[int, int]] + Shape of the kernel + basis_type: Optional[str] + Type of the basis functions + basis_norm_mode: Optional[str] + Mode for basis normalization + groups: Optional[int] + Number of groups + grid_in: Optional[str] + Input grid type + grid_out: Optional[str] + Output grid type + bias: Optional[bool] + Whether to use bias + theta_cutoff: Optional[float] + Theta cutoff for the filter basis functions + optimized_kernel: Optional[bool] + Whether to use the optimized kernel (if available) + + Returns + -------- + out: torch.Tensor + Output tensor + + References + ---------- + [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + in_shape: Tuple[int], + out_shape: Tuple[int], + kernel_shape: Union[int, Tuple[int], Tuple[int, int]], + basis_type: Optional[str] = "piecewise linear", + basis_norm_mode: Optional[str] = "mean", + groups: Optional[int] = 1, + grid_in: Optional[str] = "equiangular", + grid_out: Optional[str] = "equiangular", + bias: Optional[bool] = True, + theta_cutoff: Optional[float] = None, + optimized_kernel: Optional[bool] = True, + ): + super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias, optimized_kernel) + + self.nlat_in, self.nlon_in = in_shape + self.nlat_out, self.nlon_out = out_shape + + # make sure the p-shift works by checking that longitudes are divisible + assert self.nlon_out % self.nlon_in == 0 - print("weight.shape:", self.weight.shape) - print("after reshape, out.shape:", out.shape) - print("\n") + # bandlimit + if theta_cutoff is None: + theta_cutoff = torch.pi / float(self.nlat_in - 1) + if theta_cutoff <= 0.0: + raise ValueError("Error, theta_cutoff has to be positive.") + + # switch in_shape and out_shape since we want the transpose convolution + idx, vals, _ = _precompute_convolution_tensor_s2( + out_shape, + in_shape, + self.filter_basis, + grid_in=grid_out, + grid_out=grid_in, + theta_cutoff=theta_cutoff, + transpose_normalization=True, + basis_norm_mode=basis_norm_mode, + merge_quadrature=True, + ) + + # sort the values + ker_idx = idx[0, ...].contiguous() + row_idx = idx[1, ...].contiguous() + col_idx = idx[2, ...].contiguous() + vals = vals.contiguous() + + if self.optimized_kernel: + # preprocessed data-structure for GPU kernel + roff_idx = preprocess_psi(self.kernel_size, self.nlat_in, ker_idx, row_idx, col_idx, vals).contiguous() + self.register_buffer("psi_roff_idx", roff_idx, persistent=False) + + # save all datastructures + self.register_buffer("psi_ker_idx", ker_idx, persistent=False) + self.register_buffer("psi_row_idx", row_idx, persistent=False) + self.register_buffer("psi_col_idx", col_idx, persistent=False) + self.register_buffer("psi_vals", vals, persistent=False) + + # also store psi just in case + if not self.optimized_kernel: + self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True) + + def extra_repr(self): + return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, 
out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" + + @property + def psi_idx(self): + return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + # extract shape + B, C, H, W = x.shape + + if self.optimized_kernel: + xp = permute_to_0231(x) + xp = xp.reshape(B, H, W, self.groups, self.groupsize_in) + xpc = torch.einsum("bxygc,gock->bxygok", xp, self.weight).reshape(B, H, W, self.groups * self.groupsize_out, -1).contiguous() + outp = _disco_s2_transpose_contraction_optimized( + xpc, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out + ) + out = permute_to_0312(outp) + else: + x = x.reshape(B, self.groups, self.groupsize_in, H, W) + + # do weight multiplication + xc = torch.einsum("bgcxy,gock->bgokxy", x, self.weight).contiguous() + xc = xc.reshape(B, self.groups* self.groupsize_out, -1, H, W) + + # disco contraction + out = _disco_s2_transpose_contraction_torch(xc, self.psi_st.to(x.device), self.nlon_out) if self.bias is not None: out = out + self.bias.reshape(1, -1, 1, 1) return out - - -# class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv): -# """ -# Discrete-continuous (DISCO) transpose convolutions on the 2-Sphere as described in [1]. - -# Parameters -# ----------- -# in_channels: int -# Number of input channels -# out_channels: int -# Number of output channels -# in_shape: Tuple[int] -# Input shape of the convolution tensor -# out_shape: Tuple[int] -# Output shape of the convolution tensor -# kernel_shape: Union[int, Tuple[int], Tuple[int, int]] -# Shape of the kernel -# basis_type: Optional[str] -# Type of the basis functions -# basis_norm_mode: Optional[str] -# Mode for basis normalization -# groups: Optional[int] -# Number of groups -# grid_in: Optional[str] -# Input grid type -# grid_out: Optional[str] -# Output grid type -# bias: Optional[bool] -# Whether to use bias -# theta_cutoff: Optional[float] -# Theta cutoff for the filter basis functions -# optimized_kernel: Optional[bool] -# Whether to use the optimized kernel (if available) - -# Returns -# -------- -# out: torch.Tensor -# Output tensor - -# References -# ---------- -# [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603 -# """ - -# def __init__( -# self, -# in_channels: int, -# out_channels: int, -# in_shape: Tuple[int], -# out_shape: Tuple[int], -# kernel_shape: Union[int, Tuple[int], Tuple[int, int]], -# basis_type: Optional[str] = "piecewise linear", -# basis_norm_mode: Optional[str] = "mean", -# groups: Optional[int] = 1, -# grid_in: Optional[str] = "equiangular", -# grid_out: Optional[str] = "equiangular", -# bias: Optional[bool] = True, -# theta_cutoff: Optional[float] = None, -# optimized_kernel: Optional[bool] = True, -# ): -# super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias, optimized_kernel) - -# self.nlat_in, self.nlon_in = in_shape -# self.nlat_out, self.nlon_out = out_shape - -# # make sure the p-shift works by checking that longitudes are divisible -# assert self.nlon_out % self.nlon_in == 0 - -# # bandlimit -# if theta_cutoff is None: -# theta_cutoff = torch.pi / float(self.nlat_in - 1) - -# if theta_cutoff <= 0.0: -# raise ValueError("Error, theta_cutoff has to be positive.") - -# # switch in_shape and out_shape since we 
want the transpose convolution -# idx, vals, _ = _precompute_convolution_tensor_s2( -# out_shape, -# in_shape, -# self.filter_basis, -# grid_in=grid_out, -# grid_out=grid_in, -# theta_cutoff=theta_cutoff, -# transpose_normalization=True, -# basis_norm_mode=basis_norm_mode, -# merge_quadrature=True, -# ) - -# # sort the values -# ker_idx = idx[0, ...].contiguous() -# row_idx = idx[1, ...].contiguous() -# col_idx = idx[2, ...].contiguous() -# vals = vals.contiguous() - -# if self.optimized_kernel: -# # preprocessed data-structure for GPU kernel -# roff_idx = preprocess_psi(self.kernel_size, self.nlat_in, ker_idx, row_idx, col_idx, vals).contiguous() -# self.register_buffer("psi_roff_idx", roff_idx, persistent=False) - -# # save all datastructures -# self.register_buffer("psi_ker_idx", ker_idx, persistent=False) -# self.register_buffer("psi_row_idx", row_idx, persistent=False) -# self.register_buffer("psi_col_idx", col_idx, persistent=False) -# self.register_buffer("psi_vals", vals, persistent=False) - -# # also store psi just in case -# if not self.optimized_kernel: -# self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True) - -# def extra_repr(self): -# return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" - -# @property -# def psi_idx(self): -# return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous() - -# def forward(self, x: torch.Tensor) -> torch.Tensor: - -# # extract shape -# B, C, H, W = x.shape -# x = x.reshape(B, self.groups, self.groupsize, H, W) - -# # do weight multiplication -# x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous() -# x = x.reshape(B, -1, x.shape[-3], H, W) - -# if self.optimized_kernel: -# out = _disco_s2_transpose_contraction_optimized( -# x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out -# ) -# else: -# out = _disco_s2_transpose_contraction_torch(x, self.psi_st.to(x.device), self.nlon_out) - -# if self.bias is not None: -# out = out + self.bias.reshape(1, -1, 1, 1) - -# return out diff --git a/torch_harmonics/disco/csrc/disco_cpu.cpp b/torch_harmonics/disco/csrc/disco_cpu.cpp index b959802f..1549789b 100644 --- a/torch_harmonics/disco/csrc/disco_cpu.cpp +++ b/torch_harmonics/disco/csrc/disco_cpu.cpp @@ -34,7 +34,7 @@ namespace disco_kernels { // cpu ops torch::Tensor disco_cpu_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, - torch::Tensor col_idx, torch::Tensor vals, torch::Tensor weights, int64_t K, int64_t Ho, int64_t Wo) { + torch::Tensor col_idx, torch::Tensor vals, int64_t K, int64_t Ho, int64_t Wo) { // sanity checks CHECK_CPU_INPUT_TENSOR(inp); @@ -44,12 +44,16 @@ namespace disco_kernels { CHECK_CPU_INPUT_TENSOR(col_idx); CHECK_CPU_INPUT_TENSOR(vals); + // convert input to fp32 + auto inp_dtype = inp.scalar_type(); + inp = inp.to(torch::kFloat32).contiguous(); + // initialize output tensor - auto out = torch::zeros({inp.size(0), inp.size(1), K, Ho, Wo}, inp.options()); + auto out = torch::zeros({inp.size(0), Ho, Wo, inp.size(3), K}, inp.options()); AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), 
"disco_forward_cpu", ([&] { disco_fwd_cpu( - inp.size(0), inp.size(1), K, inp.size(2), inp.size(3), + inp.size(0), inp.size(3), K, inp.size(1), inp.size(2), Ho, Wo, vals.size(0), roff_idx.size(0) - 1, inp.packed_accessor64(), roff_idx.packed_accessor64(), @@ -60,11 +64,14 @@ namespace disco_kernels { out.packed_accessor64()); })); + // convert to input datatype + out = out.to(inp_dtype); + return out; } torch::Tensor disco_cpu_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, - torch::Tensor col_idx, torch::Tensor vals, torch::Tensor weights, int64_t K, int64_t Ho, int64_t Wo) { + torch::Tensor col_idx, torch::Tensor vals, int64_t K, int64_t Ho, int64_t Wo) { // sanity checks CHECK_CPU_INPUT_TENSOR(inp); @@ -73,14 +80,18 @@ namespace disco_kernels { CHECK_CPU_INPUT_TENSOR(row_idx); CHECK_CPU_INPUT_TENSOR(col_idx); CHECK_CPU_INPUT_TENSOR(vals); + + // convert input to fp32 + auto inp_dtype = inp.scalar_type(); + inp = inp.to(torch::kFloat32).contiguous(); - // initialize output tensor - auto out = torch::zeros({inp.size(0), inp.size(1), Ho, Wo}, inp.options()); + // initialize output tensor: assume channels last format + auto out = torch::zeros({inp.size(0), Ho, Wo, inp.size(3)}, inp.options()); AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cpu", ([&] { disco_bwd_cpu( - inp.size(0), inp.size(1), K, inp.size(3), - inp.size(4), Ho, Wo, vals.size(0), roff_idx.size(0) - 1, + inp.size(0), inp.size(3), K, inp.size(1), inp.size(2), + Ho, Wo, vals.size(0), roff_idx.size(0) - 1, inp.packed_accessor64(), roff_idx.packed_accessor64(), ker_idx.packed_accessor64(), @@ -90,6 +101,9 @@ namespace disco_kernels { out.packed_accessor64()); })); + // convert to input datatype + out = out.to(inp_dtype); + return out; } diff --git a/torch_harmonics/disco/csrc/disco_cpu.h b/torch_harmonics/disco/csrc/disco_cpu.h index 29b242c2..921ff5ae 100644 --- a/torch_harmonics/disco/csrc/disco_cpu.h +++ b/torch_harmonics/disco/csrc/disco_cpu.h @@ -55,10 +55,10 @@ namespace disco_kernels { const int64_t nblock_wo = static_cast((Wo + block_wo - 1) / block_wo); // loop over matrix entries - #pragma omp parallel for collapse(3) + #pragma omp parallel for simd collapse(3) for (int64_t b = 0; b < B; b++) { - for (int64_t c = 0; c < C; c++) { - for (int64_t row = 0; row < nnr; row++) { + for (int64_t row = 0; row < nnr; row++) { + for (int64_t c = 0; c < C; c++) { // since the rows are ordered accordingly, we can compute ho and ker in here int64_t ho = row_idx[roff_idx[row]]; @@ -89,12 +89,12 @@ namespace disco_kernels { for (int64_t wo = wo_start; wo < wo_end; wo++) { // compute shifted w int64_t wipp = static_cast((wi + pscale * wo) % Wi); - out_tmp[wo-wo_start] += val * inp[b][c][hi][wipp]; + out_tmp[wo-wo_start] += val * inp[b][hi][wipp][c]; } } // write out for (int64_t wo = wo_start; wo < wo_end; wo++) { - out[b][c][ker][ho][wo] = out_tmp[wo-wo_start]; + out[b][ho][wo][c][ker] = out_tmp[wo-wo_start]; } } } @@ -117,7 +117,7 @@ namespace disco_kernels { const int64_t pscale = static_cast(Wo / Wi); // loop over matrix entries - #pragma omp parallel for collapse(2) + #pragma omp parallel for simd collapse(2) for (int64_t b = 0; b < B; b++) { for (int64_t c = 0; c < C; c++) { @@ -142,7 +142,7 @@ namespace disco_kernels { for (int64_t wi = 0; wi < Wi; wi++) { // compute shifted w int64_t wopp = static_cast((wo + pscale * wi) % Wo); - out[b][c][ho][wopp] += val * inp[b][c][ker][hi][wi]; + out[b][ho][wopp][c] += val * inp[b][hi][wi][c][ker]; } } } From 
544d0c706fce3639eeec66f9a9d74633ec77feee Mon Sep 17 00:00:00 2001 From: Mauro Bisson Date: Tue, 2 Sep 2025 17:17:15 -0700 Subject: [PATCH 10/31] Added preliminary version of channel-last BWD kernel. --- torch_harmonics/disco/csrc/disco_cuda_bwd.cu | 898 ++++++++++++++++++- torch_harmonics/disco/csrc/disco_cuda_fwd.cu | 136 +-- torch_harmonics/utils/csrc/csr_cuda.cuh | 17 + 3 files changed, 922 insertions(+), 129 deletions(-) diff --git a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu index 710f7ae8..fa747ae4 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu @@ -31,6 +31,11 @@ #include "disco.h" #include "disco_cuda.cuh" #include "csr_cuda.cuh" +#include "cudamacro.h" + +#define THREADS (64) + +#define MAX_LOCAL_ARR_LEN (16) namespace disco_kernels { @@ -201,6 +206,825 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t return; } +// BEGIN NEW CHANNEL-LAST VERSION + +template +static __global__ void pack_vals_k(const int64_t K, + const int64_t nlat_out, + const int64_t *__restrict__ row_off, + const VAL_T *__restrict__ val_dat, + VAL_T *__restrict__ val_pck) { + + const int tidx = threadIdx.x; + const int wid = blockIdx.x*blockDim.y + threadIdx.y; + if (wid >= nlat_out) { + return; + } + + const int64_t rbeg = row_off[wid]; + const int64_t rend = row_off[wid+1]; + + const int rlen = rend-rbeg; + + val_pck += rbeg*K; + + for(int off = tidx; off < rlen; off += blockDim.x) { + for(int ker = 0; ker < K; ker++) { + + val_pck[off*K + ker] = val_dat[ row_off[ker*nlat_out + wid] + off]; + } + } + + return; +} + +template +static __device__ void processCSR_Kpow2_shm_d(const int wi, + const int rlen, + const int nchans, // no. of input FLOATV_T elements along channel dim + const int nlon_out, + const int pscale, + const int K, + const FLOATV_T *__restrict__ shx, + const int64_t *__restrict__ cols, + const FLOATV_T *__restrict__ vals, + float *__restrict__ shy, + float *__restrict__ y) { + const int tidx = threadIdx.x; + + const int log2_K = __ffs(K)-1; + + const int tidxDivK = tidx >> log2_K; + const int tidxModK = tidx & (K-1); + + vals += tidxModK; + + const int BDIMX_div_K = BDIM_X >> log2_K; + + for(int chan = tidx; chan < nchans; chan += WARP_SIZE) { + shy[chan] = 0; + } + __syncwarp(); + + for(int off = 0; off < rlen; off++) { + + const int64_t col = cols[off]; + + const int ho = col / nlon_out; + const int wo = col - (ho*nlon_out); + + int wop = wo + pscale*wi; + wop -= (wop / nlon_out)*nlon_out; + + float *_y = y + int64_t(ho)*nlon_out*nchans + int64_t(wop)*nchans; + + float *_shy = shy + tidxDivK; + + const FLOATV_T myval = vals[0]; + + for(int i = 0; i < nchans*K; i += WARP_SIZE) { + + float sum = (i+tidx < nchans*K) ? __vred(__vmul(myval, shx[i+tidx])) : 0; + + for(int j = 1; j < K; j *= 2) { + sum += __shfl_xor_sync(FULL_MASK, sum, j); + } + + if (i+tidx < nchans*K && !tidxModK) { + _shy[0] += sum; + } + _shy += BDIMX_div_K; + } + __syncwarp(); + + for(int chan = tidx; chan < nchans; chan += WARP_SIZE) { + atomicAdd(_y+chan, shy[chan]); + shy[chan] = 0; + } + __syncwarp(); + + vals += K; + } + + return; +} + +template +static __device__ void processCSR_Kanyv_shm_d(const int wi, + const int rlen, + const int nchans, // no. 
of input FLOATV_T elements along channel dim + const int nlon_out, + const int pscale, + const int K, + const FLOATV_T *__restrict__ shx, + const int64_t *__restrict__ cols, + const FLOATV_T *__restrict__ vals, + float *__restrict__ shy, + float *__restrict__ y) { + const int tidx = threadIdx.x; + + for(int chan = tidx; chan < nchans; chan += WARP_SIZE) { + shy[chan] = 0; + } + __syncwarp(); + + for(int off = 0; off < rlen; off++) { + + const int64_t col = cols[off]; + + const int ho = col / nlon_out; + const int wo = col - (ho*nlon_out); + + int wop = wo + pscale*wi; + wop -= (wop / nlon_out)*nlon_out; + + float *_y = y + int64_t(ho)*nlon_out*nchans + int64_t(wop)*nchans; + + for(int chan = tidx; chan < nchans*K; chan += WARP_SIZE) { + + const int cDivK = chan / K; + const int cModK = chan - (cDivK*K); + + float sum = __vred(__vmul(vals[cModK], shx[chan])); + + atomicAdd(shy+cDivK, sum); + } + __syncwarp(); + + for(int chan = tidx; chan < nchans; chan += WARP_SIZE) { + atomicAdd(_y+chan, shy[chan]); + shy[chan] = 0; + } + __syncwarp(); + + vals += K; + } + + return; +} + +template // either float or float4 +__global__ +__launch_bounds__(BDIM) +void s2_disco_bwd_generic_vec_k(int nchans, // no. of input float (not FLOATV_T!) elements along channel dim + int nlat_in, + int nlon_in, + int nlat_out, + int nlon_out, + int pscale, + int K, // no. of output FLOATV_T elem along K dim (kernel size) + const FLOATV_T *__restrict__ x, + const int32_t *__restrict__ row_idx, + const int64_t *__restrict__ row_off, + const int64_t *__restrict__ col_idx, + const FLOATV_T *__restrict__ val_pck, + float *__restrict__ y) { + + constexpr int VEC_SIZE = sizeof(FLOATV_T) / sizeof(float); + + const int tidx = threadIdx.x; + + const int batch = blockIdx.y; + const int ctaid = blockIdx.x*blockDim.y + threadIdx.y; + + if (ctaid >= nlat_in*nlon_in) { + return; + } + +#if 1 + const int h = ctaid / nlon_in; + const int wi = ctaid - (h*nlon_in); + const int hi = row_idx[h]; +#else + // for now don't use row_idx + const int hi = ctaid / nlon_in; + const int wi = ctaid - (hi*nlon_in); +#endif + + x += int64_t(batch)*nlat_in*nlon_in*nchans*K + int64_t(hi)*nlon_in*nchans*K + int64_t(wi)*nchans*K; + y += int64_t(batch)*nlat_out*nlon_out*nchans; + + extern __shared__ __align__(sizeof(float4)) float shext[]; + FLOATV_T *shx = reinterpret_cast(shext + (VEC_SIZE*nchans*K + nchans)*threadIdx.y); + float *shy = reinterpret_cast(shx + nchans*K); + + for(int chan = tidx; chan < nchans*K; chan += WARP_SIZE) { + shx[chan] = x[chan]; + } + + const int64_t rbeg = row_off[hi]; + const int64_t rend = row_off[hi+1]; + + col_idx += rbeg; + val_pck += rbeg*K; // val_pck CSR contains K values per element + + const int rlen = rend-rbeg; + + // check if BDIM_X is a multiple of K; since BDIM_X is a power of 2, check if K is also a power of two + if (!(K & K-1) && K <= WARP_SIZE) { processCSR_Kpow2_shm_d(wi, rlen, nchans, nlon_out, pscale, K, shx, col_idx, val_pck, shy, y); } + else { processCSR_Kanyv_shm_d(wi, rlen, nchans, nlon_out, pscale, K, shx, col_idx, val_pck, shy, y); } + + return; +} + +template +static __device__ void processCSR_Kpow2_reg_d(const int wi, + const int rlen, + const int nchans, // no. 
of input FLOATV_T elements along channel dim + const int nlon_out, + const int pscale, + const int K, + const FLOATV_T (&locx)[NLOC], + const int64_t *__restrict__ cols, + const FLOATV_T *__restrict__ vals, + float *(&shYOff)[BDIM_X+SHPAD], + float *__restrict__ shy, // NO LONGER USED + float *__restrict__ y) { + constexpr int NLOC_M1 = NLOC-1; + + const int tidx = threadIdx.x; + + unsigned int subwarp_mask = FULL_MASK; + + if constexpr(BDIM_X <= WARP_SIZE) { + const int tidy = threadIdx.y; + constexpr unsigned int MASK = (1ull << BDIM_X)-1; + subwarp_mask = MASK << (tidy*BDIM_X); + } + + // K is a power of two <= BDIM_X + const int log2_K = __popc(K-1); + + const int tidxDivK = tidx >> log2_K; + const int tidxModK = tidx & (K-1); + + cols += tidx; + vals += tidxModK; + + const int BDIMX_div_K = BDIM_X >> log2_K; + + for(int off = 0; off < rlen; off++) { + if ((off % BDIM_X) == 0) { + __sync(); + + const int64_t col = (off+tidx < rlen) ? cols[0] : 0; + + const int ho = col / nlon_out; + const int wo = col - (ho*nlon_out); + + int wop = wo + pscale*wi; + wop -= (wop / nlon_out)*nlon_out; + + shYOff[tidx] = y + int64_t(ho)*nlon_out*nchans + int64_t(wop)*nchans; + cols += BDIM_X; + + __sync(); + } + + float *_y = shYOff[off % BDIM_X] + tidxDivK; + + const FLOATV_T myval = vals[0]; + + float locy[NLOC]; + + #pragma unroll + for(int i = 0; i < NLOC; i++) { + locy[i] = __vred(__vmul(myval, locx[i])); + } + + // K is a power of two <= 32 + #pragma unroll + for(int j = 1; j < BDIM_X; j *= 2) { + + if (j >= K) break; + + #pragma unroll + for(int i = 0; i < NLOC; i++) { + locy[i] += __shfl_xor_sync(subwarp_mask, locy[i], j); + } + } + + if (!tidxModK) { + // NLOC*BDIM_X >= nchans*K + // NLOC_M1*BDIM_X < nchans*K => NLOC_M1*BDIM_X/K < nchans + + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + atomicAdd(_y + i*BDIMX_div_K, locy[i]); + } + if (NLOC_M1*BDIM_X+tidx < nchans*K) { + atomicAdd(_y + NLOC_M1*BDIMX_div_K, locy[NLOC_M1]); + } + } + vals += K; + } + + return; +} + +#if 0 +template +static __device__ void processCSR_Kpow2_reg_d2(const int wi, + const int rlen, + const int nchans, // no. of input FLOATV_T elements along channel dim + const int nlon_out, + const int pscale, + const int K, + const FLOATV_T (&locx)[NLOC], + const int64_t *__restrict__ cols, + const FLOATV_T *__restrict__ vals, + float *(&shYOff)[BDIM_X+SHPAD], + float *__restrict__ shy, // NO LONGER USED + float *__restrict__ y) { + constexpr int NLOC_M1 = NLOC-1; + + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + + const int log2_K = __ffs(K)-1; + + const int tidxDivK = tidx >> log2_K; + const int tidxModK = tidx & (K-1); + + vals += tidxModK; + + const int BDIMX_div_K = BDIM_X >> log2_K; +#if 1 + for(int chan = tidx; chan < nchans; chan += BDIM_X) { + shy[chan] = 0; + } + __sync(); +#endif + + cols += tidx; + + for(int off = 0; off < rlen; off++) { + if ((off % BDIM_X) == 0) { + __sync(); + + const int64_t col = (off+tidx < rlen) ? 
cols[0] : 0; + + const int ho = col / nlon_out; + const int wo = col - (ho*nlon_out); + + int wop = wo + pscale*wi; + wop -= (wop / nlon_out)*nlon_out; + + shYOff[tidx] = y + int64_t(ho)*nlon_out*nchans + int64_t(wop)*nchans; + cols += BDIM_X; + + __sync(); + } + + float *_y = shYOff[off % BDIM_X];// + tidxDivK; + float *_shy = shy + tidxDivK; + + const FLOATV_T myval = vals[0]; + + float locy[NLOC]; + + #pragma unroll + for(int i = 0; i < NLOC; i++) { + locy[i] = __vred(__vmul(myval, locx[i])); + } + + #pragma unroll + for(int i = 0; i < NLOC; i++) { + + // K is a power of two <= 32 + for(int j = 1; j < K; j *= 2) { + locy[i] += __shfl_xor_sync(FULL_MASK, locy[i], j); + } + } + + if (!tidxModK) { + // NLOC*BDIM_X >= nchans*K + // NLOC_M1*BDIM_X < nchans*K => NLOC_M1*BDIM_X/K < nchans + #pragma unroll + for(int i = 0; i < NLOC; i++) { + _shy[i*BDIMX_div_K] += locy[i]; + } + // if (NLOC_M1*BDIM_X+tidx < nchans*K) { + // _shy[NLOC_M1*BDIMX_div_K] += locy[NLOC_M1]; + // } + } + __sync(); + + for(int chan = tidx; chan < nchans; chan += BDIM_X) { + atomicAdd(_y+chan, shy[chan]); + shy[chan] = 0; + } + __sync(); + + vals += K; + } + + return; +} +#endif + +template +static __device__ void processCSR_Kanyv_reg_d(const int wi, + const int rlen, + const int nchans, // no. of input FLOATV_T elements along channel dim + const int nlon_out, + const int pscale, + const int K, + const FLOATV_T (&locx)[NLOC], + const int64_t *__restrict__ cols, + const FLOATV_T *__restrict__ vals, + float *(&shYOff)[BDIM_X+SHPAD], + float *__restrict__ shy, + float *__restrict__ y) { + const int tidx = threadIdx.x; + + for(int chan = tidx; chan < nchans; chan += BDIM_X) { + shy[chan] = 0; + } + __sync(); + + cols += tidx; + + for(int off = 0; off < rlen; off++) { + + if ((off % BDIM_X) == 0) { + __sync(); + + const int64_t col = (off+tidx < rlen) ? cols[0] : 0; + + const int ho = col / nlon_out; + const int wo = col - (ho*nlon_out); + + int wop = wo + pscale*wi; + wop -= (wop / nlon_out)*nlon_out; + + shYOff[tidx] = y + int64_t(ho)*nlon_out*nchans + int64_t(wop)*nchans; + cols += BDIM_X; + + __sync(); + } + + float *_y = shYOff[off % BDIM_X]; + + // shy is allocated as ceil(nchans / (BDIM_X/K))*(BDIM_X/K) + // so we can just loop NLOC timss + #pragma unroll + for(int i = 0; i < NLOC; i++) { + + const int chan = i*BDIM_X+tidx; + const int cDivK = chan / K; + const int cModK = chan - (cDivK*K); + + float sum = __vred(__vmul(vals[cModK], locx[i])); + + atomicAdd(shy+cDivK, sum); + } + __sync(); + + for(int chan = tidx; chan < nchans; chan += BDIM_X) { + atomicAdd(_y+chan, shy[chan]); + shy[chan] = 0; + } + __sync(); + + vals += K; + } + + return; +} + +template // either float or float4 +__global__ +__launch_bounds__(BDIM_X*BDIM_Y) +void s2_disco_bwd_special_vec_k(int nchans, // no. of input float (not FLOATV_T!) elements along channel dim + int nlat_in, + int nlon_in, + int nlat_out, + int nlon_out, + int pscale, + int K, // no. 
of output FLOATV_T elem along K dim (kernel size) + const FLOATV_T *__restrict__ x, + const int32_t *__restrict__ row_idx, + const int64_t *__restrict__ row_off, + const int64_t *__restrict__ col_idx, + const FLOATV_T *__restrict__ val_pck, + float *__restrict__ y) { + + static_assert(0 == (BDIM_X & (BDIM_X-1))); + static_assert(0 == (BDIM_Y & (BDIM_Y-1))); + static_assert((BDIM_X <= 32 && BDIM_Y > 1) || + (BDIM_X > 32 && BDIM_Y == 1)) ; + + constexpr int NLOC_M1 = NLOC-1; + + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + + const int batch = blockIdx.y; + const int ctaid = blockIdx.x*blockDim.y + threadIdx.y; + + if (ctaid >= nlat_in*nlon_in) { + return; + } + +#if 1 + const int h = ctaid / nlon_out; + const int wi = ctaid - (h*nlon_out); + const int hi = row_idx[h]; +#else + // for now don't use row_idx + const int hi = ctaid / nlon_out; + const int wi = ctaid - (hi*nlon_out); +#endif + + x += int64_t(batch)*nlat_in*nlon_in*nchans*K + int64_t(hi)*nlon_in*nchans*K + int64_t(wi)*nchans*K + tidx; + y += int64_t(batch)*nlat_out*nlon_out*nchans; + + FLOATV_T locx[NLOC]; + + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + locx[i] = x[i*BDIM_X]; + } + locx[NLOC_M1] = __vset(0.0f); + if (NLOC_M1*BDIM_X+tidx < nchans*K) { + locx[NLOC_M1] = x[NLOC_M1*BDIM_X]; + } + + // only used if K is not a multiple of 2 + extern __shared__ __align__(sizeof(float4)) float shext[]; + float *shy = shext + nchans*threadIdx.y; + + const int64_t rbeg = row_off[hi]; + const int64_t rend = row_off[hi+1]; + + col_idx += rbeg; + val_pck += rbeg*K; // val_pck CSR contains K values per element + + const int rlen = rend-rbeg; + + constexpr int PAD = (BDIM_X < WARP_SIZE) ? 1 : 0; + __shared__ float *shYOffAll[BDIM_Y][BDIM_X+PAD]; + + // check if BDIM_X is a multiple of K; since BDIM_X is a power of 2, check if K is also a power of two + if (!(K & K-1) && K <= BDIM_X) { processCSR_Kpow2_reg_d(wi, rlen, nchans, nlon_out, pscale, K, locx, col_idx, val_pck, shYOffAll[tidy], NULL, y); } + else { processCSR_Kanyv_reg_d(wi, rlen, nchans, nlon_out, pscale, K, locx, col_idx, val_pck, shYOffAll[tidy], shy, y); } + + return; + +} + +template +void launch_gen_disco_bwd(int64_t batch_size, + int64_t nchans, + int64_t nlat_in, + int64_t nlon_in, + int64_t nlat_out, + int64_t nlon_out, + int64_t K, + FLOATV_T *__restrict__ _xp, + int32_t *_row_idx, + int64_t *_row_off, + int64_t *_col_idx, + FLOATV_T *_val_pck, + float *__restrict__ _yp, + cudaStream_t stream) { + + dim3 block(WARP_SIZE, THREADS/WARP_SIZE); + dim3 grid(DIV_UP(nlat_in*nlon_in, block.y), batch_size); + + size_t shsize = (sizeof(FLOATV_T)*(nchans*K) + sizeof(float)*nchans)*block.y; + + const int pscale = nlon_out / nlon_in; +#if 1 + printf("Launching s2_disco_bwd_generic_vec_k<%d, float%s><<<..., ..., %zu, ...>>> with:\n" + "\tnchan_out: %ld\n" + "\tK: %ld\n\n", + THREADS, sizeof(FLOATV_T)==16?"4":"", shsize, nchans, K); +#endif + // will use only the first 1/K-th of the CSR, i.e. 
only the first nlat_out rows + s2_disco_bwd_generic_vec_k + <<>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K, + _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp); + CHECK_ERROR("s2_disco_bwd_generic_vec_k"); + + return; +} + +template +void launch_spc_disco_bwd(int nloc, // "BDIM_X*nloc" >= nchans + int64_t batch_size, + int64_t nchans, + int64_t nlat_in, + int64_t nlon_in, + int64_t nlat_out, + int64_t nlon_out, + int64_t K, + FLOATV_T *__restrict__ _xp, + int32_t *_row_idx, + int64_t *_row_off, + int64_t *_col_idx, + FLOATV_T *_val_pck, + float *__restrict__ _yp, + cudaStream_t stream) { + + if (CUR_LOC_SIZE == nloc) { + + // block size set to 64 threads + constexpr int BDIM_Y = (BDIM_X <= WARP_SIZE) ? THREADS / BDIM_X : 1; + + // groups in gridDim.y + dim3 block(BDIM_X, BDIM_Y); + dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size); + + // could be (BDIM_X/K) instead of BDIM_X but let's keep it simple + size_t shsize = (K & (K-1)) ? sizeof(float)*DIV_UP(nchans, BDIM_X)*BDIM_X*block.y : 0; + //size_t shsize = sizeof(float)*DIV_UP(nchans, BDIM_X)*BDIM_X*block.y; + + const int pscale = nlon_out / nlon_in; +#if 1 + printf("Launching s2_disco_bwd_special_vec_k<%d, %d, %d, float%s><<<(%d, %d), (%d, %d), ..., %zu, ...>>> with:\n" + "\tnchans: %ld\n" + "\tK: %ld\n\n", + BDIM_X, BDIM_Y, CUR_LOC_SIZE, sizeof(FLOATV_T)==16?"4":"", grid.x, grid.y, block.x, block.y, shsize, nchans, K); +#endif + s2_disco_bwd_special_vec_k + <<>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K, + _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp); + + CHECK_ERROR("s2_disco_bwd_special_vec_k"); + + return; + } + if constexpr(CUR_LOC_SIZE < MAX_LOC_SIZE) { + launch_spc_disco_bwd(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, + K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); + } + return; +} + +static void s2_disco_bwd_dispatch(int64_t batch_size, + int64_t nchans, + int64_t nlat_in, + int64_t nlon_in, + int64_t nlat_out, + int64_t nlon_out, + int64_t K, + at::Tensor xP, + at::Tensor row_off, // CSR row offsets + at::Tensor col_idx, // CSR column indices + at::Tensor val_dat, // CSR value data + at::Tensor yP) { + + static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1))); + + if (batch_size <= 0 || + nchans <= 0 || + nlon_in <= 0 || + nlat_out <= 0 || + nlon_out <= 0 || + K <= 0 || + K > WARP_SIZE) { + + fprintf(stderr, + ":%s:%d: invalid value of one or more input parameters!\n", + __FILE__, __LINE__); + exit(EXIT_FAILURE); + } + + // get stream + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + // sort row indices (ho-s) in descending order + // based on (row_off[ho+1]-row_off[ho]) + at::Tensor row_idx = sortRows(nlat_in, row_off, stream); + + + // replace the K sequential CRSs in "val_dat": + // + // val_dat[ 0: nnz/K) for ker = 0 + // val_dat[nnz/K:2*nnz/K) for ker = 1 + // ... + // val_dat[nnz/K:2*nnz/K) for ker = K-1 + // + // with a packed CSR: + // + // val_dat[nnz/K][K], i.e. 
with a CSR where elements of the original K CSRs are packed in consecutive elements + assert(0 == (val_idx.size(0) % K)); + + // move into "disco_cuda_utils.cu" IF val_dat format won't be changed upstream in the call chain + int64_t val_dims[] = {val_dat.size(0)}; + auto options = torch::TensorOptions().device(val_dat.device()).dtype(val_dat.dtype()); + torch::Tensor val_pck = torch::zeros(val_dims, options); + { + dim3 block(WARP_SIZE, THREADS/WARP_SIZE); + dim3 grid(DIV_UP(nlat_out, block.y)); + + pack_vals_k<<>>(K, nlat_out, + row_off.data_ptr(), + val_dat.data_ptr(), + val_pck.data_ptr()); + } + // if K is a multiple of VEC_SIZE it will be read with vector lds + + // smallest power of two "bdimx" (>=4) s.t. bdimx*MAX_LOCAL_ARR_LEN >= nchans*K + int bdimx; + bdimx = DIV_UP(nchans*K, MAX_LOCAL_ARR_LEN); + bdimx = max(bdimx, WARP_SIZE/8); // min 4 threads per group + bdimx = next_pow2(bdimx); + + float *_xp = reinterpret_cast(xP.data_ptr()); + float *_yp = reinterpret_cast(yP.data_ptr()); + + int32_t *_row_idx = reinterpret_cast(row_idx.data_ptr()); + int64_t *_row_off = reinterpret_cast(row_off.data_ptr()); + int64_t *_col_idx = reinterpret_cast(col_idx.data_ptr()); + float *_val_pck = reinterpret_cast(val_pck.data_ptr()); + + constexpr int VEC_SIZE = sizeof(float4) / sizeof(float); + + if (!is_aligned(_yp) || + !is_aligned(_xp) || + !is_aligned(_val_pck) || + (K % VEC_SIZE) != 0) { + + //printf("%s:%d: VEC_SIZE: %d, nchans: %d, K: %d, _xp: %p, _yp: %p\n", __func__, __LINE__, VEC_SIZE, nchans, K, _xp, _yp); + + const int nloc = DIV_UP(nchans*K, bdimx); + + // to avoid the compilation of unused template instances; + // we use a block size BDIM_X that is the smallest power of 2 + // such that BDIM_X*MAX_LOCAL_ARR_LEN >= nchans*K, so + // BDIM_X > 32 are used only for: + // + // (BDIM_X-1)*MAX_LOCAL_ARR_LEN < nchan <= BDIM_X*MAX_LOCAL_ARR_LEN + constexpr int MIN_LOC_ARR_LEN = MAX_LOCAL_ARR_LEN/2+1; + + // use 2D blocks only if 32 threads are enough + switch(bdimx) { + case 8: launch_spc_disco_bwd< 8, 1, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + case 16: launch_spc_disco_bwd< 16, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + case 32: launch_spc_disco_bwd< 32, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + case 64: launch_spc_disco_bwd< 64, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + case 128: launch_spc_disco_bwd< 128, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + case 256: launch_spc_disco_bwd< 256, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + case 512: launch_spc_disco_bwd< 512, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + case 1024: launch_spc_disco_bwd<1024, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, 
nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + default: launch_gen_disco_bwd ( batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + } + + } else { + + //printf("%s:%d: VEC_SIZE: %d, nchans: %d, K: %d, _xp: %p, _yp: %p\n", __func__, __LINE__, VEC_SIZE, nchans, K, _xp, _yp); + + float4 *_xp4 = reinterpret_cast(_xp); + //float4 *_yp4 = reinterpret_cast(_yp); + + float4 *_val_pck4 = reinterpret_cast(_val_pck); + + K /= VEC_SIZE; + const int nloc = DIV_UP(nchans*K, bdimx); + + constexpr int MAX_LOCAL_VEC_LEN = MAX_LOCAL_ARR_LEN / VEC_SIZE; + constexpr int MIN_LOC_VEC_LEN = MAX_LOCAL_VEC_LEN/2+1; + + // use 2D blocks only if 32 threads are enough + switch(bdimx) { + case 8: launch_spc_disco_bwd< 8, 1, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, _row_idx, _row_off, _col_idx, _val_pck4, _yp, stream); break; + case 16: launch_spc_disco_bwd< 16, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, _row_idx, _row_off, _col_idx, _val_pck4, _yp, stream); break; + case 32: launch_spc_disco_bwd< 32, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, _row_idx, _row_off, _col_idx, _val_pck4, _yp, stream); break; + case 64: launch_spc_disco_bwd< 64, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, _row_idx, _row_off, _col_idx, _val_pck4, _yp, stream); break; + case 128: launch_spc_disco_bwd< 128, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, _row_idx, _row_off, _col_idx, _val_pck4, _yp, stream); break; + case 256: launch_spc_disco_bwd< 256, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, _row_idx, _row_off, _col_idx, _val_pck4, _yp, stream); break; + case 512: launch_spc_disco_bwd< 512, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, _row_idx, _row_off, _col_idx, _val_pck4, _yp, stream); break; + case 1024: launch_spc_disco_bwd<1024, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, _row_idx, _row_off, _col_idx, _val_pck4, _yp, stream); break; + default: launch_gen_disco_bwd ( batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, _row_idx, _row_off, _col_idx, _val_pck4, _yp, stream); break; + } + } + return; +} + +// END NEW CHANNEL-LAST VERSION torch::Tensor disco_cuda_bwd(torch::Tensor ograd, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo) { @@ -212,7 +1036,7 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t CHECK_CUDA_INPUT_TENSOR(row_idx); CHECK_CUDA_INPUT_TENSOR(col_idx); CHECK_CUDA_INPUT_TENSOR(val); - +#if 0 // extract some shapes int64_t B = ograd.size(0); int64_t Hi = ograd.size(1); @@ -281,7 +1105,79 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t 1024 * ELXTH_MAX); exit(EXIT_FAILURE); } +#else + // extract some shapes + int64_t batch_size = ograd.size(0); + int64_t nlat_in = ograd.size(1); + int64_t nlon_in = ograd.size(2); + int64_t nchan = ograd.size(3); + int64_t nrows = roff_idx.size(0) - 1; + + printf("%s:%d: batch_size: %ld, nlat_in: 
%ld, nlon_in: %ld, C: %ld\n", + __func__, __LINE__, batch_size, nlat_in, nlon_in, nchan); + + if (nchan % K) { + fprintf(stderr, + "%s:%d: error, number of channles of output gradient (%ld) is expected to be a multiple of kernel size (%ld)!\n", + __func__, __LINE__, nchan, K); + exit(EXIT_FAILURE); + } + int64_t Co = nchan/K; + + printf("K: %ld, Cin: %ld\n", K, nchan/K); + + int64_t nlat_out = Ho; + int64_t nlon_out = Wo; + + // allocate output + int64_t out_dims[] = {batch_size, Ho, Wo, Co}; + + // get stream + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + // extract dtype + auto x_type = ograd.dtype(); + torch::Tensor xP = ograd.to(torch::kFloat32); + + torch::Tensor igrad = torch::zeros(out_dims, xP.options()); + +#if 0 + printf("%s:%d: tensors info:\n", __func__, __LINE__); + printf("\tbatch_size: %ld\n", batch_size); + printf("\t nlat_in: %ld\n", nlat_in); + printf("\t nlon_in: %ld\n", nlon_in); + printf("\t C: %ld\n", nchan); + printf("\t K: %ld\n\n", K); + printf("\troff_idx.size(0)-1 == nlat_in*K: %d\n", roff_idx.size(0)-1 == nlat_in*K); + //printf("\tinp channle-last: %d\n", x_is_channels_last); + printf("\treshaped inp to: {%ld, %ld, %ld, %ld}\n", xP.size(0), xP.size(1), xP.size(2), xP.size(3)); + fflush(stdout); + //exit(1); +#endif + + // call channel-last kernel implementation + s2_disco_bwd_dispatch(batch_size, + Co, //nchan, + nlat_in, + nlon_in, + nlat_out, + nlon_out, + K, + xP, + roff_idx, + col_idx, + val, + igrad); +/* + // switch back to original layout; + torch::Tensor out = yP; + if (!x_is_channels_last) { + out = permute_4D_to0312(yP); + // make y {batch_size, nchan, nlat_out, nlon_out} + } +*/ +#endif // convert back to original dtype igrad = igrad.to(x_type); diff --git a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu index 3328ae39..2a06bb93 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu @@ -423,15 +423,6 @@ __device__ void processCSR_Kpow2_reg_d(const int wo, const int tidx = threadIdx.x; const int tidy = threadIdx.y; - // unused if BDIM_X > WARP_SIZE - unsigned int subwarp_mask = FULL_MASK; - - if constexpr(BDIM_X <= WARP_SIZE) { - constexpr unsigned int MASK = (1ull << BDIM_X)-1; - subwarp_mask = MASK << (tidy*BDIM_X); - } - - // only used in K_POWER_2==1 branch const int log2_K = __ffs(K)-1; const int tidxDivK = tidx >> log2_K; @@ -445,8 +436,7 @@ __device__ void processCSR_Kpow2_reg_d(const int wo, for(int off = 0; off < rlen; off++) { if ((off % BDIM_X) == 0) { - if constexpr(BDIM_X <= WARP_SIZE) { __syncwarp(subwarp_mask); } - else { __syncthreads(); } + __sync(); const int64_t col = (off+tidx < rlen) ? 
cols[0] : 0; @@ -462,8 +452,7 @@ __device__ void processCSR_Kpow2_reg_d(const int wo, shXOff[tidx] = x + int64_t(hi)*nlon_in*nchan_in + int64_t(wip)*nchan_in; cols += BDIM_X; - if constexpr(BDIM_X <= WARP_SIZE) { __syncwarp(subwarp_mask); } - else { __syncthreads(); } + __sync(); } const float *_x = shXOff[off % BDIM_X] + tidxDivK; @@ -510,21 +499,12 @@ __device__ void processCSR_Kanyv_reg_d(const int wo, const int tidx = threadIdx.x; const int tidy = threadIdx.y; - // unused if BDIM_X > WARP_SIZE - unsigned int subwarp_mask = 0xFFFFFFFF; - - if constexpr(BDIM_X <= WARP_SIZE) { - constexpr unsigned int MASK = (1ull << BDIM_X)-1; - subwarp_mask = MASK << (tidy*BDIM_X); - } - cols += tidx; for(int off = 0; off < rlen; off++) { if ((off % BDIM_X) == 0) { - if constexpr(BDIM_X <= WARP_SIZE) { __syncwarp(subwarp_mask); } - else { __syncthreads(); } + __sync(); const int64_t col = (off+tidx < rlen) ? cols[0] : 0; @@ -540,8 +520,7 @@ __device__ void processCSR_Kanyv_reg_d(const int wo, shXOff[tidx] = x + int64_t(hi)*nlon_in*nchan_in + int64_t(wip)*nchan_in; cols += BDIM_X; - if constexpr(BDIM_X <= WARP_SIZE) { __syncwarp(subwarp_mask); } - else { __syncthreads(); } + __sync(); } const float *_x = shXOff[off % BDIM_X]; @@ -763,6 +742,7 @@ void launch_spc_disco_fwd(int nloc, // "BDIM_X*nloc" >= nchans static void s2_disco_fwd_dispatch(int64_t batch_size, int64_t nchan_in, + int64_t nlat_in, int64_t nlon_in, int64_t nlat_out, int64_t nlon_out, @@ -823,8 +803,6 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, } // if K is a multiple of VEC_SIZE it will be read with vector lds - const int nlat_in = xP.size(1); - // smallest power of two "bdimx" (>=4) s.t. bdimx*MAX_LOCAL_ARR_LEN >= nchan_in*K int bdimx; bdimx = DIV_UP(nchan_in*K, MAX_LOCAL_ARR_LEN); @@ -984,12 +962,7 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, int64_t nlon_in = Wi; int64_t nlat_out = Ho; int64_t nlon_out = Wo; -/* - int64_t ngroup = 1; - if (std::getenv("S2_NGROUP")) { - ngroup = atoi(std::getenv("S2_NGROUP")); - } -*/ + printf("%s:%d: batch_size: %ld, nchan: %ld, nlat_in: %ld, nlon_in: %ld, nlat_out: %ld, nlon_out: %ld, nrows: %ld, nnz_tot: %ld, K: %ld\n", __func__, __LINE__, batch_size, nchan, nlat_in, nlon_in, nlat_out, nlon_out, nrows, col_idx.size(0), K); @@ -1053,98 +1026,6 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, } #else -#if 0 // FUSED VERSION - // switch to channel-last - // version with fused enisum - - int64_t ngroup = weights.size(0); - int64_t chan_x_grp_out = weights.size(1); - int64_t chan_x_grp_in = weights.size(2); - int64_t weight_k = weights.size(3); - - int64_t nchan_out = ngroup*chan_x_grp_out; - - printf("weight tensor shape: %ld, %ld, %ld, %ld\n", ngroup, chan_x_grp_out, chan_x_grp_in, weight_k); fflush(stdout); - - if (nchan != chan_x_grp_in*ngroup || K != weight_k) { - fprintf(stderr, - "%s:%d: error, dimension mismatch for weight tensor!\n", - __func__, __LINE__); - exit(EXIT_FAILURE); - } - - - // input: inp[B][Ci][Hi][Wi] -> inp[B][Hi][Wi][Ci] - // - // output: out[[B][Ho][Wo][Co] -> out[B][Co][Ho][Wo] - // with Co = ngroup*chan_x_grp_out - - - // switch to channel-last - - // extract dtype - auto x_type = inp.dtype(); - torch::Tensor xP = inp.to(torch::kFloat32); - - // exract memory format: this is much safer than checking is_contiguous(at::MemoryFormat::ChannelsLast) - // the former fails for num_channels == 1 - bool x_is_channels_last = xP.strides()[1] == 1; - - // transpose if required - if (!x_is_channels_last) { xP = permute_4D_to0231(xP); } - -#if 1 - int64_t 
out_dims[] = {batch_size, nlat_out, nlon_out, nchan_out}; - auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype()); - torch::Tensor yP = torch::zeros(out_dims, options); // this will be empty_like() - // y is {batch_size, nlat_out, nlon_out, nchan_out}, -#else - // to test before fusion - int64_t out_dims[] = {batch_size, nlat_out, nlon_out, nchan*K}; - auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype()); - torch::Tensor yP = torch::zeros(out_dims, options); - // y is {batch_size, nlat_out, nlon_out, nchan*K}, -#endif - // call channel-last kernel implementation - s2_disco_fwd_dispatch(batch_size, - nchan, - nchan_out, - ngroup, - nlon_in, - //nlat_in, - nlat_out, - nlon_out, - K, - xP, - roff_idx, - col_idx, - val, - weights, - yP); - -#if 1 - // switch back to original layout; - // I'm assuming that if x was passed as channel last, then - // the output tensor should be K last - //torch::Tensor y = yP; - //if (!x_is_channels_last) { - // y = permute_4D_to0312(y); - // // make y {batch_size, nchan_out, nlat_out, nlon_out} - //} -#else - // to test before fusion - torch::Tensor y = yP; - if (!x_is_channels_last) { - y = permute_4D_to0312(y); - // make y {batch_size, nchan, K, nlat_out, nlon_out} - y = y.reshape({batch_size, nchan, K, nlat_out, nlon_out}); - } else { - // make y {batch_size, nlat_out, nlon_out, nchan, K} - y = y.reshape({batch_size, nlat_out, nlon_out, nchan, K}); - } -#endif - -#else // VERSION WITH SEPARATED EINSUM // switch to channel-last // version with fused enisum @@ -1159,8 +1040,8 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, // call channel-last kernel implementation s2_disco_fwd_dispatch(batch_size, nchan, + nlat_in, nlon_in, - //nlat_in, nlat_out, nlon_out, K, @@ -1173,7 +1054,6 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, auto y = yP.to(x_type); torch::Tensor out = y; -#endif #endif // closes ORIGINAL if #if 1 @@ -1182,7 +1062,7 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, CHECK_CUDA(cudaStreamSynchronize(stream)); printf("done\n"); fflush(stdout); - //dump_tensor("yout.txt", out); + dump_tensor("yout.txt", out); //dump_csr_linear("csr_disco.txt", roff_idx, ker_idx, row_idx, col_idx, val); //dump_out_kers("out_kers", out); } diff --git a/torch_harmonics/utils/csrc/csr_cuda.cuh b/torch_harmonics/utils/csrc/csr_cuda.cuh index 956346b9..a4653cfb 100644 --- a/torch_harmonics/utils/csrc/csr_cuda.cuh +++ b/torch_harmonics/utils/csrc/csr_cuda.cuh @@ -156,6 +156,23 @@ __device__ float4 __forceinline__ __vdiv(float s, float4 v) { return make_float4(s/v.x, s/v.y, s/v.z, s/v.w);; } +template +static __device__ void __sync() { + + unsigned int subwarp_mask = FULL_MASK; + + if constexpr(BDIM_X <= WARP_SIZE) { + const int tidy = threadIdx.y; + constexpr unsigned int MASK = (1ull << BDIM_X)-1; + subwarp_mask = MASK << (tidy*BDIM_X); + } + + if constexpr(BDIM_X <= WARP_SIZE) { __syncwarp(subwarp_mask); } + else { __syncthreads(); } + + return; +} + template __device__ VAL_T __warp_sum(VAL_T val) { From 02ba8c1cb78f982b5dca7dd394cf40c3bdbf78b6 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Tue, 2 Sep 2025 22:58:16 -0700 Subject: [PATCH 11/31] commenting out tests --- tests/test_convolution.py | 189 ++++++++++++++++++++++++++------------ 1 file changed, 129 insertions(+), 60 deletions(-) diff --git a/tests/test_convolution.py b/tests/test_convolution.py index b798cd9b..d612c9d2 100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -39,16 +39,21 @@ from torch.library import 
opcheck from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 -from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes +from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes from torch_harmonics.disco import cuda_kernels_is_available, optimized_kernels_is_available +from disco_helpers import preprocess_psi +from torch_harmonics.filter_basis import get_filter_basis +from torch_harmonics.disco.convolution import _precompute_convolution_tensor_s2 + +from testutils import compare_tensors if not optimized_kernels_is_available(): print(f"Warning: Couldn't import optimized disco convolution kernels") _devices = [(torch.device("cpu"),)] -#if torch.cuda.is_available(): -# _devices.append((torch.device("cuda"),)) +if torch.cuda.is_available(): + _devices.append((torch.device("cuda"),)) # perf thresholds # CPU results normalized to 16 OpenMP threads, @@ -183,6 +188,99 @@ def setUp(self): if self.device.type == "cuda": torch.cuda.manual_seed(333) + @parameterized.expand( + [ + # piecewise linear + # normal isotropic + [(16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular"], + [(17, 32), (17, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular"], + # normal anisotropic + [(16, 32), (16, 32), (3, 4), "piecewise linear", "mean", "equiangular", "equiangular"], + [(16, 32), (16, 32), (3, 2), "piecewise linear", "mean", "equiangular", "equiangular"], + # downsampling isotropic + [(16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular"], + [(17, 32), (9, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular"], + # downsampling anisotropic + [(16, 32), (8, 16), (3, 4), "piecewise linear", "mean", "equiangular", "equiangular"], + [(16, 32), (8, 16), (3, 2), "piecewise linear", "mean", "equiangular", "equiangular"], + # morlet + # normal isotropic + [(16, 32), (16, 32), (1), "morlet", "mean", "equiangular", "equiangular"], # important for attention + [(16, 32), (16, 32), (3), "morlet", "mean", "equiangular", "equiangular"], + [(17, 32), (17, 32), (3), "morlet", "mean", "equiangular", "equiangular"], + # normal anisotropic + [(16, 32), (16, 32), (3, 4), "morlet", "mean", "equiangular", "equiangular"], + [(16, 32), (16, 32), (3, 2), "morlet", "mean", "equiangular", "equiangular"], + # downsampling isotropic + [(16, 32), (8, 16), (1), "morlet", "mean", "equiangular", "equiangular"], # important for attention + [(16, 32), (8, 16), (3), "morlet", "mean", "equiangular", "equiangular"], + [(17, 32), (9, 16), (3), "morlet", "mean", "equiangular", "equiangular"], + # downsampling anisotropic + [(16, 32), (8, 16), (3, 4), "morlet", "mean", "equiangular", "equiangular"], + [(16, 32), (8, 16), (3, 2), "morlet", "mean", "equiangular", "equiangular"], + # zernike + # normal + [(16, 32), (16, 32), (1), "zernike", "mean", "equiangular", "equiangular"], + [(16, 32), (16, 32), (3, 3), "zernike", "mean", "equiangular", "equiangular"], + [(17, 32), (17, 32), (3, 3), "zernike", "mean", "equiangular", "equiangular"], + # downsampling + [(16, 32), (8, 16), (1), "zernike", "mean", "equiangular", "equiangular"], + [(16, 32), (8, 16), (3, 3), "zernike", "mean", "equiangular", "equiangular"], + [(17, 32), (9, 16), (3, 3), "zernike", "mean", "equiangular", "equiangular"], + ], + skip_on_empty=True, + ) + def test_convolution_tensor_integrity(self, in_shape, out_shape, kernel_shape, basis_type, basis_norm_mode, grid_in, grid_out, verbose=False): + + 
nlat_in, nlon_in = in_shape + nlat_out, nlon_out = out_shape + + filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type=basis_type) + + # use the default theta cutoff + theta_cutoff = torch.pi / float(nlat_out - 1) + + idx, vals, _ = _precompute_convolution_tensor_s2( + in_shape=in_shape, + out_shape=out_shape, + filter_basis=filter_basis, + grid_in=grid_in, + grid_out=grid_out, + theta_cutoff=theta_cutoff, + transpose_normalization=False, + basis_norm_mode=basis_norm_mode, + merge_quadrature=True, + ) + + ker_idx = idx[0, ...].contiguous() + row_idx = idx[1, ...].contiguous() + col_idx = idx[2, ...].contiguous() + vals = vals.contiguous() + + # sort values + roff_idx = preprocess_psi(filter_basis.kernel_size, nlat_out, ker_idx, row_idx, col_idx, vals).contiguous() + + # check shapes + self.assertTrue(ker_idx.shape[0] == row_idx.shape[0], f"ker_idx and row_idx have to have the same shape: found {ker_idx.shape[0]} and {row_idx.shape[0]}") + self.assertTrue(ker_idx.shape[0] == col_idx.shape[0], f"ker_idx and col_idx have to have the same shape: found {ker_idx.shape[0]} and {col_idx.shape[0]}") + self.assertTrue(ker_idx.shape[0] == vals.shape[0], f"ker_idx and vals have to have the same shape: found {ker_idx.shape[0]} and {vals.shape[0]}") + self.assertTrue((roff_idx.shape[0] - 1) == filter_basis.kernel_size * nlat_out, f"roff_idx has the wrong length: found {(roff_idx.shape[0] - 1)}, expected {filter_basis.kernel_size * nlat_out}") + + # the multiplicity in ker_idx has to be the same for all kernel indices + unique, counts = torch.unique(ker_idx, return_counts=True) + self.assertTrue(torch.all(counts.max() == counts), f"The multiplicity in ker_idx has to be the same for all kernel indices: found {counts} for entries {unique}") + + if verbose: + print(f"\n ker_idx = {ker_idx},\n row_idx = {row_idx},\n col_idx = {col_idx}") + + # the row_idx and col_idx have to be the same for all kernel indices + row_idx_ref = row_idx[ker_idx == 0] + col_idx_ref = col_idx[ker_idx == 0] + for k in range(1, filter_basis.kernel_size): + self.assertTrue(torch.all(row_idx_ref == row_idx[ker_idx == k]), f"The row_idx has to be the same for all kernel indices: found {row_idx_ref} for entries {ker_idx == k}") + self.assertTrue(torch.all(col_idx_ref == col_idx[ker_idx == k]), f"The col_idx has to be the same for all kernel indices: found {col_idx_ref} for entries {ker_idx == k}") + + @parameterized.expand( [ # regular convolution [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False], @@ -198,12 +296,12 @@ def setUp(self): [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", False, 1e-4, False], [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", False, 1e-4, False], [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4, False], - # # transpose convolution + # transpose convolution [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], [8, 4, 2, (12, 24), (24, 48), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], [8, 4, 2, (12, 24), (24, 48), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], - [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", True, 1e-4, False], + [8, 2, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", 
"equiangular", "equiangular", True, 1e-4, False], [8, 4, 2, (12, 24), (24, 48), (2, 1), "morlet", "mean", "equiangular", "equiangular", True, 1e-4, False], [8, 4, 2, (12, 24), (24, 48), (3), "zernike", "mean", "equiangular", "equiangular", True, 1e-4, False], [8, 4, 2, (8, 8), (16, 24), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], @@ -241,11 +339,6 @@ def test_sparse_against_dense( nlat_in, nlon_in = in_shape nlat_out, nlon_out = out_shape - if isinstance(kernel_shape, int): - theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat_in - 1) - else: - theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1) - Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2 conv = Conv( in_channels, @@ -259,7 +352,7 @@ def test_sparse_against_dense( grid_in=grid_in, grid_out=grid_out, bias=False, - theta_cutoff=theta_cutoff, + theta_cutoff=None, optimized_kernel=use_optimized_kernels, ).to(self.device) @@ -272,7 +365,7 @@ def test_sparse_against_dense( filter_basis, grid_in=grid_out, grid_out=grid_in, - theta_cutoff=theta_cutoff, + theta_cutoff=conv.theta_cutoff, transpose_normalization=transpose, basis_norm_mode=basis_norm_mode, merge_quadrature=True, @@ -280,7 +373,7 @@ def test_sparse_against_dense( psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_in, conv.nlat_out * conv.nlon_out)).to_dense() - self.assertTrue(torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_in, nlat_out * nlon_out))) + self.assertTrue(compare_tensors("psi", psi, psi_dense[:, :, 0].reshape(-1, nlat_in, nlat_out * nlon_out))) else: psi_dense = _precompute_convolution_tensor_dense( in_shape, @@ -288,7 +381,7 @@ def test_sparse_against_dense( filter_basis, grid_in=grid_in, grid_out=grid_out, - theta_cutoff=theta_cutoff, + theta_cutoff=conv.theta_cutoff, transpose_normalization=transpose, basis_norm_mode=basis_norm_mode, merge_quadrature=True, @@ -296,13 +389,12 @@ def test_sparse_against_dense( psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)).to_dense() - self.assertTrue(torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))) + self.assertTrue(compare_tensors("psi", psi, psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))) # create a copy of the weight w_ref = torch.empty_like(conv.weight) with torch.no_grad(): w_ref.copy_(conv.weight) - w_ref = w_ref.reshape(-1, w_ref.shape[2], w_ref.shape[3]) w_ref.requires_grad = True # create an input signal @@ -328,11 +420,11 @@ def test_sparse_against_dense( x_ref_grad = x_ref.grad.clone() # compare results - self.assertTrue(torch.allclose(y, y_ref, rtol=tol, atol=tol)) + self.assertTrue(compare_tensors("output", y, y_ref, rtol=tol, atol=tol, verbose=verbose)) # compare - self.assertTrue(torch.allclose(x_grad, x_ref_grad, rtol=tol, atol=tol)) - self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad.unsqueeze(0), rtol=tol, atol=tol)) + self.assertTrue(compare_tensors("input gradient", x_grad, x_ref_grad, rtol=tol, atol=tol, verbose=verbose)) + self.assertTrue(compare_tensors("weight gradient", conv.weight.grad, w_ref.grad.unsqueeze(0), rtol=tol, atol=tol, verbose=verbose)) @parameterized.expand( @@ -381,13 +473,7 @@ def test_optimized_against_torch( if verbose: print(f"Testing DISCO convolution on {in_shape[0]}x{in_shape[1]} {grid_in} grid to {out_shape[0]}x{out_shape[1]} {grid_out} grid on {self.device.type} device") - nlat_in, nlon_in = in_shape - 
nlat_out, nlon_out = out_shape - - if isinstance(kernel_shape, int): - theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat_in - 1) - else: - theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1) + nlat_in, _ = in_shape Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2 @@ -403,7 +489,7 @@ def test_optimized_against_torch( grid_in=grid_in, grid_out=grid_out, bias=False, - theta_cutoff=theta_cutoff, + theta_cutoff=None, optimized_kernel=False, ).to(self.device) @@ -419,7 +505,7 @@ def test_optimized_against_torch( grid_in=grid_in, grid_out=grid_out, bias=False, - theta_cutoff=theta_cutoff, + theta_cutoff=None, optimized_kernel=True, ).to(self.device) @@ -444,11 +530,11 @@ def test_optimized_against_torch( inp_grad_opt = inp.grad.clone() # compare results - self.assertTrue(torch.allclose(out_naive, out_opt, rtol=tol, atol=tol)) + self.assertTrue(compare_tensors("output", out_naive, out_opt, rtol=tol, atol=tol, verbose=verbose)) # compare - self.assertTrue(torch.allclose(inp_grad_naive, inp_grad_opt, rtol=tol, atol=tol)) - self.assertTrue(torch.allclose(conv_naive.weight.grad, conv_opt.weight.grad, rtol=tol, atol=tol)) + self.assertTrue(compare_tensors("input gradient", inp_grad_naive, inp_grad_opt, rtol=tol, atol=tol, verbose=verbose)) + self.assertTrue(compare_tensors("weight gradient", conv_naive.weight.grad, conv_opt.weight.grad, rtol=tol, atol=tol, verbose=verbose)) @parameterized.expand( @@ -466,11 +552,6 @@ def test_device_instantiation(self, batch_size, in_channels, out_channels, in_sh nlat_in, nlon_in = in_shape nlat_out, nlon_out = out_shape - if isinstance(kernel_shape, int): - theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat_in - 1) - else: - theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1) - # get handle Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2 @@ -487,7 +568,7 @@ def test_device_instantiation(self, batch_size, in_channels, out_channels, in_sh grid_in=grid_in, grid_out=grid_out, bias=False, - theta_cutoff=theta_cutoff, + theta_cutoff=None, ) #torch.set_default_device(self.device) @@ -504,7 +585,7 @@ def test_device_instantiation(self, batch_size, in_channels, out_channels, in_sh grid_in=grid_in, grid_out=grid_out, bias=False, - theta_cutoff=theta_cutoff, + theta_cutoff=None, ) # since we specified the device specifier everywhere, it should always @@ -518,10 +599,10 @@ def test_device_instantiation(self, batch_size, in_channels, out_channels, in_sh @parameterized.expand( [ - # [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False], - # [8, 4, 2, (16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False], - # [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], - # [8, 4, 2, (8, 16), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], + [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, False], + [8, 4, 2, (16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, False], + [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, False], + [8, 4, 2, (8, 16), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, False], ], skip_on_empty=True, @@ -540,7 +621,6 @@ def test_optimized_pt2_compatibility( grid_in, grid_out, 
transpose, - tol, verbose, ): """Tests whether the optimized kernels are PyTorch 2 compatible""" @@ -551,13 +631,7 @@ def test_optimized_pt2_compatibility( if verbose: print(f"Testing DISCO convolution on {in_shape[0]}x{in_shape[1]} {grid_in} grid to {out_shape[0]}x{out_shape[1]} {grid_out} grid on {self.device.type} device") - nlat_in, nlon_in = in_shape - nlat_out, nlon_out = out_shape - - if isinstance(kernel_shape, int): - theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat_in - 1) - else: - theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1) + nlat_in, _ = in_shape Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2 conv = Conv( @@ -572,14 +646,14 @@ def test_optimized_pt2_compatibility( grid_in=grid_in, grid_out=grid_out, bias=False, - theta_cutoff=theta_cutoff, + theta_cutoff=None, ).to(self.device) # forward test if not transpose: - inp = torch.randn(batch_size, in_channels, *in_shape, device=self.device) + inp = torch.randn(batch_size, *in_shape, in_channels, device=self.device) else: - inp = torch.randn(batch_size, conv.kernel_size, in_channels, *in_shape, device=self.device) + inp = torch.randn(batch_size, *in_shape, in_channels, conv.kernel_size, device=self.device) test_inputs = (inp, conv.psi_roff_idx, conv.psi_ker_idx, conv.psi_row_idx, conv.psi_col_idx, conv.psi_vals, conv.kernel_size, conv.nlat_out, conv.nlon_out) @@ -603,7 +677,7 @@ def test_optimized_pt2_compatibility( @parameterized.expand( [ - # [8, 4, 2, (91, 180), (91, 180), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4], + #[8, 4, 2, (91, 180), (91, 180), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4], ], skip_on_empty=True, ) @@ -616,11 +690,6 @@ def test_perf(self, batch_size, in_channels, out_channels, in_shape, out_shape, nlat_in, nlon_in = in_shape nlat_out, nlon_out = out_shape - if isinstance(kernel_shape, int): - theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat_in - 1) - else: - theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1) - # get handle Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2 @@ -637,7 +706,7 @@ def test_perf(self, batch_size, in_channels, out_channels, in_shape, out_shape, grid_in=grid_in, grid_out=grid_out, bias=True, - theta_cutoff=theta_cutoff, + theta_cutoff=None, optimized_kernel=True, ).to(self.device) From 8068939afb8a9da74beae1778fb1c989ce33bfea Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Tue, 2 Sep 2025 23:44:45 -0700 Subject: [PATCH 12/31] adding reshapes --- torch_harmonics/disco/convolution.py | 2 +- torch_harmonics/disco/csrc/disco_cuda_bwd.cu | 13 +++---------- torch_harmonics/disco/csrc/disco_cuda_fwd.cu | 2 +- 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/torch_harmonics/disco/convolution.py b/torch_harmonics/disco/convolution.py index bee495cc..a6be7b9a 100644 --- a/torch_harmonics/disco/convolution.py +++ b/torch_harmonics/disco/convolution.py @@ -665,7 +665,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.optimized_kernel: xp = permute_to_0231(x) xp = xp.reshape(B, H, W, self.groups, self.groupsize_in) - xpc = torch.einsum("bxygc,gock->bxygok", xp, self.weight).reshape(B, H, W, self.groups * self.groupsize_out, -1).contiguous() + xpc = torch.einsum("bxygc,gock->bxygok", xp, self.weight).reshape(B, H, W, -1, self.kernel_size).contiguous() outp = _disco_s2_transpose_contraction_optimized( xpc, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, 
self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out ) diff --git a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu index fa747ae4..8c3ad7f8 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu @@ -1110,20 +1110,13 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, int64_t batch_size = ograd.size(0); int64_t nlat_in = ograd.size(1); int64_t nlon_in = ograd.size(2); - int64_t nchan = ograd.size(3); + int64_t Co = ograd.size(3); + int64_t nchan = Co * K; int64_t nrows = roff_idx.size(0) - 1; printf("%s:%d: batch_size: %ld, nlat_in: %ld, nlon_in: %ld, C: %ld\n", __func__, __LINE__, batch_size, nlat_in, nlon_in, nchan); - if (nchan % K) { - fprintf(stderr, - "%s:%d: error, number of channles of output gradient (%ld) is expected to be a multiple of kernel size (%ld)!\n", - __func__, __LINE__, nchan, K); - exit(EXIT_FAILURE); - } - int64_t Co = nchan/K; - printf("K: %ld, Cin: %ld\n", K, nchan/K); int64_t nlat_out = Ho; @@ -1137,7 +1130,7 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, // extract dtype auto x_type = ograd.dtype(); - torch::Tensor xP = ograd.to(torch::kFloat32); + torch::Tensor xP = ograd.reshape({batch_size, nlat_in, nlon_in, nchan}).to(torch::kFloat32); torch::Tensor igrad = torch::zeros(out_dims, xP.options()); diff --git a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu index 2a06bb93..0fbb673a 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu @@ -1051,7 +1051,7 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, val, yP); - auto y = yP.to(x_type); + auto y = yP.reshape({batch_size, nlat_out, nlon_out, nchan, K}).to(x_type); torch::Tensor out = y; From 357e62004a98bb8d482e645e90ffb75db5d64077 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 3 Sep 2025 07:37:57 -0700 Subject: [PATCH 13/31] adding contiguous call before passing to kernel --- torch_harmonics/disco/_disco_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_harmonics/disco/_disco_utils.py b/torch_harmonics/disco/_disco_utils.py index b591a23b..d953bdfa 100644 --- a/torch_harmonics/disco/_disco_utils.py +++ b/torch_harmonics/disco/_disco_utils.py @@ -60,6 +60,7 @@ def _disco_s2_contraction_optimized( inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + inp = inp.contiguous() out = disco_kernels.forward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out) return out @@ -69,6 +70,7 @@ def _disco_s2_transpose_contraction_optimized( inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + inp = inp.contiguous() out = disco_kernels.backward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out) return out From f208ccd26ec736f86036a3c181832dcf62387e57 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 8 Sep 2025 00:47:07 -0700 Subject: [PATCH 14/31] dbg --- torch_harmonics/disco/_disco_utils.py | 2 - torch_harmonics/disco/convolution.py | 57 +++++++++++++++++++- torch_harmonics/disco/csrc/disco_cuda_bwd.cu | 8 ++- torch_harmonics/disco/csrc/disco_cuda_fwd.cu | 6 +-- torch_harmonics/filter_basis.py 
| 37 ++++++++----- 5 files changed, 87 insertions(+), 23 deletions(-) diff --git a/torch_harmonics/disco/_disco_utils.py b/torch_harmonics/disco/_disco_utils.py index d953bdfa..b591a23b 100644 --- a/torch_harmonics/disco/_disco_utils.py +++ b/torch_harmonics/disco/_disco_utils.py @@ -60,7 +60,6 @@ def _disco_s2_contraction_optimized( inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - inp = inp.contiguous() out = disco_kernels.forward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out) return out @@ -70,7 +69,6 @@ def _disco_s2_transpose_contraction_optimized( inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - inp = inp.contiguous() out = disco_kernels.backward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out) return out diff --git a/torch_harmonics/disco/convolution.py b/torch_harmonics/disco/convolution.py index a6be7b9a..73f585f7 100644 --- a/torch_harmonics/disco/convolution.py +++ b/torch_harmonics/disco/convolution.py @@ -166,6 +166,57 @@ def _normalize_convolution_tensor_s2( return psi_vals +def _pad_convolution_tensor_s2(psi_idx, psi_vals, kernel_size) -> Tuple[torch.Tensor, torch.Tensor]: + """Pads convolution tensor values with zeros to allow kernel vectorization. + + This function modifies the tensor to ensure that the number of elements per row is consistent for each kernel. + This is essential for vectorizing the kernel operations. + + Parameters + ----------- + psi_idx: torch.Tensor + Index tensor for the sparse convolution tensor. + psi_vals: torch.Tensor + Value tensor for the sparse convolution tensor. + kernel_size: int + Number of kernel basis functions. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + Padded index and value tensors. + """ + + ker_idx = psi_idx[0, ...] + row_idx = psi_idx[1, ...] + col_idx = psi_idx[2, ...] 
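    # NOTE: what follows in this patch is only a measuring stub; the sketch below is
    # illustrative and NOT the implementation used here. Assuming padding is applied per
    # kernel index (matching the counting loop below), a completed version could look
    # roughly like this; the names nnz_max, sel, pad, *_pad are placeholders only:
    #
    #     nnz_max = int(torch.bincount(ker_idx, minlength=kernel_size).max())
    #     ker_pad, row_pad, col_pad, val_pad = [], [], [], []
    #     for ik in range(kernel_size):
    #         sel = torch.argwhere(ker_idx == ik).flatten()
    #         pad = nnz_max - sel.shape[0]
    #         ker_pad.append(torch.full((nnz_max,), ik, dtype=psi_idx.dtype, device=psi_idx.device))
    #         row_pad.append(torch.nn.functional.pad(row_idx[sel], (0, pad), value=0))
    #         col_pad.append(torch.nn.functional.pad(col_idx[sel], (0, pad), value=0))
    #         val_pad.append(torch.nn.functional.pad(psi_vals[sel], (0, pad), value=0.0))
    #     psi_idx_pad = torch.stack([torch.cat(ker_pad), torch.cat(row_pad), torch.cat(col_pad)], dim=0)
    #     psi_vals_pad = torch.cat(val_pad)
    #
    # Zero-valued padded entries do not change the contraction result; they only equalize
    # the number of entries per kernel so the CUDA kernels can use fixed-stride loads.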
+ + # determine the maximum number of non-zero entries per kernel index in a single pass: + indices = {} + nnz = 0 + for ik in range(kernel_size): + iidx = torch.argwhere(ker_idx == ik) + nnz_tmp = iidx.shape[0] + nnz = max(nnz, nnz_tmp) + #row_idx_ker = row_idx[] + #col_idx_ker = col_idx[torch.where(ker_idx == ik)[0]] + + print(f"max number of non-zero entries per kernel: {nnz}") + + # create padded kernels + ker_idx_pad = torch.zeros(nnz*kernel_size, dtype=psi_idx.dtype, device=psi_idx.device) + row_idx_pad = torch.zeros(nnz*kernel_size, dtype=psi_idx.dtype, device=psi_idx.device) + col_idx_pad = torch.zeros(nnz*kernel_size, dtype=psi_idx.dtype, device=psi_idx.device) + vals_pad = torch.zeros(nnz*kernel_size, dtype=psi_vals.dtype, device=psi_vals.device) + + #off = 0 + for ik in range(kernel_size): + iidx = torch.argwhere(ker_idx == ik) + print(f"kernel {ik}:", row_idx[iidx], col_idx[iidx]) + + return psi_idx_pad, psi_vals_pad + + @lru_cache(typed=True, copy=True) def _precompute_convolution_tensor_s2( in_shape: Tuple[int], @@ -479,7 +530,11 @@ def __init__( merge_quadrature=True, ) - # sort the values + #if self.optimized_kernel: + # # pad the convolution tensor to allow kernel vectorization + # idx, vals = _pad_convolution_tensor_s2(idx, vals, self.kernel_size) + + # extract values and indices ker_idx = idx[0, ...].contiguous() row_idx = idx[1, ...].contiguous() col_idx = idx[2, ...].contiguous() diff --git a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu index 8c3ad7f8..4124d7d0 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu @@ -1114,10 +1114,8 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, int64_t nchan = Co * K; int64_t nrows = roff_idx.size(0) - 1; - printf("%s:%d: batch_size: %ld, nlat_in: %ld, nlon_in: %ld, C: %ld\n", - __func__, __LINE__, batch_size, nlat_in, nlon_in, nchan); - - printf("K: %ld, Cin: %ld\n", K, nchan/K); + printf("%s:%d: batch_size: %ld, nlat_in: %ld, nlon_in: %ld, nchan: %ld, K: %ld\n", + __func__, __LINE__, batch_size, nlat_in, nlon_in, nchan, K); int64_t nlat_out = Ho; int64_t nlon_out = Wo; @@ -1130,7 +1128,7 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, // extract dtype auto x_type = ograd.dtype(); - torch::Tensor xP = ograd.to(torch::kFloat32); + torch::Tensor xP = ograd.reshape({batch_size, nlat_in, nlon_in, nchan}).to(torch::kFloat32).contiguous(); torch::Tensor igrad = torch::zeros(out_dims, xP.options()); diff --git a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu index 0fbb673a..eaccb80a 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu @@ -823,7 +823,7 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, !is_aligned(_val_pck) || (K % VEC_SIZE) != 0) { - //printf("%s:%d: VEC_SIZE: %d, nchan_in: %d, K: %d, _xp: %p, _yp: %p\n", __func__, __LINE__, VEC_SIZE, nchan_in, K, _xp, _yp); + printf("Is not aligned: %s:%d: VEC_SIZE: %d, nchan_in: %d, K: %d, _xp: %p, _yp: %p\n", __func__, __LINE__, VEC_SIZE, nchan_in, K, _xp, _yp); const int nloc = DIV_UP(nchan_in*K, bdimx); @@ -850,7 +850,7 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, } else { - //printf("%s:%d: VEC_SIZE: %d, nchan_in: %d, K: %d, _xp: %p, _yp: %p\n", __func__, __LINE__, VEC_SIZE, nchan_in, K, _xp, _yp); + printf("Is aligned: %s:%d: VEC_SIZE: %d, nchan_in: %d, K: %d, _xp: %p, _yp: %p\n", __func__, __LINE__,
VEC_SIZE, nchan_in, K, _xp, _yp); //float4 *_xp4 = reinterpret_cast(_xp); float4 *_yp4 = reinterpret_cast(_yp); @@ -1030,7 +1030,7 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, // version with fused enisum auto x_type = inp.dtype(); - auto xP = inp.to(torch::kFloat32); + auto xP = inp.to(torch::kFloat32).contiguous(); // to test before fusion int64_t out_dims[] = {batch_size, nlat_out, nlon_out, nchan*K}; diff --git a/torch_harmonics/filter_basis.py b/torch_harmonics/filter_basis.py index e7163d4e..2882e89b 100644 --- a/torch_harmonics/filter_basis.py +++ b/torch_harmonics/filter_basis.py @@ -131,8 +131,11 @@ def _compute_support_vals_isotropic(self, r: torch.Tensor, phi: torch.Tensor, r_ ir = (ikernel + 0.5) * dr # find the indices where the rotated position falls into the support of the kernel - iidx = torch.argwhere(((r - ir).abs() <= dr) & (r <= r_cutoff)) + iidx = torch.argwhere(((r - ir.max()).abs() <= dr) & (r <= r_cutoff)) + #vals = 1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr + #iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device)) vals = 1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr + vals = torch.clamp(vals, min=0.0) return iidx, vals @@ -155,11 +158,15 @@ def _compute_support_vals_anisotropic(self, r: torch.Tensor, phi: torch.Tensor, ir = (ikernel // nphi + 0.5) * dr iphi = (ikernel % nphi) * dphi - math.pi + #cond_r = ((r - ir.max()).abs() <= dr) & (r <= r_cutoff) + cond_r = (r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device) + + # find the indices where the rotated position falls into the support of the kernel if nr % 2 == 1: # find the support - cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff) - cond_phi = (ikernel == 0) | (_circle_dist(phi, iphi).abs() <= dphi) + cond_phi = (_circle_dist(phi, iphi.max()).abs() <= dphi) + #print(_circle_dist(phi, iphi).abs() <= dphi) # find indices where conditions are met iidx = torch.argwhere(cond_r & cond_phi) # compute the distance to the collocation points @@ -169,27 +176,33 @@ def _compute_support_vals_anisotropic(self, r: torch.Tensor, phi: torch.Tensor, vals = 1 - dist_r / dr vals *= torch.where((iidx[:, 0] > 0), (1 - dist_phi / dphi), 1.0) + # clamp values + vals = torch.clamp(vals, min=0.0) else: + print("IN THERE!") # in the even case, the inner basis functions overlap into areas with a negative areas rn = -r phin = torch.where(phi + math.pi >= math.pi, phi - math.pi, phi + math.pi) - # find the support - cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff) - cond_phi = _circle_dist(phi, iphi).abs() <= dphi - cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff) - cond_phin = _circle_dist(phin, iphi) <= dphi + cond_phi = _circle_dist(phi, iphi.max()).abs() <= dphi + #cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff) + cond_rn = (rn <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=rn.device) + cond_phin = _circle_dist(phin, iphi.min()) <= dphi # find indices where conditions are met iidx = torch.argwhere((cond_r & cond_phi) | (cond_rn & cond_phin)) + #iidx = torch.argwhere(cond_r | cond_rn) # & cond_phin) dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() dist_phi = _circle_dist(phi[iidx[:, 1], iidx[:, 2]], iphi[iidx[:, 0], 0, 0]) dist_rn = (rn[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() dist_phin = _circle_dist(phin[iidx[:, 1], iidx[:, 2]], iphi[iidx[:, 0], 0, 0]) # compute the value of the basis functions - vals = cond_r[iidx[:, 0], iidx[:, 1], 
iidx[:, 2]] * (1 - dist_r / dr) - vals *= cond_phi[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_phi / dphi) - valsn = cond_rn[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_rn / dr) - valsn *= cond_phin[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_phin / dphi) + #vals = cond_r[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_r / dr) + #vals *= cond_phi[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_phi / dphi) + #valsn = cond_rn[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_rn / dr) + #valsn *= cond_phin[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_phin / dphi) + + vals = torch.clamp((1 - dist_r / dr) * (1 - dist_phi / dphi), min=0.0) + valsn = torch.clamp((1 - dist_rn / dr) * (1 - dist_phin / dphi), min=0.0) vals += valsn return iidx, vals From 27d541474582866d1bedb9608d1d51bcdf5b8383 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 8 Sep 2025 00:58:25 -0700 Subject: [PATCH 15/31] fixing some selection criterium --- torch_harmonics/filter_basis.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/torch_harmonics/filter_basis.py b/torch_harmonics/filter_basis.py index 2882e89b..656cff19 100644 --- a/torch_harmonics/filter_basis.py +++ b/torch_harmonics/filter_basis.py @@ -130,10 +130,12 @@ def _compute_support_vals_isotropic(self, r: torch.Tensor, phi: torch.Tensor, r_ else: ir = (ikernel + 0.5) * dr + #print("ir", ir, ir.min()) + # find the indices where the rotated position falls into the support of the kernel - iidx = torch.argwhere(((r - ir.max()).abs() <= dr) & (r <= r_cutoff)) + #iidx = torch.argwhere((r.abs() <= dr) & (r <= r_cutoff)) #vals = 1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr - #iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device)) + iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device)) vals = 1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr vals = torch.clamp(vals, min=0.0) @@ -161,7 +163,6 @@ def _compute_support_vals_anisotropic(self, r: torch.Tensor, phi: torch.Tensor, #cond_r = ((r - ir.max()).abs() <= dr) & (r <= r_cutoff) cond_r = (r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device) - # find the indices where the rotated position falls into the support of the kernel if nr % 2 == 1: # find the support @@ -185,7 +186,7 @@ def _compute_support_vals_anisotropic(self, r: torch.Tensor, phi: torch.Tensor, phin = torch.where(phi + math.pi >= math.pi, phi - math.pi, phi + math.pi) cond_phi = _circle_dist(phi, iphi.max()).abs() <= dphi #cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff) - cond_rn = (rn <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=rn.device) + cond_rn = (rn.abs() <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=rn.device) cond_phin = _circle_dist(phin, iphi.min()) <= dphi # find indices where conditions are met iidx = torch.argwhere((cond_r & cond_phi) | (cond_rn & cond_phin)) @@ -200,7 +201,8 @@ def _compute_support_vals_anisotropic(self, r: torch.Tensor, phi: torch.Tensor, #vals *= cond_phi[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_phi / dphi) #valsn = cond_rn[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_rn / dr) #valsn *= cond_phin[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_phin / dphi) - + #vals += valsn + vals = torch.clamp((1 - dist_r / dr) * (1 - dist_phi / dphi), min=0.0) valsn = torch.clamp((1 - dist_rn / dr) * (1 - dist_phin / dphi), min=0.0) vals += valsn From 
9d3f351227eb0bffc4b58fe34936a1427a8080d6 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 8 Sep 2025 01:00:28 -0700 Subject: [PATCH 16/31] removing some debug prints --- torch_harmonics/filter_basis.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_harmonics/filter_basis.py b/torch_harmonics/filter_basis.py index 656cff19..2a1e5b12 100644 --- a/torch_harmonics/filter_basis.py +++ b/torch_harmonics/filter_basis.py @@ -180,7 +180,6 @@ def _compute_support_vals_anisotropic(self, r: torch.Tensor, phi: torch.Tensor, # clamp values vals = torch.clamp(vals, min=0.0) else: - print("IN THERE!") # in the even case, the inner basis functions overlap into areas with a negative areas rn = -r phin = torch.where(phi + math.pi >= math.pi, phi - math.pi, phi + math.pi) From f45d3cea93026d7e31a493060becc340d112fe4f Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 8 Sep 2025 01:51:14 -0700 Subject: [PATCH 17/31] small fix --- torch_harmonics/filter_basis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_harmonics/filter_basis.py b/torch_harmonics/filter_basis.py index 2a1e5b12..a35bd842 100644 --- a/torch_harmonics/filter_basis.py +++ b/torch_harmonics/filter_basis.py @@ -186,7 +186,7 @@ def _compute_support_vals_anisotropic(self, r: torch.Tensor, phi: torch.Tensor, cond_phi = _circle_dist(phi, iphi.max()).abs() <= dphi #cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff) cond_rn = (rn.abs() <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=rn.device) - cond_phin = _circle_dist(phin, iphi.min()) <= dphi + cond_phin = _circle_dist(phin, iphi.max()) <= dphi # find indices where conditions are met iidx = torch.argwhere((cond_r & cond_phi) | (cond_rn & cond_phin)) #iidx = torch.argwhere(cond_r | cond_rn) # & cond_phin) From fb1e589578f367267dda57b021ed1df4e4578d3a Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Tue, 9 Sep 2025 05:39:57 -0700 Subject: [PATCH 18/31] cleaning up attention --- setup.py | 2 +- tests/test_attention.py | 8 +- torch_harmonics/attention/_attention_utils.py | 38 +++- torch_harmonics/attention/attention.py | 1 + .../attention/csrc/attention_cpu.h | 28 +-- .../attention/csrc/attention_cpu_bwd.cpp | 34 +--- .../attention/csrc/attention_cpu_fwd.cpp | 18 +- .../attention/csrc/attention_cuda.cuh | 2 +- .../attention/csrc/attention_cuda_bwd.cu | 29 +-- .../attention/csrc/attention_cuda_fwd.cu | 18 +- .../attention/csrc/attention_cuda_utils.cu | 57 ------ .../attention/csrc/attention_cuda_utils.cuh | 171 ------------------ 12 files changed, 77 insertions(+), 329 deletions(-) diff --git a/setup.py b/setup.py index a5bd9e78..8773a8be 100644 --- a/setup.py +++ b/setup.py @@ -200,7 +200,7 @@ def get_ext_modules(): "torch_harmonics/attention/csrc/attention_cpu_bwd.cpp", ] - if False: #BUILD_CUDA: + if BUILD_CUDA: print(f"Compiling attention CUDA kernels for torch-harmonics.") attention_sources.extend([ "torch_harmonics/attention/csrc/attention_cuda_utils.cu", diff --git a/tests/test_attention.py b/tests/test_attention.py index 67364485..9fa0ea82 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -54,8 +54,8 @@ # CPU results normalized to 16 OpenMP threads, # GPU results normalized to V100 16 GB GPU # this is just to detect performance regressions, not for absolute performance -_perf_test_thresholds = {"cpu": {"fwd_ms": 1000, "bwd_ms": 8000}, - "cuda": {"fwd_ms": 50, "bwd_ms": 150}} +_perf_test_thresholds = {"cpu": {"fwd_ms": 800, "bwd_ms": 6000}, + "cuda": {"fwd_ms": 10, "bwd_ms": 30}} _run_perf_tests = 
(os.getenv("TORCH_HARMONICS_RUN_PERF_TESTS", "0") == "1") @@ -326,12 +326,12 @@ def test_optimized_pt2_compatibility(self, batch_size, channels, heads, in_shape @parameterized.expand( [ # self attention - [1, 256, 1, (91, 180), (91, 180), "equiangular", "equiangular", 1e-5, 1e-5], + [1, 256, 1, (91, 180), (91, 180), "equiangular", "equiangular"], ], skip_on_empty=True, ) @unittest.skipUnless(optimized_kernels_is_available() and _run_perf_tests, "skipping performance test because optimized kernels are not available or perf tests are disabled") - def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=False): + def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, verbose=False): if (self.device.type == "cuda") and (not cuda_kernels_is_available()): raise unittest.SkipTest("skipping test because CUDA kernels are not available") diff --git a/torch_harmonics/attention/_attention_utils.py b/torch_harmonics/attention/_attention_utils.py index a8c1b58a..3b0ade28 100644 --- a/torch_harmonics/attention/_attention_utils.py +++ b/torch_harmonics/attention/_attention_utils.py @@ -35,6 +35,7 @@ import torch import torch.nn.functional as F from attention_helpers import optimized_kernels_is_available +from torch_harmonics.utils import permute_to_0231, permute_to_0312 from . import attention_kernels # HELPER ROUTINE FOR BACKWARD setup_context @@ -87,6 +88,11 @@ def _neighborhood_s2_attention_optimized(k: torch.Tensor, v: torch.Tensor, q: to B, _, H, W = qw.shape qw = qw.reshape(B*nh, -1, H, W) + # permute to 0231 + kw = permute_to_0231(kw) + vw = permute_to_0231(vw) + qw = permute_to_0231(qw) + # convert to float32 inp_dtype = kw.dtype kw = kw.to(torch.float32).contiguous() @@ -97,12 +103,18 @@ def _neighborhood_s2_attention_optimized(k: torch.Tensor, v: torch.Tensor, q: to col_idx, row_off, nlon_in, nlat_out, nlon_out) - _, C, H, W = output.shape - output = output.reshape(B, -1, H, W) + #_, H, W, C = output.shape + #output = output.reshape(-1, H, W, C) # convert back precision output = output.to(dtype=inp_dtype) + # permute back to 0312 + output = permute_to_0312(output) + + # fold heads back into channel dimension + output = output.reshape(B, -1, H, W) + return output @torch.library.register_fake("attention_kernels::_neighborhood_s2_attention_optimized") @@ -147,21 +159,31 @@ def _neighborhood_s2_attention_bwd_optimized(ctx, grad_output): B, _, H, W = grad_output.shape grad_output = grad_output.reshape(B*nh, -1, H, W) - # save type and convert to float32 + # permute to 0231 + kw = permute_to_0231(kw.contiguous()) + vw = permute_to_0231(vw.contiguous()) + qw = permute_to_0231(qw.contiguous()) + grad_output = permute_to_0231(grad_output.contiguous()) + + # save type and convert to float32 kw_dtype = kw.dtype vw_dtype = vw.dtype qw_dtype = qw.dtype - - kw = kw.to(torch.float32).contiguous() - vw = vw.to(torch.float32).contiguous() - qw = qw.to(torch.float32).contiguous() - grad_output = grad_output.to(torch.float32).contiguous() + kw = kw.to(torch.float32) + vw = vw.to(torch.float32) + qw = qw.to(torch.float32) + grad_output = grad_output.to(torch.float32) dkw, dvw, dqw = attention_kernels.backward.default(kw, vw, qw, grad_output, quad_weights, col_idx, row_off, nlon_in, nlat_out, nlon_out) + # permute back to 0312 + dkw = permute_to_0312(dkw) + dvw = permute_to_0312(dvw) + dqw = permute_to_0312(dqw) + # weight grads _, C, H, W = dkw.shape dkw = dkw.reshape(B, -1, H, W) diff --git 
a/torch_harmonics/attention/attention.py b/torch_harmonics/attention/attention.py index 1d76286f..0110b63d 100644 --- a/torch_harmonics/attention/attention.py +++ b/torch_harmonics/attention/attention.py @@ -353,6 +353,7 @@ def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value self.nlon_out, ) + # compute the output out = nn.functional.conv2d(out, self.proj_weights, bias=self.proj_bias) return out diff --git a/torch_harmonics/attention/csrc/attention_cpu.h b/torch_harmonics/attention/csrc/attention_cpu.h index f96e7a7b..c61a987f 100644 --- a/torch_harmonics/attention/csrc/attention_cpu.h +++ b/torch_harmonics/attention/csrc/attention_cpu.h @@ -50,6 +50,8 @@ namespace attention_kernels { const int64_t nlon_in, const int64_t nlat_out, const int64_t nlon_out, const int64_t batch_size, const int64_t nchannels_in, const int64_t nchannels_out) { + // IMPORTANT: all input tensors are in channels last format! + // some parameters const int64_t block_wo = CACHE_BLOCK_SIZE; const int64_t nblock_wo = static_cast((nlon_out + block_wo - 1) / block_wo); @@ -95,7 +97,7 @@ namespace attention_kernels { float qdotk = 0.0; //#pragma omp simd reduction(+:qdotk) for (int64_t ci = 0; ci < nchannels_in; ci++) { - qdotk += static_cast(qy_arr[b][ci][ho][wo] * kx_arr[b][ci][hi][wip]); + qdotk += static_cast(qy_arr[b][ho][wo][ci] * kx_arr[b][hi][wip][ci]); } // update tmp max @@ -106,7 +108,7 @@ namespace attention_kernels { alpha_sum[wo-wo_start] = alpha + alpha_sum[wo-wo_start] * std::exp(qdotk_max[wo-wo_start] - qdotk_max_tmp); // update output - y_tmp[wo-wo_start] = y_tmp[wo-wo_start] * std::exp(qdotk_max[wo-wo_start] - qdotk_max_tmp) + alpha * static_cast(vx_arr[b][co][hi][wip]); + y_tmp[wo-wo_start] = y_tmp[wo-wo_start] * std::exp(qdotk_max[wo-wo_start] - qdotk_max_tmp) + alpha * static_cast(vx_arr[b][hi][wip][co]); // define new max qdotk_max[wo-wo_start] = qdotk_max_tmp; @@ -115,7 +117,7 @@ namespace attention_kernels { // update output for (int64_t wo = wo_start; wo < wo_end; wo++) { - y_arr[b][co][ho][wo] = static_cast(y_tmp[wo-wo_start] / alpha_sum[wo-wo_start]); + y_arr[b][ho][wo][co] = static_cast(y_tmp[wo-wo_start] / alpha_sum[wo-wo_start]); } } } @@ -138,6 +140,8 @@ namespace attention_kernels { const int64_t nlon_in, const int64_t nlat_out, const int64_t nlon_out, const int64_t batch_size, const int64_t nchannels_in, const int64_t nchannels_out) { + // IMPORTANT: all input tensors are in channels last format! 
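    // NOTE (illustrative sketch, not part of this patch): both the forward and backward
    // kernels in this header rely on the same numerically stable streaming softmax over the
    // sparse neighborhood, visible above through qdotk_max / alpha_sum / y_tmp. In scalar
    // form the recurrence is (the names l, m, s, y, w_i, v_i below are placeholders only):
    //
    //   float m = -FLT_MAX, s = 0.f, y = 0.f;
    //   // for each neighbor i of the output point:
    //   //   float l  = dot(q, k_i);             // attention logit q . k_i
    //   //   float mn = fmaxf(m, l);             // updated running max
    //   //   float a  = expf(l - mn) * w_i;      // quadrature-weighted softmax numerator
    //   //   s = a + s * expf(m - mn);           // rescale old normalizer to the new max
    //   //   y = y * expf(m - mn) + a * v_i;     // rescale accumulator, then add new term
    //   //   m = mn;
    //   // output = y / s;
    //
    // This keeps expf() from overflowing without a second pass over the neighborhood.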
+ // compute dqy and dkx #pragma omp parallel for collapse(2) for (int64_t b = 0; b < batch_size; b++) { @@ -176,7 +180,7 @@ namespace attention_kernels { // compute correlation & softmax numerator qdotk_nz[idz-zstart] = 0.0; for (int64_t cit = 0; cit < nchannels_in; cit++) { - qdotk_nz[idz-zstart] += qy_arr[b][cit][ho][wo] * kx_arr[b][cit][hi][wip]; + qdotk_nz[idz-zstart] += qy_arr[b][ho][wo][cit] * kx_arr[b][hi][wip][cit]; } // tmp max and discount @@ -190,16 +194,16 @@ namespace attention_kernels { // dkx: input dot float gdotv = 0.0; for (int64_t cot = 0; cot < nchannels_out; cot++) { - gdotv += dy_arr[b][cot][ho][wo] * vx_arr[b][cot][hi][wip]; + gdotv += dy_arr[b][ho][wo][cot] * vx_arr[b][hi][wip][cot]; } float alpha_gdotv_tmp = alpha_nz[idz-zstart] * gdotv; alpha_gdotv = alpha_gdotv_tmp + alpha_gdotv * discount; // dqy: alpha_k - alpha_k = alpha_nz[idz-zstart] * kx_arr[b][ci][hi][wip] + alpha_k * discount; + alpha_k = alpha_nz[idz-zstart] * kx_arr[b][hi][wip][ci] + alpha_k * discount; // dqy: alpha_k_gdotv - alpha_k_gdotv = alpha_gdotv_tmp * kx_arr[b][ci][hi][wip] + alpha_k_gdotv * discount; + alpha_k_gdotv = alpha_gdotv_tmp * kx_arr[b][hi][wip][ci] + alpha_k_gdotv * discount; // define new max qdotk_max = qdotk_max_tmp; @@ -211,7 +215,7 @@ namespace attention_kernels { alpha_k_gdotv = alpha_k_gdotv / alpha_sum; // dqy: update - dqy_arr[b][ci][ho][wo] = (alpha_k_gdotv - alpha_gdotv * alpha_k); + dqy_arr[b][ho][wo][ci] = (alpha_k_gdotv - alpha_gdotv * alpha_k); for (int64_t idz = zstart; idz < zend; idz++) { int64_t nz_col_idx = col_idx_arr[idz]; @@ -228,11 +232,11 @@ namespace attention_kernels { // dkx: input dot float gdotv = 0.0; for (int64_t cot = 0; cot < nchannels_out; cot++) { - gdotv += dy_arr[b][cot][ho][wo] * vx_arr[b][cot][hi][wip]; + gdotv += dy_arr[b][ho][wo][cot] * vx_arr[b][hi][wip][cot]; } // dkx: update - dkx_arr[b][ci][hi][wip] += qy_arr[b][ci][ho][wo] * alpha_norm * (gdotv - alpha_gdotv); + dkx_arr[b][hi][wip][ci] += qy_arr[b][ho][wo][ci] * alpha_norm * (gdotv - alpha_gdotv); } } } @@ -270,7 +274,7 @@ namespace attention_kernels { // compute correlation & softmax numerator qdotk_nz[idz-zstart] = 0.0; for (int64_t ci = 0; ci < nchannels_in; ci++) { - qdotk_nz[idz-zstart] += qy_arr[b][ci][ho][wo] * kx_arr[b][ci][hi][wip]; + qdotk_nz[idz-zstart] += qy_arr[b][ho][wo][ci] * kx_arr[b][hi][wip][ci]; } // tmp max and discount @@ -296,7 +300,7 @@ namespace attention_kernels { // recompute alpha float alpha_norm = std::exp(qdotk_nz[idz-zstart] - qdotk_max) * quad_weights_arr[hi] / alpha_sum; - dvx_arr[b][co][hi][wip] += alpha_norm * dy_arr[b][co][ho][wo]; + dvx_arr[b][hi][wip][co] += alpha_norm * dy_arr[b][ho][wo][co]; } } } diff --git a/torch_harmonics/attention/csrc/attention_cpu_bwd.cpp b/torch_harmonics/attention/csrc/attention_cpu_bwd.cpp index f7ec9b6d..0297a79b 100644 --- a/torch_harmonics/attention/csrc/attention_cpu_bwd.cpp +++ b/torch_harmonics/attention/csrc/attention_cpu_bwd.cpp @@ -38,16 +38,16 @@ std::tuple s2_attention_bwd_cpu(tor torch::Tensor quad_weights, torch::Tensor col_idx, torch::Tensor row_off, int64_t nlon_in, int64_t nlat_out, int64_t nlon_out) { - // shapes: + // shapes: all channels LAST! 
// input - // kx: B, C, Hi, Wi - // vx: B, C, Hi, Wi - // qy: B, C, Ho, Wo + // kx: B, Hi, Wi, C + // vx: B, Hi, Wi, C + // qy: B, Ho, Wo, C // quad_weights: Hi // output - // dkx: B, C, Hi, Wi - // dvx: B, C, Hi, Wi - // dqy: B, C, Ho, Wo + // dkx: B, Hi, Wi, C + // dvx: B, Hi, Wi, C + // dqy: B, Ho, Wo, C // sanity checks CHECK_CPU_INPUT_TENSOR(kx); @@ -58,25 +58,14 @@ std::tuple s2_attention_bwd_cpu(tor CHECK_CPU_INPUT_TENSOR(col_idx); CHECK_CPU_INPUT_TENSOR(row_off); - // change to channels first: - bool kx_is_channels_last = kx.strides()[1] == 1; - bool vx_is_channels_last = vx.strides()[1] == 1; - bool qy_is_channels_last = qy.strides()[1] == 1; - bool dy_is_channels_last = dy.strides()[1] == 1; - - if (!kx_is_channels_last) { kx = kx.contiguous(at::MemoryFormat::ChannelsLast); } - if (!vx_is_channels_last) { vx = vx.contiguous(at::MemoryFormat::ChannelsLast); } - if (!qy_is_channels_last) { qy = qy.contiguous(at::MemoryFormat::ChannelsLast); } - if (!dy_is_channels_last) { dy = dy.contiguous(at::MemoryFormat::ChannelsLast); } - auto dkx = torch::zeros_like(kx); auto dvx = torch::zeros_like(vx); auto dqy = torch::zeros_like(qy); // some parameters const int64_t batch_size = kx.size(0); - const int64_t nchannels_out = vx.size(1); - const int64_t nchannels_in = qy.size(1); + const int64_t nchannels_out = vx.size(3); + const int64_t nchannels_in = qy.size(3); // extract accessors auto kx_arr = kx.packed_accessor64(); @@ -97,11 +86,6 @@ std::tuple s2_attention_bwd_cpu(tor nlon_in, nlat_out, nlon_out, batch_size, nchannels_in, nchannels_out); - // permute back - if (!qy_is_channels_last) { dqy = dqy.contiguous(at::MemoryFormat::Contiguous); } - if (!vx_is_channels_last) { dvx = dvx.contiguous(at::MemoryFormat::Contiguous); } - if (!kx_is_channels_last) { dkx = dkx.contiguous(at::MemoryFormat::Contiguous); } - return std::make_tuple(dkx, dvx, dqy); } diff --git a/torch_harmonics/attention/csrc/attention_cpu_fwd.cpp b/torch_harmonics/attention/csrc/attention_cpu_fwd.cpp index 3abbb7b3..9f4d96bc 100644 --- a/torch_harmonics/attention/csrc/attention_cpu_fwd.cpp +++ b/torch_harmonics/attention/csrc/attention_cpu_fwd.cpp @@ -45,22 +45,13 @@ namespace attention_kernels { CHECK_CPU_INPUT_TENSOR(col_idx); CHECK_CPU_INPUT_TENSOR(row_off); - // change to channels first: - bool kx_is_channels_last = kx.strides()[1] == 1; - bool vx_is_channels_last = vx.strides()[1] == 1; - bool qy_is_channels_last = qy.strides()[1] == 1; - - if (!kx_is_channels_last) { kx = kx.contiguous(at::MemoryFormat::ChannelsLast); } - if (!vx_is_channels_last) { vx = vx.contiguous(at::MemoryFormat::ChannelsLast); } - if (!qy_is_channels_last) { qy = qy.contiguous(at::MemoryFormat::ChannelsLast); } - // some parameters const int64_t batch_size = kx.size(0); - const int64_t nchannels_out = vx.size(1); - const int64_t nchannels_in = qy.size(1); + const int64_t nchannels_out = vx.size(3); + const int64_t nchannels_in = qy.size(3); // prepare result tensor - auto y = torch::zeros({batch_size, nchannels_out, nlat_out, nlon_out}, qy.options()); + auto y = torch::zeros({batch_size, nlat_out, nlon_out, nchannels_out}, qy.options()); // extract accessors auto roff_arr = row_off.packed_accessor64(); @@ -74,9 +65,6 @@ namespace attention_kernels { s2_attn_fwd_kernel(kx_arr, vx_arr, qy_arr, quad_weights_arr, col_idx_arr, roff_arr, y_arr, nlon_in, nlat_out, nlon_out, batch_size, nchannels_in, nchannels_out); - // permute back - if (!qy_is_channels_last) { y = y.contiguous(at::MemoryFormat::Contiguous); } - return y; } diff --git 
a/torch_harmonics/attention/csrc/attention_cuda.cuh b/torch_harmonics/attention/csrc/attention_cuda.cuh index 0042bfce..533b5938 100644 --- a/torch_harmonics/attention/csrc/attention_cuda.cuh +++ b/torch_harmonics/attention/csrc/attention_cuda.cuh @@ -35,7 +35,7 @@ #include #define CHECK_CUDA_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCUDA) -#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_INTERNAL_ASSERT(x.is_contiguous() || x.is_contiguous(at::MemoryFormat::ChannelsLast)) +#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_INTERNAL_ASSERT(x.is_contiguous(), #x " must be contiguous") #define CHECK_CUDA_INPUT_TENSOR(x) \ CHECK_CUDA_TENSOR(x); \ CHECK_CONTIGUOUS_TENSOR(x) diff --git a/torch_harmonics/attention/csrc/attention_cuda_bwd.cu b/torch_harmonics/attention/csrc/attention_cuda_bwd.cu index c34ccecd..21e0b28e 100644 --- a/torch_harmonics/attention/csrc/attention_cuda_bwd.cu +++ b/torch_harmonics/attention/csrc/attention_cuda_bwd.cu @@ -1003,16 +1003,17 @@ std::tuple s2_attention_bwd_dkvq_cuda(at::Te CHECK_CUDA_TENSOR(psi_col_idx); CHECK_CUDA_TENSOR(psi_row_off); - //const size_t uo_num_channels = kx.size(1); - size_t nchans_in = qy.size(1); // or kx.size(1) - size_t nchans_out = vx.size(1); + // IMPORTANT: all input tensors are in channels last format! + + size_t nchans_in = qy.size(3); // or kx.size(3) + size_t nchans_out = vx.size(3); const int batch_size = kx.size(0); // extract dtype - auto kx_type = kx.dtype(); // nchans_in + auto kx_type = kx.dtype(); auto qy_type = qy.dtype(); - auto vx_type = vx.dtype(); // ncahn_out + auto vx_type = vx.dtype(); auto dy_type = dy.dtype(); torch::Tensor kxP = kx.to(torch::kFloat32); @@ -1020,19 +1021,7 @@ std::tuple s2_attention_bwd_dkvq_cuda(at::Te torch::Tensor qyP = qy.to(torch::kFloat32); torch::Tensor dyP = dy.to(torch::kFloat32); - // exract memory format: this is much safer than checking is_contiguous(at::MemoryFormat::ChannelsLast) - // the former fails for num_channels == 1 - bool kx_is_channels_last = kxP.strides()[1] == 1; - bool vx_is_channels_last = vxP.strides()[1] == 1; - bool qy_is_channels_last = qyP.strides()[1] == 1; - bool dy_is_channels_last = dyP.strides()[1] == 1; - - // transpose if required - if (!kx_is_channels_last) { kxP = permute_4D_to0231(kxP); } - if (!vx_is_channels_last) { vxP = permute_4D_to0231(vxP); } - if (!qy_is_channels_last) { qyP = permute_4D_to0231(qyP); } - if (!dy_is_channels_last) { dyP = permute_4D_to0231(dyP); } - + // create output tensors torch::Tensor dkxP = torch::zeros_like(kxP); torch::Tensor dvxP = torch::zeros_like(vxP); torch::Tensor dqyP = torch::zeros_like(qyP); @@ -1053,10 +1042,6 @@ std::tuple s2_attention_bwd_dkvq_cuda(at::Te torch::Tensor dvx = dvxP; torch::Tensor dqy = dqyP; - if (!kx_is_channels_last) { dkx = permute_4D_to0312(dkx); } - if (!vx_is_channels_last) { dvx = permute_4D_to0312(dvx); } - if (!qy_is_channels_last) { dqy = permute_4D_to0312(dqy); } - // convert precision back to starting dkx = dkx.to(kx_type); dvx = dvx.to(vx_type); diff --git a/torch_harmonics/attention/csrc/attention_cuda_fwd.cu b/torch_harmonics/attention/csrc/attention_cuda_fwd.cu index 6de7321d..a6c9d65a 100644 --- a/torch_harmonics/attention/csrc/attention_cuda_fwd.cu +++ b/torch_harmonics/attention/csrc/attention_cuda_fwd.cu @@ -530,8 +530,10 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, CHECK_CUDA_TENSOR(psi_col_idx); CHECK_CUDA_TENSOR(psi_row_off); - size_t nchans_in = qy.size(1); // or kx.size(1) - size_t nchans_out = vx.size(1); + // IMPORTANT: all input tensors are in channels last 
format! + + size_t nchans_in = qy.size(3); // or kx.size(3) + size_t nchans_out = vx.size(3); const int batch_size = kx.size(0); @@ -542,16 +544,7 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, torch::Tensor vxP = vx.to(torch::kFloat32); torch::Tensor qyP = qy.to(torch::kFloat32); - // these are much safer than checking is_contiguous(at::MemoryFormat::ChannelsLast) - // the former fails for num_channels == 1 - bool kx_is_channels_last = kxP.strides()[1] == 1; - bool vx_is_channels_last = vxP.strides()[1] == 1; - bool qy_is_channels_last = qyP.strides()[1] == 1; - - if (!kx_is_channels_last) { kxP = permute_4D_to0231(kxP); } - if (!vx_is_channels_last) { vxP = permute_4D_to0231(vxP); } - if (!qy_is_channels_last) { qyP = permute_4D_to0231(qyP); } - + // output tensor torch::Tensor yP = torch::empty_like(vxP); s2_attn_fwd_dispatch(batch_size, @@ -567,7 +560,6 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, yP); torch::Tensor y = yP; - if (!qy_is_channels_last) { y = permute_4D_to0312(y); } // convert precision back to starting y = y.to(qy_type); diff --git a/torch_harmonics/attention/csrc/attention_cuda_utils.cu b/torch_harmonics/attention/csrc/attention_cuda_utils.cu index 09b42cfc..f758fcae 100644 --- a/torch_harmonics/attention/csrc/attention_cuda_utils.cu +++ b/torch_harmonics/attention/csrc/attention_cuda_utils.cu @@ -111,63 +111,6 @@ at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream) { } // END - CSR rows sorting kernels and functions - -// BEGIN - 4D tensor permutation kernels and functions -__global__ void empty_k() {} - -static int getPtxver() { - cudaFuncAttributes attrs; - CHECK_CUDA(cudaFuncGetAttributes(&attrs, empty_k)); - return attrs.ptxVersion*10; -} - -at::Tensor permute_4D_to0231(at::Tensor src) { - - auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device()); - torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options); - - const int ptxv = getPtxver(); - - // to be further specialized for additional archs, if necessary - if (ptxv < 100) { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_generic", ([&] { - launch_permute_to0231(src, dst); - })); - CHECK_ERROR("permute_to0231_k_tile_generic"); - } else { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_sm100", ([&] { - launch_permute_to0231(src, dst); - })); - CHECK_ERROR("permute_to0231_k_tile_sm100"); - } - - return dst; -} - -at::Tensor permute_4D_to0312(at::Tensor src) { - - auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device()); - torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options); - - const int ptxv = getPtxver(); - - // to be further specialized for additional archs, if necessary - if (ptxv < 100) { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_generic", ([&] { - launch_permute_to0312(src, dst); - })); - CHECK_ERROR("permute_to0312_k_tile_generic"); - } else { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_sm100", ([&] { - launch_permute_to0312(src, dst); - })); - CHECK_ERROR("permute_to0312_k_tile_sm100"); - } - - return dst; -} -// END - tensor permutation kernels and functions - // BEGIN - general host-side functions unsigned int next_pow2(unsigned int x) { diff --git a/torch_harmonics/attention/csrc/attention_cuda_utils.cuh b/torch_harmonics/attention/csrc/attention_cuda_utils.cuh index fbda09d5..81c75871 100644 --- 
a/torch_harmonics/attention/csrc/attention_cuda_utils.cuh +++ b/torch_harmonics/attention/csrc/attention_cuda_utils.cuh @@ -205,175 +205,4 @@ __device__ VAL_T __block_sum(VAL_T val) { return val; } -// transpose utils -template -__global__ -__launch_bounds__(BDIM_X*BDIM_Y) -void permute_to0231_k(const int nchn, - const int nlat, - const int nlon, - const at::PackedTensorAccessor32 src, - at::PackedTensorAccessor32 dst) { - - static_assert(!(BDIM_X & (BDIM_X-1))); - static_assert(!(BDIM_Y & (BDIM_Y-1))); - static_assert(BDIM_X >= BDIM_Y); - - __shared__ VAL_T sh[BDIM_X][BDIM_X+1]; - - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - - const int coff = blockIdx.x*BDIM_X; // channel offset - const int woff = blockIdx.y*BDIM_X; // width offset - const int batch = blockIdx.z / nlat; // batch (same for all block) - const int h = blockIdx.z - (batch * nlat); // height (same for all block) - - const int nchn_full = (nchn-coff) >= BDIM_X; - const int nlon_full = (nlon-woff) >= BDIM_X; - - if (nchn_full && nlon_full) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - sh[j+tidy][tidx] = src[batch][coff + j+tidy][h][woff+tidx]; - } - __syncthreads(); - - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy]; - } - } else { - if (woff+tidx < nlon) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - sh[j+tidy][tidx] = (coff + j+tidy < nchn) ? src[batch][coff + j+tidy][h][woff+tidx] : VAL_T(0); - } - } - __syncthreads(); - - if (coff+tidx < nchn) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - if (woff + j+tidy < nlon) { - dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy]; - } - } - } - } - return; -} - -template -void launch_permute_to0231(at::Tensor src, at::Tensor dst){ - dim3 block; - dim3 grid; - - block.x = WARP_SIZE; - block.y = WARPS_X_TILE; - grid.x = DIV_UP(src.size(1), block.x); - grid.y = DIV_UP(src.size(3), block.x); - grid.z = src.size(2)*src.size(0); - - assert(grid.y < 65536); - assert(grid.z < 65536); - - // get stream - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - permute_to0231_k - <<>>(src.size(1), - src.size(2), - src.size(3), - src.packed_accessor32(), - dst.packed_accessor32()); -} - -template -__global__ -__launch_bounds__(BDIM_X*BDIM_Y) -void permute_to0312_k(const int nchn, - const int nlat, - const int nlon, - const at::PackedTensorAccessor32 src, - at::PackedTensorAccessor32 dst) { - - static_assert(!(BDIM_X & (BDIM_X-1))); - static_assert(!(BDIM_Y & (BDIM_Y-1))); - static_assert(BDIM_X >= BDIM_Y); - - __shared__ VAL_T sh[BDIM_X][BDIM_X+1]; - - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - - const int woff = blockIdx.x*BDIM_X; // width offset - const int coff = blockIdx.y*BDIM_X; // channel offset - const int batch = blockIdx.z / nlat; // batch (same for all block) - const int h = blockIdx.z - (batch * nlat); // height (same for all block) - - const int nchn_full = (nchn-coff) >= BDIM_X; - const int nlon_full = (nlon-woff) >= BDIM_X; - - if (nchn_full && nlon_full) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - sh[j+tidy][tidx] = src[batch][h][woff + j+tidy][coff+tidx]; - } - __syncthreads(); - - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy]; - } - } else { - if (coff+tidx < nchn) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - sh[j+tidy][tidx] = (woff + j+tidy < nlon) ? 
src[batch][h][woff + j+tidy][coff+tidx] : VAL_T(0); - } - } - __syncthreads(); - - if (woff+tidx < nlon) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - if (coff + j+tidy < nchn) { - dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];; - } - } - } - } - return; -} - -template -void launch_permute_to0312(at::Tensor src, at::Tensor dst){ - dim3 block; - dim3 grid; - - block.x = WARP_SIZE; - block.y = WARPS_X_TILE; - grid.x = DIV_UP(src.size(2), block.x); - grid.y = DIV_UP(src.size(3), block.x); - grid.z = src.size(1)*src.size(0); - - assert(grid.y < 65536); - assert(grid.z < 65536); - - // get stream - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - permute_to0312_k - <<>>(src.size(3), - src.size(1), - src.size(2), - src.packed_accessor32(), - dst.packed_accessor32()); -} - } \ No newline at end of file From 4a1d7854b9305da1121874ce000a98f8292ca26c Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Tue, 9 Sep 2025 05:53:00 -0700 Subject: [PATCH 19/31] fixing raw meta kernels --- torch_harmonics/attention/_attention_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch_harmonics/attention/_attention_utils.py b/torch_harmonics/attention/_attention_utils.py index 3b0ade28..633b6607 100644 --- a/torch_harmonics/attention/_attention_utils.py +++ b/torch_harmonics/attention/_attention_utils.py @@ -55,7 +55,8 @@ def _setup_context_attention_backward(ctx, inputs, output): def _(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - out_shape = (kw.shape[0], vw.shape[1], nlat_out, nlon_out) + # the raw kernel uses channels last format + out_shape = (kw.shape[0], nlat_out, nlon_out, vw.shape[3]) return torch.empty(out_shape, dtype=kw.dtype, device=kw.device) # raw backward fake @@ -63,6 +64,7 @@ def _(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, def _(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, grad_output: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, nlon_in: int, nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # the raw kernel uses channels last format dk = torch.empty_like(kw) dv = torch.empty_like(vw) dq = torch.empty_like(qw) @@ -123,6 +125,7 @@ def _(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + # the wrapped kernel uses channels first format out_shape = (k.shape[0], wv.shape[0], nlat_out, nlon_out) return torch.empty(out_shape, dtype=k.dtype, device=k.device) From 1a7f96d20b147e29c299a62be20a5a59018ecc39 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Tue, 9 Sep 2025 06:30:46 -0700 Subject: [PATCH 20/31] better comments --- setup.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 8773a8be..42832a1b 100644 --- a/setup.py +++ b/setup.py @@ -101,7 +101,7 @@ def get_ext_modules(): print(f"Compiling helper routines for torch-harmonics.") - # UTILITIES + # Utility helpers ext_modules.append( CppExtension( "utility_helpers", @@ -112,7 +112,7 @@ def get_ext_modules(): ) ) - # DISCO + # DISCO helpers ext_modules.append( CppExtension( "disco_helpers", @@ -123,6 +123,7 @@ def get_ext_modules(): ) ) 
+ # Attention helpers ext_modules.append( CppExtension( "attention_helpers", From 1b715b3ce67fe5eae0f5937bb994f5b0102780c4 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 6 Oct 2025 02:47:53 -0700 Subject: [PATCH 21/31] snapshot --- torch_harmonics/disco/_disco_utils.py | 3 +- torch_harmonics/disco/convolution.py | 136 +++++++----------- torch_harmonics/disco/csrc/disco_cuda_bwd.cu | 24 +++- torch_harmonics/disco/csrc/disco_cuda_fwd.cu | 10 +- .../distributed/distributed_convolution.py | 107 ++++++++++---- 5 files changed, 156 insertions(+), 124 deletions(-) diff --git a/torch_harmonics/disco/_disco_utils.py b/torch_harmonics/disco/_disco_utils.py index b591a23b..5bc0fe80 100644 --- a/torch_harmonics/disco/_disco_utils.py +++ b/torch_harmonics/disco/_disco_utils.py @@ -29,8 +29,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -from typing import Optional, Tuple -import math +from typing import Optional import torch from disco_helpers import optimized_kernels_is_available diff --git a/torch_harmonics/disco/convolution.py b/torch_harmonics/disco/convolution.py index 73f585f7..78a8f08d 100644 --- a/torch_harmonics/disco/convolution.py +++ b/torch_harmonics/disco/convolution.py @@ -30,21 +30,15 @@ # import abc -from typing import List, Tuple, Union, Optional -from warnings import warn +from typing import Tuple, Union, Optional import math -import xmlrpc import torch import torch.nn as nn -import nvtx - -from functools import partial - from torch_harmonics.cache import lru_cache -from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes +from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes from torch_harmonics.utils import permute_to_0231, permute_to_0312 from ._disco_utils import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch from ._disco_utils import _disco_s2_contraction_optimized, _disco_s2_transpose_contraction_optimized @@ -166,57 +160,6 @@ def _normalize_convolution_tensor_s2( return psi_vals -def _pad_convolution_tensor_s2(psi_idx, psi_vals, kernel_size) -> Tuple[torch.Tensor, torch.Tensor]: - """Pads convolution tensor values with zeros to allow kernel vectorization. - - This function modifies the tensor to ensure that the number of elements per row is consistent for each kernel. - This is essential for vectorizing the kernel operations. - - Parameters - ----------- - psi_idx: torch.Tensor - Index tensor for the sparse convolution tensor. - psi_vals: torch.Tensor - Value tensor for the sparse convolution tensor. - kernel_size: int - Number of kernel basis functions. - - Returns - ------- - Tuple[torch.Tensor, torch.Tensor] - Padded index and value tensors. - """ - - ker_idx = psi_idx[0, ...] - row_idx = psi_idx[1, ...] - col_idx = psi_idx[2, ...] 
- - # check number of elements per row for each kernel and collect unique indices in the same go: - indices = {} - nnz = 0 - for ik in range(kernel_size): - iidx = torch.argwhere(ker_idx == ik) - nnz_tmp = iidx.shape[0] - nnz = max(nnz, nnz_tmp) - #row_idx_ker = row_idx[] - #col_idx_ker = col_idx[torch.where(ker_idx == ik)[0]] - - print(f"max number of elements per row: {nnz}") - - # create padded kernels - ker_idx_pad = torch.zeros(nnz*kernel_size, dtype=psi_idx.dtype, device=psi_idx.device) - row_idx_pad = torch.zeros(nnz*kernel_size, dtype=psi_idx.dtype, device=psi_idx.device) - col_idx_pad = torch.zeros(nnz*kernel_size, dtype=psi_idx.dtype, device=psi_idx.device) - vals_pad = torch.zeros(nnz*kernel_size, dtype=psi_vals.dtype, device=psi_vals.device) - - #off = 0 - for ik in range(kernel_size): - iidx = torch.argwhere(ker_idx == ik) - print(f"kernel {ik}:", row_idx[iidx], col_idx[iidx]) - - return psi_idx_pad, psi_vals_pad - - @lru_cache(typed=True, copy=True) def _precompute_convolution_tensor_s2( in_shape: Tuple[int], @@ -530,10 +473,6 @@ def __init__( merge_quadrature=True, ) - #if self.optimized_kernel: - # # pad the convolution tensor to allow kernel vectorization - # idx, vals = _pad_convolution_tensor_s2(idx, vals, self.kernel_size) - # extract values and indices ker_idx = idx[0, ...].contiguous() row_idx = idx[1, ...].contiguous() @@ -562,34 +501,44 @@ def extra_repr(self): def psi_idx(self): return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous() - @nvtx.annotate("forward", color="purple") def forward(self, x: torch.Tensor) -> torch.Tensor: if self.optimized_kernel: - with nvtx.annotate("_disco_s2_contraction_optimized", color="red"): - xp = permute_to_0231(x) + # permute input + xp = permute_to_0231(x) - xpc = _disco_s2_contraction_optimized( - xp, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, - self.kernel_size, self.nlat_out, self.nlon_out - ).reshape(x.shape[0], self.nlat_out, self.nlon_out, self.groups, self.groupsize_in, self.kernel_size) + # disco contaction + xpc = _disco_s2_contraction_optimized( + xp, + self.psi_roff_idx, + self.psi_ker_idx, + self.psi_row_idx, + self.psi_col_idx, + self.psi_vals, + self.kernel_size, + self.nlat_out, + self.nlon_out + ) - outp = torch.einsum("bxygck,gock->bxygo", xpc, self.weight).reshape(xpc.shape[0], self.nlat_out, self.nlon_out, -1).contiguous() + # weight multiplication + B, H, W, _, K = xpc.shape + xpc = xpc.reshape(B, H, W, self.groups, self.groupsize_in, K) + outp = torch.einsum("bxygck,gock->bxygo", xpc, self.weight) + outp = outp.reshape(B, H, W, -1).contiguous() - out = permute_to_0312(outp) + # permute output + out = permute_to_0312(outp) else: + # disco contaction x = _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out) # extract shape B, _, K, H, W = x.shape - with nvtx.annotate("reshape", color="blue"): - x = x.reshape(B, self.groups, self.groupsize_in, K, H, W) + x = x.reshape(B, self.groups, self.groupsize_in, K, H, W) - # do weight multiplication - with nvtx.annotate("einsum", color="blue"): - out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight).contiguous() - - out = out.reshape(B, -1, H, W) + # weight multiplication + out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight) + out = out.reshape(B, -1, H, W).contiguous() if self.bias is not None: out = out + self.bias.reshape(1, -1, 1, 1) @@ -715,22 +664,37 @@ def psi_idx(self): def forward(self, x: torch.Tensor) -> torch.Tensor: # extract shape - B, C, H, W = x.shape 
+ B, _, H, W = x.shape if self.optimized_kernel: + # permute input xp = permute_to_0231(x) + + # weight multiplication xp = xp.reshape(B, H, W, self.groups, self.groupsize_in) - xpc = torch.einsum("bxygc,gock->bxygok", xp, self.weight).reshape(B, H, W, -1, self.kernel_size).contiguous() + xpc = torch.einsum("bxygc,gock->bxygok", xp, self.weight) + xpc = xpc.reshape(B, H, W, -1, self.kernel_size).contiguous() + + # disco contraction outp = _disco_s2_transpose_contraction_optimized( - xpc, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out + xpc, + self.psi_roff_idx, + self.psi_ker_idx, + self.psi_row_idx, + self.psi_col_idx, + self.psi_vals, + self.kernel_size, + self.nlat_out, + self.nlon_out ) + + # permute output out = permute_to_0312(outp) else: + # weight multiplication x = x.reshape(B, self.groups, self.groupsize_in, H, W) - - # do weight multiplication - xc = torch.einsum("bgcxy,gock->bgokxy", x, self.weight).contiguous() - xc = xc.reshape(B, self.groups* self.groupsize_out, -1, H, W) + xc = torch.einsum("bgcxy,gock->bgokxy", x, self.weight) + xc = xc.reshape(B, self.groups* self.groupsize_out, -1, H, W).contiguous() # disco contraction out = _disco_s2_transpose_contraction_torch(xc, self.psi_st.to(x.device), self.nlon_out) diff --git a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu index 4124d7d0..c6dacde8 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu @@ -811,7 +811,7 @@ void launch_gen_disco_bwd(int64_t batch_size, size_t shsize = (sizeof(FLOATV_T)*(nchans*K) + sizeof(float)*nchans)*block.y; const int pscale = nlon_out / nlon_in; -#if 1 +#if 0 printf("Launching s2_disco_bwd_generic_vec_k<%d, float%s><<<..., ..., %zu, ...>>> with:\n" "\tnchan_out: %ld\n" "\tK: %ld\n\n", @@ -860,7 +860,7 @@ void launch_spc_disco_bwd(int nloc, // "BDIM_X*nloc" >= nchans //size_t shsize = sizeof(float)*DIV_UP(nchans, BDIM_X)*BDIM_X*block.y; const int pscale = nlon_out / nlon_in; -#if 1 +#if 0 printf("Launching s2_disco_bwd_special_vec_k<%d, %d, %d, float%s><<<(%d, %d), (%d, %d), ..., %zu, ...>>> with:\n" "\tnchans: %ld\n" "\tK: %ld\n\n", @@ -1111,12 +1111,22 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, int64_t nlat_in = ograd.size(1); int64_t nlon_in = ograd.size(2); int64_t Co = ograd.size(3); - int64_t nchan = Co * K; + int64_t Kograd = ograd.size(4); + if (K != Kograd) { + fprintf(stderr, + "%s:%d: error, K (%ld) must match size of dimension 4 of ograd (%ld)!\n", + __func__, __LINE__, K, Kograd); + exit(EXIT_FAILURE); + } + + int64_t nchan = Co * Kograd; int64_t nrows = roff_idx.size(0) - 1; +#if 0 + printf("%s:%d: batch_size: %ld, nchan: %ld, nlat_in: %ld, nlon_in: %ld, K: %ld\n", + __func__, __LINE__, batch_size, Co, nlat_in, nlon_in, Kograd); - printf("%s:%d: batch_size: %ld, nlat_in: %ld, nlon_in: %ld, nchan: %ld, K: %ld\n", - __func__, __LINE__, batch_size, nlat_in, nlon_in, nchan, K); - + printf("K: %ld, Cin: %ld\n", K, nchan/K); +#endif int64_t nlat_out = Ho; int64_t nlon_out = Wo; @@ -1128,7 +1138,7 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, // extract dtype auto x_type = ograd.dtype(); - torch::Tensor xP = ograd.reshape({batch_size, nlat_in, nlon_in, nchan}).to(torch::kFloat32).contiguous(); + torch::Tensor xP = ograd.reshape({batch_size, nlat_in, nlon_in, nchan}).to(torch::kFloat32); torch::Tensor igrad = torch::zeros(out_dims, xP.options()); diff --git 
a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu index eaccb80a..75a8bddb 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu @@ -823,7 +823,7 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, !is_aligned(_val_pck) || (K % VEC_SIZE) != 0) { - printf("Is aligned: %s:%d: VEC_SIZE: %d, nchan_in: %d, K: %d, _xp: %p, _yp: %p\n", __func__, __LINE__, VEC_SIZE, nchan_in, K, _xp, _yp); + //printf("Is aligned: %s:%d: VEC_SIZE: %d, nchan_in: %d, K: %d, _xp: %p, _yp: %p\n", __func__, __LINE__, VEC_SIZE, nchan_in, K, _xp, _yp); const int nloc = DIV_UP(nchan_in*K, bdimx); @@ -850,7 +850,7 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, } else { - printf("Is not aligned: %s:%d: VEC_SIZE: %d, nchan_in: %d, K: %d, _xp: %p, _yp: %p\n", __func__, __LINE__, VEC_SIZE, nchan_in, K, _xp, _yp); + //printf("Is not aligned: %s:%d: VEC_SIZE: %d, nchan_in: %d, K: %d, _xp: %p, _yp: %p\n", __func__, __LINE__, VEC_SIZE, nchan_in, K, _xp, _yp); //float4 *_xp4 = reinterpret_cast(_xp); float4 *_yp4 = reinterpret_cast(_yp); @@ -963,8 +963,8 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, int64_t nlat_out = Ho; int64_t nlon_out = Wo; - printf("%s:%d: batch_size: %ld, nchan: %ld, nlat_in: %ld, nlon_in: %ld, nlat_out: %ld, nlon_out: %ld, nrows: %ld, nnz_tot: %ld, K: %ld\n", - __func__, __LINE__, batch_size, nchan, nlat_in, nlon_in, nlat_out, nlon_out, nrows, col_idx.size(0), K); + //printf("%s:%d: batch_size: %ld, nchan: %ld, nlat_in: %ld, nlon_in: %ld, nlat_out: %ld, nlon_out: %ld, nrows: %ld, nnz_tot: %ld, K: %ld\n", + // __func__, __LINE__, batch_size, nchan, nlat_in, nlon_in, nlat_out, nlon_out, nrows, col_idx.size(0), K); // get stream auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -1056,7 +1056,7 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, torch::Tensor out = y; #endif // closes ORIGINAL if -#if 1 +#if 0 if (std::getenv("S2_DISCO_DUMP_Y")) { printf("waiting for kernel to finish..."); CHECK_CUDA(cudaStreamSynchronize(stream)); diff --git a/torch_harmonics/distributed/distributed_convolution.py b/torch_harmonics/distributed/distributed_convolution.py index 13d7cbb4..8504b446 100644 --- a/torch_harmonics/distributed/distributed_convolution.py +++ b/torch_harmonics/distributed/distributed_convolution.py @@ -29,7 +29,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
# -from typing import List, Tuple, Union, Optional +from typing import Tuple, Union, Optional from itertools import accumulate import torch @@ -37,10 +37,9 @@ from functools import partial -from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes from torch_harmonics.disco._disco_utils import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch from torch_harmonics.disco._disco_utils import _disco_s2_contraction_optimized, _disco_s2_transpose_contraction_optimized -from torch_harmonics.filter_basis import get_filter_basis +from torch_harmonics.utils import permute_to_0231, permute_to_0312 from disco_helpers import optimized_kernels_is_available, preprocess_psi from torch_harmonics.disco.convolution import ( _precompute_convolution_tensor_s2, @@ -49,10 +48,10 @@ # distirbuted stuff from torch_harmonics.distributed import polar_group_size, azimuth_group_size -from torch_harmonics.distributed import distributed_transpose_azimuth, distributed_transpose_polar +from torch_harmonics.distributed import distributed_transpose_azimuth from torch_harmonics.distributed import reduce_from_polar_region, scatter_to_polar_region, gather_from_polar_region, copy_to_polar_region from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank -from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim +from torch_harmonics.distributed import compute_split_shapes def _split_distributed_convolution_tensor_s2( @@ -256,28 +255,58 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes) if self.optimized_kernel: + polar_dim = -3 + azimuth_dim = -2 + chan_dim = -1 + + # permute input + xp = permute_to_0231(x) + + # disco contraction x = _disco_s2_contraction_optimized( - x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out_local, self.nlon_out + xp, + self.psi_roff_idx, + self.psi_ker_idx, + self.psi_row_idx, + self.psi_col_idx, + self.psi_vals, + self.kernel_size, + self.nlat_out_local, + self.nlon_out ) else: + polar_dim = -2 + azimuth_dim = -1 + chan_dim = -3 + + # disco contraction x = _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out) # perform reduce scatter in polar region x = reduce_from_polar_region(x) - x = scatter_to_polar_region(x, -2) + x = scatter_to_polar_region(x, polar_dim) # now we can transpose back the result, so that lon is split and channels are local if self.comm_size_azimuth > 1: chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth) - x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes) + x = distributed_transpose_azimuth.apply(x, (azimuth_dim, chan_dim), chan_shapes) # extract shape - B, C, K, H, W = x.shape - x = x.reshape(B, self.groups, self.groupsize, K, H, W) - - # do weight multiplication - out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous() - out = out.reshape(out.shape[0], -1, H, W) + if self.optimized_kernel: + # weight multiplication + B, H, W, _, K = x.shape + x = x.reshape(B, H, W, self.groups, self.groupsize_in, K) + outp = torch.einsum("bxygck,gock->bxygo", x, self.weight) + outp = outp.reshape(B, H, W, -1).contiguous() + + # permute output + out = permute_to_0312(outp) + else: + # weight multiplication + B, _, K, H, W = x.shape + x = x.reshape(B, self.groups, self.groupsize_in, K, H, W) + out = 
torch.einsum("bgckxy,gock->bgoxy", x, self.weight) + out = out.reshape(B, -1, H, W).contiguous() if self.bias is not None: out = out + self.bias.reshape(1, -1, 1, 1) @@ -426,27 +455,57 @@ def psi_idx(self): def forward(self, x: torch.Tensor) -> torch.Tensor: # extract shape - B, C, H, W = x.shape - x = x.reshape(B, self.groups, self.groupsize, H, W) + B, _, H, W = x.shape - # do weight multiplication - x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous() - x = x.reshape(B, -1, x.shape[-3], H, W) - num_chans = x.shape[1] + if self.optimized_kernel: + # permute input + xp = permute_to_0231(x) + + # weight multiplication + xp = xp.reshape(B, H, W, self.groups, self.groupsize_in) + x = torch.einsum("bxygc,gock->bxygok", xp, self.weight) + x = x.reshape(B, H, W, -1, self.kernel_size).contiguous() + polar_dim = -3 + azimuth_dim = -2 + chan_dim = -1 + else: + # weight multiplication + x = x.reshape(B, self.groups, self.groupsize_in, H, W) + x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight) + x = x.reshape(B, -1, x.shape[-3], H, W).contiguous() + polar_dim = -2 + azimuth_dim = -1 + chan_dim = 1 + + # save number of channels + num_chans = x.shape[chan_dim] # transpose such that lon is local, channels are split if self.comm_size_azimuth > 1: - x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes) + x = distributed_transpose_azimuth.apply(x, (chan_dim, azimuth_dim), self.lon_in_shapes) # gather input tensor and set up backward reduction hooks - x = gather_from_polar_region(x, -2, self.lat_in_shapes) + x = gather_from_polar_region(x, polar_dim, self.lat_in_shapes) x = copy_to_polar_region(x) if self.optimized_kernel: - out = _disco_s2_transpose_contraction_optimized( - x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out_local, self.nlon_out + # disco contraction + outp = _disco_s2_transpose_contraction_optimized( + x, + self.psi_roff_idx, + self.psi_ker_idx, + self.psi_row_idx, + self.psi_col_idx, + self.psi_vals, + self.kernel_size, + self.nlat_out_local, + self.nlon_out ) + + # permute output + out = permute_to_0312(outp) else: + # disco contraction + out = _disco_s2_transpose_contraction_torch(x, self.psi_st.to(x.device), self.nlon_out) # now we can transpose back the result, so that lon is split and channels are local From 5dcda4d2f638f25dda21e0c0c4297bd12bf3c5dc Mon Sep 17 00:00:00 2001 From: Mauro Bisson Date: Wed, 8 Oct 2025 15:07:36 -0700 Subject: [PATCH 22/31] Fixed absolute difference errors w.r.t. CPU and torch DISCO implementations of the BWD kernel when the BWD pass is upsampling.
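
For context, the layout convention shared by the channels-last kernels and the torch fallback can be sanity-checked in plain PyTorch, independently of the CUDA code: the channels-last contraction ("bxygck,gock->bxygo") followed by a permute back to channels-first must agree with the original channels-first einsum ("bgckxy,gock->bgoxy"). The sketch below is purely illustrative and not part of this patch; it uses made-up sizes and a random tensor standing in for the DISCO contraction output, and it exercises only the einsum/layout step, not the CSR contraction touched by the kernel changes below.

    import torch

    # hypothetical sizes: batch, lat, lon, groups, channels per group, kernel size, out channels per group
    B, H, W, G, Cg, K, Og = 2, 6, 12, 1, 4, 3, 5
    xpc = torch.randn(B, H, W, G, Cg, K)    # stand-in for the channels-last DISCO contraction output
    weight = torch.randn(G, Og, Cg, K)

    # optimized path: channels-last einsum, then permute back to channels first (0312)
    out_cl = torch.einsum("bxygck,gock->bxygo", xpc, weight).reshape(B, H, W, -1)
    out_cl = out_cl.permute(0, 3, 1, 2).contiguous()

    # fallback path: permute the contraction output to channels first and apply the original einsum
    x_cf = xpc.permute(0, 3, 4, 5, 1, 2)
    out_cf = torch.einsum("bgckxy,gock->bgoxy", x_cf, weight).reshape(B, -1, H, W)

    # both layouts must produce the same output up to floating-point round-off
    assert torch.allclose(out_cl, out_cf, atol=1e-5)
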
--- torch_harmonics/disco/csrc/disco_cuda_bwd.cu | 428 ++++----------- torch_harmonics/disco/csrc/disco_cuda_fwd.cu | 525 ++++--------------- torch_harmonics/utils/csrc/csr_cuda.cuh | 15 +- 3 files changed, 225 insertions(+), 743 deletions(-) diff --git a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu index c6dacde8..288e5a71 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu @@ -210,14 +210,14 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t template static __global__ void pack_vals_k(const int64_t K, - const int64_t nlat_out, + const int64_t nrows, const int64_t *__restrict__ row_off, const VAL_T *__restrict__ val_dat, VAL_T *__restrict__ val_pck) { const int tidx = threadIdx.x; const int wid = blockIdx.x*blockDim.y + threadIdx.y; - if (wid >= nlat_out) { + if (wid >= nrows) { return; } @@ -231,7 +231,7 @@ static __global__ void pack_vals_k(const int64_t K, for(int off = tidx; off < rlen; off += blockDim.x) { for(int ker = 0; ker < K; ker++) { - val_pck[off*K + ker] = val_dat[ row_off[ker*nlat_out + wid] + off]; + val_pck[off*K + ker] = val_dat[ row_off[ker*nrows + wid] + off]; } } @@ -257,9 +257,9 @@ static __device__ void processCSR_Kpow2_shm_d(const int wi, const int tidxDivK = tidx >> log2_K; const int tidxModK = tidx & (K-1); - + vals += tidxModK; - + const int BDIMX_div_K = BDIM_X >> log2_K; for(int chan = tidx; chan < nchans; chan += WARP_SIZE) { @@ -284,7 +284,7 @@ static __device__ void processCSR_Kpow2_shm_d(const int wi, const FLOATV_T myval = vals[0]; for(int i = 0; i < nchans*K; i += WARP_SIZE) { - + float sum = (i+tidx < nchans*K) ? __vred(__vmul(myval, shx[i+tidx])) : 0; for(int j = 1; j < K; j *= 2) { @@ -377,8 +377,10 @@ void s2_disco_bwd_generic_vec_k(int nchans, // no. of input float (not FLOATV int pscale, int K, // no. of output FLOATV_T elem along K dim (kernel size) const FLOATV_T *__restrict__ x, - const int32_t *__restrict__ row_idx, - const int64_t *__restrict__ row_off, + const int64_t csr_nrow, + const int32_t *__restrict__ row_sort, + const int64_t *__restrict__ row_off, + const int64_t *__restrict__ row_idx, const int64_t *__restrict__ col_idx, const FLOATV_T *__restrict__ val_pck, float *__restrict__ y) { @@ -390,34 +392,33 @@ void s2_disco_bwd_generic_vec_k(int nchans, // no. 
of input float (not FLOATV const int batch = blockIdx.y; const int ctaid = blockIdx.x*blockDim.y + threadIdx.y; - if (ctaid >= nlat_in*nlon_in) { + if (ctaid >= csr_nrow*nlon_in) { return; } -#if 1 - const int h = ctaid / nlon_in; + const int h = ctaid / nlon_in; const int wi = ctaid - (h*nlon_in); - const int hi = row_idx[h]; -#else - // for now don't use row_idx - const int hi = ctaid / nlon_in; - const int wi = ctaid - (hi*nlon_in); -#endif - + + // set csr_row to "h" to bypass the row sorting + const int csr_row = row_sort[h]; // h + + const int64_t rbeg = row_off[csr_row ]; + const int64_t rend = row_off[csr_row+1]; + + const int hi = row_idx[rbeg]; // reads only the first "nrow" rows of row_idx and only the first element of each row + x += int64_t(batch)*nlat_in*nlon_in*nchans*K + int64_t(hi)*nlon_in*nchans*K + int64_t(wi)*nchans*K; y += int64_t(batch)*nlat_out*nlon_out*nchans; extern __shared__ __align__(sizeof(float4)) float shext[]; - FLOATV_T *shx = reinterpret_cast(shext + (VEC_SIZE*nchans*K + nchans)*threadIdx.y); - float *shy = reinterpret_cast(shx + nchans*K); + + FLOATV_T *shx = reinterpret_cast(shext) + nchans*K*threadIdx.y; + float *shy = reinterpret_cast(shext) + nchans*K*VEC_SIZE*blockDim.y + nchans*threadIdx.y; for(int chan = tidx; chan < nchans*K; chan += WARP_SIZE) { shx[chan] = x[chan]; } - const int64_t rbeg = row_off[hi]; - const int64_t rend = row_off[hi+1]; - col_idx += rbeg; val_pck += rbeg*K; // val_pck CSR contains K values per element @@ -453,9 +454,9 @@ static __device__ void processCSR_Kpow2_reg_d(const int wi, unsigned int subwarp_mask = FULL_MASK; if constexpr(BDIM_X <= WARP_SIZE) { - const int tidy = threadIdx.y; constexpr unsigned int MASK = (1ull << BDIM_X)-1; - subwarp_mask = MASK << (tidy*BDIM_X); + unsigned int subwarp_id = threadIdx.y % (WARP_SIZE/BDIM_X); + subwarp_mask = MASK << (subwarp_id*BDIM_X); } // K is a power of two <= BDIM_X @@ -463,7 +464,7 @@ static __device__ void processCSR_Kpow2_reg_d(const int wi, const int tidxDivK = tidx >> log2_K; const int tidxModK = tidx & (K-1); - + cols += tidx; vals += tidxModK; @@ -492,7 +493,7 @@ static __device__ void processCSR_Kpow2_reg_d(const int wi, const FLOATV_T myval = vals[0]; float locy[NLOC]; - + #pragma unroll for(int i = 0; i < NLOC; i++) { locy[i] = __vred(__vmul(myval, locx[i])); @@ -506,10 +507,10 @@ static __device__ void processCSR_Kpow2_reg_d(const int wi, #pragma unroll for(int i = 0; i < NLOC; i++) { - locy[i] += __shfl_xor_sync(subwarp_mask, locy[i], j); + locy[i] += __shfl_xor_sync(subwarp_mask, locy[i], j, BDIM_X); } } - + if (!tidxModK) { // NLOC*BDIM_X >= nchans*K // NLOC_M1*BDIM_X < nchans*K => NLOC_M1*BDIM_X/K < nchans @@ -528,110 +529,6 @@ static __device__ void processCSR_Kpow2_reg_d(const int wi, return; } -#if 0 -template -static __device__ void processCSR_Kpow2_reg_d2(const int wi, - const int rlen, - const int nchans, // no. 
of input FLOATV_T elements along channel dim - const int nlon_out, - const int pscale, - const int K, - const FLOATV_T (&locx)[NLOC], - const int64_t *__restrict__ cols, - const FLOATV_T *__restrict__ vals, - float *(&shYOff)[BDIM_X+SHPAD], - float *__restrict__ shy, // NO LONGER USED - float *__restrict__ y) { - constexpr int NLOC_M1 = NLOC-1; - - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - - const int log2_K = __ffs(K)-1; - - const int tidxDivK = tidx >> log2_K; - const int tidxModK = tidx & (K-1); - - vals += tidxModK; - - const int BDIMX_div_K = BDIM_X >> log2_K; -#if 1 - for(int chan = tidx; chan < nchans; chan += BDIM_X) { - shy[chan] = 0; - } - __sync(); -#endif - - cols += tidx; - - for(int off = 0; off < rlen; off++) { - if ((off % BDIM_X) == 0) { - __sync(); - - const int64_t col = (off+tidx < rlen) ? cols[0] : 0; - - const int ho = col / nlon_out; - const int wo = col - (ho*nlon_out); - - int wop = wo + pscale*wi; - wop -= (wop / nlon_out)*nlon_out; - - shYOff[tidx] = y + int64_t(ho)*nlon_out*nchans + int64_t(wop)*nchans; - cols += BDIM_X; - - __sync(); - } - - float *_y = shYOff[off % BDIM_X];// + tidxDivK; - float *_shy = shy + tidxDivK; - - const FLOATV_T myval = vals[0]; - - float locy[NLOC]; - - #pragma unroll - for(int i = 0; i < NLOC; i++) { - locy[i] = __vred(__vmul(myval, locx[i])); - } - - #pragma unroll - for(int i = 0; i < NLOC; i++) { - - // K is a power of two <= 32 - for(int j = 1; j < K; j *= 2) { - locy[i] += __shfl_xor_sync(FULL_MASK, locy[i], j); - } - } - - if (!tidxModK) { - // NLOC*BDIM_X >= nchans*K - // NLOC_M1*BDIM_X < nchans*K => NLOC_M1*BDIM_X/K < nchans - #pragma unroll - for(int i = 0; i < NLOC; i++) { - _shy[i*BDIMX_div_K] += locy[i]; - } - // if (NLOC_M1*BDIM_X+tidx < nchans*K) { - // _shy[NLOC_M1*BDIMX_div_K] += locy[NLOC_M1]; - // } - } - __sync(); - - for(int chan = tidx; chan < nchans; chan += BDIM_X) { - atomicAdd(_y+chan, shy[chan]); - shy[chan] = 0; - } - __sync(); - - vals += K; - } - - return; -} -#endif - template= nlat_in*nlon_in) { + if (ctaid >= csr_nrow*nlon_in) { return; } -#if 1 - const int h = ctaid / nlon_out; - const int wi = ctaid - (h*nlon_out); - const int hi = row_idx[h]; -#else - // for now don't use row_idx - const int hi = ctaid / nlon_out; - const int wi = ctaid - (hi*nlon_out); -#endif + const int h = ctaid / nlon_in; + const int wi = ctaid - (h*nlon_in); + + // set csr_row to "h" to bypass the row sorting + const int csr_row = row_sort[h]; // h + + const int64_t rbeg = row_off[csr_row ]; + const int64_t rend = row_off[csr_row+1]; + + const int hi = row_idx[rbeg]; // reads only the first "nrow" rows of row_idx and only the first element of each row x += int64_t(batch)*nlat_in*nlon_in*nchans*K + int64_t(hi)*nlon_in*nchans*K + int64_t(wi)*nchans*K + tidx; y += int64_t(batch)*nlat_out*nlon_out*nchans; FLOATV_T locx[NLOC]; - + #pragma unroll for(int i = 0; i < NLOC_M1; i++) { locx[i] = x[i*BDIM_X]; @@ -768,16 +668,13 @@ void s2_disco_bwd_special_vec_k(int nchans, // no. of input float (not FLOATV // only used if K is not a multiple of 2 extern __shared__ __align__(sizeof(float4)) float shext[]; - float *shy = shext + nchans*threadIdx.y; - - const int64_t rbeg = row_off[hi]; - const int64_t rend = row_off[hi+1]; + float *shy = shext + DIV_UP(nchans, BDIM_X)*BDIM_X*threadIdx.y; col_idx += rbeg; val_pck += rbeg*K; // val_pck CSR contains K values per element const int rlen = rend-rbeg; - + constexpr int PAD = (BDIM_X < WARP_SIZE) ? 
1 : 0; __shared__ float *shYOffAll[BDIM_Y][BDIM_X+PAD]; @@ -798,29 +695,26 @@ void launch_gen_disco_bwd(int64_t batch_size, int64_t nlon_out, int64_t K, FLOATV_T *__restrict__ _xp, - int32_t *_row_idx, + int64_t nrow, + int32_t *_row_sort, int64_t *_row_off, + int64_t *_row_idx, int64_t *_col_idx, FLOATV_T *_val_pck, float *__restrict__ _yp, cudaStream_t stream) { dim3 block(WARP_SIZE, THREADS/WARP_SIZE); - dim3 grid(DIV_UP(nlat_in*nlon_in, block.y), batch_size); + dim3 grid(DIV_UP(nrow*nlon_in, block.y), batch_size); size_t shsize = (sizeof(FLOATV_T)*(nchans*K) + sizeof(float)*nchans)*block.y; const int pscale = nlon_out / nlon_in; -#if 0 - printf("Launching s2_disco_bwd_generic_vec_k<%d, float%s><<<..., ..., %zu, ...>>> with:\n" - "\tnchan_out: %ld\n" - "\tK: %ld\n\n", - THREADS, sizeof(FLOATV_T)==16?"4":"", shsize, nchans, K); -#endif + // will use only the first 1/K-th of the CSR, i.e. only the first nlat_out rows s2_disco_bwd_generic_vec_k <<>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K, - _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp); + _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp); CHECK_ERROR("s2_disco_bwd_generic_vec_k"); return; @@ -839,8 +733,10 @@ void launch_spc_disco_bwd(int nloc, // "BDIM_X*nloc" >= nchans int64_t nlon_out, int64_t K, FLOATV_T *__restrict__ _xp, - int32_t *_row_idx, + int64_t nrow, + int32_t *_row_sort, int64_t *_row_off, + int64_t *_row_idx, int64_t *_col_idx, FLOATV_T *_val_pck, float *__restrict__ _yp, @@ -848,27 +744,19 @@ void launch_spc_disco_bwd(int nloc, // "BDIM_X*nloc" >= nchans if (CUR_LOC_SIZE == nloc) { - // block size set to 64 threads constexpr int BDIM_Y = (BDIM_X <= WARP_SIZE) ? THREADS / BDIM_X : 1; - // groups in gridDim.y dim3 block(BDIM_X, BDIM_Y); - dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size); + dim3 grid(DIV_UP(nrow*nlon_in, block.y), batch_size); // could be (BDIM_X/K) instead of BDIM_X but let's keep it simple size_t shsize = (K & (K-1)) ? 
sizeof(float)*DIV_UP(nchans, BDIM_X)*BDIM_X*block.y : 0; - //size_t shsize = sizeof(float)*DIV_UP(nchans, BDIM_X)*BDIM_X*block.y; const int pscale = nlon_out / nlon_in; -#if 0 - printf("Launching s2_disco_bwd_special_vec_k<%d, %d, %d, float%s><<<(%d, %d), (%d, %d), ..., %zu, ...>>> with:\n" - "\tnchans: %ld\n" - "\tK: %ld\n\n", - BDIM_X, BDIM_Y, CUR_LOC_SIZE, sizeof(FLOATV_T)==16?"4":"", grid.x, grid.y, block.x, block.y, shsize, nchans, K); -#endif + s2_disco_bwd_special_vec_k <<>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K, - _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp); + _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp); CHECK_ERROR("s2_disco_bwd_special_vec_k"); @@ -878,7 +766,7 @@ void launch_spc_disco_bwd(int nloc, // "BDIM_X*nloc" >= nchans launch_spc_disco_bwd(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, - K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); + K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); } return; } @@ -891,9 +779,10 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, int64_t nlon_out, int64_t K, at::Tensor xP, - at::Tensor row_off, // CSR row offsets - at::Tensor col_idx, // CSR column indices - at::Tensor val_dat, // CSR value data + at::Tensor row_off, // CSR non-empty row offsets + at::Tensor row_idx, // CSR non-empty row indices + at::Tensor col_idx, // CSR non-empty col indices + at::Tensor val_dat, // CSR non-empty value data at::Tensor yP) { static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1))); @@ -907,7 +796,7 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, K > WARP_SIZE) { fprintf(stderr, - ":%s:%d: invalid value of one or more input parameters!\n", + ":%s:%d: invalid value of one or more input parameters!\n", __FILE__, __LINE__); exit(EXIT_FAILURE); } @@ -915,11 +804,6 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, // get stream auto stream = at::cuda::getCurrentCUDAStream().stream(); - // sort row indices (ho-s) in descending order - // based on (row_off[ho+1]-row_off[ho]) - at::Tensor row_idx = sortRows(nlat_in, row_off, stream); - - // replace the K sequential CRSs in "val_dat": // // val_dat[ 0: nnz/K) for ker = 0 @@ -928,19 +812,27 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, // val_dat[nnz/K:2*nnz/K) for ker = K-1 // // with a packed CSR: - // + // // val_dat[nnz/K][K], i.e. 
with a CSR where elements of the original K CSRs are packed in consecutive elements assert(0 == (val_idx.size(0) % K)); + int64_t nrow_csr = row_off.size(0)-1; + assert(0 == (nrow_csr % K)); + + int64_t nrow = nrow_csr / K; + + // sort row indices (ho-s) in descending order + // based on (row_off[ho+1]-row_off[ho]) + at::Tensor row_sort = sortRows(nrow, row_off, stream); + // move into "disco_cuda_utils.cu" IF val_dat format won't be changed upstream in the call chain int64_t val_dims[] = {val_dat.size(0)}; auto options = torch::TensorOptions().device(val_dat.device()).dtype(val_dat.dtype()); torch::Tensor val_pck = torch::zeros(val_dims, options); { dim3 block(WARP_SIZE, THREADS/WARP_SIZE); - dim3 grid(DIV_UP(nlat_out, block.y)); - - pack_vals_k<<>>(K, nlat_out, + dim3 grid(DIV_UP(nlat_in, block.y)); + pack_vals_k<<>>(K, nrow, row_off.data_ptr(), val_dat.data_ptr(), val_pck.data_ptr()); @@ -956,8 +848,9 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, float *_xp = reinterpret_cast(xP.data_ptr()); float *_yp = reinterpret_cast(yP.data_ptr()); - int32_t *_row_idx = reinterpret_cast(row_idx.data_ptr()); + int32_t *_row_sort = reinterpret_cast(row_sort.data_ptr()); int64_t *_row_off = reinterpret_cast(row_off.data_ptr()); + int64_t *_row_idx = reinterpret_cast(row_idx.data_ptr()); int64_t *_col_idx = reinterpret_cast(col_idx.data_ptr()); float *_val_pck = reinterpret_cast(val_pck.data_ptr()); @@ -968,37 +861,32 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, !is_aligned(_val_pck) || (K % VEC_SIZE) != 0) { - //printf("%s:%d: VEC_SIZE: %d, nchans: %d, K: %d, _xp: %p, _yp: %p\n", __func__, __LINE__, VEC_SIZE, nchans, K, _xp, _yp); - const int nloc = DIV_UP(nchans*K, bdimx); - + // to avoid the compilation of unused template instances; // we use a block size BDIM_X that is the smallest power of 2 // such that BDIM_X*MAX_LOCAL_ARR_LEN >= nchans*K, so // BDIM_X > 32 are used only for: // // (BDIM_X-1)*MAX_LOCAL_ARR_LEN < nchan <= BDIM_X*MAX_LOCAL_ARR_LEN - constexpr int MIN_LOC_ARR_LEN = MAX_LOCAL_ARR_LEN/2+1; + constexpr int MIN_LOCAL_ARR_LEN = MAX_LOCAL_ARR_LEN/2+1; // use 2D blocks only if 32 threads are enough switch(bdimx) { - case 8: launch_spc_disco_bwd< 8, 1, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; - case 16: launch_spc_disco_bwd< 16, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; - case 32: launch_spc_disco_bwd< 32, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; - case 64: launch_spc_disco_bwd< 64, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; - case 128: launch_spc_disco_bwd< 128, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; - case 256: launch_spc_disco_bwd< 256, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; - case 512: launch_spc_disco_bwd< 512, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, 
_row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; - case 1024: launch_spc_disco_bwd<1024, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; - default: launch_gen_disco_bwd ( batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + case 8: launch_spc_disco_bwd< 8, 1, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 16: launch_spc_disco_bwd< 16, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 32: launch_spc_disco_bwd< 32, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 64: launch_spc_disco_bwd< 64, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 128: launch_spc_disco_bwd< 128, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 256: launch_spc_disco_bwd< 256, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 512: launch_spc_disco_bwd< 512, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 1024: launch_spc_disco_bwd<1024, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + default: launch_gen_disco_bwd ( batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; } } else { - //printf("%s:%d: VEC_SIZE: %d, nchans: %d, K: %d, _xp: %p, _yp: %p\n", __func__, __LINE__, VEC_SIZE, nchans, K, _xp, _yp); - float4 *_xp4 = reinterpret_cast(_xp); - //float4 *_yp4 = reinterpret_cast(_yp); float4 *_val_pck4 = reinterpret_cast(_val_pck); @@ -1006,19 +894,19 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, const int nloc = DIV_UP(nchans*K, bdimx); constexpr int MAX_LOCAL_VEC_LEN = MAX_LOCAL_ARR_LEN / VEC_SIZE; - constexpr int MIN_LOC_VEC_LEN = MAX_LOCAL_VEC_LEN/2+1; + constexpr int MIN_LOCAL_VEC_LEN = MAX_LOCAL_VEC_LEN/2+1; // use 2D blocks only if 32 threads are enough switch(bdimx) { - case 8: launch_spc_disco_bwd< 8, 1, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, _row_idx, _row_off, _col_idx, _val_pck4, _yp, stream); break; - case 16: launch_spc_disco_bwd< 16, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, _row_idx, _row_off, _col_idx, _val_pck4, _yp, stream); break; - case 32: launch_spc_disco_bwd< 32, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, 
nlon_out, K, _xp4, _row_idx, _row_off, _col_idx, _val_pck4, _yp, stream); break; - case 64: launch_spc_disco_bwd< 64, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, _row_idx, _row_off, _col_idx, _val_pck4, _yp, stream); break; - case 128: launch_spc_disco_bwd< 128, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, _row_idx, _row_off, _col_idx, _val_pck4, _yp, stream); break; - case 256: launch_spc_disco_bwd< 256, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, _row_idx, _row_off, _col_idx, _val_pck4, _yp, stream); break; - case 512: launch_spc_disco_bwd< 512, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, _row_idx, _row_off, _col_idx, _val_pck4, _yp, stream); break; - case 1024: launch_spc_disco_bwd<1024, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, _row_idx, _row_off, _col_idx, _val_pck4, _yp, stream); break; - default: launch_gen_disco_bwd ( batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, _row_idx, _row_off, _col_idx, _val_pck4, _yp, stream); break; + case 8: launch_spc_disco_bwd< 8, 1, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp, stream); break; + case 16: launch_spc_disco_bwd< 16, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp, stream); break; + case 32: launch_spc_disco_bwd< 32, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp, stream); break; + case 64: launch_spc_disco_bwd< 64, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp, stream); break; + case 128: launch_spc_disco_bwd< 128, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp, stream); break; + case 256: launch_spc_disco_bwd< 256, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp, stream); break; + case 512: launch_spc_disco_bwd< 512, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp, stream); break; + case 1024: launch_spc_disco_bwd<1024, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp, stream); break; + default: launch_gen_disco_bwd ( batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp, stream); break; } } return; @@ -1036,76 +924,7 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, CHECK_CUDA_INPUT_TENSOR(row_idx); CHECK_CUDA_INPUT_TENSOR(col_idx); CHECK_CUDA_INPUT_TENSOR(val); -#if 0 - // extract some shapes - int64_t B = ograd.size(0); - int64_t Hi = 
ograd.size(1); - int64_t Wi = ograd.size(2); - int64_t C = ograd.size(3); - int64_t BC = B * C; - int64_t nrows = roff_idx.size(0) - 1; - - // allocate output - int64_t out_dims[] = {B, Ho, Wo, C}; - - // get stream - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - // extract dtype - auto x_type = ograd.dtype(); - torch::Tensor xP = ograd.to(torch::kFloat32); - - torch::Tensor igrad = torch::zeros(out_dims, xP.options()); - // assert - static_assert(0 == (ELXTH_MAX % 2)); - - if (Wo <= 64 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(xP.scalar_type(), "disco_backward_cuda", ([&] { - launch_kernel<64, 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - xP.data_ptr(), igrad.data_ptr(), stream); - })); - } else if (Wo <= 128 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(xP.scalar_type(), "disco_backward_cuda", ([&] { - launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - xP.data_ptr(), igrad.data_ptr(), stream); - })); - } else if (Wo <= 256 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(xP.scalar_type(), "disco_backward_cuda", ([&] { - launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - xP.data_ptr(), igrad.data_ptr(), stream); - })); - } else if (Wo <= 512 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(xP.scalar_type(), "disco_backward_cuda", ([&] { - launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - xP.data_ptr(), igrad.data_ptr(), stream); - })); - } else if (Wo <= 1024 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(xP.scalar_type(), "disco_backward_cuda", ([&] { - launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - xP.data_ptr(), igrad.data_ptr(), stream); - })); - } else { - fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo, - 1024 * ELXTH_MAX); - exit(EXIT_FAILURE); - } -#else // extract some shapes int64_t batch_size = ograd.size(0); int64_t nlat_in = ograd.size(1); @@ -1121,12 +940,6 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, int64_t nchan = Co * Kograd; int64_t nrows = roff_idx.size(0) - 1; -#if 0 - printf("%s:%d: batch_size: %ld, nchan: %ld, nlat_in: %ld, nlon_in: %ld, K: %ld\n", - __func__, __LINE__, batch_size, Co, nlat_in, nlon_in, Kograd); - - printf("K: %ld, Cin: %ld\n", K, nchan/K); -#endif int64_t nlat_out = Ho; int64_t nlon_out = Wo; @@ -1136,27 +949,12 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, // get stream auto stream = at::cuda::getCurrentCUDAStream().stream(); - // extract dtype + // extract dtype, convert to fp32 and make contiguous auto x_type = ograd.dtype(); - torch::Tensor xP = ograd.reshape({batch_size, nlat_in, nlon_in, nchan}).to(torch::kFloat32); + torch::Tensor xP = ograd.reshape({batch_size, nlat_in, nlon_in, nchan}).to(torch::kFloat32).contiguous(); torch::Tensor igrad = torch::zeros(out_dims, xP.options()); - -#if 0 - printf("%s:%d: tensors info:\n", __func__, __LINE__); - printf("\tbatch_size: %ld\n", batch_size); - printf("\t nlat_in: %ld\n", 
nlat_in); - printf("\t nlon_in: %ld\n", nlon_in); - printf("\t C: %ld\n", nchan); - printf("\t K: %ld\n\n", K); - printf("\troff_idx.size(0)-1 == nlat_in*K: %d\n", roff_idx.size(0)-1 == nlat_in*K); - //printf("\tinp channle-last: %d\n", x_is_channels_last); - printf("\treshaped inp to: {%ld, %ld, %ld, %ld}\n", xP.size(0), xP.size(1), xP.size(2), xP.size(3)); - fflush(stdout); - - //exit(1); -#endif - + // call channel-last kernel implementation s2_disco_bwd_dispatch(batch_size, Co, //nchan, @@ -1167,18 +965,11 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, K, xP, roff_idx, + row_idx, col_idx, val, igrad); -/* - // switch back to original layout; - torch::Tensor out = yP; - if (!x_is_channels_last) { - out = permute_4D_to0312(yP); - // make y {batch_size, nchan, nlat_out, nlon_out} - } -*/ -#endif + // convert back to original dtype igrad = igrad.to(x_type); @@ -1189,5 +980,4 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, { m.impl("backward", &disco_cuda_bwd); } - } diff --git a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu index 75a8bddb..7cdd2683 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu @@ -41,10 +41,6 @@ namespace disco_kernels { using namespace utility_kernels; -void dump_tensor(const char *fname, at::Tensor t); -void dump_csr(const char *fname, at::Tensor roff, at::Tensor cols); -void dump_csr_linear(const char *fname, at::Tensor roff, at::Tensor kers, at::Tensor rows, at::Tensor cols, at::Tensor vals); - template __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int Ho, const int Wo, const int pscale, const int64_t *__restrict__ roff, const int64_t *__restrict__ kers, @@ -196,17 +192,16 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t return; } - template -__global__ void pack_vals_k(const int64_t K, - const int64_t nlat_out, - const int64_t *__restrict__ row_off, - const VAL_T *__restrict__ val_dat, - VAL_T *__restrict__ val_pck) { +static __global__ void pack_vals_k(const int64_t K, + const int64_t nrows, + const int64_t *__restrict__ row_off, + const VAL_T *__restrict__ val_dat, + VAL_T *__restrict__ val_pck) { const int tidx = threadIdx.x; const int wid = blockIdx.x*blockDim.y + threadIdx.y; - if (wid >= nlat_out) { + if (wid >= nrows) { return; } @@ -220,16 +215,15 @@ __global__ void pack_vals_k(const int64_t K, for(int off = tidx; off < rlen; off += blockDim.x) { for(int ker = 0; ker < K; ker++) { - val_pck[off*K + ker] = val_dat[ row_off[ker*nlat_out + wid] + off]; + val_pck[off*K + ker] = val_dat[ row_off[ker*nrows + wid] + off]; } } return; } - + // BEGIN VERSION WITH CHANNEL-LAST WITH 2D BLOCKS, 2ND DIM IDENTIFYING CHANNLES, NO EINSUM -#if 1 template __device__ void processCSR_Kpow2_shm_d(const int wo, @@ -270,9 +264,9 @@ __device__ void processCSR_Kpow2_shm_d(const int wo, // if BDIM_X is a multiple of K then "i*(j*BDIM_X) % K = const", // so thread "i" only needs to read vals[off*K + (i % K)] to update the // whole channel array - + const FLOATV_T myval = vals[0]; //vals[off*K + tidxModK]; - + for(int chan = tidx; chan < nchan_in*K; chan += BDIM_X) { // no. 
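The comment above relies on a small modular-arithmetic fact: when K divides BDIM_X (both are powers of two here), the channel index tidx + j*BDIM_X visited by the "chan += BDIM_X" loop always has the same remainder mod K, so each thread can keep reusing a single packed kernel value per CSR entry. A quick stand-alone check of that claim in plain Python (the concrete sizes are arbitrary):

    K, BDIM_X, NCHAN_IN = 4, 32, 24              # any power-of-two K dividing BDIM_X behaves the same
    for tidx in range(BDIM_X):
        # remainders seen by thread tidx across all iterations of the strided channel loop
        rems = {(tidx + j * BDIM_X) % K for j in range((NCHAN_IN * K) // BDIM_X + 1)}
        assert rems == {tidx % K}                # constant, so reading vals[off*K + tidx % K] once suffices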
of vectors in nchan_in*K dim on intermediate out shy[chan] = __vadd(shy[chan], @@ -321,12 +315,12 @@ __device__ void processCSR_Kanyv_shm_d(const int wo, const int iDivK = chan / K; const int iModK = chan - (iDivK*K); - + shy[chan] = __vadd(shy[chan], __vmul(vals[iModK], __vset(_x[iDivK]))); } - + vals += K; } return; @@ -344,38 +338,39 @@ void s2_disco_fwd_generic_vec_k(int nchan_in, // no. of input float (not FLOA int pscale, int K, // no. of output FLOATV_T elem along K dim (kernel size) const float *__restrict__ x, - const int32_t *__restrict__ row_idx, + const int64_t csr_nrow, + const int32_t *__restrict__ row_sort, const int64_t *__restrict__ row_off, + const int64_t *__restrict__ row_idx, const int64_t *__restrict__ col_idx, const FLOATV_T *__restrict__ val_pck, FLOATV_T *__restrict__ y) { - constexpr int VEC_SIZE = sizeof(FLOATV_T) / sizeof(float); - const int tidx = threadIdx.x; const int batch = blockIdx.y; const int ctaid = blockIdx.x*blockDim.y + threadIdx.y; - if (ctaid >= nlat_out*nlon_out) { + if (ctaid >= csr_nrow*nlon_out) { return; } -#if 1 - const int h = ctaid / nlon_out; + const int h = ctaid / nlon_out; const int wo = ctaid - (h*nlon_out); - const int ho = row_idx[h]; -#else - // for now don't use row_idx - const int ho = wid / nlon_out; - const int wo = wid - (ho*nlon_out); -#endif - + + // set csr_row to "h" to bypass the row sorting + const int csr_row = row_sort[h]; // h + + const int64_t rbeg = row_off[csr_row ]; + const int64_t rend = row_off[csr_row+1]; + + const int ho = row_idx[rbeg]; // reads only the first "nrow" rows of row_idx and only the first element of each row + const int nchan_out = nchan_in*K; extern __shared__ __align__(sizeof(float4)) float shext[]; FLOATV_T *shy = reinterpret_cast(shext) + threadIdx.y*nchan_out; - + for(int chan = tidx; chan < nchan_out; chan += WARP_SIZE) { shy[chan] = __vset(0.f); } @@ -383,9 +378,6 @@ void s2_disco_fwd_generic_vec_k(int nchan_in, // no. 
of input float (not FLOA x += int64_t(batch)*nlat_in*nlon_in*nchan_in; y += int64_t(batch)*nlat_out*nlon_out*nchan_out + int64_t(ho)*nlon_out*nchan_out + int64_t(wo)*nchan_out; - const int64_t rbeg = row_off[ho]; - const int64_t rend = row_off[ho+1]; - col_idx += rbeg; val_pck += rbeg*K; // val_pck CSR contains K values per element @@ -421,13 +413,12 @@ __device__ void processCSR_Kpow2_reg_d(const int wo, constexpr int NLOC_M1 = NLOC-1; const int tidx = threadIdx.x; - const int tidy = threadIdx.y; const int log2_K = __ffs(K)-1; const int tidxDivK = tidx >> log2_K; const int tidxModK = tidx & (K-1); - + cols += tidx; vals += tidxModK; @@ -460,7 +451,7 @@ __device__ void processCSR_Kpow2_reg_d(const int wo, // if BDIM_X is a multiple of K then "i*(j*BDIM_X) % K = const", // so thread "i" only needs to read vals[off*K + (i % K)] to update the // whole channel array - + #pragma unroll for(int i = 0; i < NLOC_M1; i++) { locy[i] = __vadd(locy[i], @@ -497,7 +488,6 @@ __device__ void processCSR_Kanyv_reg_d(const int wo, constexpr int NLOC_M1 = NLOC-1; const int tidx = threadIdx.x; - const int tidy = threadIdx.y; cols += tidx; @@ -527,7 +517,7 @@ __device__ void processCSR_Kanyv_reg_d(const int wo, // if BDIM_X is not a multiple of K then "i*(j*BDIM_X) % K = f(i,j)", // so the mod need to be recomputed at each iteration of update the update loop - + #pragma unroll for(int i = 0; i < NLOC_M1; i++) { @@ -541,7 +531,7 @@ __device__ void processCSR_Kanyv_reg_d(const int wo, locy[i] = __vadd(locy[i], __vmul(vval, xval)); } if (NLOC_M1*BDIM_X+tidx < nchan_in*K) { - + const int chan = NLOC_M1*BDIM_X+tidx; const int iDivK = chan / K; const int iModK = chan - (iDivK*K); @@ -551,7 +541,7 @@ __device__ void processCSR_Kanyv_reg_d(const int wo, locy[NLOC_M1] = __vadd(locy[NLOC_M1], __vmul(vval, xval)); } - + vals += K; } return; @@ -571,8 +561,10 @@ void s2_disco_fwd_special_vec_k(const int nchan_in, // no. of input float (no const int pscale, const int K, // no. of output FLOATV_T elem along K dim (kernel size) const float *__restrict__ x, - const int32_t *__restrict__ row_idx, + const int64_t csr_nrow, + const int32_t *__restrict__ row_sort, const int64_t *__restrict__ row_off, + const int64_t *__restrict__ row_idx, const int64_t *__restrict__ col_idx, const FLOATV_T *__restrict__ val_pck, FLOATV_T *__restrict__ y) { @@ -584,7 +576,7 @@ void s2_disco_fwd_special_vec_k(const int nchan_in, // no. of input float (no constexpr int NLOC_M1 = NLOC-1; - constexpr int VEC_SIZE = sizeof(FLOATV_T) / sizeof(float); + //constexpr int VEC_SIZE = sizeof(FLOATV_T) / sizeof(float); const int tidx = threadIdx.x; const int tidy = threadIdx.y; @@ -592,19 +584,20 @@ void s2_disco_fwd_special_vec_k(const int nchan_in, // no. 
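The register variant above leans on the usual power-of-two shortcuts: __ffs(K)-1 equals log2(K) when K is a power of two, so tidx >> log2_K reproduces tidx / K and tidx & (K-1) reproduces tidx % K. A short check of both identities, with Python's bit_length()-1 playing the role of __ffs(K)-1:

    for K in (1, 2, 4, 8, 16):
        log2_K = K.bit_length() - 1              # same value as __ffs(K) - 1 for power-of-two K
        for tidx in range(128):
            assert tidx >> log2_K == tidx // K   # shift replaces the division
            assert tidx & (K - 1) == tidx % K    # mask replaces the modulo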
of input float (no const int batch = blockIdx.y; const int ctaid = blockIdx.x*blockDim.y + threadIdx.y; - if (ctaid >= nlat_out*nlon_out) { + if (ctaid >= csr_nrow*nlon_out) { return; } -#if 1 - const int h = ctaid / nlon_out; + const int h = ctaid / nlon_out; const int wo = ctaid - (h*nlon_out); - const int ho = row_idx[h]; -#else - // for now don't use row_idx - const int ho = ctaid / nlon_out; - const int wo = ctaid - (ho*nlon_out); -#endif + + // set csr_row to "h" to bypass the row sorting + const int csr_row = row_sort[h]; // h + + const int64_t rbeg = row_off[csr_row ]; + const int64_t rend = row_off[csr_row+1]; + + const int ho = row_idx[rbeg]; // reads only the first "nrow" rows of row_idx and only the first element of each row const int nchan_out = nchan_in*K; @@ -618,9 +611,6 @@ void s2_disco_fwd_special_vec_k(const int nchan_in, // no. of input float (no locy[i] = __vset(0.f); } - const int64_t rbeg = row_off[ho]; - const int64_t rend = row_off[ho+1]; - col_idx += rbeg; val_pck += rbeg*K; // val_pck CSR contains K values per element @@ -655,30 +645,26 @@ void launch_gen_disco_fwd(int64_t batch_size, int64_t nlon_out, int64_t K, float *__restrict__ _xp, - int32_t *_row_idx, + int64_t nrow, + int32_t *_row_sort, int64_t *_row_off, + int64_t *_row_idx, int64_t *_col_idx, FLOATV_T *_val_pck, FLOATV_T *__restrict__ _yp, cudaStream_t stream) { dim3 block(WARP_SIZE, THREADS/WARP_SIZE); - dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size); + dim3 grid(DIV_UP(nrow*nlon_out, block.y), batch_size); size_t shsize = sizeof(FLOATV_T)*(nchan_in*K)*block.y; const int pscale = nlon_in / nlon_out; -#if 0 - printf("Launching s2_disco_fwd_generic_vec_k<%d, float%s><<<..., ..., %zu, ...>>> with:\n" - "\tngroup: %ld\n" - "\tnchan_in: %ld\n" - "\tK: %ld\n\n", - THREADS, sizeof(FLOATV_T)==16?"4":"", shsize, ngroup, nchan_in, K); -#endif + // will use only the first 1/K-th of the CSR, i.e. only the first nlat_out rows s2_disco_fwd_generic_vec_k <<>>(nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K, - _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp); + _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp); CHECK_ERROR("s2_disco_fwd_generic_vec_k"); return; @@ -697,8 +683,10 @@ void launch_spc_disco_fwd(int nloc, // "BDIM_X*nloc" >= nchans int64_t nlon_out, int64_t K, float *__restrict__ _xp, - int32_t *_row_idx, + int64_t nrow, + int32_t *_row_sort, int64_t *_row_off, + int64_t *_row_idx, int64_t *_col_idx, FLOATV_T *_val_pck, FLOATV_T *__restrict__ _yp, @@ -706,26 +694,18 @@ void launch_spc_disco_fwd(int nloc, // "BDIM_X*nloc" >= nchans if (CUR_LOC_SIZE == nloc) { - // block size set to 64 threads constexpr int BDIM_Y = (BDIM_X <= WARP_SIZE) ? 
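Both the generic and the specialized kernels above use the same bookkeeping to turn a CTA slot (blockIdx.x*blockDim.y + threadIdx.y) into output coordinates: slots walk the non-empty rows through row_sort, so the longest CSR rows are scheduled first, and the actual output latitude ho is read from the first row_idx entry of the selected row's slice. A host-side sketch of that index math and of the grid sizing, assuming row_sort, row_off and row_idx are the 1-D tensors passed to the kernels:

    def cta_slot_to_output_coords(ctaid, nlon_out, row_sort, row_off, row_idx):
        h = ctaid // nlon_out                    # position in the sorted row list
        wo = ctaid - h * nlon_out                # output longitude
        csr_row = int(row_sort[h])               # bypassing the sort would simply use csr_row = h
        rbeg = int(row_off[csr_row])
        ho = int(row_idx[rbeg])                  # first entry of the row slice carries the latitude index
        return ho, wo

    def grid_x(nrow, nlon_out, block_y):
        return (nrow * nlon_out + block_y - 1) // block_y   # DIV_UP(nrow*nlon_out, block.y)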
THREADS / BDIM_X : 1; - // groups in gridDim.y dim3 block(BDIM_X, BDIM_Y); - dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size); - + dim3 grid(DIV_UP(nrow*nlon_out, block.y), batch_size); + size_t shsize = 0; //sizeof(float)*chxgrp_out * block.y; const int pscale = nlon_in / nlon_out; -#if 0 - printf("Launching s2_disco_fwd_special_vec_k<%d, %d, %d, float%s><<<(%d, %d, %d), (%d, %d), ..., %zu, ...>>> with:\n" - "\tngroup: %ld\n" - "\tnchan_in: %ld\n" - "\tK: %ld\n\n", - BDIM_X, BDIM_Y, CUR_LOC_SIZE, sizeof(FLOATV_T)==16?"4":"", grid.x, grid.y, grid.z, block.x, block.y, shsize, ngroup, nchan_in, K); -#endif + s2_disco_fwd_special_vec_k <<>>(nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K, - _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp); + _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp); CHECK_ERROR("s2_disco_fwd_special_vec_k"); @@ -735,7 +715,7 @@ void launch_spc_disco_fwd(int nloc, // "BDIM_X*nloc" >= nchans launch_spc_disco_fwd(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, - K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); + K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); } return; } @@ -748,9 +728,10 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, int64_t nlon_out, int64_t K, at::Tensor xP, - at::Tensor row_off, // CSR row offsets - at::Tensor col_idx, // CSR column indices - at::Tensor val_dat, // CSR value data + at::Tensor row_off, // CSR non-empty row offsets + at::Tensor row_idx, // CSR non-empty row indices + at::Tensor col_idx, // CSR non-empty col indices + at::Tensor val_dat, // CSR non-empty value data at::Tensor yP) { static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1))); @@ -763,7 +744,7 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, K <= 0) { fprintf(stderr, - ":%s:%d: invalid value of one or more input parameters!\n", + ":%s:%d: invalid value of one or more input parameters!\n", __FILE__, __LINE__); exit(EXIT_FAILURE); } @@ -771,11 +752,6 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, // get stream auto stream = at::cuda::getCurrentCUDAStream().stream(); - // sort row indices (ho-s) in descending order - // based on (row_off[ho+1]-row_off[ho]) - at::Tensor row_idx = sortRows(nlat_out, row_off, stream); - - // replace the K sequential CRSs in "val_dat": // // val_dat[ 0: nnz/K) for ker = 0 @@ -784,19 +760,28 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, // val_dat[nnz/K:2*nnz/K) for ker = K-1 // // with a packed CSR: - // + // // val_dat[nnz/K][K], i.e. 
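launch_spc_disco_fwd is instantiated for a contiguous range of per-thread chunk sizes, and the runtime value nloc selects which instantiation actually launches; when CUR_LOC_SIZE does not match, the call falls through to the next larger specialization. The recursion step itself sits in unchanged context here, so the following is a sketch of the pattern rather than the exact code:

    def launch_spc_disco(nloc, cur_loc_size, max_loc_size, run_specialized):
        # run_specialized(n) stands in for the kernel compiled with CUR_LOC_SIZE == n
        if cur_loc_size == nloc:
            run_specialized(cur_loc_size)
        elif cur_loc_size < max_loc_size:
            launch_spc_disco(nloc, cur_loc_size + 1, max_loc_size, run_specialized)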
with a CSR where elements of the original K CSRs are packed in consecutive elements assert(0 == (val_idx.size(0) % K)); + int64_t nrow_csr = row_off.size(0)-1; + assert(0 == (nrow_csr % K)); + + int64_t nrow = nrow_csr / K; + + // sort row indices (ho-s) in descending order + // based on (row_off[ho+1]-row_off[ho]) + at::Tensor row_sort = sortRows(nrow, row_off, stream); + // move into "disco_cuda_utils.cu" IF val_dat format won't be changed upstream in the call chain int64_t val_dims[] = {val_dat.size(0)}; auto options = torch::TensorOptions().device(val_dat.device()).dtype(val_dat.dtype()); torch::Tensor val_pck = torch::zeros(val_dims, options); { dim3 block(WARP_SIZE, THREADS/WARP_SIZE); - dim3 grid(DIV_UP(nlat_out, block.y)); + dim3 grid(DIV_UP(nrow, block.y)); - pack_vals_k<<>>(K, nlat_out, + pack_vals_k<<>>(K, nrow, row_off.data_ptr(), val_dat.data_ptr(), val_pck.data_ptr()); @@ -812,8 +797,9 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, float *_xp = reinterpret_cast(xP.data_ptr()); float *_yp = reinterpret_cast(yP.data_ptr()); - int32_t *_row_idx = reinterpret_cast(row_idx.data_ptr()); + int32_t *_row_sort = reinterpret_cast(row_sort.data_ptr()); int64_t *_row_off = reinterpret_cast(row_off.data_ptr()); + int64_t *_row_idx = reinterpret_cast(row_idx.data_ptr()); int64_t *_col_idx = reinterpret_cast(col_idx.data_ptr()); float *_val_pck = reinterpret_cast(val_pck.data_ptr()); @@ -823,35 +809,31 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, !is_aligned(_val_pck) || (K % VEC_SIZE) != 0) { - //printf("Is aligned: %s:%d: VEC_SIZE: %d, nchan_in: %d, K: %d, _xp: %p, _yp: %p\n", __func__, __LINE__, VEC_SIZE, nchan_in, K, _xp, _yp); - const int nloc = DIV_UP(nchan_in*K, bdimx); - + // to avoid the compilation of unused template instances; // we use a block size BDIM_X that is the smallest power of 2 // such that BDIM_X*MAX_LOCAL_ARR_LEN >= nchan_in*K, so // BDIM_X > 32 are used only for: // // (BDIM_X-1)*MAX_LOCAL_ARR_LEN < nchan <= BDIM_X*MAX_LOCAL_ARR_LEN - constexpr int MIN_LOC_ARR_LEN = MAX_LOCAL_ARR_LEN/2+1; + constexpr int MIN_LOCAL_ARR_LEN = MAX_LOCAL_ARR_LEN/2+1; // use 2D blocks only if 32 threads are enough switch(bdimx) { - case 8: launch_spc_disco_fwd< 8, 1, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; - case 16: launch_spc_disco_fwd< 16, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; - case 32: launch_spc_disco_fwd< 32, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; - case 64: launch_spc_disco_fwd< 64, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; - case 128: launch_spc_disco_fwd< 128, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; - case 256: launch_spc_disco_fwd< 256, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; - case 512: launch_spc_disco_fwd< 512, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, 
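The repacking described above can be written as a host-side reference in a few lines of PyTorch, which also makes the role of sortRows explicit. The per-row destination offset of pack_vals_k sits in unchanged context, so the row_off[r]*K base below is an assumption (it is the layout consistent with the packed buffer holding exactly nnz entries), and the argsort only mirrors the "longest rows first" behaviour stated in the comment; sortRows itself lives in disco_cuda_utils.cu and may break ties differently:

    import torch

    def pack_vals_reference(K, nrow, row_off, val_dat):
        # K CSRs with identical sparsity are stacked row-wise in row_off; the packed
        # layout stores the K kernel values of each nonzero contiguously:
        #   val_pck[row_off[r]*K + off*K + k] = val_dat[row_off[k*nrow + r] + off]
        val_pck = torch.zeros_like(val_dat)
        for r in range(nrow):
            rlen = int(row_off[r + 1] - row_off[r])
            base = int(row_off[r]) * K
            for off in range(rlen):
                for k in range(K):
                    val_pck[base + off * K + k] = val_dat[int(row_off[k * nrow + r]) + off]
        return val_pck

    def sort_rows_reference(nrow, row_off):
        rlen = row_off[1:nrow + 1] - row_off[:nrow]          # lengths of the kernel-0 rows
        return torch.argsort(rlen, descending=True).to(torch.int32)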
nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; - case 1024: launch_spc_disco_fwd<1024, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; - default: launch_gen_disco_fwd ( batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck, _yp, stream); break; + case 8: launch_spc_disco_fwd< 8, 1, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 16: launch_spc_disco_fwd< 16, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 32: launch_spc_disco_fwd< 32, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 64: launch_spc_disco_fwd< 64, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 128: launch_spc_disco_fwd< 128, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 256: launch_spc_disco_fwd< 256, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 512: launch_spc_disco_fwd< 512, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 1024: launch_spc_disco_fwd<1024, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + default: launch_gen_disco_fwd ( batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; } } else { - //printf("Is not aligned: %s:%d: VEC_SIZE: %d, nchan_in: %d, K: %d, _xp: %p, _yp: %p\n", __func__, __LINE__, VEC_SIZE, nchan_in, K, _xp, _yp); - //float4 *_xp4 = reinterpret_cast(_xp); float4 *_yp4 = reinterpret_cast(_yp); @@ -861,80 +843,26 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, const int nloc = DIV_UP(nchan_in*K, bdimx); constexpr int MAX_LOCAL_VEC_LEN = MAX_LOCAL_ARR_LEN / VEC_SIZE; - constexpr int MIN_LOC_VEC_LEN = MAX_LOCAL_VEC_LEN/2+1; + constexpr int MIN_LOCAL_VEC_LEN = MAX_LOCAL_VEC_LEN/2+1; // use 2D blocks only if 32 threads are enough switch(bdimx) { - case 8: launch_spc_disco_fwd< 8, 1, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck4, _yp4, stream); break; - case 16: launch_spc_disco_fwd< 16, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck4, _yp4, stream); break; - case 32: launch_spc_disco_fwd< 32, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, 
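The computation of bdimx is not part of this hunk, but the comment above pins down the rule: the smallest power of two such that BDIM_X*MAX_LOCAL_ARR_LEN covers the fused channel-times-kernel dimension, with anything that still does not fit at 1024 threads handled by the generic shared-memory kernel (the default case). A sketch of that rule; the lower bound of 8 and the concrete MAX_LOCAL_ARR_LEN value are assumptions read off the switch cases:

    def pick_block_width(nchan_in, K, max_local_arr_len, min_bdimx=8, max_bdimx=1024):
        need = nchan_in * K
        bdimx = min_bdimx
        while bdimx * max_local_arr_len < need and bdimx < max_bdimx:
            bdimx *= 2
        if bdimx * max_local_arr_len < need:
            return None, None                    # falls through to launch_gen_disco_fwd
        nloc = (need + bdimx - 1) // bdimx       # DIV_UP(nchan_in*K, bdimx), per-thread chunk size
        return bdimx, nloc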
nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck4, _yp4, stream); break; - case 64: launch_spc_disco_fwd< 64, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck4, _yp4, stream); break; - case 128: launch_spc_disco_fwd< 128, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck4, _yp4, stream); break; - case 256: launch_spc_disco_fwd< 256, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck4, _yp4, stream); break; - case 512: launch_spc_disco_fwd< 512, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck4, _yp4, stream); break; - case 1024: launch_spc_disco_fwd<1024, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck4, _yp4, stream); break; - default: launch_gen_disco_fwd ( batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, _row_idx, _row_off, _col_idx, _val_pck4, _yp4, stream); break; + case 8: launch_spc_disco_fwd< 8, 1, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp4, stream); break; + case 16: launch_spc_disco_fwd< 16, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp4, stream); break; + case 32: launch_spc_disco_fwd< 32, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp4, stream); break; + case 64: launch_spc_disco_fwd< 64, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp4, stream); break; + case 128: launch_spc_disco_fwd< 128, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp4, stream); break; + case 256: launch_spc_disco_fwd< 256, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp4, stream); break; + case 512: launch_spc_disco_fwd< 512, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp4, stream); break; + case 1024: launch_spc_disco_fwd<1024, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp4, stream); break; + default: launch_gen_disco_fwd ( batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp4, stream); break; } } return; } -#endif // END VERSION WITH CHANNEL-LAST WITH 2D BLOCKS, 2ND DIM IDENTIFYING CHANNLES, NO EINSUM - - - - - - - // utility functions - void dump_out_kers(const char *fprefix, at::Tensor t) { - - int64_t 
B = t.size(0); - int64_t C = t.size(1); - int64_t K = t.size(2); - int64_t Ho = t.size(3); - int64_t Wo = t.size(4); - - at::Tensor t_h = t.to(torch::kCPU); - - auto accessor = t_h.accessor(); - - printf("Writing data to file..."); - - char fname[256]; - - for(size_t k = 0; k < K; k++) { - - snprintf(fname, sizeof(fname), "%s_%ld.txt", fprefix, k); - - FILE *fp = fopen(fname, "w"); - if (!fp) { - fprintf(stderr, "Cannot open file %s for writing!\n", fname); - exit(EXIT_FAILURE); - } - for(int64_t b = 0; b < B; b++) { - fprintf(fp, "b: %ld\n", b); - for(int64_t c = 0; c < C; c++) { - fprintf(fp, "c: %ld\n", c); - for(int64_t h = 0; h < Ho; h++) { - for(int64_t w = 0; w < Wo; w++) { - fprintf(fp, " %f", accessor[b][c][k][h][w]); - } - fprintf(fp, "\n"); - } - fprintf(fp, "\n"); - } - fprintf(fp, "\n"); - } - fclose(fp); - } - printf("done\n"); - - return; - } - torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo) { @@ -963,71 +891,11 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, int64_t nlat_out = Ho; int64_t nlon_out = Wo; - //printf("%s:%d: batch_size: %ld, nchan: %ld, nlat_in: %ld, nlon_in: %ld, nlat_out: %ld, nlon_out: %ld, nrows: %ld, nnz_tot: %ld, K: %ld\n", - // __func__, __LINE__, batch_size, nchan, nlat_in, nlon_in, nlat_out, nlon_out, nrows, col_idx.size(0), K); - // get stream auto stream = at::cuda::getCurrentCUDAStream().stream(); - // assert - static_assert(0 == (ELXTH_MAX % 2)); -#if 0 - // allocate output - int64_t out_dims[] = {B, C, K, Ho, Wo}; - auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype()); - torch::Tensor out = torch::zeros(out_dims, options); - - - // pick the correct launch config - if (Wo <= 64 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { - launch_kernel<64, 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); - })); - } else if (Wo <= 128 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { - launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); - })); - } else if (Wo <= 256 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { - launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); - })); - } else if (Wo <= 512 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { - launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); - })); - } else if (Wo <= 1024 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { - launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); - })); - } else { 
- fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo, - 1024 * ELXTH_MAX); - exit(EXIT_FAILURE); - } -#else - // switch to channel-last - // version with fused enisum + // version with fused enisum auto x_type = inp.dtype(); auto xP = inp.to(torch::kFloat32).contiguous(); @@ -1035,7 +903,7 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, // to test before fusion int64_t out_dims[] = {batch_size, nlat_out, nlon_out, nchan*K}; //auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype()); - torch::Tensor yP = torch::empty(out_dims, xP.options()); + torch::Tensor yP = torch::zeros(out_dims, xP.options()); // call channel-last kernel implementation s2_disco_fwd_dispatch(batch_size, @@ -1047,194 +915,19 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, K, xP, roff_idx, + row_idx, col_idx, val, yP); - auto y = yP.reshape({batch_size, nlat_out, nlon_out, nchan, K}).to(x_type); - - torch::Tensor out = y; - -#endif // closes ORIGINAL if -#if 0 - if (std::getenv("S2_DISCO_DUMP_Y")) { - printf("waiting for kernel to finish..."); - CHECK_CUDA(cudaStreamSynchronize(stream)); - printf("done\n"); - fflush(stdout); - dump_tensor("yout.txt", out); - //dump_csr_linear("csr_disco.txt", roff_idx, ker_idx, row_idx, col_idx, val); - //dump_out_kers("out_kers", out); - } -#endif - return out; - } - - TORCH_LIBRARY_IMPL(disco_kernels, CUDA, m) - { - m.impl("forward", &disco_cuda_fwd); - } - - // utility functions - void dump_tensor(const char *fname, at::Tensor t) { - - size_t n = 1; - for(int i = 0; i < t.dim(); i++) { - n *= t.size(i); - } - - float *data_h = (float *)malloc(sizeof(*data_h)*n); - if (!data_h) { - fprintf(stderr, "Cannot allcoate %zu bytes!\n", sizeof(*data_h)*n); - exit(EXIT_FAILURE); - } - - float *float_d = t.data_ptr(); - - CHECK_CUDA(cudaMemcpy(data_h, float_d, sizeof(*data_h)*n, cudaMemcpyDeviceToHost)); - - printf("Writing data to file..."); - - FILE *fp = fopen(fname, "w"); - if (!fp) { - fprintf(stderr, "Cannot open file %s for writing!\n", fname); - exit(EXIT_FAILURE); - } - - for(size_t i = 0; i < n; i++) { - fprintf(fp, "%f\n", data_h[i]); - } - - fclose(fp); - printf("done\n"); - - free(data_h); - - return; - } - - void dump_csr(const char *fname, - at::Tensor roff, - at::Tensor cols) { - - int64_t nrows = roff.size(0)-1; - int64_t nnz = cols.size(0); - - int64_t *roff_h = new int64_t[nrows+1]; - int64_t *cols_h = new int64_t[nnz]; - - int64_t *roff_d = roff.data_ptr(); - int64_t *cols_d = cols.data_ptr(); - - CHECK_CUDA(cudaMemcpy(roff_h, roff_d, sizeof(*roff_h)*(nrows+1), cudaMemcpyDeviceToHost)); - CHECK_CUDA(cudaMemcpy(cols_h, cols_d, sizeof(*cols_d)*nnz , cudaMemcpyDeviceToHost)); - - printf("Writing data to file..."); + auto out = yP.reshape({batch_size, nlat_out, nlon_out, nchan, K}).to(x_type); - FILE *fp = fopen(fname, "w"); - if (!fp) { - fprintf(stderr, "Cannot open file %s for writing!\n", fname); - exit(EXIT_FAILURE); - } - for(int64_t r = 0; r < nrows; r++) { - - fprintf(fp, "%10ld %10ld", r, roff_h[r+1]-roff_h[r]); - - for(int64_t o = roff_h[r]; o < roff_h[r+1]; o++) { - fprintf(fp, "%10ld", cols_h[o]); - } - fprintf(fp, "\n"); - } - fclose(fp); - printf("done\n"); - - delete [] roff_h; - delete [] cols_h; + return out; } - - void dump_csr_linear(const char *fname, - at::Tensor roff, - at::Tensor kers, - at::Tensor rows, - at::Tensor cols, - at::Tensor vals) { - - int64_t nrows = roff.size(0)-1; - int64_t nnz = cols.size(0); - - int64_t *roff_h = new int64_t[nrows+1]; - int64_t 
*kers_h = new int64_t[nnz]; - int64_t *rows_h = new int64_t[nnz]; - int64_t *cols_h = new int64_t[nnz]; - float *vals_h = new float[nnz]; - - int64_t *roff_d = roff.data_ptr(); - int64_t *kers_d = kers.data_ptr(); - int64_t *rows_d = rows.data_ptr(); - int64_t *cols_d = cols.data_ptr(); - float *vals_d = vals.data_ptr(); - - CHECK_CUDA(cudaMemcpy(roff_h, roff_d, sizeof(*roff_h)*(nrows+1), cudaMemcpyDeviceToHost)); - CHECK_CUDA(cudaMemcpy(kers_h, kers_d, sizeof(*kers_h)*nnz , cudaMemcpyDeviceToHost)); - CHECK_CUDA(cudaMemcpy(rows_h, rows_d, sizeof(*rows_h)*nnz , cudaMemcpyDeviceToHost)); - CHECK_CUDA(cudaMemcpy(cols_h, cols_d, sizeof(*cols_h)*nnz , cudaMemcpyDeviceToHost)); - CHECK_CUDA(cudaMemcpy(vals_h, vals_d, sizeof(*vals_h)*nnz , cudaMemcpyDeviceToHost)); - - printf("Writing data to file..."); - - FILE *fp = fopen(fname, "w"); - if (!fp) { - fprintf(stderr, "Cannot open file %s for writing!\n", fname); - exit(EXIT_FAILURE); - } - fprintf(fp, "COLS:\n"); - for(int64_t r = 0; r < nrows; r++) { - fprintf(fp, "%10ld %10ld", r, roff_h[r+1]-roff_h[r]); - - for(int64_t o = roff_h[r]; o < roff_h[r+1]; o++) { - fprintf(fp, "%10ld", cols_h[o]); - } - fprintf(fp, "\n"); - } - fprintf(fp, "KERS:\n"); - for(int64_t r = 0; r < nrows; r++) { - - fprintf(fp, "%10ld %10ld", r, roff_h[r+1]-roff_h[r]); - - for(int64_t o = roff_h[r]; o < roff_h[r+1]; o++) { - fprintf(fp, "%10ld", kers_h[o]); - } - fprintf(fp, "\n"); - } - fprintf(fp, "ROWS:\n"); - for(int64_t r = 0; r < nrows; r++) { - - fprintf(fp, "%10ld %10ld", r, roff_h[r+1]-roff_h[r]); - - for(int64_t o = roff_h[r]; o < roff_h[r+1]; o++) { - fprintf(fp, "%10ld", rows_h[o]); - } - fprintf(fp, "\n"); - } - fprintf(fp, "VALS:\n"); - for(int64_t r = 0; r < nrows; r++) { - - fprintf(fp, "%10ld %10ld", r, roff_h[r+1]-roff_h[r]); - - for(int64_t o = roff_h[r]; o < roff_h[r+1]; o++) { - fprintf(fp, "%10f", vals_h[o]); - } - fprintf(fp, "\n"); - } - fclose(fp); - printf("done\n"); - - delete [] roff_h; - delete [] kers_h; - delete [] rows_h; - delete [] cols_h; - delete [] vals_h; - } +TORCH_LIBRARY_IMPL(disco_kernels, CUDA, m) +{ + m.impl("forward", &disco_cuda_fwd); +} } diff --git a/torch_harmonics/utils/csrc/csr_cuda.cuh b/torch_harmonics/utils/csrc/csr_cuda.cuh index a4653cfb..51fcb5a9 100644 --- a/torch_harmonics/utils/csrc/csr_cuda.cuh +++ b/torch_harmonics/utils/csrc/csr_cuda.cuh @@ -159,17 +159,16 @@ __device__ float4 __forceinline__ __vdiv(float s, float4 v) { template static __device__ void __sync() { - unsigned int subwarp_mask = FULL_MASK; + static_assert(BDIM_X > 0 && 0 == (BDIM_X & (BDIM_X-1))); - if constexpr(BDIM_X <= WARP_SIZE) { - const int tidy = threadIdx.y; + if constexpr(BDIM_X > WARP_SIZE) { __syncthreads(); } + else if constexpr(BDIM_X == WARP_SIZE) { __syncwarp(); } + else { // BDIM_X < WARP_SIZE constexpr unsigned int MASK = (1ull << BDIM_X)-1; - subwarp_mask = MASK << (tidy*BDIM_X); + unsigned int subwarp_id = threadIdx.y % (WARP_SIZE/BDIM_X); + unsigned int subwarp_mask = MASK << (subwarp_id*BDIM_X); + __syncwarp(subwarp_mask); } - - if constexpr(BDIM_X <= WARP_SIZE) { __syncwarp(subwarp_mask); } - else { __syncthreads(); } - return; } From 001c28ab54a7afb5a40fa0109900726471f1809f Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 8 Oct 2025 23:56:19 -0700 Subject: [PATCH 23/31] adding permute with contig call --- torch_harmonics/utils/csrc/permute_cuda.cu | 30 +++++++++++++--------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/torch_harmonics/utils/csrc/permute_cuda.cu 
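The reworked __sync<BDIM_X>() above synchronizes only the sub-warp that a row of threads actually occupies: with BDIM_X threads per x-row and 32-thread warps, each warp holds 32/BDIM_X rows, and every row gets its own contiguous BDIM_X-bit lane mask. The masks it builds can be reproduced directly:

    WARP_SIZE = 32
    for BDIM_X in (4, 8, 16):
        MASK = (1 << BDIM_X) - 1
        rows_per_warp = WARP_SIZE // BDIM_X
        masks = [hex(MASK << ((ty % rows_per_warp) * BDIM_X)) for ty in range(rows_per_warp)]
        print(f"BDIM_X={BDIM_X}: {masks}")
    # BDIM_X=8 prints ['0xff', '0xff00', '0xff0000', '0xff000000']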
b/torch_harmonics/utils/csrc/permute_cuda.cu index e97eb574..eb23ea4f 100644 --- a/torch_harmonics/utils/csrc/permute_cuda.cu +++ b/torch_harmonics/utils/csrc/permute_cuda.cu @@ -59,20 +59,23 @@ static int getPtxver() { torch::Tensor permute_4D_to0231(torch::Tensor src) { - auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device()); - torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options); + // make input contiguous: + auto srcc = src.contiguous(); + + auto options = torch::TensorOptions().dtype(srcc.dtype()).device(srcc.device()); + torch::Tensor dst = torch::empty({srcc.size(0), srcc.size(2), srcc.size(3), srcc.size(1)}, options); const int ptxv = getPtxver(); // to be further specialized for additional archs, if necessary if (ptxv < 100) { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_generic", ([&] { - launch_permute_to0231(src, dst); + AT_DISPATCH_FLOATING_TYPES(srcc.scalar_type(), "permute_to0231_k_tile_generic", ([&] { + launch_permute_to0231(srcc, dst); })); CHECK_ERROR("permute_to0231_k_tile_generic"); } else { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_sm100", ([&] { - launch_permute_to0231(src, dst); + AT_DISPATCH_FLOATING_TYPES(srcc.scalar_type(), "permute_to0231_k_tile_sm100", ([&] { + launch_permute_to0231(srcc, dst); })); CHECK_ERROR("permute_to0231_k_tile_sm100"); } @@ -82,20 +85,23 @@ torch::Tensor permute_4D_to0231(torch::Tensor src) { torch::Tensor permute_4D_to0312(torch::Tensor src) { - auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device()); - torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options); + // make input contiguous: + auto srcc = src.contiguous(); + + auto options = torch::TensorOptions().dtype(srcc.dtype()).device(srcc.device()); + torch::Tensor dst = torch::empty({srcc.size(0), srcc.size(3), srcc.size(1), srcc.size(2)}, options); const int ptxv = getPtxver(); // to be further specialized for additional archs, if necessary if (ptxv < 100) { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_generic", ([&] { - launch_permute_to0312(src, dst); + AT_DISPATCH_FLOATING_TYPES(srcc.scalar_type(), "permute_to0312_k_tile_generic", ([&] { + launch_permute_to0312(srcc, dst); })); CHECK_ERROR("permute_to0312_k_tile_generic"); } else { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_sm100", ([&] { - launch_permute_to0312(src, dst); + AT_DISPATCH_FLOATING_TYPES(srcc.scalar_type(), "permute_to0312_k_tile_sm100", ([&] { + launch_permute_to0312(srcc, dst); })); CHECK_ERROR("permute_to0312_k_tile_sm100"); } From 104bf6ddcc8c701bb12b21c06d13be3848678d07 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 13 Oct 2025 02:16:04 -0700 Subject: [PATCH 24/31] dbg --- tests/test_cache.py | 41 ++++- tests/test_distributed_convolution.py | 163 +++++++++++++++--- torch_harmonics/disco/convolution.py | 26 +-- .../distributed/distributed_convolution.py | 86 +++++---- torch_harmonics/filter_basis.py | 2 - 5 files changed, 254 insertions(+), 64 deletions(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index 3a2d1720..747a5c70 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -34,8 +34,11 @@ import math import torch +from torch_harmonics.cache import lru_cache +from testutils import compare_tensors class TestCacheConsistency(unittest.TestCase): + def test_consistency(self, verbose=False): if verbose: print("Testing that cache values does not get 
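In plain PyTorch, the two permute kernels patched above perform the following data movement; the CUDA versions are tiled implementations of the same transposition, which is presumably why a non-contiguous input now goes through .contiguous() before the kernel indexes it as a dense array:

    import torch

    src = torch.randn(2, 3, 4, 5)                     # (B, C, H, W)
    to0231 = src.permute(0, 2, 3, 1).contiguous()     # (B, H, W, C), what permute_4D_to0231 returns
    to0312 = to0231.permute(0, 3, 1, 2).contiguous()  # (B, C, H, W), what permute_4D_to0312 returns
    assert torch.equal(to0312, src)                   # the two permutations are inverses of each other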
modified externally") @@ -47,7 +50,43 @@ def test_consistency(self, verbose=False): # perform in-place modification of leg1 leg1 *= -1.0 leg2 = _precompute_legpoly(10, 10, cost) - self.assertFalse(torch.allclose(leg1, leg2)) + self.assertFalse(compare_tensors("legendre", leg2, leg1, verbose=verbose)) + + + def test_pytorch_tensors(self, verbose=False): + if verbose: + print("Testing that PyTorch tensors are cached") + + @lru_cache(typed=True, copy=True) + def test_func(tens1, tens2): + return tens1, tens2 + + # initial tensors + tens1 = torch.randn(4, 4, dtype=torch.float32) + tens2 = torch.randn(4, 4, dtype=torch.float32) + + # retrieve from cache + tens1c, tens2c = test_func(tens1, tens2) + + # modify copies + tens1c *= -1.0 + tens2c *= -1.0 + + # retrieve from cache again + tens1cc, tens2cc = test_func(tens1, tens2) + + if verbose: + print("first tensor", tens1) + print("first tensor after modification", tens1c) + print("first tensor cached", tens1cc) + print("second tensor", tens2) + print("second tensor after modification", tens2c) + print("second tensor cached", tens2cc) + + self.assertFalse(compare_tensors("first cached", tens1cc, tens1c, verbose=verbose)) + self.assertFalse(compare_tensors("second cached", tens2cc, tens2c, verbose=verbose)) + self.assertTrue(compare_tensors("first raw", tens1, tens1cc, verbose=verbose)) + self.assertTrue(compare_tensors("second raw", tens2, tens2cc, verbose=verbose)) if __name__ == "__main__": diff --git a/tests/test_distributed_convolution.py b/tests/test_distributed_convolution.py index c364ac5a..2dd30ecb 100644 --- a/tests/test_distributed_convolution.py +++ b/tests/test_distributed_convolution.py @@ -38,6 +38,17 @@ import torch.distributed as dist import torch_harmonics as th import torch_harmonics.distributed as thd +from torch_harmonics.disco import optimized_kernels_is_available + +from disco_helpers import preprocess_psi +from torch_harmonics.filter_basis import get_filter_basis +from torch_harmonics.disco.convolution import _precompute_convolution_tensor_s2 +from torch_harmonics.distributed.distributed_convolution import _split_distributed_convolution_tensor_s2 + +from testutils import compare_tensors + +if not optimized_kernels_is_available(): + print(f"Warning: Couldn't import optimized disco convolution kernels") class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): @@ -128,6 +139,9 @@ def _split_helper(self, tensor): tensor_list_local = thd.split_tensor_along_dim(tensor_local, dim=-2, num_chunks=self.grid_size_h) tensor_local = tensor_list_local[self.hrank] + # make contiguous + tensor_local = tensor_local.contiguous() + return tensor_local def _gather_helper_fwd(self, tensor, B, C, convolution_dist): @@ -156,6 +170,9 @@ def _gather_helper_fwd(self, tensor, B, C, convolution_dist): dist.all_gather(olist, tensor_gather, group=self.h_group) tensor_gather = torch.cat(olist, dim=-2) + # make contiguous + tensor_gather = tensor_gather.contiguous() + return tensor_gather def _gather_helper_bwd(self, tensor, B, C, convolution_dist): @@ -165,6 +182,7 @@ def _gather_helper_bwd(self, tensor, B, C, convolution_dist): lon_shapes = convolution_dist.lon_in_shapes # gather in W + tensor = tensor.contiguous() if self.grid_size_w > 1: gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes] olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape in gather_shapes] @@ -175,6 +193,7 @@ def _gather_helper_bwd(self, tensor, B, C, convolution_dist): tensor_gather = tensor # gather in H + tensor_gather = 
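test_pytorch_tensors above pins down the contract of torch_harmonics.cache.lru_cache(typed=True, copy=True): callers receive copies, so in-place edits cannot poison the cached tensors. A minimal stand-in built on functools illustrates the same copy-on-return idea; it keys on plain hashable arguments only, whereas the library decorator also has to handle tensor arguments:

    import copy
    import functools
    import torch

    def lru_cache_with_copy(maxsize=128):
        def decorator(fn):
            cached = functools.lru_cache(maxsize=maxsize, typed=True)(fn)
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                # hand out a deep copy so the cached object is never mutated in place
                return copy.deepcopy(cached(*args, **kwargs))
            return wrapper
        return decorator

    @lru_cache_with_copy()
    def make_tensor(n):
        return torch.ones(n)

    t = make_tensor(4)
    t *= -1.0                                          # mutate the returned copy
    assert torch.equal(make_tensor(4), torch.ones(4))  # the cached value is untouched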
tensor_gather.contiguous() if self.grid_size_h > 1: gather_shapes = [(B, C, h, convolution_dist.nlon_in) for h in lat_shapes] olist = [torch.empty(shape, dtype=tensor_gather.dtype, device=tensor_gather.device) for shape in gather_shapes] @@ -182,31 +201,139 @@ def _gather_helper_bwd(self, tensor, B, C, convolution_dist): dist.all_gather(olist, tensor_gather, group=self.h_group) tensor_gather = torch.cat(olist, dim=-2) + # make contiguous + tensor_gather = tensor_gather.contiguous() + return tensor_gather @parameterized.expand( [ + # piecewise linear + # normal isotropic + [(16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular"], + [(17, 32), (17, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular"], + # normal anisotropic + [(16, 32), (16, 32), (3, 4), "piecewise linear", "mean", "equiangular", "equiangular"], + [(16, 32), (16, 32), (3, 2), "piecewise linear", "mean", "equiangular", "equiangular"], + # downsampling isotropic + [(16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular"], + [(17, 32), (9, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular"], + # downsampling anisotropic + [(16, 32), (8, 16), (3, 4), "piecewise linear", "mean", "equiangular", "equiangular"], + [(16, 32), (8, 16), (3, 2), "piecewise linear", "mean", "equiangular", "equiangular"], + # morlet + # normal isotropic + [(16, 32), (16, 32), (1), "morlet", "mean", "equiangular", "equiangular"], # important for attention + [(16, 32), (16, 32), (3), "morlet", "mean", "equiangular", "equiangular"], + [(17, 32), (17, 32), (3), "morlet", "mean", "equiangular", "equiangular"], + # normal anisotropic + [(16, 32), (16, 32), (3, 4), "morlet", "mean", "equiangular", "equiangular"], + [(16, 32), (16, 32), (3, 2), "morlet", "mean", "equiangular", "equiangular"], + # downsampling isotropic + [(16, 32), (8, 16), (1), "morlet", "mean", "equiangular", "equiangular"], # important for attention + [(16, 32), (8, 16), (3), "morlet", "mean", "equiangular", "equiangular"], + [(17, 32), (9, 16), (3), "morlet", "mean", "equiangular", "equiangular"], + # downsampling anisotropic + [(16, 32), (8, 16), (3, 4), "morlet", "mean", "equiangular", "equiangular"], + [(16, 32), (8, 16), (3, 2), "morlet", "mean", "equiangular", "equiangular"], + # zernike + # normal + [(16, 32), (16, 32), (1), "zernike", "mean", "equiangular", "equiangular"], + [(16, 32), (16, 32), (3, 3), "zernike", "mean", "equiangular", "equiangular"], + [(17, 32), (17, 32), (3, 3), "zernike", "mean", "equiangular", "equiangular"], + # downsampling + [(16, 32), (8, 16), (1), "zernike", "mean", "equiangular", "equiangular"], + [(16, 32), (8, 16), (3, 3), "zernike", "mean", "equiangular", "equiangular"], + [(17, 32), (9, 16), (3, 3), "zernike", "mean", "equiangular", "equiangular"], + ], + skip_on_empty=True, + ) + def test_distributed_convolution_tensor_integrity(self, in_shape, out_shape, kernel_shape, basis_type, basis_norm_mode, grid_in, grid_out, verbose=False): + + nlat_in, _ = in_shape + nlat_out, _ = out_shape + + # get filter basis + filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type=basis_type) + + # use default value for cutoff + theta_cutoff = torch.pi / float(nlat_out - 1) + + idx, vals, _ = _precompute_convolution_tensor_s2( + in_shape=in_shape, + out_shape=out_shape, + filter_basis=filter_basis, + grid_in=grid_in, + grid_out=grid_out, + theta_cutoff=theta_cutoff, + transpose_normalization=False, + basis_norm_mode=basis_norm_mode, + merge_quadrature=True, + ) + + # split 
tensor along latitude + idx, vals = _split_distributed_convolution_tensor_s2(idx, vals, in_shape) + + # define contiguous tensors + ker_idx = idx[0, ...].contiguous() + row_idx = idx[1, ...].contiguous() + col_idx = idx[2, ...].contiguous() + vals = vals.contiguous() + + # sort values + roff_idx = preprocess_psi(filter_basis.kernel_size, nlat_out, ker_idx, row_idx, col_idx, vals).contiguous() + + print(f"{self.hrank} after splitting sorted \n ker = {ker_idx}\n row = {row_idx}\n col = {col_idx}", flush=True) + print(f"{self.hrank} roff_idx = {roff_idx}", flush=True) + print(f"{self.hrank} ker_shape = {ker_idx.shape}, row_shape = {row_idx.shape}, col_shape = {col_idx.shape}, vals_shape = {vals.shape}", flush=True) + + # check shapes + self.assertTrue(ker_idx.shape[0] == row_idx.shape[0], f"ker_idx and row_idx have to have the same shape: found {ker_idx.shape[0]} and {row_idx.shape[0]}") + self.assertTrue(ker_idx.shape[0] == col_idx.shape[0], f"ker_idx and col_idx have to have the same shape: found {ker_idx.shape[0]} and {col_idx.shape[0]}") + self.assertTrue(ker_idx.shape[0] == vals.shape[0], f"ker_idx and vals have to have the same shape: found {ker_idx.shape[0]} and {vals.shape[0]}") + + # the multiplicitiy in ker_idx has to be the same for all kernel indices + unique, counts = torch.unique(ker_idx, return_counts=True) + self.assertTrue(torch.all(counts.max() == counts), f"The multiplicity in ker_idx has to be the same for all kernel indices: found {counts} for entries {unique}") + + if verbose: + print(f"\n ker_idx = {ker_idx},\n row_idx = {row_idx},\n col_idx = {col_idx}") + + # the following has to be true: the row_idx and col_idx have to be the same for all kernel indices + row_idx_ref = row_idx[ker_idx == 0] + col_idx_ref = col_idx[ker_idx == 0] + for k in range(1, filter_basis.kernel_size): + self.assertTrue(torch.all(row_idx_ref == row_idx[ker_idx == k]), f"The row_idx has to be the same for all kernel indices: found {row_idx_ref} for entries {ker_idx == k}") + self.assertTrue(torch.all(col_idx_ref == col_idx[ker_idx == k]), f"The row_idx has to be the same for all kernel indices: found {col_idx_ref} for entries {ker_idx == k}") + + + @parameterized.expand( + [ + # Forward [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5], [129, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5], [128, 256, 128, 256, 32, 8, (3, 2), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5], [128, 256, 64, 128, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5], [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 2, "equiangular", "equiangular", False, 1e-5], [128, 256, 128, 256, 32, 6, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5], + [129, 256, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", False, 1e-5], + [129, 256, 65, 128, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", False, 1e-5], + # Transpose [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5], [129, 256, 129, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5], [128, 256, 128, 256, 32, 8, (3, 2), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5], [64, 128, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5], [128, 256, 128, 256, 32, 8, (3), 
"piecewise linear", "mean", 2, "equiangular", "equiangular", True, 1e-5], [128, 256, 128, 256, 32, 6, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5], - [129, 256, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", False, 1e-5], [129, 256, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", True, 1e-5], [65, 128, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", True, 1e-5], - [129, 256, 65, 128, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", False, 1e-5], - ] + ], + skip_on_empty=True, ) def test_distributed_disco_conv( - self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, basis_type, basis_norm_mode, groups, grid_in, grid_out, transpose, tol + self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, basis_type, basis_norm_mode, groups, grid_in, grid_out, transpose, tol, verbose=True ): + verbose = verbose and (self.world_rank == 0) B, C, H, W = batch_size, num_chan, nlat_in, nlon_in @@ -222,21 +349,22 @@ def test_distributed_disco_conv( grid_in=grid_in, grid_out=grid_out, bias=True, + optimized_kernel=True, ) # set up handles if transpose: - conv_local = th.DiscreteContinuousConvTransposeS2(**disco_args).to(self.device) + conv_full = th.DiscreteContinuousConvTransposeS2(**disco_args).to(self.device) conv_dist = thd.DistributedDiscreteContinuousConvTransposeS2(**disco_args).to(self.device) else: - conv_local = th.DiscreteContinuousConvS2(**disco_args).to(self.device) + conv_full = th.DiscreteContinuousConvS2(**disco_args).to(self.device) conv_dist = thd.DistributedDiscreteContinuousConvS2(**disco_args).to(self.device) # copy the weights from the local conv into the dist conv with torch.no_grad(): - conv_dist.weight.copy_(conv_local.weight) + conv_dist.weight.copy_(conv_full.weight) if disco_args["bias"]: - conv_dist.bias.copy_(conv_local.bias) + conv_dist.bias.copy_(conv_full.bias) # create tensors inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device) @@ -244,7 +372,7 @@ def test_distributed_disco_conv( # local conv # FWD pass inp_full.requires_grad = True - out_full = conv_local(inp_full) + out_full = conv_full(inp_full) # create grad for backward with torch.no_grad(): @@ -257,12 +385,14 @@ def test_distributed_disco_conv( # distributed conv # FWD pass - inp_local = self._split_helper(inp_full) + with torch.no_grad(): + inp_local = self._split_helper(inp_full) inp_local.requires_grad = True out_local = conv_dist(inp_local) # BWD pass - ograd_local = self._split_helper(ograd_full) + with torch.no_grad(): + ograd_local = self._split_helper(ograd_full) out_local = conv_dist(inp_local) out_local.backward(ograd_local) igrad_local = inp_local.grad.clone() @@ -270,19 +400,12 @@ def test_distributed_disco_conv( # evaluate FWD pass with torch.no_grad(): out_gather_full = self._gather_helper_fwd(out_local, B, C, conv_dist) - err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2))) - if self.world_rank == 0: - print(f"final relative error of output: {err.item()}") - self.assertTrue(err.item() <= tol) + self.assertTrue(compare_tensors("output", out_gather_full, out_full, rtol=tol, atol=tol, verbose=verbose)) # evaluate BWD pass with torch.no_grad(): igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist) - - err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, 
p="fro", dim=(-1, -2))) - if self.world_rank == 0: - print(f"final relative error of gradients: {err.item()}") - self.assertTrue(err.item() <= tol) + self.assertTrue(compare_tensors("input gradient", igrad_gather_full, igrad_full, rtol=tol, atol=tol, verbose=verbose)) if __name__ == "__main__": diff --git a/torch_harmonics/disco/convolution.py b/torch_harmonics/disco/convolution.py index 78a8f08d..3e96b3c5 100644 --- a/torch_harmonics/disco/convolution.py +++ b/torch_harmonics/disco/convolution.py @@ -162,8 +162,8 @@ def _normalize_convolution_tensor_s2( @lru_cache(typed=True, copy=True) def _precompute_convolution_tensor_s2( - in_shape: Tuple[int], - out_shape: Tuple[int], + in_shape: Tuple[int, int], + out_shape: Tuple[int, int], filter_basis: FilterBasis, grid_in: Optional[str]="equiangular", grid_out: Optional[str]="equiangular", @@ -190,9 +190,9 @@ def _precompute_convolution_tensor_s2( Parameters ----------- - in_shape: Tuple[int] + in_shape: Tuple[int, int] Input shape of the convolution tensor - out_shape: Tuple[int] + out_shape: Tuple[int, int] Output shape of the convolution tensor filter_basis: FilterBasis Filter basis functions @@ -450,15 +450,16 @@ def __init__( self.nlat_in, self.nlon_in = in_shape self.nlat_out, self.nlon_out = out_shape + self.theta_cutoff = theta_cutoff # make sure the p-shift works by checking that longitudes are divisible assert self.nlon_in % self.nlon_out == 0 # heuristic to compute theta cutoff based on the bandlimit of the input field and overlaps of the basis functions - if theta_cutoff is None: - theta_cutoff = torch.pi / float(self.nlat_out - 1) + if self.theta_cutoff is None: + self.theta_cutoff = torch.pi / float(self.nlat_out - 1) - if theta_cutoff <= 0.0: + if self.theta_cutoff <= 0.0: raise ValueError("Error, theta_cutoff has to be positive.") idx, vals, _ = _precompute_convolution_tensor_s2( @@ -467,7 +468,7 @@ def __init__( self.filter_basis, grid_in=grid_in, grid_out=grid_out, - theta_cutoff=theta_cutoff, + theta_cutoff=self.theta_cutoff, transpose_normalization=False, basis_norm_mode=basis_norm_mode, merge_quadrature=True, @@ -609,15 +610,16 @@ def __init__( self.nlat_in, self.nlon_in = in_shape self.nlat_out, self.nlon_out = out_shape + self.theta_cutoff = theta_cutoff # make sure the p-shift works by checking that longitudes are divisible assert self.nlon_out % self.nlon_in == 0 # bandlimit - if theta_cutoff is None: - theta_cutoff = torch.pi / float(self.nlat_in - 1) + if self.theta_cutoff is None: + self.theta_cutoff = torch.pi / float(self.nlat_in - 1) - if theta_cutoff <= 0.0: + if self.theta_cutoff <= 0.0: raise ValueError("Error, theta_cutoff has to be positive.") # switch in_shape and out_shape since we want the transpose convolution @@ -627,7 +629,7 @@ def __init__( self.filter_basis, grid_in=grid_out, grid_out=grid_in, - theta_cutoff=theta_cutoff, + theta_cutoff=self.theta_cutoff, transpose_normalization=True, basis_norm_mode=basis_norm_mode, merge_quadrature=True, diff --git a/torch_harmonics/distributed/distributed_convolution.py b/torch_harmonics/distributed/distributed_convolution.py index 8504b446..5786d5e7 100644 --- a/torch_harmonics/distributed/distributed_convolution.py +++ b/torch_harmonics/distributed/distributed_convolution.py @@ -58,7 +58,6 @@ def _split_distributed_convolution_tensor_s2( idx: torch.Tensor, vals: torch.Tensor, in_shape: Tuple[int], - out_shape: Tuple[int], ): """ Splits a pre-computed convolution tensor along the latitude dimension for distributed processing. 
@@ -75,8 +74,6 @@ def _split_distributed_convolution_tensor_s2( Values of the pre-computed convolution tensor in_shape: Tuple[int] Shape of the input tensor (nlat_in, nlon_in) - out_shape: Tuple[int] - Shape of the output tensor (nlat_out, nlon_out) Returns ------- @@ -87,7 +84,6 @@ def _split_distributed_convolution_tensor_s2( """ nlat_in, nlon_in = in_shape - nlat_out, nlon_out = out_shape comm_size_polar = polar_group_size() comm_rank_polar = polar_group_rank() @@ -214,8 +210,13 @@ def __init__( merge_quadrature=True, ) + print(f"{self.comm_rank_polar} before splitting shapes idx = {idx.shape}, vals = {vals.shape}") + print(f"{self.comm_rank_polar} before splitting \n ker = {idx[0]}\n row = {idx[1]}\n col = {idx[2]}", flush=True) + # split the convolution tensor along latitude - idx, vals = _split_distributed_convolution_tensor_s2(idx, vals, in_shape, out_shape) + idx, vals = _split_distributed_convolution_tensor_s2(idx, vals, in_shape) + + #print(f"{self.comm_rank_polar} after splitting \n ker = {idx[0]}\n row = {idx[1]}\n col = {idx[2]}", flush=True) # sort the values ker_idx = idx[0, ...].contiguous() @@ -225,7 +226,7 @@ def __init__( if self.optimized_kernel: # preprocessed data-structure for GPU kernel - roff_idx = preprocess_psi(self.kernel_size, self.nlat_out_local, ker_idx, row_idx, col_idx, vals).contiguous() + roff_idx = preprocess_psi(self.kernel_size, self.nlat_out, ker_idx, row_idx, col_idx, vals) self.register_buffer("psi_roff_idx", roff_idx, persistent=False) # save all datastructures @@ -234,12 +235,15 @@ def __init__( self.register_buffer("psi_col_idx", col_idx, persistent=False) self.register_buffer("psi_vals", vals, persistent=False) + print(f"{self.comm_rank_polar} after splitting sorted \n ker_idx = {self.psi_ker_idx}\n row = {self.psi_row_idx}\n col = {self.psi_col_idx}", flush=True) + print(f"{self.comm_rank_polar} roff_idx = {self.psi_roff_idx}", flush=True) + # store psi jic: if not self.optimized_kernel: self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, self.nlat_in_local, self.nlat_out_local) def extra_repr(self): - return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" + return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize_in * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" @property def psi_idx(self): @@ -250,19 +254,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # store number of channels num_chans = x.shape[1] + print(f"{self.comm_rank_polar} input shape", x.shape) + # h and w is split. 
First we make w local by transposing into channel dim if self.comm_size_azimuth > 1: x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes) - if self.optimized_kernel: - polar_dim = -3 - azimuth_dim = -2 - chan_dim = -1 + print(f"{self.comm_rank_polar} after azimuth transpose forward", x.shape) - # permute input + if self.optimized_kernel: + # permute input: B, C, Hi, Wi -> B, Hi, Wi, C xp = permute_to_0231(x) - # disco contraction + print(f"{self.comm_rank_polar} before disco contraction", xp.shape) + + # disco contraction: B, Hi, Wi, C -> B, Ho, Wo, C, K x = _disco_s2_contraction_optimized( xp, self.psi_roff_idx, @@ -271,32 +277,49 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.psi_col_idx, self.psi_vals, self.kernel_size, - self.nlat_out_local, + self.nlat_out, self.nlon_out ) + + # extract shapes + polar_dim = -4 + azimuth_dim = -3 + chan_dim = -2 + else: + # disco contraction: B, C, Hi, Wi -> B, C, K, Ho, Wo + x = _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out) + + # extract shapes polar_dim = -2 azimuth_dim = -1 - chan_dim = -3 + chan_dim = -4 - # disco contraction - x = _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out) + print(f"{self.comm_rank_polar} after disco contraction", x.shape) + + print(f"{self.comm_rank_polar} polar dim", polar_dim) + print(f"{self.comm_rank_polar} azimuth dim", azimuth_dim) + print(f"{self.comm_rank_polar} chan dim", chan_dim) # perform reduce scatter in polar region x = reduce_from_polar_region(x) x = scatter_to_polar_region(x, polar_dim) + print(f"{self.comm_rank_polar} after polar scatter", x.shape) + # now we can transpose back the result, so that lon is split and channels are local if self.comm_size_azimuth > 1: chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth) x = distributed_transpose_azimuth.apply(x, (azimuth_dim, chan_dim), chan_shapes) + print(f"{self.comm_rank_polar} after azimuth transpose inverse", x.shape, num_chans, azimuth_dim, chan_dim) + # extract shape if self.optimized_kernel: # weight multiplication B, H, W, _, K = x.shape x = x.reshape(B, H, W, self.groups, self.groupsize_in, K) - outp = torch.einsum("bxygck,gock->bxygo", x, self.weight) + outp = torch.einsum("bxygck,gock->bxygo", x, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size)) outp = outp.reshape(B, H, W, -1).contiguous() # permute output @@ -305,9 +328,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # weight multiplication B, _, K, H, W = x.shape x = x.reshape(B, self.groups, self.groupsize_in, K, H, W) - out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight) + out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size)) out = out.reshape(B, -1, H, W).contiguous() + print(f"{self.comm_rank_polar} after weight multiplication", out.shape) + if self.bias is not None: out = out + self.bias.reshape(1, -1, 1, 1) @@ -422,7 +447,7 @@ def __init__( # split the convolution tensor along latitude, again, we need to swap the meaning # of in_shape and out_shape - idx, vals = _split_distributed_convolution_tensor_s2(idx, vals, out_shape, in_shape) + idx, vals = _split_distributed_convolution_tensor_s2(idx, vals, out_shape) # sort the values ker_idx = idx[0, ...].contiguous() @@ -432,7 +457,7 @@ def __init__( if self.optimized_kernel: # preprocessed data-structure for GPU kernel - roff_idx = preprocess_psi(self.kernel_size, self.nlat_in_local, ker_idx, row_idx, col_idx, 
vals).contiguous() + roff_idx = preprocess_psi(self.kernel_size, self.nlat_in, ker_idx, row_idx, col_idx, vals).contiguous() self.register_buffer("psi_roff_idx", roff_idx, persistent=False) # save all datastructures @@ -446,7 +471,7 @@ def __init__( self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, self.nlat_in_local, self.nlat_out_local, semi_transposed=True) def extra_repr(self): - return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" + return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize_in * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" @property def psi_idx(self): @@ -463,21 +488,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # weight multiplication xp = xp.reshape(B, H, W, self.groups, self.groupsize_in) - x = torch.einsum("bxygc,gock->bxygok", xp, self.weight) + x = torch.einsum("bxygc,gock->bxygok", xp, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size)) x = x.reshape(B, H, W, -1, self.kernel_size).contiguous() - polar_dim = -3 - azimuth_dim = -2 - chan_dim = -1 + # count from front since this does not change + # after disco conv + polar_dim = 1 + azimuth_dim = 2 + chan_dim = 3 else: # weight multiplication x = x.reshape(B, self.groups, self.groupsize_in, H, W) - x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight) - x = x.reshape(B, -1, x.shape[-3], H, W).contiguous() + x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size)) + x = x.reshape(B, -1, self.kernel_size, H, W).contiguous() + # count from back since this changes after disco transpose conv polar_dim = -2 azimuth_dim = -1 chan_dim = 1 - # save number of channels + # store number of channels num_chans = x.shape[chan_dim] # transpose such that lon is local, channels are split diff --git a/torch_harmonics/filter_basis.py b/torch_harmonics/filter_basis.py index a35bd842..086518a7 100644 --- a/torch_harmonics/filter_basis.py +++ b/torch_harmonics/filter_basis.py @@ -130,8 +130,6 @@ def _compute_support_vals_isotropic(self, r: torch.Tensor, phi: torch.Tensor, r_ else: ir = (ikernel + 0.5) * dr - #print("ir", ir, ir.min()) - # find the indices where the rotated position falls into the support of the kernel #iidx = torch.argwhere((r.abs() <= dr) & (r <= r_cutoff)) #vals = 1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr From 455ae903d30cdf927288d1398e41578eeefddd44 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 20 Oct 2025 00:31:22 -0700 Subject: [PATCH 25/31] after rebase --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 42832a1b..7e801b49 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ def get_compile_args(module_name): cpp_extra_flags.append("-fopenmp") nvcc_extra_flags = [] - if True or profile_mode: + if profile_mode: nvcc_extra_flags.append("-lineinfo") nvcc_extra_flags.append("-Xptxas=-v") From 8fc8427a2646f0779ea4d1571347059e44a1ab9c Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 20 Oct 2025 02:21:01 -0700 Subject: [PATCH 26/31] small fixes after rebase --- 
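[Editor's note, illustrative commentary placed after the "---" so it is not part of the commit] This patch adds a small test helper, compare_tensors, which the distributed convolution test shown earlier in the series calls instead of a hand-rolled relative-error check. A quick usage sketch; the sample tensors and the import path are assumptions (the helper itself lives in tests/testutils.py):

import torch
from testutils import compare_tensors  # assumes tests/ is on sys.path

a = torch.randn(4, 8)
b = a + 1e-6 * torch.randn_like(a)
# returns torch.allclose(a, b, rtol=..., atol=...); on failure with verbose=True
# it prints min/mean/max absolute and relative differences and the worst element
ok = compare_tensors("example", a, b, rtol=1e-5, atol=1e-5, verbose=True)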
tests/testutils.py | 44 ++++++++++++++++++++++++++++ torch_harmonics/disco/convolution.py | 16 +++++----- 2 files changed, 53 insertions(+), 7 deletions(-) create mode 100644 tests/testutils.py diff --git a/tests/testutils.py b/tests/testutils.py new file mode 100644 index 00000000..ef7ef0c0 --- /dev/null +++ b/tests/testutils.py @@ -0,0 +1,44 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# + +import torch + +def compare_tensors(msg, tensor, tensor_ref, rtol=1e-8, atol=1e-5, verbose=False): + allclose = torch.allclose(tensor, tensor_ref, rtol=rtol, atol=atol) + if (not allclose) and verbose: + diff = torch.abs(tensor - tensor_ref) + print(f"{msg} absolute tensor diff: min = {torch.min(diff)}, mean = {torch.mean(diff)}, max = {torch.max(diff)}.") + reldiff = diff / torch.abs(tensor_ref) + print(f"{msg} relative tensor diff: min = {torch.min(reldiff)}, mean = {torch.mean(reldiff)}, max = {torch.max(reldiff)}.") + # find element with maximum difference + index = torch.argmax(diff) + print(f"{msg} element {index} with maximum difference: value = {tensor.flatten()[index]}, reference value = {tensor_ref.flatten()[index]}, diff = {diff.flatten()[index]}.") + return allclose diff --git a/torch_harmonics/disco/convolution.py b/torch_harmonics/disco/convolution.py index 3e96b3c5..1a87c70f 100644 --- a/torch_harmonics/disco/convolution.py +++ b/torch_harmonics/disco/convolution.py @@ -370,8 +370,10 @@ def __init__( raise ValueError("Error, the number of output channels has to be an integer multiple of the group size") self.groupsize_in = in_channels // self.groups self.groupsize_out = out_channels // self.groups + # keep this for backward compatibility + self.groupsize = self.groupsize_in scale = math.sqrt(1.0 / self.groupsize_in / self.kernel_size) - self.weight = nn.Parameter(scale * torch.randn(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size)) + self.weight = nn.Parameter(scale * torch.randn(self.groups * self.groupsize_out, self.groupsize_in, self.kernel_size)) if bias: self.bias = nn.Parameter(torch.zeros(out_channels)) @@ -496,7 +498,7 @@ def __init__( self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out) def extra_repr(self): - return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" + return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize_in * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" @property def psi_idx(self): @@ -524,7 +526,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # weight multiplication B, H, W, _, K = xpc.shape xpc = xpc.reshape(B, H, W, self.groups, self.groupsize_in, K) - outp = torch.einsum("bxygck,gock->bxygo", xpc, self.weight) + outp = torch.einsum("bxygck,gock->bxygo", xpc, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size)) outp = outp.reshape(B, H, W, -1).contiguous() # permute output @@ -538,7 +540,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.reshape(B, self.groups, self.groupsize_in, K, H, W) # weight multiplication - out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight) + out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size)) out = out.reshape(B, -1, H, W).contiguous() if self.bias is not None: @@ -657,7 +659,7 @@ def __init__( self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True) def extra_repr(self): - return f"in_shape={(self.nlat_in, self.nlon_in)}, 
out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" + return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize_in * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" @property def psi_idx(self): @@ -674,7 +676,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # weight multiplication xp = xp.reshape(B, H, W, self.groups, self.groupsize_in) - xpc = torch.einsum("bxygc,gock->bxygok", xp, self.weight) + xpc = torch.einsum("bxygc,gock->bxygok", xp, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size)) xpc = xpc.reshape(B, H, W, -1, self.kernel_size).contiguous() # disco contraction @@ -695,7 +697,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: else: # weight multiplication x = x.reshape(B, self.groups, self.groupsize_in, H, W) - xc = torch.einsum("bgcxy,gock->bgokxy", x, self.weight) + xc = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size)) xc = xc.reshape(B, self.groups* self.groupsize_out, -1, H, W).contiguous() # disco contraction From d9b0a2a9453162ea15da608aca20933a6bb75d85 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Mon, 27 Oct 2025 02:06:34 -0700 Subject: [PATCH 27/31] adding theta cutoff to test --- tests/test_attention.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_attention.py b/tests/test_attention.py index 9fa0ea82..98a930b6 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -326,12 +326,12 @@ def test_optimized_pt2_compatibility(self, batch_size, channels, heads, in_shape @parameterized.expand( [ # self attention - [1, 256, 1, (91, 180), (91, 180), "equiangular", "equiangular"], + [1, 256, 1, (91, 180), (91, 180), "equiangular", "equiangular", None], ], skip_on_empty=True, ) @unittest.skipUnless(optimized_kernels_is_available() and _run_perf_tests, "skipping performance test because optimized kernels are not available or perf tests are disabled") - def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, verbose=False): + def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, theta_cutoff, verbose=False): if (self.device.type == "cuda") and (not cuda_kernels_is_available()): raise unittest.SkipTest("skipping test because CUDA kernels are not available") @@ -350,7 +350,8 @@ def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, g att_optimized = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, - grid_in=grid_in, grid_out=grid_out, bias=True, + grid_in=grid_in, grid_out=grid_out, + theta_cutoff=theta_cutoff, bias=True, optimized_kernel=True).to(self.device) # random weights From 52f5253510e558d7b968fa0d102384bb94a5800d Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Tue, 28 Oct 2025 06:58:27 -0700 Subject: [PATCH 28/31] cleanups --- setup.py | 5 + torch_harmonics/attention/csrc/attention.h | 7 +- .../attention/csrc/attention_cpu.h | 2 + .../attention/csrc/attention_cuda.cuh | 7 +- .../attention/csrc/attention_cuda_bwd.cu | 149 +----------------- .../attention/csrc/attention_cuda_fwd.cu | 3 +- .../attention/csrc/attention_cuda_utils.cu | 2 +- 
.../attention/csrc/attention_cuda_utils.cuh | 4 +- torch_harmonics/disco/csrc/cudamacro.h | 47 ------ torch_harmonics/disco/csrc/disco.h | 7 - torch_harmonics/disco/csrc/disco_cpu.h | 2 + torch_harmonics/disco/csrc/disco_cuda.cuh | 11 +- torch_harmonics/disco/csrc/disco_cuda_bwd.cu | 1 - torch_harmonics/disco/csrc/disco_cuda_fwd.cu | 2 +- torch_harmonics/disco/csrc/disco_helpers.cpp | 4 +- .../cudamacro.h => utils/csrc/cppmacro.h} | 21 +-- torch_harmonics/utils/csrc/csr_cuda.cu | 4 - torch_harmonics/utils/csrc/csr_cuda.cuh | 10 +- torch_harmonics/utils/csrc/cudamacro.h | 11 ++ torch_harmonics/utils/csrc/permute_cuda.cu | 4 - torch_harmonics/utils/csrc/permute_cuda.cuh | 8 - 21 files changed, 43 insertions(+), 268 deletions(-) delete mode 100644 torch_harmonics/disco/csrc/cudamacro.h rename torch_harmonics/{attention/csrc/cudamacro.h => utils/csrc/cppmacro.h} (63%) diff --git a/setup.py b/setup.py index 7e801b49..841b865e 100644 --- a/setup.py +++ b/setup.py @@ -119,6 +119,7 @@ def get_ext_modules(): [ "torch_harmonics/disco/csrc/disco_helpers.cpp", ], + include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")], extra_compile_args=get_helpers_compile_args(), ) ) @@ -130,6 +131,7 @@ def get_ext_modules(): [ "torch_harmonics/attention/csrc/attention_helpers.cpp", ], + include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")], extra_compile_args=get_helpers_compile_args(), ) ) @@ -189,6 +191,7 @@ def get_ext_modules(): CppExtension( "torch_harmonics.disco._C", disco_sources, + include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")], extra_compile_args=get_compile_args("disco") ) ) @@ -212,6 +215,7 @@ def get_ext_modules(): CUDAExtension( "torch_harmonics.attention._C", attention_sources, + include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")], extra_compile_args=get_compile_args("attention") ) ) @@ -220,6 +224,7 @@ def get_ext_modules(): CppExtension( "torch_harmonics.attention._C", attention_sources, + include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")], extra_compile_args=get_compile_args("attention") ) ) diff --git a/torch_harmonics/attention/csrc/attention.h b/torch_harmonics/attention/csrc/attention.h index 373d4494..28b29905 100644 --- a/torch_harmonics/attention/csrc/attention.h +++ b/torch_harmonics/attention/csrc/attention.h @@ -36,9 +36,4 @@ #include #include -#define CHECK_CPU_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCPU) -#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT_TENSOR(x) CHECK_CONTIGUOUS_TENSOR(x) -#define CHECK_CPU_INPUT_TENSOR(x) \ - CHECK_CPU_TENSOR(x); \ - CHECK_CONTIGUOUS_TENSOR(x) \ No newline at end of file +#include "cppmacro.h" diff --git a/torch_harmonics/attention/csrc/attention_cpu.h b/torch_harmonics/attention/csrc/attention_cpu.h index c61a987f..93008aab 100644 --- a/torch_harmonics/attention/csrc/attention_cpu.h +++ b/torch_harmonics/attention/csrc/attention_cpu.h @@ -34,6 +34,8 @@ #include #include +#include "cppmacro.h" + #define CACHE_BLOCK_SIZE (64) namespace attention_kernels { diff --git a/torch_harmonics/attention/csrc/attention_cuda.cuh b/torch_harmonics/attention/csrc/attention_cuda.cuh index 533b5938..a17c7559 100644 --- a/torch_harmonics/attention/csrc/attention_cuda.cuh +++ b/torch_harmonics/attention/csrc/attention_cuda.cuh @@ -34,11 +34,8 @@ #include #include -#define CHECK_CUDA_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCUDA) -#define CHECK_CONTIGUOUS_TENSOR(x) 
TORCH_INTERNAL_ASSERT(x.is_contiguous(), #x " must be contiguous") -#define CHECK_CUDA_INPUT_TENSOR(x) \ - CHECK_CUDA_TENSOR(x); \ - CHECK_CONTIGUOUS_TENSOR(x) +#include "cudamacro.h" + namespace attention_kernels { diff --git a/torch_harmonics/attention/csrc/attention_cuda_bwd.cu b/torch_harmonics/attention/csrc/attention_cuda_bwd.cu index 21e0b28e..ec34a2a0 100644 --- a/torch_harmonics/attention/csrc/attention_cuda_bwd.cu +++ b/torch_harmonics/attention/csrc/attention_cuda_bwd.cu @@ -41,7 +41,7 @@ #include #include -#include "cudamacro.h" +//#include "cudamacro.h" #include "attention_cuda_utils.cuh" #include @@ -52,153 +52,8 @@ #define MAX_LOCAL_ARR_LEN (16) -namespace attention_kernels { - -#if 0 -class ScopeTimer -{ - public: - explicit ScopeTimer(const std::string &label = "") : - label_(label), start_(std::chrono::high_resolution_clock::now()) - { - } - - ~ScopeTimer() - { - auto end = std::chrono::high_resolution_clock::now(); - auto elapsed = std::chrono::duration_cast(end - start_); - std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl; - } - - private: - std::string label_; - std::chrono::high_resolution_clock::time_point start_; -}; - -// easier to understand version of manual shfl_xor_sync, performance appears similar -static __device__ float __warp_sum_cub(float val) -{ - // use cub to reduce within a warp - __shared__ typename cub::WarpReduce::TempStorage temp_storage; - - // 1. Compute sum (initially only in lane 0) - float sum = cub::WarpReduce(temp_storage).Sum(val); - // 2. Broadcast sum to all threads - sum = __shfl_sync(0xFFFFFFFF, sum, 0); - return sum; -} - -// This kernel computes the backward pass for the S2 attention mechanism, using -// shared memory as a cache and one warp per output point, warp-parallel over -// channels, which should be layed out in the fastest dimension for coalesced -// memory access. 
-template -__global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel( - int num_channels, int nlon_in, int nlat_out, int nlon_out, - const torch::PackedTensorAccessor32 kx, - const torch::PackedTensorAccessor32 vx, - const torch::PackedTensorAccessor32 qy, - const torch::PackedTensorAccessor32 dy, - torch::PackedTensorAccessor32 dydk, - torch::PackedTensorAccessor32 dydv, - torch::PackedTensorAccessor32 dydq, - const torch::PackedTensorAccessor64 psi_col_idx, - const torch::PackedTensorAccessor64 psi_row_offset, - const torch::PackedTensorAccessor32 quad_weights) -{ - extern __shared__ float sh[]; - float *sh_alpha_k = sh + threadIdx.y * num_channels * 5; - float *sh_alpha_vw = sh_alpha_k + num_channels; - float *sh_alpha_kvw = sh_alpha_vw + num_channels; - float *sh_dy = sh_alpha_kvw + num_channels; - float *sh_qy = sh_dy + num_channels; - // (optionally, could use more shared memory for other intermediates) - - const uint64_t batchId = blockIdx.y; - const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y; - if (wid >= uint64_t(nlat_out) * nlon_in) return; - const int tidx = threadIdx.x; - const int ho = wid / nlon_out; - const int wo = wid - (ho * nlon_out); - - // Zero shared memory - for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { - sh_alpha_k[chan] = 0.0f; - sh_alpha_vw[chan] = 0.0f; - sh_alpha_kvw[chan] = 0.0f; - sh_dy[chan] = dy[batchId][chan][ho][wo]; - sh_qy[chan] = qy[batchId][chan][ho][wo]; - } - float alpha_sum = 0.0f; - float qdotk_max = -FLT_MAX; - float integral = 0.0f; - __syncthreads(); - - const int64_t rbeg = psi_row_offset[ho]; - const int64_t rend = psi_row_offset[ho + 1]; - const int rlen = rend - rbeg; - - // 1st pass: accumulate alpha_sum, integral, and shared stats, along with a progressively computed qdotk_max. 
- for (int off = 0; off < rlen; off++) { - const int64_t col = psi_col_idx[rbeg + off]; - const int hi = col / nlon_in; - const int wi = col - (hi * nlon_in); - const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in; - float qdotk = 0.0f, gdotv = 0.0f; - for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { - qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip]; - gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip]; - } - qdotk = __warp_sum_cub(qdotk); - gdotv = __warp_sum_cub(gdotv); - float qdotk_max_tmp = max(qdotk_max, qdotk); - float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi]; - float max_correction = expf(qdotk_max - qdotk_max_tmp); - alpha_sum = alpha_sum * max_correction + alpha_inz; - integral = integral * max_correction + alpha_inz * gdotv; - for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { - float kxval = kx[batchId][chan][hi][wip]; - sh_alpha_k[chan] = sh_alpha_k[chan] * max_correction + alpha_inz * kxval; - sh_alpha_vw[chan] = sh_alpha_vw[chan] * max_correction + alpha_inz * gdotv; - sh_alpha_kvw[chan] = sh_alpha_kvw[chan] * max_correction + alpha_inz * kxval * gdotv; - } - qdotk_max = qdotk_max_tmp; - } - - integral /= alpha_sum; - - // Write dydq - for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { - dydq[batchId][chan][ho][wo] - = (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum); - } - - // Third pass: accumulate gradients for k and v - for (int off = 0; off < rlen; off++) { - const int64_t col = psi_col_idx[rbeg + off]; - const int hi = col / nlon_in; - const int wi = col - (hi * nlon_in); - const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in; - float qdotk = 0.0f, gdotv = 0.0f; - for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { - qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip]; - gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip]; - } - qdotk = __warp_sum_cub(qdotk); - gdotv = __warp_sum_cub(gdotv); - float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi]; - for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { - float qyval = qy[batchId][chan][ho][wo]; - float dyval = sh_dy[chan]; - atomicAdd(&dydk[batchId][chan][hi][wip], qyval * (alpha_inz / alpha_sum) * (gdotv - integral)); - atomicAdd(&dydv[batchId][chan][hi][wip], (alpha_inz / alpha_sum) * dyval); - } - } -} -#endif - -// BEGIN backward kernels and functions +namespace attention_kernels { // called with (blockDim.x=32 and blockDim.y>1, BDIM=blockDim.x*blockDim.y) template #include -#include "cudamacro.h" +//#include "cudamacro.h" #include "attention_cuda_utils.cuh" #define THREADS (64) #define MAX_LOCAL_ARR_LEN (16) -// BEGIN - forward kernels and functions namespace attention_kernels { diff --git a/torch_harmonics/attention/csrc/attention_cuda_utils.cu b/torch_harmonics/attention/csrc/attention_cuda_utils.cu index f758fcae..b1e21571 100644 --- a/torch_harmonics/attention/csrc/attention_cuda_utils.cu +++ b/torch_harmonics/attention/csrc/attention_cuda_utils.cu @@ -39,7 +39,7 @@ #include #include -#include "cudamacro.h" +//#include "cudamacro.h" #include "attention_cuda.cuh" #define THREADS (64) diff --git a/torch_harmonics/attention/csrc/attention_cuda_utils.cuh b/torch_harmonics/attention/csrc/attention_cuda_utils.cuh index 81c75871..625e8131 100644 --- a/torch_harmonics/attention/csrc/attention_cuda_utils.cuh +++ b/torch_harmonics/attention/csrc/attention_cuda_utils.cuh @@ -34,9 +34,9 @@ #include #include +#include "cudamacro.h" + #define WARP_SIZE (32) 
-#define FULL_MASK (0xFFFFFFFF) -#define DIV_UP(a,b) (((a)+((b)-1))/(b)) namespace attention_kernels { diff --git a/torch_harmonics/disco/csrc/cudamacro.h b/torch_harmonics/disco/csrc/cudamacro.h deleted file mode 100644 index 0edef184..00000000 --- a/torch_harmonics/disco/csrc/cudamacro.h +++ /dev/null @@ -1,47 +0,0 @@ -// coding=utf-8 -// -// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. -// SPDX-License-Identifier: BSD-3-Clause -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -#pragma once - -#define CHECK_CUDA(call) { \ - cudaError_t err = call; \ - if( cudaSuccess != err) { \ - fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \ - __FILE__, __LINE__, cudaGetErrorString( err) ); \ - exit(EXIT_FAILURE); \ - }} - -#define CHECK_ERROR(errorMessage) { \ - cudaError_t err = cudaGetLastError(); \ - if( cudaSuccess != err) { \ - fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \ - errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) );\ - exit(EXIT_FAILURE); \ - }} diff --git a/torch_harmonics/disco/csrc/disco.h b/torch_harmonics/disco/csrc/disco.h index 7198ad8c..b20ccff9 100644 --- a/torch_harmonics/disco/csrc/disco.h +++ b/torch_harmonics/disco/csrc/disco.h @@ -35,10 +35,3 @@ #include #include #include - -#define CHECK_CPU_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCPU) -#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT_TENSOR(x) CHECK_CONTIGUOUS_TENSOR(x) -#define CHECK_CPU_INPUT_TENSOR(x) \ - CHECK_CPU_TENSOR(x); \ - CHECK_CONTIGUOUS_TENSOR(x) diff --git a/torch_harmonics/disco/csrc/disco_cpu.h b/torch_harmonics/disco/csrc/disco_cpu.h index 921ff5ae..96c768a6 100644 --- a/torch_harmonics/disco/csrc/disco_cpu.h +++ b/torch_harmonics/disco/csrc/disco_cpu.h @@ -32,6 +32,8 @@ #include "disco.h" +#include "cppmacro.h" + #define CACHE_BLOCK_SIZE (64) namespace disco_kernels { diff --git a/torch_harmonics/disco/csrc/disco_cuda.cuh b/torch_harmonics/disco/csrc/disco_cuda.cuh index 2d342a24..64d2d6cd 100644 --- a/torch_harmonics/disco/csrc/disco_cuda.cuh +++ b/torch_harmonics/disco/csrc/disco_cuda.cuh @@ -35,16 +35,7 @@ #include #include -#define CHECK_CUDA_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCUDA) -#define CHECK_CUDA_INPUT_TENSOR(x) \ - CHECK_CUDA_TENSOR(x); \ - CHECK_CONTIGUOUS_TENSOR(x) - -// will come from ../../attention/csrc/attention_cuda_utils.cuh -#ifndef DIV_UP -#define DIV_UP(a, b) (((a) + ((b)-1)) / (b)) -#endif - +#include "cudamacro.h" #define MIN_THREADS (64) #define ELXTH_MAX (32) diff --git a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu index 288e5a71..874e9b9d 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu @@ -31,7 +31,6 @@ #include "disco.h" #include "disco_cuda.cuh" #include "csr_cuda.cuh" -#include "cudamacro.h" #define THREADS (64) diff --git a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu index 7cdd2683..f769b308 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu @@ -28,7 +28,7 @@ // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -#include "cudamacro.h" +//#include "cudamacro.h" #include "disco.h" #include "disco_cuda.cuh" #include "csr_cuda.cuh" diff --git a/torch_harmonics/disco/csrc/disco_helpers.cpp b/torch_harmonics/disco/csrc/disco_helpers.cpp index 1737cec2..67a4a583 100644 --- a/torch_harmonics/disco/csrc/disco_helpers.cpp +++ b/torch_harmonics/disco/csrc/disco_helpers.cpp @@ -28,9 +28,11 @@ // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-#include "disco.h" #include +#include "disco.h" +#include "cppmacro.h" + template void preprocess_psi_kernel(int64_t nnz, int64_t K, int64_t Ho, int64_t *ker_h, int64_t *row_h, int64_t *col_h, int64_t *roff_h, REAL_T *val_h, int64_t &nrows) diff --git a/torch_harmonics/attention/csrc/cudamacro.h b/torch_harmonics/utils/csrc/cppmacro.h similarity index 63% rename from torch_harmonics/attention/csrc/cudamacro.h rename to torch_harmonics/utils/csrc/cppmacro.h index 0edef184..3dccace8 100644 --- a/torch_harmonics/attention/csrc/cudamacro.h +++ b/torch_harmonics/utils/csrc/cppmacro.h @@ -30,18 +30,11 @@ #pragma once -#define CHECK_CUDA(call) { \ - cudaError_t err = call; \ - if( cudaSuccess != err) { \ - fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \ - __FILE__, __LINE__, cudaGetErrorString( err) ); \ - exit(EXIT_FAILURE); \ - }} +#include -#define CHECK_ERROR(errorMessage) { \ - cudaError_t err = cudaGetLastError(); \ - if( cudaSuccess != err) { \ - fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \ - errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) );\ - exit(EXIT_FAILURE); \ - }} +#define CHECK_CPU_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCPU, #x " must be on CPU") +#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_INTERNAL_ASSERT(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT_TENSOR(x) CHECK_CONTIGUOUS_TENSOR(x) +#define CHECK_CPU_INPUT_TENSOR(x) \ + CHECK_CPU_TENSOR(x); \ + CHECK_CONTIGUOUS_TENSOR(x) diff --git a/torch_harmonics/utils/csrc/csr_cuda.cu b/torch_harmonics/utils/csrc/csr_cuda.cu index f6f0f5b1..ba0edfde 100644 --- a/torch_harmonics/utils/csrc/csr_cuda.cu +++ b/torch_harmonics/utils/csrc/csr_cuda.cu @@ -28,10 +28,6 @@ // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -//#include -//#include -//#include #include #include diff --git a/torch_harmonics/utils/csrc/csr_cuda.cuh b/torch_harmonics/utils/csrc/csr_cuda.cuh index 51fcb5a9..a9bf6cca 100644 --- a/torch_harmonics/utils/csrc/csr_cuda.cuh +++ b/torch_harmonics/utils/csrc/csr_cuda.cuh @@ -30,21 +30,15 @@ #pragma once -// -//#include -//#include - #include #include #include #include +#include "cudamacro.h" + #define WARP_SIZE (32) -#define FULL_MASK (0xFFFFFFFF) -#ifndef DIV_UP -#define DIV_UP(a,b) (((a)+((b)-1))/(b)) -#endif namespace utility_kernels { diff --git a/torch_harmonics/utils/csrc/cudamacro.h b/torch_harmonics/utils/csrc/cudamacro.h index 0edef184..1dcb8227 100644 --- a/torch_harmonics/utils/csrc/cudamacro.h +++ b/torch_harmonics/utils/csrc/cudamacro.h @@ -30,6 +30,11 @@ #pragma once +#include + +#define DIV_UP(a,b) (((a)+((b)-1))/(b)) +#define FULL_MASK (0xFFFFFFFF) + #define CHECK_CUDA(call) { \ cudaError_t err = call; \ if( cudaSuccess != err) { \ @@ -45,3 +50,9 @@ errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) );\ exit(EXIT_FAILURE); \ }} + +#define CHECK_CUDA_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCUDA, #x " must be on GPU") +#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_INTERNAL_ASSERT(x.is_contiguous(), #x " must be contiguous") +#define CHECK_CUDA_INPUT_TENSOR(x) \ + CHECK_CUDA_TENSOR(x); \ + CHECK_CONTIGUOUS_TENSOR(x) diff --git a/torch_harmonics/utils/csrc/permute_cuda.cu b/torch_harmonics/utils/csrc/permute_cuda.cu index eb23ea4f..4cba299c 100644 --- a/torch_harmonics/utils/csrc/permute_cuda.cu +++ b/torch_harmonics/utils/csrc/permute_cuda.cu @@ -28,10 +28,6 @@ // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -//#include -//#include -//#include #include #include diff --git a/torch_harmonics/utils/csrc/permute_cuda.cuh b/torch_harmonics/utils/csrc/permute_cuda.cuh index 768b84db..5052810b 100644 --- a/torch_harmonics/utils/csrc/permute_cuda.cuh +++ b/torch_harmonics/utils/csrc/permute_cuda.cuh @@ -30,21 +30,13 @@ #pragma once -// -//#include -//#include - #include #include #include #include #define WARP_SIZE (32) -#define FULL_MASK (0xFFFFFFFF) -#ifndef DIV_UP -#define DIV_UP(a,b) (((a)+((b)-1))/(b)) -#endif namespace utility_kernels { From cfc533847e452b506c1924172e4040e046523633 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Tue, 28 Oct 2025 07:52:41 -0700 Subject: [PATCH 29/31] removing prints --- torch_harmonics/attention/csrc/attention.h | 2 -- .../distributed/distributed_convolution.py | 31 +------------------ 2 files changed, 1 insertion(+), 32 deletions(-) diff --git a/torch_harmonics/attention/csrc/attention.h b/torch_harmonics/attention/csrc/attention.h index 28b29905..b20ccff9 100644 --- a/torch_harmonics/attention/csrc/attention.h +++ b/torch_harmonics/attention/csrc/attention.h @@ -35,5 +35,3 @@ #include #include #include - -#include "cppmacro.h" diff --git a/torch_harmonics/distributed/distributed_convolution.py b/torch_harmonics/distributed/distributed_convolution.py index 5786d5e7..d0b5cd41 100644 --- a/torch_harmonics/distributed/distributed_convolution.py +++ b/torch_harmonics/distributed/distributed_convolution.py @@ -33,14 +33,11 @@ from itertools import accumulate import torch -import torch.nn as nn - -from functools import partial from torch_harmonics.disco._disco_utils import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch from torch_harmonics.disco._disco_utils import _disco_s2_contraction_optimized, _disco_s2_transpose_contraction_optimized from torch_harmonics.utils import permute_to_0231, permute_to_0312 -from disco_helpers import optimized_kernels_is_available, preprocess_psi +from disco_helpers import preprocess_psi from torch_harmonics.disco.convolution import ( _precompute_convolution_tensor_s2, DiscreteContinuousConv, @@ -210,14 +207,9 @@ def __init__( merge_quadrature=True, ) - print(f"{self.comm_rank_polar} before splitting shapes idx = {idx.shape}, vals = {vals.shape}") - print(f"{self.comm_rank_polar} before splitting \n ker = {idx[0]}\n row = {idx[1]}\n col = {idx[2]}", flush=True) - # split the convolution tensor along latitude idx, vals = _split_distributed_convolution_tensor_s2(idx, vals, in_shape) - #print(f"{self.comm_rank_polar} after splitting \n ker = {idx[0]}\n row = {idx[1]}\n col = {idx[2]}", flush=True) - # sort the values ker_idx = idx[0, ...].contiguous() row_idx = idx[1, ...].contiguous() @@ -235,9 +227,6 @@ def __init__( self.register_buffer("psi_col_idx", col_idx, persistent=False) self.register_buffer("psi_vals", vals, persistent=False) - print(f"{self.comm_rank_polar} after splitting sorted \n ker_idx = {self.psi_ker_idx}\n row = {self.psi_row_idx}\n col = {self.psi_col_idx}", flush=True) - print(f"{self.comm_rank_polar} roff_idx = {self.psi_roff_idx}", flush=True) - # store psi jic: if not self.optimized_kernel: self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, self.nlat_in_local, self.nlat_out_local) @@ -254,20 +243,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # store number of channels num_chans = x.shape[1] - print(f"{self.comm_rank_polar} input shape", x.shape) - # h and w is split. 
First we make w local by transposing into channel dim if self.comm_size_azimuth > 1: x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes) - print(f"{self.comm_rank_polar} after azimuth transpose forward", x.shape) - if self.optimized_kernel: # permute input: B, C, Hi, Wi -> B, Hi, Wi, C xp = permute_to_0231(x) - print(f"{self.comm_rank_polar} before disco contraction", xp.shape) - # disco contraction: B, Hi, Wi, C -> B, Ho, Wo, C, K x = _disco_s2_contraction_optimized( xp, @@ -295,25 +278,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: azimuth_dim = -1 chan_dim = -4 - print(f"{self.comm_rank_polar} after disco contraction", x.shape) - - print(f"{self.comm_rank_polar} polar dim", polar_dim) - print(f"{self.comm_rank_polar} azimuth dim", azimuth_dim) - print(f"{self.comm_rank_polar} chan dim", chan_dim) - # perform reduce scatter in polar region x = reduce_from_polar_region(x) x = scatter_to_polar_region(x, polar_dim) - print(f"{self.comm_rank_polar} after polar scatter", x.shape) - # now we can transpose back the result, so that lon is split and channels are local if self.comm_size_azimuth > 1: chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth) x = distributed_transpose_azimuth.apply(x, (azimuth_dim, chan_dim), chan_shapes) - print(f"{self.comm_rank_polar} after azimuth transpose inverse", x.shape, num_chans, azimuth_dim, chan_dim) - # extract shape if self.optimized_kernel: # weight multiplication @@ -331,8 +304,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size)) out = out.reshape(B, -1, H, W).contiguous() - print(f"{self.comm_rank_polar} after weight multiplication", out.shape) - if self.bias is not None: out = out + self.bias.reshape(1, -1, 1, 1) From 48269710b6fc76708dcbe21cfa688de4da205eed Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 29 Oct 2025 01:26:24 -0700 Subject: [PATCH 30/31] further cleanup --- torch_harmonics/disco/csrc/disco_cuda_bwd.cu | 4 +--- torch_harmonics/disco/csrc/disco_cuda_fwd.cu | 5 +---- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu index 874e9b9d..1790398d 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu @@ -34,7 +34,7 @@ #define THREADS (64) -#define MAX_LOCAL_ARR_LEN (16) +#define MAX_LOCAL_ARR_LEN (32) namespace disco_kernels { @@ -784,8 +784,6 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, at::Tensor val_dat, // CSR non-empty value data at::Tensor yP) { - static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1))); - if (batch_size <= 0 || nchans <= 0 || nlon_in <= 0 || diff --git a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu index f769b308..7e338ddc 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu @@ -35,7 +35,7 @@ #define THREADS (64) -#define MAX_LOCAL_ARR_LEN (16) +#define MAX_LOCAL_ARR_LEN (32) namespace disco_kernels { @@ -576,7 +576,6 @@ void s2_disco_fwd_special_vec_k(const int nchan_in, // no. 
of input float (no constexpr int NLOC_M1 = NLOC-1; - //constexpr int VEC_SIZE = sizeof(FLOATV_T) / sizeof(float); const int tidx = threadIdx.x; const int tidy = threadIdx.y; @@ -734,8 +733,6 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, at::Tensor val_dat, // CSR non-empty value data at::Tensor yP) { - static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1))); - if (batch_size <= 0 || nchan_in <= 0 || nlon_in <= 0 || From 09e8ebf85df9c680395df3cbc84cc2eabb015190 Mon Sep 17 00:00:00 2001 From: Mauro Bisson Date: Thu, 30 Oct 2025 11:02:11 -0700 Subject: [PATCH 31/31] Increased max no. of element per thread to 20 to both fwd and bwd and changes bwd to limit the unroll length in processCSR. --- torch_harmonics/disco/csrc/disco_cuda_bwd.cu | 40 ++++++++++++++++---- torch_harmonics/disco/csrc/disco_cuda_fwd.cu | 20 ++++++++-- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu index 874e9b9d..216ae2fe 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu @@ -34,7 +34,7 @@ #define THREADS (64) -#define MAX_LOCAL_ARR_LEN (16) +#define MAX_LOCAL_ARR_LEN (20) namespace disco_kernels { @@ -457,6 +457,7 @@ static __device__ void processCSR_Kpow2_reg_d(const int wi, unsigned int subwarp_id = threadIdx.y % (WARP_SIZE/BDIM_X); subwarp_mask = MASK << (subwarp_id*BDIM_X); } + constexpr int MAX_POW2_K = (BDIM_X < WARP_SIZE) ? BDIM_X : WARP_SIZE; // K is a power of two <= BDIM_X const int log2_K = __popc(K-1); @@ -500,13 +501,13 @@ static __device__ void processCSR_Kpow2_reg_d(const int wi, // K is a power of two <= 32 #pragma unroll - for(int j = 1; j < BDIM_X; j *= 2) { + for(int j = 1; j < MAX_POW2_K; j *= 2) { if (j >= K) break; #pragma unroll for(int i = 0; i < NLOC; i++) { - locy[i] += __shfl_xor_sync(subwarp_mask, locy[i], j, BDIM_X); + locy[i] += __shfl_xor_sync(subwarp_mask, locy[i], j, MAX_POW2_K); } } @@ -678,8 +679,9 @@ void s2_disco_bwd_special_vec_k(int nchans, // no. of input float (not FLOATV __shared__ float *shYOffAll[BDIM_Y][BDIM_X+PAD]; // check if BDIM_X is a multiple of K; since BDIM_X is a power of 2, check if K is also a power of two - if (!(K & K-1) && K <= BDIM_X) { processCSR_Kpow2_reg_d(wi, rlen, nchans, nlon_out, pscale, K, locx, col_idx, val_pck, shYOffAll[tidy], NULL, y); } - else { processCSR_Kanyv_reg_d(wi, rlen, nchans, nlon_out, pscale, K, locx, col_idx, val_pck, shYOffAll[tidy], shy, y); } + constexpr int MAX_POW2_K = (BDIM_X < WARP_SIZE) ? BDIM_X : WARP_SIZE; + if (!(K & K-1) && K <= MAX_POW2_K) { processCSR_Kpow2_reg_d(wi, rlen, nchans, nlon_out, pscale, K, locx, col_idx, val_pck, shYOffAll[tidy], NULL, y); } + else { processCSR_Kanyv_reg_d(wi, rlen, nchans, nlon_out, pscale, K, locx, col_idx, val_pck, shYOffAll[tidy], shy, y); } return; @@ -709,7 +711,18 @@ void launch_gen_disco_bwd(int64_t batch_size, size_t shsize = (sizeof(FLOATV_T)*(nchans*K) + sizeof(float)*nchans)*block.y; const int pscale = nlon_out / nlon_in; - +#if 0 + printf("Launching s2_disco_bwd_generic_vec_k<%d, float%s><<<(%d,%d), (%d,%d)..., ..., %zu, ...>>> with:\n" + "\tnchan_out: %ld\n" + "\tK: %ld\n" + "\tpscale: %d\n" + "\tnlat_in: %ld\n" + "\tnlon_in: %ld\n" + "\tnlat_out: %ld\n" + "\tnlon_out: %ld\n\n", + THREADS, sizeof(FLOATV_T)==16?"4":"", grid.x, grid.y, block.x, block.y, shsize, nchans, K, pscale, + nlat_in, nlon_in, nlat_out, nlon_out); +#endif // will use only the first 1/K-th of the CSR, i.e. 
only the first nlat_out rows s2_disco_bwd_generic_vec_k <<>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K, @@ -752,7 +765,18 @@ void launch_spc_disco_bwd(int nloc, // "BDIM_X*nloc" >= nchans size_t shsize = (K & (K-1)) ? sizeof(float)*DIV_UP(nchans, BDIM_X)*BDIM_X*block.y : 0; const int pscale = nlon_out / nlon_in; - +#if 0 + printf("Launching s2_disco_bwd_special_vec_k<%d, %d, %d, float%s><<<(%d, %d), (%d, %d), ..., %zu, ...>>> with:\n" + "\tnchans: %ld\n" + "\tK: %ld\n" + "\tpscale: %d\n" + "\tnlat_in: %ld\n" + "\tnlon_in: %ld\n" + "\tnlat_out: %ld\n" + "\tnlon_in: %ld\n\n", + BDIM_X, BDIM_Y, CUR_LOC_SIZE, sizeof(FLOATV_T)==16?"4":"", grid.x, grid.y, block.x, block.y, shsize, nchans, K, pscale, + nlat_in, nlon_in, nlat_out, nlon_out); +#endif s2_disco_bwd_special_vec_k <<>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp); @@ -784,7 +808,7 @@ static void s2_disco_bwd_dispatch(int64_t batch_size, at::Tensor val_dat, // CSR non-empty value data at::Tensor yP) { - static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1))); + //static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1))); if (batch_size <= 0 || nchans <= 0 || diff --git a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu index f769b308..c9fdc349 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu @@ -35,7 +35,7 @@ #define THREADS (64) -#define MAX_LOCAL_ARR_LEN (16) +#define MAX_LOCAL_ARR_LEN (20) namespace disco_kernels { @@ -660,7 +660,13 @@ void launch_gen_disco_fwd(int64_t batch_size, size_t shsize = sizeof(FLOATV_T)*(nchan_in*K)*block.y; const int pscale = nlon_in / nlon_out; - +#if 0 + printf("Launching s2_disco_fwd_generic_vec_k<%d, float%s><<<..., ..., %zu, ...>>> with:\n" + "\tnchan_in: %ld\n" + "\tK: %ld\n" + "\tpscale: %d\n\n", + THREADS, sizeof(FLOATV_T)==16?"4":"", shsize, nchan_in, K, pscale); +#endif // will use only the first 1/K-th of the CSR, i.e. only the first nlat_out rows s2_disco_fwd_generic_vec_k <<>>(nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K, @@ -702,7 +708,13 @@ void launch_spc_disco_fwd(int nloc, // "BDIM_X*nloc" >= nchans size_t shsize = 0; //sizeof(float)*chxgrp_out * block.y; const int pscale = nlon_in / nlon_out; - +#if 0 + printf("Launching s2_disco_fwd_special_vec_k<%d, %d, %d, float%s><<<(%d, %d), (%d, %d), ..., %zu, ...>>> with:\n" + "\tnchan_in: %ld\n" + "\tK: %ld\n" + "\tpscale: %d\n\n", + BDIM_X, BDIM_Y, CUR_LOC_SIZE, sizeof(FLOATV_T)==16?"4":"", grid.x, grid.y, block.x, block.y, shsize, nchan_in, K, pscale); +#endif s2_disco_fwd_special_vec_k <<>>(nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp); @@ -734,7 +746,7 @@ static void s2_disco_fwd_dispatch(int64_t batch_size, at::Tensor val_dat, // CSR non-empty value data at::Tensor yP) { - static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1))); + //static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1))); if (batch_size <= 0 || nchan_in <= 0 ||