diff --git a/setup.py b/setup.py index b10a7e2e..841b865e 100644 --- a/setup.py +++ b/setup.py @@ -93,31 +93,77 @@ def get_helpers_compile_args(): def get_ext_modules(): """Get list of extension modules to compile.""" + # define setup dir + setup_dir = os.path.abspath(os.path.dirname(__file__)) + ext_modules = [] cmdclass = {} print(f"Compiling helper routines for torch-harmonics.") + + # Utility helpers + ext_modules.append( + CppExtension( + "utility_helpers", + [ + "torch_harmonics/utils/csrc/utils_helpers.cpp", + ], + extra_compile_args=get_helpers_compile_args(), + ) + ) + + # DISCO helpers ext_modules.append( CppExtension( "disco_helpers", [ "torch_harmonics/disco/csrc/disco_helpers.cpp", ], + include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")], extra_compile_args=get_helpers_compile_args(), ) ) + # Attention helpers ext_modules.append( CppExtension( "attention_helpers", [ "torch_harmonics/attention/csrc/attention_helpers.cpp", ], + include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")], extra_compile_args=get_helpers_compile_args(), ) ) if BUILD_CPP: + # HELPERS + utility_sources = [ + "torch_harmonics/utils/csrc/utils_interface.cpp", + "torch_harmonics/utils/csrc/permute_cpu.cpp", + ] + + if BUILD_CUDA: + print(f"Compiling custom CUDA kernels for torch-harmonics.") + utility_sources.extend([ + "torch_harmonics/utils/csrc/permute_cuda.cu", + ]) + ext_modules.append( + CUDAExtension( + "torch_harmonics.utils._C", + utility_sources, + extra_compile_args=get_compile_args("utils") + ) + ) + else: + ext_modules.append( + CppExtension( + "torch_harmonics.utils._C", + utility_sources, + extra_compile_args=get_compile_args("utils") + ) + ) + # DISCO # Create a single extension that includes both CPU and CUDA code disco_sources = [ @@ -128,6 +174,7 @@ def get_ext_modules(): if BUILD_CUDA: print(f"Compiling custom CUDA kernels for torch-harmonics.") disco_sources.extend([ + "torch_harmonics/utils/csrc/csr_cuda.cu", "torch_harmonics/disco/csrc/disco_cuda_fwd.cu", "torch_harmonics/disco/csrc/disco_cuda_bwd.cu", ]) @@ -135,6 +182,7 @@ def get_ext_modules(): CUDAExtension( "torch_harmonics.disco._C", disco_sources, + include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")], extra_compile_args=get_compile_args("disco") ) ) @@ -143,10 +191,10 @@ def get_ext_modules(): CppExtension( "torch_harmonics.disco._C", disco_sources, + include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")], extra_compile_args=get_compile_args("disco") ) ) - cmdclass["build_ext"] = BuildExtension # ATTENTION # Create a single extension that includes both CPU and CUDA code @@ -167,6 +215,7 @@ def get_ext_modules(): CUDAExtension( "torch_harmonics.attention._C", attention_sources, + include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")], extra_compile_args=get_compile_args("attention") ) ) @@ -175,9 +224,12 @@ def get_ext_modules(): CppExtension( "torch_harmonics.attention._C", attention_sources, + include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")], extra_compile_args=get_compile_args("attention") ) ) + + # set cmdclass cmdclass["build_ext"] = BuildExtension return ext_modules, cmdclass diff --git a/tests/test_attention.py b/tests/test_attention.py index 67364485..98a930b6 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -54,8 +54,8 @@ # CPU results normalized to 16 OpenMP threads, # GPU results normalized to V100 16 GB GPU # this is just to detect performance regressions, not for absolute performance 
-_perf_test_thresholds = {"cpu": {"fwd_ms": 1000, "bwd_ms": 8000}, - "cuda": {"fwd_ms": 50, "bwd_ms": 150}} +_perf_test_thresholds = {"cpu": {"fwd_ms": 800, "bwd_ms": 6000}, + "cuda": {"fwd_ms": 10, "bwd_ms": 30}} _run_perf_tests = (os.getenv("TORCH_HARMONICS_RUN_PERF_TESTS", "0") == "1") @@ -326,12 +326,12 @@ def test_optimized_pt2_compatibility(self, batch_size, channels, heads, in_shape @parameterized.expand( [ # self attention - [1, 256, 1, (91, 180), (91, 180), "equiangular", "equiangular", 1e-5, 1e-5], + [1, 256, 1, (91, 180), (91, 180), "equiangular", "equiangular", None], ], skip_on_empty=True, ) @unittest.skipUnless(optimized_kernels_is_available() and _run_perf_tests, "skipping performance test because optimized kernels are not available or perf tests are disabled") - def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=False): + def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, theta_cutoff, verbose=False): if (self.device.type == "cuda") and (not cuda_kernels_is_available()): raise unittest.SkipTest("skipping test because CUDA kernels are not available") @@ -350,7 +350,8 @@ def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, g att_optimized = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads, in_shape=in_shape, out_shape=out_shape, - grid_in=grid_in, grid_out=grid_out, bias=True, + grid_in=grid_in, grid_out=grid_out, + theta_cutoff=theta_cutoff, bias=True, optimized_kernel=True).to(self.device) # random weights diff --git a/tests/test_cache.py b/tests/test_cache.py index 3a2d1720..747a5c70 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -34,8 +34,11 @@ import math import torch +from torch_harmonics.cache import lru_cache +from testutils import compare_tensors class TestCacheConsistency(unittest.TestCase): + def test_consistency(self, verbose=False): if verbose: print("Testing that cache values does not get modified externally") @@ -47,7 +50,43 @@ def test_consistency(self, verbose=False): # perform in-place modification of leg1 leg1 *= -1.0 leg2 = _precompute_legpoly(10, 10, cost) - self.assertFalse(torch.allclose(leg1, leg2)) + self.assertFalse(compare_tensors("legendre", leg2, leg1, verbose=verbose)) + + + def test_pytorch_tensors(self, verbose=False): + if verbose: + print("Testing that PyTorch tensors are cached") + + @lru_cache(typed=True, copy=True) + def test_func(tens1, tens2): + return tens1, tens2 + + # initial tensors + tens1 = torch.randn(4, 4, dtype=torch.float32) + tens2 = torch.randn(4, 4, dtype=torch.float32) + + # retrieve from cache + tens1c, tens2c = test_func(tens1, tens2) + + # modify copies + tens1c *= -1.0 + tens2c *= -1.0 + + # retrieve from cache again + tens1cc, tens2cc = test_func(tens1, tens2) + + if verbose: + print("first tensor", tens1) + print("first tensor after modification", tens1c) + print("first tensor cached", tens1cc) + print("second tensor", tens2) + print("second tensor after modification", tens2c) + print("second tensor cached", tens2cc) + + self.assertFalse(compare_tensors("first cached", tens1cc, tens1c, verbose=verbose)) + self.assertFalse(compare_tensors("second cached", tens2cc, tens2c, verbose=verbose)) + self.assertTrue(compare_tensors("first raw", tens1, tens1cc, verbose=verbose)) + self.assertTrue(compare_tensors("second raw", tens2, tens2cc, verbose=verbose)) if __name__ == "__main__": diff --git a/tests/test_convolution.py b/tests/test_convolution.py index a4261614..d612c9d2 
100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -39,8 +39,13 @@ from torch.library import opcheck from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 -from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes +from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes from torch_harmonics.disco import cuda_kernels_is_available, optimized_kernels_is_available +from disco_helpers import preprocess_psi +from torch_harmonics.filter_basis import get_filter_basis +from torch_harmonics.disco.convolution import _precompute_convolution_tensor_s2 + +from testutils import compare_tensors if not optimized_kernels_is_available(): print(f"Warning: Couldn't import optimized disco convolution kernels") @@ -183,6 +188,99 @@ def setUp(self): if self.device.type == "cuda": torch.cuda.manual_seed(333) + @parameterized.expand( + [ + # piecewise linear + # normal isotropic + [(16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular"], + [(17, 32), (17, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular"], + # normal anisotropic + [(16, 32), (16, 32), (3, 4), "piecewise linear", "mean", "equiangular", "equiangular"], + [(16, 32), (16, 32), (3, 2), "piecewise linear", "mean", "equiangular", "equiangular"], + # downsampling isotropic + [(16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular"], + [(17, 32), (9, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular"], + # downsampling anisotropic + [(16, 32), (8, 16), (3, 4), "piecewise linear", "mean", "equiangular", "equiangular"], + [(16, 32), (8, 16), (3, 2), "piecewise linear", "mean", "equiangular", "equiangular"], + # morlet + # normal isotropic + [(16, 32), (16, 32), (1), "morlet", "mean", "equiangular", "equiangular"], # important for attention + [(16, 32), (16, 32), (3), "morlet", "mean", "equiangular", "equiangular"], + [(17, 32), (17, 32), (3), "morlet", "mean", "equiangular", "equiangular"], + # normal anisotropic + [(16, 32), (16, 32), (3, 4), "morlet", "mean", "equiangular", "equiangular"], + [(16, 32), (16, 32), (3, 2), "morlet", "mean", "equiangular", "equiangular"], + # downsampling isotropic + [(16, 32), (8, 16), (1), "morlet", "mean", "equiangular", "equiangular"], # important for attention + [(16, 32), (8, 16), (3), "morlet", "mean", "equiangular", "equiangular"], + [(17, 32), (9, 16), (3), "morlet", "mean", "equiangular", "equiangular"], + # downsampling anisotropic + [(16, 32), (8, 16), (3, 4), "morlet", "mean", "equiangular", "equiangular"], + [(16, 32), (8, 16), (3, 2), "morlet", "mean", "equiangular", "equiangular"], + # zernike + # normal + [(16, 32), (16, 32), (1), "zernike", "mean", "equiangular", "equiangular"], + [(16, 32), (16, 32), (3, 3), "zernike", "mean", "equiangular", "equiangular"], + [(17, 32), (17, 32), (3, 3), "zernike", "mean", "equiangular", "equiangular"], + # downsampling + [(16, 32), (8, 16), (1), "zernike", "mean", "equiangular", "equiangular"], + [(16, 32), (8, 16), (3, 3), "zernike", "mean", "equiangular", "equiangular"], + [(17, 32), (9, 16), (3, 3), "zernike", "mean", "equiangular", "equiangular"], + ], + skip_on_empty=True, + ) + def test_convolution_tensor_integrity(self, in_shape, out_shape, kernel_shape, basis_type, basis_norm_mode, grid_in, grid_out, verbose=False): + + nlat_in, nlon_in = in_shape + nlat_out, nlon_out = out_shape + + filter_basis = get_filter_basis(kernel_shape=kernel_shape, 
basis_type=basis_type) + + # use default value for cutoff + theta_cutoff = torch.pi / float(nlat_out - 1) + + idx, vals, _ = _precompute_convolution_tensor_s2( + in_shape=in_shape, + out_shape=out_shape, + filter_basis=filter_basis, + grid_in=grid_in, + grid_out=grid_out, + theta_cutoff=theta_cutoff, + transpose_normalization=False, + basis_norm_mode=basis_norm_mode, + merge_quadrature=True, + ) + + ker_idx = idx[0, ...].contiguous() + row_idx = idx[1, ...].contiguous() + col_idx = idx[2, ...].contiguous() + vals = vals.contiguous() + + # sort values + roff_idx = preprocess_psi(filter_basis.kernel_size, nlat_out, ker_idx, row_idx, col_idx, vals).contiguous() + + # check shapes + self.assertTrue(ker_idx.shape[0] == row_idx.shape[0], f"ker_idx and row_idx have to have the same shape: found {ker_idx.shape[0]} and {row_idx.shape[0]}") + self.assertTrue(ker_idx.shape[0] == col_idx.shape[0], f"ker_idx and col_idx have to have the same shape: found {ker_idx.shape[0]} and {col_idx.shape[0]}") + self.assertTrue(ker_idx.shape[0] == vals.shape[0], f"ker_idx and vals have to have the same shape: found {ker_idx.shape[0]} and {vals.shape[0]}") + self.assertTrue((roff_idx.shape[0] - 1) == filter_basis.kernel_size * nlat_out, f"roff_idx has to have length kernel_size * nlat_out + 1: found {(roff_idx.shape[0] - 1)} and {filter_basis.kernel_size * nlat_out}") + + # the multiplicity in ker_idx has to be the same for all kernel indices + unique, counts = torch.unique(ker_idx, return_counts=True) + self.assertTrue(torch.all(counts.max() == counts), f"The multiplicity in ker_idx has to be the same for all kernel indices: found {counts} for entries {unique}") + + if verbose: + print(f"\n ker_idx = {ker_idx},\n row_idx = {row_idx},\n col_idx = {col_idx}") + + # the following has to be true: the row_idx and col_idx have to be the same for all kernel indices + row_idx_ref = row_idx[ker_idx == 0] + col_idx_ref = col_idx[ker_idx == 0] + for k in range(1, filter_basis.kernel_size): + self.assertTrue(torch.all(row_idx_ref == row_idx[ker_idx == k]), f"The row_idx has to be the same for all kernel indices: found {row_idx_ref} for entries {ker_idx == k}") + self.assertTrue(torch.all(col_idx_ref == col_idx[ker_idx == k]), f"The col_idx has to be the same for all kernel indices: found {col_idx_ref} for entries {ker_idx == k}") + + @parameterized.expand( [ # regular convolution [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False], @@ -203,7 +301,7 @@ def setUp(self): [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], [8, 4, 2, (12, 24), (24, 48), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], [8, 4, 2, (12, 24), (24, 48), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], - [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", True, 1e-4, False], + [8, 2, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", True, 1e-4, False], [8, 4, 2, (12, 24), (24, 48), (2, 1), "morlet", "mean", "equiangular", "equiangular", True, 1e-4, False], [8, 4, 2, (12, 24), (24, 48), (3), "zernike", "mean", "equiangular", "equiangular", True, 1e-4, False], [8, 4, 2, (8, 8), (16, 24), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], @@ -241,11 +339,6 @@ def test_sparse_against_dense( nlat_in, nlon_in = in_shape nlat_out, nlon_out = out_shape - if isinstance(kernel_shape, int): - theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat_in - 1) - else: - theta_cutoff = (kernel_shape[0] + 1) * torch.pi /
float(nlat_in - 1) - Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2 conv = Conv( in_channels, @@ -259,7 +352,7 @@ def test_sparse_against_dense( grid_in=grid_in, grid_out=grid_out, bias=False, - theta_cutoff=theta_cutoff, + theta_cutoff=None, optimized_kernel=use_optimized_kernels, ).to(self.device) @@ -272,7 +365,7 @@ def test_sparse_against_dense( filter_basis, grid_in=grid_out, grid_out=grid_in, - theta_cutoff=theta_cutoff, + theta_cutoff=conv.theta_cutoff, transpose_normalization=transpose, basis_norm_mode=basis_norm_mode, merge_quadrature=True, @@ -280,7 +373,7 @@ def test_sparse_against_dense( psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_in, conv.nlat_out * conv.nlon_out)).to_dense() - self.assertTrue(torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_in, nlat_out * nlon_out))) + self.assertTrue(compare_tensors("psi", psi, psi_dense[:, :, 0].reshape(-1, nlat_in, nlat_out * nlon_out))) else: psi_dense = _precompute_convolution_tensor_dense( in_shape, @@ -288,7 +381,7 @@ def test_sparse_against_dense( filter_basis, grid_in=grid_in, grid_out=grid_out, - theta_cutoff=theta_cutoff, + theta_cutoff=conv.theta_cutoff, transpose_normalization=transpose, basis_norm_mode=basis_norm_mode, merge_quadrature=True, @@ -296,7 +389,7 @@ def test_sparse_against_dense( psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)).to_dense() - self.assertTrue(torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))) + self.assertTrue(compare_tensors("psi", psi, psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))) # create a copy of the weight w_ref = torch.empty_like(conv.weight) @@ -327,11 +420,11 @@ def test_sparse_against_dense( x_ref_grad = x_ref.grad.clone() # compare results - self.assertTrue(torch.allclose(y, y_ref, rtol=tol, atol=tol)) + self.assertTrue(compare_tensors("output", y, y_ref, rtol=tol, atol=tol, verbose=verbose)) # compare - self.assertTrue(torch.allclose(x_grad, x_ref_grad, rtol=tol, atol=tol)) - self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol)) + self.assertTrue(compare_tensors("input gradient", x_grad, x_ref_grad, rtol=tol, atol=tol, verbose=verbose)) + self.assertTrue(compare_tensors("weight gradient", conv.weight.grad, w_ref.grad.unsqueeze(0), rtol=tol, atol=tol, verbose=verbose)) @parameterized.expand( @@ -380,13 +473,7 @@ def test_optimized_against_torch( if verbose: print(f"Testing DISCO convolution on {in_shape[0]}x{in_shape[1]} {grid_in} grid to {out_shape[0]}x{out_shape[1]} {grid_out} grid on {self.device.type} device") - nlat_in, nlon_in = in_shape - nlat_out, nlon_out = out_shape - - if isinstance(kernel_shape, int): - theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat_in - 1) - else: - theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1) + nlat_in, _ = in_shape Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2 @@ -402,7 +489,7 @@ def test_optimized_against_torch( grid_in=grid_in, grid_out=grid_out, bias=False, - theta_cutoff=theta_cutoff, + theta_cutoff=None, optimized_kernel=False, ).to(self.device) @@ -418,7 +505,7 @@ def test_optimized_against_torch( grid_in=grid_in, grid_out=grid_out, bias=False, - theta_cutoff=theta_cutoff, + theta_cutoff=None, optimized_kernel=True, ).to(self.device) @@ -443,11 +530,11 @@ def test_optimized_against_torch( inp_grad_opt = inp.grad.clone() # compare results - 
self.assertTrue(torch.allclose(out_naive, out_opt, rtol=tol, atol=tol)) + self.assertTrue(compare_tensors("output", out_naive, out_opt, rtol=tol, atol=tol, verbose=verbose)) # compare - self.assertTrue(torch.allclose(inp_grad_naive, inp_grad_opt, rtol=tol, atol=tol)) - self.assertTrue(torch.allclose(conv_naive.weight.grad, conv_opt.weight.grad, rtol=tol, atol=tol)) + self.assertTrue(compare_tensors("input gradient", inp_grad_naive, inp_grad_opt, rtol=tol, atol=tol, verbose=verbose)) + self.assertTrue(compare_tensors("weight gradient", conv_naive.weight.grad, conv_opt.weight.grad, rtol=tol, atol=tol, verbose=verbose)) @parameterized.expand( @@ -465,11 +552,6 @@ def test_device_instantiation(self, batch_size, in_channels, out_channels, in_sh nlat_in, nlon_in = in_shape nlat_out, nlon_out = out_shape - if isinstance(kernel_shape, int): - theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat_in - 1) - else: - theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1) - # get handle Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2 @@ -486,7 +568,7 @@ def test_device_instantiation(self, batch_size, in_channels, out_channels, in_sh grid_in=grid_in, grid_out=grid_out, bias=False, - theta_cutoff=theta_cutoff, + theta_cutoff=None, ) #torch.set_default_device(self.device) @@ -503,7 +585,7 @@ def test_device_instantiation(self, batch_size, in_channels, out_channels, in_sh grid_in=grid_in, grid_out=grid_out, bias=False, - theta_cutoff=theta_cutoff, + theta_cutoff=None, ) # since we specified the device specifier everywhere, it should always @@ -517,10 +599,10 @@ def test_device_instantiation(self, batch_size, in_channels, out_channels, in_sh @parameterized.expand( [ - [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False], - [8, 4, 2, (16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False], - [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], - [8, 4, 2, (8, 16), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False], + [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, False], + [8, 4, 2, (16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, False], + [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, False], + [8, 4, 2, (8, 16), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, False], ], skip_on_empty=True, @@ -539,7 +621,6 @@ def test_optimized_pt2_compatibility( grid_in, grid_out, transpose, - tol, verbose, ): """Tests whether the optimized kernels are PyTorch 2 compatible""" @@ -550,13 +631,7 @@ def test_optimized_pt2_compatibility( if verbose: print(f"Testing DISCO convolution on {in_shape[0]}x{in_shape[1]} {grid_in} grid to {out_shape[0]}x{out_shape[1]} {grid_out} grid on {self.device.type} device") - nlat_in, nlon_in = in_shape - nlat_out, nlon_out = out_shape - - if isinstance(kernel_shape, int): - theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat_in - 1) - else: - theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1) + nlat_in, _ = in_shape Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2 conv = Conv( @@ -571,14 +646,14 @@ def test_optimized_pt2_compatibility( grid_in=grid_in, grid_out=grid_out, bias=False, - 
theta_cutoff=theta_cutoff, + theta_cutoff=None, ).to(self.device) # forward test if not transpose: - inp = torch.randn(batch_size, in_channels, *in_shape, device=self.device) + inp = torch.randn(batch_size, *in_shape, in_channels, device=self.device) else: - inp = torch.randn(batch_size, conv.kernel_size, in_channels, *in_shape, device=self.device) + inp = torch.randn(batch_size, *in_shape, in_channels, conv.kernel_size, device=self.device) test_inputs = (inp, conv.psi_roff_idx, conv.psi_ker_idx, conv.psi_row_idx, conv.psi_col_idx, conv.psi_vals, conv.kernel_size, conv.nlat_out, conv.nlon_out) @@ -602,7 +677,7 @@ def test_optimized_pt2_compatibility( @parameterized.expand( [ - [8, 4, 2, (91, 180), (91, 180), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4], + #[8, 4, 2, (91, 180), (91, 180), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4], ], skip_on_empty=True, ) @@ -615,11 +690,6 @@ def test_perf(self, batch_size, in_channels, out_channels, in_shape, out_shape, nlat_in, nlon_in = in_shape nlat_out, nlon_out = out_shape - if isinstance(kernel_shape, int): - theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat_in - 1) - else: - theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1) - # get handle Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2 @@ -636,7 +706,7 @@ def test_perf(self, batch_size, in_channels, out_channels, in_shape, out_shape, grid_in=grid_in, grid_out=grid_out, bias=True, - theta_cutoff=theta_cutoff, + theta_cutoff=None, optimized_kernel=True, ).to(self.device) diff --git a/tests/test_distributed_convolution.py b/tests/test_distributed_convolution.py index c364ac5a..2dd30ecb 100644 --- a/tests/test_distributed_convolution.py +++ b/tests/test_distributed_convolution.py @@ -38,6 +38,17 @@ import torch.distributed as dist import torch_harmonics as th import torch_harmonics.distributed as thd +from torch_harmonics.disco import optimized_kernels_is_available + +from disco_helpers import preprocess_psi +from torch_harmonics.filter_basis import get_filter_basis +from torch_harmonics.disco.convolution import _precompute_convolution_tensor_s2 +from torch_harmonics.distributed.distributed_convolution import _split_distributed_convolution_tensor_s2 + +from testutils import compare_tensors + +if not optimized_kernels_is_available(): + print(f"Warning: Couldn't import optimized disco convolution kernels") class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): @@ -128,6 +139,9 @@ def _split_helper(self, tensor): tensor_list_local = thd.split_tensor_along_dim(tensor_local, dim=-2, num_chunks=self.grid_size_h) tensor_local = tensor_list_local[self.hrank] + # make contiguous + tensor_local = tensor_local.contiguous() + return tensor_local def _gather_helper_fwd(self, tensor, B, C, convolution_dist): @@ -156,6 +170,9 @@ def _gather_helper_fwd(self, tensor, B, C, convolution_dist): dist.all_gather(olist, tensor_gather, group=self.h_group) tensor_gather = torch.cat(olist, dim=-2) + # make contiguous + tensor_gather = tensor_gather.contiguous() + return tensor_gather def _gather_helper_bwd(self, tensor, B, C, convolution_dist): @@ -165,6 +182,7 @@ def _gather_helper_bwd(self, tensor, B, C, convolution_dist): lon_shapes = convolution_dist.lon_in_shapes # gather in W + tensor = tensor.contiguous() if self.grid_size_w > 1: gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes] olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape 
in gather_shapes] @@ -175,6 +193,7 @@ def _gather_helper_bwd(self, tensor, B, C, convolution_dist): tensor_gather = tensor # gather in H + tensor_gather = tensor_gather.contiguous() if self.grid_size_h > 1: gather_shapes = [(B, C, h, convolution_dist.nlon_in) for h in lat_shapes] olist = [torch.empty(shape, dtype=tensor_gather.dtype, device=tensor_gather.device) for shape in gather_shapes] @@ -182,31 +201,139 @@ def _gather_helper_bwd(self, tensor, B, C, convolution_dist): dist.all_gather(olist, tensor_gather, group=self.h_group) tensor_gather = torch.cat(olist, dim=-2) + # make contiguous + tensor_gather = tensor_gather.contiguous() + return tensor_gather @parameterized.expand( [ + # piecewise linear + # normal isotropic + [(16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular"], + [(17, 32), (17, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular"], + # normal anisotropic + [(16, 32), (16, 32), (3, 4), "piecewise linear", "mean", "equiangular", "equiangular"], + [(16, 32), (16, 32), (3, 2), "piecewise linear", "mean", "equiangular", "equiangular"], + # downsampling isotropic + [(16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular"], + [(17, 32), (9, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular"], + # downsampling anisotropic + [(16, 32), (8, 16), (3, 4), "piecewise linear", "mean", "equiangular", "equiangular"], + [(16, 32), (8, 16), (3, 2), "piecewise linear", "mean", "equiangular", "equiangular"], + # morlet + # normal isotropic + [(16, 32), (16, 32), (1), "morlet", "mean", "equiangular", "equiangular"], # important for attention + [(16, 32), (16, 32), (3), "morlet", "mean", "equiangular", "equiangular"], + [(17, 32), (17, 32), (3), "morlet", "mean", "equiangular", "equiangular"], + # normal anisotropic + [(16, 32), (16, 32), (3, 4), "morlet", "mean", "equiangular", "equiangular"], + [(16, 32), (16, 32), (3, 2), "morlet", "mean", "equiangular", "equiangular"], + # downsampling isotropic + [(16, 32), (8, 16), (1), "morlet", "mean", "equiangular", "equiangular"], # important for attention + [(16, 32), (8, 16), (3), "morlet", "mean", "equiangular", "equiangular"], + [(17, 32), (9, 16), (3), "morlet", "mean", "equiangular", "equiangular"], + # downsampling anisotropic + [(16, 32), (8, 16), (3, 4), "morlet", "mean", "equiangular", "equiangular"], + [(16, 32), (8, 16), (3, 2), "morlet", "mean", "equiangular", "equiangular"], + # zernike + # normal + [(16, 32), (16, 32), (1), "zernike", "mean", "equiangular", "equiangular"], + [(16, 32), (16, 32), (3, 3), "zernike", "mean", "equiangular", "equiangular"], + [(17, 32), (17, 32), (3, 3), "zernike", "mean", "equiangular", "equiangular"], + # downsampling + [(16, 32), (8, 16), (1), "zernike", "mean", "equiangular", "equiangular"], + [(16, 32), (8, 16), (3, 3), "zernike", "mean", "equiangular", "equiangular"], + [(17, 32), (9, 16), (3, 3), "zernike", "mean", "equiangular", "equiangular"], + ], + skip_on_empty=True, + ) + def test_distributed_convolution_tensor_integrity(self, in_shape, out_shape, kernel_shape, basis_type, basis_norm_mode, grid_in, grid_out, verbose=False): + + nlat_in, _ = in_shape + nlat_out, _ = out_shape + + # get filter basis + filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type=basis_type) + + # use default value for cutoff + theta_cutoff = torch.pi / float(nlat_out - 1) + + idx, vals, _ = _precompute_convolution_tensor_s2( + in_shape=in_shape, + out_shape=out_shape, + filter_basis=filter_basis, + grid_in=grid_in, + 
grid_out=grid_out, + theta_cutoff=theta_cutoff, + transpose_normalization=False, + basis_norm_mode=basis_norm_mode, + merge_quadrature=True, + ) + + # split tensor along latitude + idx, vals = _split_distributed_convolution_tensor_s2(idx, vals, in_shape) + + # define contiguous tensors + ker_idx = idx[0, ...].contiguous() + row_idx = idx[1, ...].contiguous() + col_idx = idx[2, ...].contiguous() + vals = vals.contiguous() + + # sort values + roff_idx = preprocess_psi(filter_basis.kernel_size, nlat_out, ker_idx, row_idx, col_idx, vals).contiguous() + + if verbose: print(f"{self.hrank} after splitting sorted \n ker = {ker_idx}\n row = {row_idx}\n col = {col_idx}", flush=True) + if verbose: print(f"{self.hrank} roff_idx = {roff_idx}", flush=True) + if verbose: print(f"{self.hrank} ker_shape = {ker_idx.shape}, row_shape = {row_idx.shape}, col_shape = {col_idx.shape}, vals_shape = {vals.shape}", flush=True) + + # check shapes + self.assertTrue(ker_idx.shape[0] == row_idx.shape[0], f"ker_idx and row_idx have to have the same shape: found {ker_idx.shape[0]} and {row_idx.shape[0]}") + self.assertTrue(ker_idx.shape[0] == col_idx.shape[0], f"ker_idx and col_idx have to have the same shape: found {ker_idx.shape[0]} and {col_idx.shape[0]}") + self.assertTrue(ker_idx.shape[0] == vals.shape[0], f"ker_idx and vals have to have the same shape: found {ker_idx.shape[0]} and {vals.shape[0]}") + + # the multiplicity in ker_idx has to be the same for all kernel indices + unique, counts = torch.unique(ker_idx, return_counts=True) + self.assertTrue(torch.all(counts.max() == counts), f"The multiplicity in ker_idx has to be the same for all kernel indices: found {counts} for entries {unique}") + + if verbose: + print(f"\n ker_idx = {ker_idx},\n row_idx = {row_idx},\n col_idx = {col_idx}") + + # the following has to be true: the row_idx and col_idx have to be the same for all kernel indices + row_idx_ref = row_idx[ker_idx == 0] + col_idx_ref = col_idx[ker_idx == 0] + for k in range(1, filter_basis.kernel_size): + self.assertTrue(torch.all(row_idx_ref == row_idx[ker_idx == k]), f"The row_idx has to be the same for all kernel indices: found {row_idx_ref} for entries {ker_idx == k}") + self.assertTrue(torch.all(col_idx_ref == col_idx[ker_idx == k]), f"The col_idx has to be the same for all kernel indices: found {col_idx_ref} for entries {ker_idx == k}") + + + @parameterized.expand( [ + # Forward [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5], [129, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5], [128, 256, 128, 256, 32, 8, (3, 2), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5], [128, 256, 64, 128, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5], [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 2, "equiangular", "equiangular", False, 1e-5], [128, 256, 128, 256, 32, 6, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5], + [129, 256, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", False, 1e-5], + [129, 256, 65, 128, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", False, 1e-5], + # Transpose [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5], [129, 256, 129, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5], [128, 256, 128, 256, 32, 8, (3, 2), "piecewise linear", "mean", 1, "equiangular",
"equiangular", True, 1e-5], [64, 128, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5], [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 2, "equiangular", "equiangular", True, 1e-5], [128, 256, 128, 256, 32, 6, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5], - [129, 256, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", False, 1e-5], [129, 256, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", True, 1e-5], [65, 128, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", True, 1e-5], - [129, 256, 65, 128, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", False, 1e-5], - ] + ], + skip_on_empty=True, ) def test_distributed_disco_conv( - self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, basis_type, basis_norm_mode, groups, grid_in, grid_out, transpose, tol + self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, basis_type, basis_norm_mode, groups, grid_in, grid_out, transpose, tol, verbose=True ): + verbose = verbose and (self.world_rank == 0) B, C, H, W = batch_size, num_chan, nlat_in, nlon_in @@ -222,21 +349,22 @@ def test_distributed_disco_conv( grid_in=grid_in, grid_out=grid_out, bias=True, + optimized_kernel=True, ) # set up handles if transpose: - conv_local = th.DiscreteContinuousConvTransposeS2(**disco_args).to(self.device) + conv_full = th.DiscreteContinuousConvTransposeS2(**disco_args).to(self.device) conv_dist = thd.DistributedDiscreteContinuousConvTransposeS2(**disco_args).to(self.device) else: - conv_local = th.DiscreteContinuousConvS2(**disco_args).to(self.device) + conv_full = th.DiscreteContinuousConvS2(**disco_args).to(self.device) conv_dist = thd.DistributedDiscreteContinuousConvS2(**disco_args).to(self.device) # copy the weights from the local conv into the dist conv with torch.no_grad(): - conv_dist.weight.copy_(conv_local.weight) + conv_dist.weight.copy_(conv_full.weight) if disco_args["bias"]: - conv_dist.bias.copy_(conv_local.bias) + conv_dist.bias.copy_(conv_full.bias) # create tensors inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device) @@ -244,7 +372,7 @@ def test_distributed_disco_conv( # local conv # FWD pass inp_full.requires_grad = True - out_full = conv_local(inp_full) + out_full = conv_full(inp_full) # create grad for backward with torch.no_grad(): @@ -257,12 +385,14 @@ def test_distributed_disco_conv( # distributed conv # FWD pass - inp_local = self._split_helper(inp_full) + with torch.no_grad(): + inp_local = self._split_helper(inp_full) inp_local.requires_grad = True out_local = conv_dist(inp_local) # BWD pass - ograd_local = self._split_helper(ograd_full) + with torch.no_grad(): + ograd_local = self._split_helper(ograd_full) out_local = conv_dist(inp_local) out_local.backward(ograd_local) igrad_local = inp_local.grad.clone() @@ -270,19 +400,12 @@ def test_distributed_disco_conv( # evaluate FWD pass with torch.no_grad(): out_gather_full = self._gather_helper_fwd(out_local, B, C, conv_dist) - err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2))) - if self.world_rank == 0: - print(f"final relative error of output: {err.item()}") - self.assertTrue(err.item() <= tol) + self.assertTrue(compare_tensors("output", out_gather_full, out_full, rtol=tol, atol=tol, verbose=verbose)) # evaluate BWD pass with torch.no_grad(): igrad_gather_full = 
self._gather_helper_bwd(igrad_local, B, C, conv_dist) - - err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2))) - if self.world_rank == 0: - print(f"final relative error of gradients: {err.item()}") - self.assertTrue(err.item() <= tol) + self.assertTrue(compare_tensors("input gradient", igrad_gather_full, igrad_full, rtol=tol, atol=tol, verbose=verbose)) if __name__ == "__main__": diff --git a/tests/test_permute.py b/tests/test_permute.py new file mode 100644 index 00000000..b9cd82c6 --- /dev/null +++ b/tests/test_permute.py @@ -0,0 +1,116 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# + +import os +import unittest +from parameterized import parameterized, parameterized_class + + +import torch +from torch.library import opcheck +from torch_harmonics.utils import permute_to_0231, permute_to_0312 + +_devices = [(torch.device("cpu"),)] +if torch.cuda.is_available(): + _devices.append((torch.device("cuda"),)) + +@parameterized_class(("device"), _devices) +class TestPermutation(unittest.TestCase): + """Test the optimized convolution module (CPU/CUDA if available).""" + + def setUp(self): + torch.manual_seed(333) + if self.device.type == "cuda": + torch.cuda.manual_seed(333) + + @parameterized.expand( + [ + [8, 8, 16, 32, "0231"], + [8, 1, 16, 32, "0231"], + [1, 8, 16, 32, "0231"], + [8, 8, 16, 32, "0312"], + [8, 1, 16, 32, "0312"], + [1, 8, 16, 32, "0312"], + ], + skip_on_empty=True, + ) + def test_permutation( + self, batch_size, channels, nlat, nlon, mode + ): + # create input + if mode == "0231": + inp = torch.randn(batch_size, channels, nlat, nlon, device=self.device) + permute_fn = permute_to_0231 + permute_shape = (0, 2, 3, 1) + else: + inp = torch.randn(batch_size, nlat, nlon, channels, device=self.device) + permute_fn = permute_to_0312 + permute_shape = (0, 3, 1, 2) + inp.requires_grad = True + + # forward test + out_opt = permute_fn(inp) + out_naive = inp.permute(*permute_shape).contiguous().clone() + self.assertTrue(torch.allclose(out_opt, out_naive)) + + # backward test + ograd = torch.randn_like(out_opt) + out_opt.backward(ograd) + igrad_opt = inp.grad.clone() + inp.grad = None + out_naive.backward(ograd) + igrad_naive = inp.grad.clone() + self.assertTrue(torch.allclose(igrad_opt, igrad_naive)) + + + @parameterized.expand( + [ + [8, 8, 16, 32, "0231"], + [8, 8, 16, 32, "0312"], + ], + skip_on_empty=True, + ) + def test_pt2_compatibility(self, batch_size, channels, nlat, nlon, mode): + + if mode == "0231": + inp = torch.randn(batch_size, channels, nlat, nlon, device=self.device) + permute_fn = torch.ops.utility_kernels.permute_to_0231 + else: + inp = torch.randn(batch_size, nlat, nlon, channels, device=self.device) + permute_fn = torch.ops.utility_kernels.permute_to_0312 + + test_inputs = (inp, ) + + opcheck(permute_fn, test_inputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/testutils.py b/tests/testutils.py new file mode 100644 index 00000000..ef7ef0c0 --- /dev/null +++ b/tests/testutils.py @@ -0,0 +1,44 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import torch + +def compare_tensors(msg, tensor, tensor_ref, rtol=1e-8, atol=1e-5, verbose=False): + allclose = torch.allclose(tensor, tensor_ref, rtol=rtol, atol=atol) + if (not allclose) and verbose: + diff = torch.abs(tensor - tensor_ref) + print(f"{msg} absolute tensor diff: min = {torch.min(diff)}, mean = {torch.mean(diff)}, max = {torch.max(diff)}.") + reldiff = diff / torch.abs(tensor_ref) + print(f"{msg} relative tensor diff: min = {torch.min(reldiff)}, mean = {torch.mean(reldiff)}, max = {torch.max(reldiff)}.") + # find element with maximum difference + index = torch.argmax(diff) + print(f"{msg} element {index} with maximum difference: value = {tensor.flatten()[index]}, reference value = {tensor_ref.flatten()[index]}, diff = {diff.flatten()[index]}.") + return allclose diff --git a/torch_harmonics/attention/__init__.py b/torch_harmonics/attention/__init__.py index ee69cb83..4fea72b9 100644 --- a/torch_harmonics/attention/__init__.py +++ b/torch_harmonics/attention/__init__.py @@ -39,6 +39,6 @@ from torch.ops import attention_kernels else: attention_kernels = None - warnings.warn("No optimized kernels are available. Please compile the extension first setting BUILD_CPP and BUILD_CUDA to 1.") + warnings.warn("No optimized attention kernels are available. Please compile the extension first setting BUILD_CPP and BUILD_CUDA to 1.") from .attention import AttentionS2, NeighborhoodAttentionS2 diff --git a/torch_harmonics/attention/_attention_utils.py b/torch_harmonics/attention/_attention_utils.py index a8c1b58a..633b6607 100644 --- a/torch_harmonics/attention/_attention_utils.py +++ b/torch_harmonics/attention/_attention_utils.py @@ -35,6 +35,7 @@ import torch import torch.nn.functional as F from attention_helpers import optimized_kernels_is_available +from torch_harmonics.utils import permute_to_0231, permute_to_0312 from . 
import attention_kernels # HELPER ROUTINE FOR BACKWARD setup_context @@ -54,7 +55,8 @@ def _setup_context_attention_backward(ctx, inputs, output): def _(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - out_shape = (kw.shape[0], vw.shape[1], nlat_out, nlon_out) + # the raw kernel uses channels last format + out_shape = (kw.shape[0], nlat_out, nlon_out, vw.shape[3]) return torch.empty(out_shape, dtype=kw.dtype, device=kw.device) # raw backward fake @@ -62,6 +64,7 @@ def _(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, def _(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, grad_output: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, nlon_in: int, nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # the raw kernel uses channels last format dk = torch.empty_like(kw) dv = torch.empty_like(vw) dq = torch.empty_like(qw) @@ -87,6 +90,11 @@ def _neighborhood_s2_attention_optimized(k: torch.Tensor, v: torch.Tensor, q: to B, _, H, W = qw.shape qw = qw.reshape(B*nh, -1, H, W) + # permute to 0231 + kw = permute_to_0231(kw) + vw = permute_to_0231(vw) + qw = permute_to_0231(qw) + # convert to float32 inp_dtype = kw.dtype kw = kw.to(torch.float32).contiguous() @@ -97,12 +105,18 @@ def _neighborhood_s2_attention_optimized(k: torch.Tensor, v: torch.Tensor, q: to col_idx, row_off, nlon_in, nlat_out, nlon_out) - _, C, H, W = output.shape - output = output.reshape(B, -1, H, W) + #_, H, W, C = output.shape + #output = output.reshape(-1, H, W, C) # convert back precision output = output.to(dtype=inp_dtype) + # permute back to 0312 + output = permute_to_0312(output) + + # fold heads back into channel dimension + output = output.reshape(B, -1, H, W) + return output @torch.library.register_fake("attention_kernels::_neighborhood_s2_attention_optimized") @@ -111,6 +125,7 @@ def _(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + # the wrapped kernel uses channels first format out_shape = (k.shape[0], wv.shape[0], nlat_out, nlon_out) return torch.empty(out_shape, dtype=k.dtype, device=k.device) @@ -147,21 +162,31 @@ def _neighborhood_s2_attention_bwd_optimized(ctx, grad_output): B, _, H, W = grad_output.shape grad_output = grad_output.reshape(B*nh, -1, H, W) - # save type and convert to float32 + # permute to 0231 + kw = permute_to_0231(kw.contiguous()) + vw = permute_to_0231(vw.contiguous()) + qw = permute_to_0231(qw.contiguous()) + grad_output = permute_to_0231(grad_output.contiguous()) + + # save type and convert to float32 kw_dtype = kw.dtype vw_dtype = vw.dtype qw_dtype = qw.dtype - - kw = kw.to(torch.float32).contiguous() - vw = vw.to(torch.float32).contiguous() - qw = qw.to(torch.float32).contiguous() - grad_output = grad_output.to(torch.float32).contiguous() + kw = kw.to(torch.float32) + vw = vw.to(torch.float32) + qw = qw.to(torch.float32) + grad_output = grad_output.to(torch.float32) dkw, dvw, dqw = attention_kernels.backward.default(kw, vw, qw, grad_output, quad_weights, col_idx, row_off, nlon_in, nlat_out, nlon_out) + # permute back to 0312 + dkw = permute_to_0312(dkw) + dvw = permute_to_0312(dvw) + dqw = permute_to_0312(dqw) + # 
weight grads _, C, H, W = dkw.shape dkw = dkw.reshape(B, -1, H, W) diff --git a/torch_harmonics/attention/attention.py b/torch_harmonics/attention/attention.py index 1d76286f..0110b63d 100644 --- a/torch_harmonics/attention/attention.py +++ b/torch_harmonics/attention/attention.py @@ -353,6 +353,7 @@ def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value self.nlon_out, ) + # compute the output out = nn.functional.conv2d(out, self.proj_weights, bias=self.proj_bias) return out diff --git a/torch_harmonics/attention/csrc/attention.h b/torch_harmonics/attention/csrc/attention.h index 373d4494..b20ccff9 100644 --- a/torch_harmonics/attention/csrc/attention.h +++ b/torch_harmonics/attention/csrc/attention.h @@ -35,10 +35,3 @@ #include #include #include - -#define CHECK_CPU_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCPU) -#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT_TENSOR(x) CHECK_CONTIGUOUS_TENSOR(x) -#define CHECK_CPU_INPUT_TENSOR(x) \ - CHECK_CPU_TENSOR(x); \ - CHECK_CONTIGUOUS_TENSOR(x) \ No newline at end of file diff --git a/torch_harmonics/attention/csrc/attention_cpu.h b/torch_harmonics/attention/csrc/attention_cpu.h index f96e7a7b..93008aab 100644 --- a/torch_harmonics/attention/csrc/attention_cpu.h +++ b/torch_harmonics/attention/csrc/attention_cpu.h @@ -34,6 +34,8 @@ #include #include +#include "cppmacro.h" + #define CACHE_BLOCK_SIZE (64) namespace attention_kernels { @@ -50,6 +52,8 @@ namespace attention_kernels { const int64_t nlon_in, const int64_t nlat_out, const int64_t nlon_out, const int64_t batch_size, const int64_t nchannels_in, const int64_t nchannels_out) { + // IMPORTANT: all input tensors are in channels last format! + // some parameters const int64_t block_wo = CACHE_BLOCK_SIZE; const int64_t nblock_wo = static_cast((nlon_out + block_wo - 1) / block_wo); @@ -95,7 +99,7 @@ namespace attention_kernels { float qdotk = 0.0; //#pragma omp simd reduction(+:qdotk) for (int64_t ci = 0; ci < nchannels_in; ci++) { - qdotk += static_cast(qy_arr[b][ci][ho][wo] * kx_arr[b][ci][hi][wip]); + qdotk += static_cast(qy_arr[b][ho][wo][ci] * kx_arr[b][hi][wip][ci]); } // update tmp max @@ -106,7 +110,7 @@ namespace attention_kernels { alpha_sum[wo-wo_start] = alpha + alpha_sum[wo-wo_start] * std::exp(qdotk_max[wo-wo_start] - qdotk_max_tmp); // update output - y_tmp[wo-wo_start] = y_tmp[wo-wo_start] * std::exp(qdotk_max[wo-wo_start] - qdotk_max_tmp) + alpha * static_cast(vx_arr[b][co][hi][wip]); + y_tmp[wo-wo_start] = y_tmp[wo-wo_start] * std::exp(qdotk_max[wo-wo_start] - qdotk_max_tmp) + alpha * static_cast(vx_arr[b][hi][wip][co]); // define new max qdotk_max[wo-wo_start] = qdotk_max_tmp; @@ -115,7 +119,7 @@ namespace attention_kernels { // update output for (int64_t wo = wo_start; wo < wo_end; wo++) { - y_arr[b][co][ho][wo] = static_cast(y_tmp[wo-wo_start] / alpha_sum[wo-wo_start]); + y_arr[b][ho][wo][co] = static_cast(y_tmp[wo-wo_start] / alpha_sum[wo-wo_start]); } } } @@ -138,6 +142,8 @@ namespace attention_kernels { const int64_t nlon_in, const int64_t nlat_out, const int64_t nlon_out, const int64_t batch_size, const int64_t nchannels_in, const int64_t nchannels_out) { + // IMPORTANT: all input tensors are in channels last format! 
+ // compute dqy and dkx #pragma omp parallel for collapse(2) for (int64_t b = 0; b < batch_size; b++) { @@ -176,7 +182,7 @@ namespace attention_kernels { // compute correlation & softmax numerator qdotk_nz[idz-zstart] = 0.0; for (int64_t cit = 0; cit < nchannels_in; cit++) { - qdotk_nz[idz-zstart] += qy_arr[b][cit][ho][wo] * kx_arr[b][cit][hi][wip]; + qdotk_nz[idz-zstart] += qy_arr[b][ho][wo][cit] * kx_arr[b][hi][wip][cit]; } // tmp max and discount @@ -190,16 +196,16 @@ namespace attention_kernels { // dkx: input dot float gdotv = 0.0; for (int64_t cot = 0; cot < nchannels_out; cot++) { - gdotv += dy_arr[b][cot][ho][wo] * vx_arr[b][cot][hi][wip]; + gdotv += dy_arr[b][ho][wo][cot] * vx_arr[b][hi][wip][cot]; } float alpha_gdotv_tmp = alpha_nz[idz-zstart] * gdotv; alpha_gdotv = alpha_gdotv_tmp + alpha_gdotv * discount; // dqy: alpha_k - alpha_k = alpha_nz[idz-zstart] * kx_arr[b][ci][hi][wip] + alpha_k * discount; + alpha_k = alpha_nz[idz-zstart] * kx_arr[b][hi][wip][ci] + alpha_k * discount; // dqy: alpha_k_gdotv - alpha_k_gdotv = alpha_gdotv_tmp * kx_arr[b][ci][hi][wip] + alpha_k_gdotv * discount; + alpha_k_gdotv = alpha_gdotv_tmp * kx_arr[b][hi][wip][ci] + alpha_k_gdotv * discount; // define new max qdotk_max = qdotk_max_tmp; @@ -211,7 +217,7 @@ namespace attention_kernels { alpha_k_gdotv = alpha_k_gdotv / alpha_sum; // dqy: update - dqy_arr[b][ci][ho][wo] = (alpha_k_gdotv - alpha_gdotv * alpha_k); + dqy_arr[b][ho][wo][ci] = (alpha_k_gdotv - alpha_gdotv * alpha_k); for (int64_t idz = zstart; idz < zend; idz++) { int64_t nz_col_idx = col_idx_arr[idz]; @@ -228,11 +234,11 @@ namespace attention_kernels { // dkx: input dot float gdotv = 0.0; for (int64_t cot = 0; cot < nchannels_out; cot++) { - gdotv += dy_arr[b][cot][ho][wo] * vx_arr[b][cot][hi][wip]; + gdotv += dy_arr[b][ho][wo][cot] * vx_arr[b][hi][wip][cot]; } // dkx: update - dkx_arr[b][ci][hi][wip] += qy_arr[b][ci][ho][wo] * alpha_norm * (gdotv - alpha_gdotv); + dkx_arr[b][hi][wip][ci] += qy_arr[b][ho][wo][ci] * alpha_norm * (gdotv - alpha_gdotv); } } } @@ -270,7 +276,7 @@ namespace attention_kernels { // compute correlation & softmax numerator qdotk_nz[idz-zstart] = 0.0; for (int64_t ci = 0; ci < nchannels_in; ci++) { - qdotk_nz[idz-zstart] += qy_arr[b][ci][ho][wo] * kx_arr[b][ci][hi][wip]; + qdotk_nz[idz-zstart] += qy_arr[b][ho][wo][ci] * kx_arr[b][hi][wip][ci]; } // tmp max and discount @@ -296,7 +302,7 @@ namespace attention_kernels { // recompute alpha float alpha_norm = std::exp(qdotk_nz[idz-zstart] - qdotk_max) * quad_weights_arr[hi] / alpha_sum; - dvx_arr[b][co][hi][wip] += alpha_norm * dy_arr[b][co][ho][wo]; + dvx_arr[b][hi][wip][co] += alpha_norm * dy_arr[b][ho][wo][co]; } } } diff --git a/torch_harmonics/attention/csrc/attention_cpu_bwd.cpp b/torch_harmonics/attention/csrc/attention_cpu_bwd.cpp index f7ec9b6d..0297a79b 100644 --- a/torch_harmonics/attention/csrc/attention_cpu_bwd.cpp +++ b/torch_harmonics/attention/csrc/attention_cpu_bwd.cpp @@ -38,16 +38,16 @@ std::tuple s2_attention_bwd_cpu(tor torch::Tensor quad_weights, torch::Tensor col_idx, torch::Tensor row_off, int64_t nlon_in, int64_t nlat_out, int64_t nlon_out) { - // shapes: + // shapes: all channels LAST! 
// input - // kx: B, C, Hi, Wi - // vx: B, C, Hi, Wi - // qy: B, C, Ho, Wo + // kx: B, Hi, Wi, C + // vx: B, Hi, Wi, C + // qy: B, Ho, Wo, C // quad_weights: Hi // output - // dkx: B, C, Hi, Wi - // dvx: B, C, Hi, Wi - // dqy: B, C, Ho, Wo + // dkx: B, Hi, Wi, C + // dvx: B, Hi, Wi, C + // dqy: B, Ho, Wo, C // sanity checks CHECK_CPU_INPUT_TENSOR(kx); @@ -58,25 +58,14 @@ std::tuple s2_attention_bwd_cpu(tor CHECK_CPU_INPUT_TENSOR(col_idx); CHECK_CPU_INPUT_TENSOR(row_off); - // change to channels first: - bool kx_is_channels_last = kx.strides()[1] == 1; - bool vx_is_channels_last = vx.strides()[1] == 1; - bool qy_is_channels_last = qy.strides()[1] == 1; - bool dy_is_channels_last = dy.strides()[1] == 1; - - if (!kx_is_channels_last) { kx = kx.contiguous(at::MemoryFormat::ChannelsLast); } - if (!vx_is_channels_last) { vx = vx.contiguous(at::MemoryFormat::ChannelsLast); } - if (!qy_is_channels_last) { qy = qy.contiguous(at::MemoryFormat::ChannelsLast); } - if (!dy_is_channels_last) { dy = dy.contiguous(at::MemoryFormat::ChannelsLast); } - auto dkx = torch::zeros_like(kx); auto dvx = torch::zeros_like(vx); auto dqy = torch::zeros_like(qy); // some parameters const int64_t batch_size = kx.size(0); - const int64_t nchannels_out = vx.size(1); - const int64_t nchannels_in = qy.size(1); + const int64_t nchannels_out = vx.size(3); + const int64_t nchannels_in = qy.size(3); // extract accessors auto kx_arr = kx.packed_accessor64(); @@ -97,11 +86,6 @@ std::tuple s2_attention_bwd_cpu(tor nlon_in, nlat_out, nlon_out, batch_size, nchannels_in, nchannels_out); - // permute back - if (!qy_is_channels_last) { dqy = dqy.contiguous(at::MemoryFormat::Contiguous); } - if (!vx_is_channels_last) { dvx = dvx.contiguous(at::MemoryFormat::Contiguous); } - if (!kx_is_channels_last) { dkx = dkx.contiguous(at::MemoryFormat::Contiguous); } - return std::make_tuple(dkx, dvx, dqy); } diff --git a/torch_harmonics/attention/csrc/attention_cpu_fwd.cpp b/torch_harmonics/attention/csrc/attention_cpu_fwd.cpp index 3abbb7b3..9f4d96bc 100644 --- a/torch_harmonics/attention/csrc/attention_cpu_fwd.cpp +++ b/torch_harmonics/attention/csrc/attention_cpu_fwd.cpp @@ -45,22 +45,13 @@ namespace attention_kernels { CHECK_CPU_INPUT_TENSOR(col_idx); CHECK_CPU_INPUT_TENSOR(row_off); - // change to channels first: - bool kx_is_channels_last = kx.strides()[1] == 1; - bool vx_is_channels_last = vx.strides()[1] == 1; - bool qy_is_channels_last = qy.strides()[1] == 1; - - if (!kx_is_channels_last) { kx = kx.contiguous(at::MemoryFormat::ChannelsLast); } - if (!vx_is_channels_last) { vx = vx.contiguous(at::MemoryFormat::ChannelsLast); } - if (!qy_is_channels_last) { qy = qy.contiguous(at::MemoryFormat::ChannelsLast); } - // some parameters const int64_t batch_size = kx.size(0); - const int64_t nchannels_out = vx.size(1); - const int64_t nchannels_in = qy.size(1); + const int64_t nchannels_out = vx.size(3); + const int64_t nchannels_in = qy.size(3); // prepare result tensor - auto y = torch::zeros({batch_size, nchannels_out, nlat_out, nlon_out}, qy.options()); + auto y = torch::zeros({batch_size, nlat_out, nlon_out, nchannels_out}, qy.options()); // extract accessors auto roff_arr = row_off.packed_accessor64(); @@ -74,9 +65,6 @@ namespace attention_kernels { s2_attn_fwd_kernel(kx_arr, vx_arr, qy_arr, quad_weights_arr, col_idx_arr, roff_arr, y_arr, nlon_in, nlat_out, nlon_out, batch_size, nchannels_in, nchannels_out); - // permute back - if (!qy_is_channels_last) { y = y.contiguous(at::MemoryFormat::Contiguous); } - return y; } diff --git 
a/torch_harmonics/attention/csrc/attention_cuda.cuh b/torch_harmonics/attention/csrc/attention_cuda.cuh index 0042bfce..a17c7559 100644 --- a/torch_harmonics/attention/csrc/attention_cuda.cuh +++ b/torch_harmonics/attention/csrc/attention_cuda.cuh @@ -34,11 +34,8 @@ #include #include -#define CHECK_CUDA_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCUDA) -#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_INTERNAL_ASSERT(x.is_contiguous() || x.is_contiguous(at::MemoryFormat::ChannelsLast)) -#define CHECK_CUDA_INPUT_TENSOR(x) \ - CHECK_CUDA_TENSOR(x); \ - CHECK_CONTIGUOUS_TENSOR(x) +#include "cudamacro.h" + namespace attention_kernels { diff --git a/torch_harmonics/attention/csrc/attention_cuda_bwd.cu b/torch_harmonics/attention/csrc/attention_cuda_bwd.cu index c34ccecd..ec34a2a0 100644 --- a/torch_harmonics/attention/csrc/attention_cuda_bwd.cu +++ b/torch_harmonics/attention/csrc/attention_cuda_bwd.cu @@ -41,7 +41,7 @@ #include #include -#include "cudamacro.h" +//#include "cudamacro.h" #include "attention_cuda_utils.cuh" #include @@ -52,153 +52,8 @@ #define MAX_LOCAL_ARR_LEN (16) -namespace attention_kernels { - -#if 0 -class ScopeTimer -{ - public: - explicit ScopeTimer(const std::string &label = "") : - label_(label), start_(std::chrono::high_resolution_clock::now()) - { - } - - ~ScopeTimer() - { - auto end = std::chrono::high_resolution_clock::now(); - auto elapsed = std::chrono::duration_cast(end - start_); - std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl; - } - - private: - std::string label_; - std::chrono::high_resolution_clock::time_point start_; -}; - -// easier to understand version of manual shfl_xor_sync, performance appears similar -static __device__ float __warp_sum_cub(float val) -{ - // use cub to reduce within a warp - __shared__ typename cub::WarpReduce::TempStorage temp_storage; - - // 1. Compute sum (initially only in lane 0) - float sum = cub::WarpReduce(temp_storage).Sum(val); - // 2. Broadcast sum to all threads - sum = __shfl_sync(0xFFFFFFFF, sum, 0); - return sum; -} - -// This kernel computes the backward pass for the S2 attention mechanism, using -// shared memory as a cache and one warp per output point, warp-parallel over -// channels, which should be layed out in the fastest dimension for coalesced -// memory access. 
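The backward kernels in this file (including the #if 0 variant being deleted here) all recompute the same quadrature-weighted neighborhood softmax before accumulating gradients. For orientation, a naive channels-last reference of that forward computation, reconstructed from the indexing visible in these hunks (CSR row offsets per output latitude, the longitude shift wip = (wi + wo) mod nlon_in, and exp(qdotk - max) * quad_weights[hi] normalization); the function and argument names are illustrative, not part of the library:

    import torch

    def s2_attention_reference(kx, vx, qy, quad_weights, col_idx, row_off, nlon_in):
        """Naive S2 neighborhood attention, channels-last:
        kx, vx are (B, Hi, Wi, C), qy is (B, Ho, Wo, C);
        col_idx / row_off are the CSR neighborhood arrays used by the kernels."""
        B, Ho, Wo, _ = qy.shape
        C_out = vx.shape[-1]
        y = torch.zeros(B, Ho, Wo, C_out, dtype=vx.dtype, device=vx.device)
        for b in range(B):
            for ho in range(Ho):
                zs, ze = int(row_off[ho]), int(row_off[ho + 1])
                for wo in range(Wo):
                    qdotk, wq, vals = [], [], []
                    for idz in range(zs, ze):
                        col = int(col_idx[idz])
                        hi, wi = divmod(col, nlon_in)
                        wip = (wi + wo) % nlon_in          # longitude shift, as in the kernels
                        qdotk.append(torch.dot(qy[b, ho, wo], kx[b, hi, wip]))
                        wq.append(quad_weights[hi])
                        vals.append(vx[b, hi, wip])
                    qdotk_t = torch.stack(qdotk)
                    alpha = torch.exp(qdotk_t - qdotk_t.max()) * torch.stack(wq)
                    alpha = alpha / alpha.sum()            # quadrature-weighted softmax
                    y[b, ho, wo] = (alpha[:, None] * torch.stack(vals)).sum(dim=0)
        return y

The optimized kernels produce the same result without ever materializing all logits of a row at once.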
-template -__global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel( - int num_channels, int nlon_in, int nlat_out, int nlon_out, - const torch::PackedTensorAccessor32 kx, - const torch::PackedTensorAccessor32 vx, - const torch::PackedTensorAccessor32 qy, - const torch::PackedTensorAccessor32 dy, - torch::PackedTensorAccessor32 dydk, - torch::PackedTensorAccessor32 dydv, - torch::PackedTensorAccessor32 dydq, - const torch::PackedTensorAccessor64 psi_col_idx, - const torch::PackedTensorAccessor64 psi_row_offset, - const torch::PackedTensorAccessor32 quad_weights) -{ - - extern __shared__ float sh[]; - float *sh_alpha_k = sh + threadIdx.y * num_channels * 5; - float *sh_alpha_vw = sh_alpha_k + num_channels; - float *sh_alpha_kvw = sh_alpha_vw + num_channels; - float *sh_dy = sh_alpha_kvw + num_channels; - float *sh_qy = sh_dy + num_channels; - // (optionally, could use more shared memory for other intermediates) - const uint64_t batchId = blockIdx.y; - const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y; - if (wid >= uint64_t(nlat_out) * nlon_in) return; - const int tidx = threadIdx.x; - const int ho = wid / nlon_out; - const int wo = wid - (ho * nlon_out); - - // Zero shared memory - for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { - sh_alpha_k[chan] = 0.0f; - sh_alpha_vw[chan] = 0.0f; - sh_alpha_kvw[chan] = 0.0f; - sh_dy[chan] = dy[batchId][chan][ho][wo]; - sh_qy[chan] = qy[batchId][chan][ho][wo]; - } - float alpha_sum = 0.0f; - float qdotk_max = -FLT_MAX; - float integral = 0.0f; - __syncthreads(); - - const int64_t rbeg = psi_row_offset[ho]; - const int64_t rend = psi_row_offset[ho + 1]; - const int rlen = rend - rbeg; - - // 1st pass: accumulate alpha_sum, integral, and shared stats, along with a progressively computed qdotk_max. 
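The "1st pass" described in the comment above never sees all logits of a neighborhood at once, so the maximum is updated on the fly and previously accumulated terms are rescaled by exp(old_max - new_max). A standalone sketch of that streaming-softmax normalizer (the kernels additionally fold quad_weights[hi] into each term, omitted here for clarity; names are illustrative):

    import torch

    def online_softmax_norm(logits):
        """Accumulate sum(exp(logit - running_max)) in one pass, rescaling the
        previous partial sum whenever the running maximum increases."""
        running_max = torch.tensor(float("-inf"))
        alpha_sum = torch.tensor(0.0)
        for x in logits:
            new_max = torch.maximum(running_max, x)
            correction = torch.exp(running_max - new_max)   # rescale old contributions
            alpha_sum = alpha_sum * correction + torch.exp(x - new_max)
            running_max = new_max
        return alpha_sum, running_max

    logits = torch.randn(16)
    s, m = online_softmax_norm(logits)
    assert torch.allclose(s, torch.exp(logits - logits.max()).sum())
    assert m == logits.max()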
- for (int off = 0; off < rlen; off++) { - const int64_t col = psi_col_idx[rbeg + off]; - const int hi = col / nlon_in; - const int wi = col - (hi * nlon_in); - const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in; - float qdotk = 0.0f, gdotv = 0.0f; - for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { - qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip]; - gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip]; - } - qdotk = __warp_sum_cub(qdotk); - gdotv = __warp_sum_cub(gdotv); - float qdotk_max_tmp = max(qdotk_max, qdotk); - float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi]; - float max_correction = expf(qdotk_max - qdotk_max_tmp); - alpha_sum = alpha_sum * max_correction + alpha_inz; - integral = integral * max_correction + alpha_inz * gdotv; - for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { - float kxval = kx[batchId][chan][hi][wip]; - sh_alpha_k[chan] = sh_alpha_k[chan] * max_correction + alpha_inz * kxval; - sh_alpha_vw[chan] = sh_alpha_vw[chan] * max_correction + alpha_inz * gdotv; - sh_alpha_kvw[chan] = sh_alpha_kvw[chan] * max_correction + alpha_inz * kxval * gdotv; - } - qdotk_max = qdotk_max_tmp; - } - - integral /= alpha_sum; - - // Write dydq - for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { - dydq[batchId][chan][ho][wo] - = (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum); - } - - // Third pass: accumulate gradients for k and v - for (int off = 0; off < rlen; off++) { - const int64_t col = psi_col_idx[rbeg + off]; - const int hi = col / nlon_in; - const int wi = col - (hi * nlon_in); - const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in; - float qdotk = 0.0f, gdotv = 0.0f; - for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { - qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip]; - gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip]; - } - qdotk = __warp_sum_cub(qdotk); - gdotv = __warp_sum_cub(gdotv); - float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi]; - for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { - float qyval = qy[batchId][chan][ho][wo]; - float dyval = sh_dy[chan]; - atomicAdd(&dydk[batchId][chan][hi][wip], qyval * (alpha_inz / alpha_sum) * (gdotv - integral)); - atomicAdd(&dydv[batchId][chan][hi][wip], (alpha_inz / alpha_sum) * dyval); - } - } -} -#endif - -// BEGIN backward kernels and functions +namespace attention_kernels { // called with (blockDim.x=32 and blockDim.y>1, BDIM=blockDim.x*blockDim.y) template s2_attention_bwd_dkvq_cuda(at::Te CHECK_CUDA_TENSOR(psi_col_idx); CHECK_CUDA_TENSOR(psi_row_off); - //const size_t uo_num_channels = kx.size(1); - size_t nchans_in = qy.size(1); // or kx.size(1) - size_t nchans_out = vx.size(1); + // IMPORTANT: all input tensors are in channels last format! 
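With the strides()[1] == 1 checks and the permute_4D_to0231 / permute_4D_to0312 round trips removed from this wrapper, the layout conversion happens once on the Python side (the patch imports permute_to_0231 and permute_to_0312 from torch_harmonics.utils for this purpose further down, in convolution.py). A sketch of that calling convention using plain torch permutes; `compiled_bwd` and the argument order are hypothetical stand-ins, not the library's actual binding:

    import torch

    def call_channels_last(compiled_bwd, kx, vx, qy, dy, *extra):
        """Permute (B, C, H, W) inputs to channels-last once, run the compiled
        op, and permute the returned gradients back to channels-first."""
        to_0231 = lambda t: t.permute(0, 2, 3, 1).contiguous()  # (B,C,H,W) -> (B,H,W,C)
        to_0312 = lambda t: t.permute(0, 3, 1, 2).contiguous()  # (B,H,W,C) -> (B,C,H,W)
        dk, dv, dq = compiled_bwd(to_0231(kx), to_0231(vx), to_0231(qy), to_0231(dy), *extra)
        return to_0312(dk), to_0312(dv), to_0312(dq)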
+ + size_t nchans_in = qy.size(3); // or kx.size(3) + size_t nchans_out = vx.size(3); const int batch_size = kx.size(0); // extract dtype - auto kx_type = kx.dtype(); // nchans_in + auto kx_type = kx.dtype(); auto qy_type = qy.dtype(); - auto vx_type = vx.dtype(); // ncahn_out + auto vx_type = vx.dtype(); auto dy_type = dy.dtype(); torch::Tensor kxP = kx.to(torch::kFloat32); @@ -1020,19 +876,7 @@ std::tuple s2_attention_bwd_dkvq_cuda(at::Te torch::Tensor qyP = qy.to(torch::kFloat32); torch::Tensor dyP = dy.to(torch::kFloat32); - // exract memory format: this is much safer than checking is_contiguous(at::MemoryFormat::ChannelsLast) - // the former fails for num_channels == 1 - bool kx_is_channels_last = kxP.strides()[1] == 1; - bool vx_is_channels_last = vxP.strides()[1] == 1; - bool qy_is_channels_last = qyP.strides()[1] == 1; - bool dy_is_channels_last = dyP.strides()[1] == 1; - - // transpose if required - if (!kx_is_channels_last) { kxP = permute_4D_to0231(kxP); } - if (!vx_is_channels_last) { vxP = permute_4D_to0231(vxP); } - if (!qy_is_channels_last) { qyP = permute_4D_to0231(qyP); } - if (!dy_is_channels_last) { dyP = permute_4D_to0231(dyP); } - + // create output tensors torch::Tensor dkxP = torch::zeros_like(kxP); torch::Tensor dvxP = torch::zeros_like(vxP); torch::Tensor dqyP = torch::zeros_like(qyP); @@ -1053,10 +897,6 @@ std::tuple s2_attention_bwd_dkvq_cuda(at::Te torch::Tensor dvx = dvxP; torch::Tensor dqy = dqyP; - if (!kx_is_channels_last) { dkx = permute_4D_to0312(dkx); } - if (!vx_is_channels_last) { dvx = permute_4D_to0312(dvx); } - if (!qy_is_channels_last) { dqy = permute_4D_to0312(dqy); } - // convert precision back to starting dkx = dkx.to(kx_type); dvx = dvx.to(vx_type); diff --git a/torch_harmonics/attention/csrc/attention_cuda_fwd.cu b/torch_harmonics/attention/csrc/attention_cuda_fwd.cu index 6de7321d..01bdab41 100644 --- a/torch_harmonics/attention/csrc/attention_cuda_fwd.cu +++ b/torch_harmonics/attention/csrc/attention_cuda_fwd.cu @@ -39,14 +39,13 @@ #include #include -#include "cudamacro.h" +//#include "cudamacro.h" #include "attention_cuda_utils.cuh" #define THREADS (64) #define MAX_LOCAL_ARR_LEN (16) -// BEGIN - forward kernels and functions namespace attention_kernels { @@ -530,8 +529,10 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, CHECK_CUDA_TENSOR(psi_col_idx); CHECK_CUDA_TENSOR(psi_row_off); - size_t nchans_in = qy.size(1); // or kx.size(1) - size_t nchans_out = vx.size(1); + // IMPORTANT: all input tensors are in channels last format! 
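Both wrappers still upcast their inputs to float32 before launching (kx.to(torch::kFloat32) and friends) and cast the results back to the original dtype at the end; only the layout permutations were dropped. A compact Python rendering of that dtype round trip, with a hypothetical helper name:

    import torch

    def run_in_fp32(kernel, *tensors):
        """Compute in float32, return results in the dtype of the first input
        (mirrors the kx_type / dkx.to(kx_type) round trip in the wrapper)."""
        out_dtype = tensors[0].dtype
        out = kernel(*(t.to(torch.float32) for t in tensors))
        if isinstance(out, tuple):
            return tuple(o.to(out_dtype) for o in out)
        return out.to(out_dtype)

    q = torch.randn(2, 8, 16, 4, dtype=torch.bfloat16)   # channels-last example
    k = torch.randn(2, 8, 16, 4, dtype=torch.bfloat16)
    y = run_in_fp32(lambda a, b: (a * b).sum(dim=-1), q, k)
    assert y.dtype == torch.bfloat16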
+ + size_t nchans_in = qy.size(3); // or kx.size(3) + size_t nchans_out = vx.size(3); const int batch_size = kx.size(0); @@ -542,16 +543,7 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, torch::Tensor vxP = vx.to(torch::kFloat32); torch::Tensor qyP = qy.to(torch::kFloat32); - // these are much safer than checking is_contiguous(at::MemoryFormat::ChannelsLast) - // the former fails for num_channels == 1 - bool kx_is_channels_last = kxP.strides()[1] == 1; - bool vx_is_channels_last = vxP.strides()[1] == 1; - bool qy_is_channels_last = qyP.strides()[1] == 1; - - if (!kx_is_channels_last) { kxP = permute_4D_to0231(kxP); } - if (!vx_is_channels_last) { vxP = permute_4D_to0231(vxP); } - if (!qy_is_channels_last) { qyP = permute_4D_to0231(qyP); } - + // output tensor torch::Tensor yP = torch::empty_like(vxP); s2_attn_fwd_dispatch(batch_size, @@ -567,7 +559,6 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, yP); torch::Tensor y = yP; - if (!qy_is_channels_last) { y = permute_4D_to0312(y); } // convert precision back to starting y = y.to(qy_type); diff --git a/torch_harmonics/attention/csrc/attention_cuda_utils.cu b/torch_harmonics/attention/csrc/attention_cuda_utils.cu index 09b42cfc..b1e21571 100644 --- a/torch_harmonics/attention/csrc/attention_cuda_utils.cu +++ b/torch_harmonics/attention/csrc/attention_cuda_utils.cu @@ -39,7 +39,7 @@ #include #include -#include "cudamacro.h" +//#include "cudamacro.h" #include "attention_cuda.cuh" #define THREADS (64) @@ -111,63 +111,6 @@ at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream) { } // END - CSR rows sorting kernels and functions - -// BEGIN - 4D tensor permutation kernels and functions -__global__ void empty_k() {} - -static int getPtxver() { - cudaFuncAttributes attrs; - CHECK_CUDA(cudaFuncGetAttributes(&attrs, empty_k)); - return attrs.ptxVersion*10; -} - -at::Tensor permute_4D_to0231(at::Tensor src) { - - auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device()); - torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options); - - const int ptxv = getPtxver(); - - // to be further specialized for additional archs, if necessary - if (ptxv < 100) { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_generic", ([&] { - launch_permute_to0231(src, dst); - })); - CHECK_ERROR("permute_to0231_k_tile_generic"); - } else { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_sm100", ([&] { - launch_permute_to0231(src, dst); - })); - CHECK_ERROR("permute_to0231_k_tile_sm100"); - } - - return dst; -} - -at::Tensor permute_4D_to0312(at::Tensor src) { - - auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device()); - torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options); - - const int ptxv = getPtxver(); - - // to be further specialized for additional archs, if necessary - if (ptxv < 100) { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_generic", ([&] { - launch_permute_to0312(src, dst); - })); - CHECK_ERROR("permute_to0312_k_tile_generic"); - } else { - AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_sm100", ([&] { - launch_permute_to0312(src, dst); - })); - CHECK_ERROR("permute_to0312_k_tile_sm100"); - } - - return dst; -} -// END - tensor permutation kernels and functions - // BEGIN - general host-side functions unsigned int next_pow2(unsigned int x) { diff --git a/torch_harmonics/attention/csrc/attention_cuda_utils.cuh 
b/torch_harmonics/attention/csrc/attention_cuda_utils.cuh index fbda09d5..625e8131 100644 --- a/torch_harmonics/attention/csrc/attention_cuda_utils.cuh +++ b/torch_harmonics/attention/csrc/attention_cuda_utils.cuh @@ -34,9 +34,9 @@ #include #include +#include "cudamacro.h" + #define WARP_SIZE (32) -#define FULL_MASK (0xFFFFFFFF) -#define DIV_UP(a,b) (((a)+((b)-1))/(b)) namespace attention_kernels { @@ -205,175 +205,4 @@ __device__ VAL_T __block_sum(VAL_T val) { return val; } -// transpose utils -template -__global__ -__launch_bounds__(BDIM_X*BDIM_Y) -void permute_to0231_k(const int nchn, - const int nlat, - const int nlon, - const at::PackedTensorAccessor32 src, - at::PackedTensorAccessor32 dst) { - - static_assert(!(BDIM_X & (BDIM_X-1))); - static_assert(!(BDIM_Y & (BDIM_Y-1))); - static_assert(BDIM_X >= BDIM_Y); - - __shared__ VAL_T sh[BDIM_X][BDIM_X+1]; - - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - - const int coff = blockIdx.x*BDIM_X; // channel offset - const int woff = blockIdx.y*BDIM_X; // width offset - const int batch = blockIdx.z / nlat; // batch (same for all block) - const int h = blockIdx.z - (batch * nlat); // height (same for all block) - - const int nchn_full = (nchn-coff) >= BDIM_X; - const int nlon_full = (nlon-woff) >= BDIM_X; - - if (nchn_full && nlon_full) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - sh[j+tidy][tidx] = src[batch][coff + j+tidy][h][woff+tidx]; - } - __syncthreads(); - - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy]; - } - } else { - if (woff+tidx < nlon) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - sh[j+tidy][tidx] = (coff + j+tidy < nchn) ? src[batch][coff + j+tidy][h][woff+tidx] : VAL_T(0); - } - } - __syncthreads(); - - if (coff+tidx < nchn) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - if (woff + j+tidy < nlon) { - dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy]; - } - } - } - } - return; -} - -template -void launch_permute_to0231(at::Tensor src, at::Tensor dst){ - dim3 block; - dim3 grid; - - block.x = WARP_SIZE; - block.y = WARPS_X_TILE; - grid.x = DIV_UP(src.size(1), block.x); - grid.y = DIV_UP(src.size(3), block.x); - grid.z = src.size(2)*src.size(0); - - assert(grid.y < 65536); - assert(grid.z < 65536); - - // get stream - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - permute_to0231_k - <<>>(src.size(1), - src.size(2), - src.size(3), - src.packed_accessor32(), - dst.packed_accessor32()); -} - -template -__global__ -__launch_bounds__(BDIM_X*BDIM_Y) -void permute_to0312_k(const int nchn, - const int nlat, - const int nlon, - const at::PackedTensorAccessor32 src, - at::PackedTensorAccessor32 dst) { - - static_assert(!(BDIM_X & (BDIM_X-1))); - static_assert(!(BDIM_Y & (BDIM_Y-1))); - static_assert(BDIM_X >= BDIM_Y); - - __shared__ VAL_T sh[BDIM_X][BDIM_X+1]; - - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - - const int woff = blockIdx.x*BDIM_X; // width offset - const int coff = blockIdx.y*BDIM_X; // channel offset - const int batch = blockIdx.z / nlat; // batch (same for all block) - const int h = blockIdx.z - (batch * nlat); // height (same for all block) - - const int nchn_full = (nchn-coff) >= BDIM_X; - const int nlon_full = (nlon-woff) >= BDIM_X; - - if (nchn_full && nlon_full) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - sh[j+tidy][tidx] = src[batch][h][woff + j+tidy][coff+tidx]; - } - __syncthreads(); - - #pragma unroll - 
for(int j = 0; j < BDIM_X; j += BDIM_Y) { - dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy]; - } - } else { - if (coff+tidx < nchn) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - sh[j+tidy][tidx] = (woff + j+tidy < nlon) ? src[batch][h][woff + j+tidy][coff+tidx] : VAL_T(0); - } - } - __syncthreads(); - - if (woff+tidx < nlon) { - #pragma unroll - for(int j = 0; j < BDIM_X; j += BDIM_Y) { - if (coff + j+tidy < nchn) { - dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];; - } - } - } - } - return; -} - -template -void launch_permute_to0312(at::Tensor src, at::Tensor dst){ - dim3 block; - dim3 grid; - - block.x = WARP_SIZE; - block.y = WARPS_X_TILE; - grid.x = DIV_UP(src.size(2), block.x); - grid.y = DIV_UP(src.size(3), block.x); - grid.z = src.size(1)*src.size(0); - - assert(grid.y < 65536); - assert(grid.z < 65536); - - // get stream - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - permute_to0312_k - <<>>(src.size(3), - src.size(1), - src.size(2), - src.packed_accessor32(), - dst.packed_accessor32()); -} - } \ No newline at end of file diff --git a/torch_harmonics/disco/__init__.py b/torch_harmonics/disco/__init__.py index 50d3268b..14ae7462 100644 --- a/torch_harmonics/disco/__init__.py +++ b/torch_harmonics/disco/__init__.py @@ -40,6 +40,6 @@ from torch.ops import disco_kernels else: disco_kernels = None - warnings.warn("No optimized kernels are available. Please compile the extension first setting BUILD_CPP and BUILD_CUDA to 1.") + warnings.warn("No optimized DISCO kernels are available. Please compile the extension first setting BUILD_CPP and BUILD_CUDA to 1.") from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 diff --git a/torch_harmonics/disco/_disco_utils.py b/torch_harmonics/disco/_disco_utils.py index 4ad1effb..5bc0fe80 100644 --- a/torch_harmonics/disco/_disco_utils.py +++ b/torch_harmonics/disco/_disco_utils.py @@ -30,7 +30,6 @@ # from typing import Optional -import math import torch from disco_helpers import optimized_kernels_is_available @@ -41,17 +40,17 @@ # raw forward fake @torch.library.register_fake("disco_kernels::forward") def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, - row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, - kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - out_shape = (inp.shape[0], inp.shape[1], kernel_size, nlat_out, nlon_out) - return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) + row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, + kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + out_shape = (inp.shape[0], nlat_out, nlon_out, inp.shape[3], kernel_size) + return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) # raw backward fake @torch.library.register_fake("disco_kernels::backward") def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - out_shape = (inp.shape[0], inp.shape[1], nlat_out, nlon_out) + out_shape = (inp.shape[0], nlat_out, nlon_out, inp.shape[3]) return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) # forward @@ -60,10 +59,7 @@ def _disco_s2_contraction_optimized( inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - itype = inp.dtype - inp 
= inp.to(torch.float32).contiguous() out = disco_kernels.forward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out) - out = out.to(itype) return out # transpose @@ -72,10 +68,7 @@ def _disco_s2_transpose_contraction_optimized( inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - itype = inp.dtype - inp = inp.to(torch.float32).contiguous() out = disco_kernels.backward.default(inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out) - out = out.to(itype) return out # forward fake @@ -83,39 +76,36 @@ def _disco_s2_transpose_contraction_optimized( def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - out_shape = (inp.shape[0], inp.shape[1], kernel_size, nlat_out, nlon_out) + out_shape = (inp.shape[0], nlat_out, nlon_out, inp.shape[3], kernel_size) return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) # transpose fake @torch.library.register_fake("disco_kernels::_disco_s2_transpose_contraction_optimized") def _(inp: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, - kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: - out_shape = (inp.shape[0], inp.shape[1], nlat_out, nlon_out) + kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + out_shape = (inp.shape[0], nlat_out, nlon_out, inp.shape[3]) return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) #general routines: this is the same for forward and transpose def _setup_context_conv_backward(ctx, inputs, output): - inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out = inputs + inp, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, _, _ = inputs ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals) + ctx.nlat_in = inp.shape[1] + ctx.nlon_in = inp.shape[2] ctx.kernel_size = kernel_size - ctx.nlat_in = inp.shape[-2] - ctx.nlon_in = inp.shape[-1] # convolution related def _disco_s2_contraction_bwd_optimized(ctx, grad_output): roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors if ctx.needs_input_grad[0]: - gtype = grad_output.dtype - grad_output = grad_output.to(torch.float32).contiguous() grad_input = disco_kernels.backward.default(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals, - ctx.kernel_size, ctx.nlat_in, ctx.nlon_in) - grad_input = grad_input.to(gtype) + ctx.kernel_size, ctx.nlat_in, ctx.nlon_in) # Mauro else: grad_input = None - return grad_input, None, None, None, None, None, None, None, None + return grad_input, None, None, None, None, None, None, None, None # Mauro: added a None for weights if optimized_kernels_is_available(): torch.library.register_autograd( @@ -126,15 +116,12 @@ def _disco_s2_transpose_contraction_bwd_optimized(ctx, grad_output): roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors if ctx.needs_input_grad[0]: - gtype = grad_output.dtype - grad_output = grad_output.to(torch.float32).contiguous() grad_input = disco_kernels.forward.default(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals, - ctx.kernel_size, ctx.nlat_in, ctx.nlon_in) - grad_input = grad_input.to(gtype) + ctx.kernel_size, ctx.nlat_in, ctx.nlon_in) # Mauro else: grad_input = None - return grad_input, None, 
None, None, None, None, None, None, None + return grad_input, None, None, None, None, None, None, None, None, None # Mauro: added a None for weights if optimized_kernels_is_available(): torch.library.register_autograd( @@ -182,7 +169,7 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in # add a dummy dimension for nkernel and move the batch and channel dims to the end x = x.reshape(1, batch_size * n_chans, nlat_in, nlon_in).permute(0, 2, 3, 1) - x = x.expand(kernel_size, -1, -1, -1) + x = x.expand(kernel_size, -1, -1, -1).contiguous() y = torch.zeros(nlon_out, kernel_size, nlat_out, batch_size * n_chans, device=x.device, dtype=x.dtype) @@ -193,7 +180,7 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in x = torch.roll(x, -pscale, dims=2) # reshape y back to expose the correct dimensions - y = y.permute(3, 1, 2, 0).reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out) + y = y.permute(3, 1, 2, 0).reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out).contiguous() return y @@ -212,7 +199,7 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl # interleave zeros along the longitude dimension to allow for fractional offsets to be considered x_ext = torch.zeros(kernel_size, nlat_in, nlon_out, batch_size * n_chans, device=x.device, dtype=x.dtype) - x = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0) + x = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0).contiguous() # x has shape kernel_size x nlat_in x nlon_in x batch_size * n_chans # we only need to apoply the nlon stride here, since nlat stride is taken care of by the kernel diff --git a/torch_harmonics/disco/convolution.py b/torch_harmonics/disco/convolution.py index 83f11c82..1a87c70f 100644 --- a/torch_harmonics/disco/convolution.py +++ b/torch_harmonics/disco/convolution.py @@ -30,18 +30,16 @@ # import abc -from typing import List, Tuple, Union, Optional -from warnings import warn +from typing import Tuple, Union, Optional import math import torch import torch.nn as nn -from functools import partial - from torch_harmonics.cache import lru_cache -from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes +from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes +from torch_harmonics.utils import permute_to_0231, permute_to_0312 from ._disco_utils import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch from ._disco_utils import _disco_s2_contraction_optimized, _disco_s2_transpose_contraction_optimized from torch_harmonics.filter_basis import FilterBasis, get_filter_basis @@ -164,8 +162,8 @@ def _normalize_convolution_tensor_s2( @lru_cache(typed=True, copy=True) def _precompute_convolution_tensor_s2( - in_shape: Tuple[int], - out_shape: Tuple[int], + in_shape: Tuple[int, int], + out_shape: Tuple[int, int], filter_basis: FilterBasis, grid_in: Optional[str]="equiangular", grid_out: Optional[str]="equiangular", @@ -192,9 +190,9 @@ def _precompute_convolution_tensor_s2( Parameters ----------- - in_shape: Tuple[int] + in_shape: Tuple[int, int] Input shape of the convolution tensor - out_shape: Tuple[int] + out_shape: Tuple[int, int] Output shape of the convolution tensor filter_basis: FilterBasis Filter basis functions @@ -370,9 +368,12 @@ def __init__( raise ValueError("Error, the number of input channels has to be an integer multiple of the group size") if out_channels % 
self.groups != 0: raise ValueError("Error, the number of output channels has to be an integer multiple of the group size") - self.groupsize = in_channels // self.groups - scale = math.sqrt(1.0 / self.groupsize / self.kernel_size) - self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size)) + self.groupsize_in = in_channels // self.groups + self.groupsize_out = out_channels // self.groups + # keep this for backward compatibility + self.groupsize = self.groupsize_in + scale = math.sqrt(1.0 / self.groupsize_in / self.kernel_size) + self.weight = nn.Parameter(scale * torch.randn(self.groups * self.groupsize_out, self.groupsize_in, self.kernel_size)) if bias: self.bias = nn.Parameter(torch.zeros(out_channels)) @@ -451,15 +452,16 @@ def __init__( self.nlat_in, self.nlon_in = in_shape self.nlat_out, self.nlon_out = out_shape + self.theta_cutoff = theta_cutoff # make sure the p-shift works by checking that longitudes are divisible assert self.nlon_in % self.nlon_out == 0 # heuristic to compute theta cutoff based on the bandlimit of the input field and overlaps of the basis functions - if theta_cutoff is None: - theta_cutoff = torch.pi / float(self.nlat_out - 1) + if self.theta_cutoff is None: + self.theta_cutoff = torch.pi / float(self.nlat_out - 1) - if theta_cutoff <= 0.0: + if self.theta_cutoff <= 0.0: raise ValueError("Error, theta_cutoff has to be positive.") idx, vals, _ = _precompute_convolution_tensor_s2( @@ -468,13 +470,13 @@ def __init__( self.filter_basis, grid_in=grid_in, grid_out=grid_out, - theta_cutoff=theta_cutoff, + theta_cutoff=self.theta_cutoff, transpose_normalization=False, basis_norm_mode=basis_norm_mode, merge_quadrature=True, ) - # sort the values + # extract values and indices ker_idx = idx[0, ...].contiguous() row_idx = idx[1, ...].contiguous() col_idx = idx[2, ...].contiguous() @@ -496,7 +498,7 @@ def __init__( self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out) def extra_repr(self): - return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" + return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize_in * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" @property def psi_idx(self): @@ -505,19 +507,41 @@ def psi_idx(self): def forward(self, x: torch.Tensor) -> torch.Tensor: if self.optimized_kernel: - x = _disco_s2_contraction_optimized( - x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out + # permute input + xp = permute_to_0231(x) + + # disco contaction + xpc = _disco_s2_contraction_optimized( + xp, + self.psi_roff_idx, + self.psi_ker_idx, + self.psi_row_idx, + self.psi_col_idx, + self.psi_vals, + self.kernel_size, + self.nlat_out, + self.nlon_out ) + + # weight multiplication + B, H, W, _, K = xpc.shape + xpc = xpc.reshape(B, H, W, self.groups, self.groupsize_in, K) + outp = torch.einsum("bxygck,gock->bxygo", xpc, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size)) + outp = outp.reshape(B, H, W, -1).contiguous() + + # permute output + out = permute_to_0312(outp) else: + # disco contaction x = 
_disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out) - # extract shape - B, C, K, H, W = x.shape - x = x.reshape(B, self.groups, self.groupsize, K, H, W) + # extract shape + B, _, K, H, W = x.shape + x = x.reshape(B, self.groups, self.groupsize_in, K, H, W) - # do weight multiplication - out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous() - out = out.reshape(B, -1, H, W) + # weight multiplication + out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size)) + out = out.reshape(B, -1, H, W).contiguous() if self.bias is not None: out = out + self.bias.reshape(1, -1, 1, 1) @@ -588,15 +612,16 @@ def __init__( self.nlat_in, self.nlon_in = in_shape self.nlat_out, self.nlon_out = out_shape + self.theta_cutoff = theta_cutoff # make sure the p-shift works by checking that longitudes are divisible assert self.nlon_out % self.nlon_in == 0 # bandlimit - if theta_cutoff is None: - theta_cutoff = torch.pi / float(self.nlat_in - 1) + if self.theta_cutoff is None: + self.theta_cutoff = torch.pi / float(self.nlat_in - 1) - if theta_cutoff <= 0.0: + if self.theta_cutoff <= 0.0: raise ValueError("Error, theta_cutoff has to be positive.") # switch in_shape and out_shape since we want the transpose convolution @@ -606,7 +631,7 @@ def __init__( self.filter_basis, grid_in=grid_out, grid_out=grid_in, - theta_cutoff=theta_cutoff, + theta_cutoff=self.theta_cutoff, transpose_normalization=True, basis_norm_mode=basis_norm_mode, merge_quadrature=True, @@ -634,7 +659,7 @@ def __init__( self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True) def extra_repr(self): - return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" + return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize_in * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" @property def psi_idx(self): @@ -643,19 +668,40 @@ def psi_idx(self): def forward(self, x: torch.Tensor) -> torch.Tensor: # extract shape - B, C, H, W = x.shape - x = x.reshape(B, self.groups, self.groupsize, H, W) - - # do weight multiplication - x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous() - x = x.reshape(B, -1, x.shape[-3], H, W) + B, _, H, W = x.shape if self.optimized_kernel: - out = _disco_s2_transpose_contraction_optimized( - x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out + # permute input + xp = permute_to_0231(x) + + # weight multiplication + xp = xp.reshape(B, H, W, self.groups, self.groupsize_in) + xpc = torch.einsum("bxygc,gock->bxygok", xp, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size)) + xpc = xpc.reshape(B, H, W, -1, self.kernel_size).contiguous() + + # disco contraction + outp = _disco_s2_transpose_contraction_optimized( + xpc, + self.psi_roff_idx, + self.psi_ker_idx, + self.psi_row_idx, + self.psi_col_idx, + self.psi_vals, + 
self.kernel_size, + self.nlat_out, + self.nlon_out ) + + # permute output + out = permute_to_0312(outp) else: - out = _disco_s2_transpose_contraction_torch(x, self.psi_st.to(x.device), self.nlon_out) + # weight multiplication + x = x.reshape(B, self.groups, self.groupsize_in, H, W) + xc = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size)) + xc = xc.reshape(B, self.groups* self.groupsize_out, -1, H, W).contiguous() + + # disco contraction + out = _disco_s2_transpose_contraction_torch(xc, self.psi_st.to(x.device), self.nlon_out) if self.bias is not None: out = out + self.bias.reshape(1, -1, 1, 1) diff --git a/torch_harmonics/disco/csrc/disco.h b/torch_harmonics/disco/csrc/disco.h index 7198ad8c..b20ccff9 100644 --- a/torch_harmonics/disco/csrc/disco.h +++ b/torch_harmonics/disco/csrc/disco.h @@ -35,10 +35,3 @@ #include #include #include - -#define CHECK_CPU_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCPU) -#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT_TENSOR(x) CHECK_CONTIGUOUS_TENSOR(x) -#define CHECK_CPU_INPUT_TENSOR(x) \ - CHECK_CPU_TENSOR(x); \ - CHECK_CONTIGUOUS_TENSOR(x) diff --git a/torch_harmonics/disco/csrc/disco_cpu.cpp b/torch_harmonics/disco/csrc/disco_cpu.cpp index 8a544e66..1549789b 100644 --- a/torch_harmonics/disco/csrc/disco_cpu.cpp +++ b/torch_harmonics/disco/csrc/disco_cpu.cpp @@ -44,12 +44,16 @@ namespace disco_kernels { CHECK_CPU_INPUT_TENSOR(col_idx); CHECK_CPU_INPUT_TENSOR(vals); + // convert input to fp32 + auto inp_dtype = inp.scalar_type(); + inp = inp.to(torch::kFloat32).contiguous(); + // initialize output tensor - auto out = torch::zeros({inp.size(0), inp.size(1), K, Ho, Wo}, inp.options()); + auto out = torch::zeros({inp.size(0), Ho, Wo, inp.size(3), K}, inp.options()); AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cpu", ([&] { disco_fwd_cpu( - inp.size(0), inp.size(1), K, inp.size(2), inp.size(3), + inp.size(0), inp.size(3), K, inp.size(1), inp.size(2), Ho, Wo, vals.size(0), roff_idx.size(0) - 1, inp.packed_accessor64(), roff_idx.packed_accessor64(), @@ -60,6 +64,9 @@ namespace disco_kernels { out.packed_accessor64()); })); + // convert to input datatype + out = out.to(inp_dtype); + return out; } @@ -73,14 +80,18 @@ namespace disco_kernels { CHECK_CPU_INPUT_TENSOR(row_idx); CHECK_CPU_INPUT_TENSOR(col_idx); CHECK_CPU_INPUT_TENSOR(vals); + + // convert input to fp32 + auto inp_dtype = inp.scalar_type(); + inp = inp.to(torch::kFloat32).contiguous(); - // initialize output tensor - auto out = torch::zeros({inp.size(0), inp.size(1), Ho, Wo}, inp.options()); + // initialize output tensor: assume channels last format + auto out = torch::zeros({inp.size(0), Ho, Wo, inp.size(3)}, inp.options()); AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cpu", ([&] { disco_bwd_cpu( - inp.size(0), inp.size(1), K, inp.size(3), - inp.size(4), Ho, Wo, vals.size(0), roff_idx.size(0) - 1, + inp.size(0), inp.size(3), K, inp.size(1), inp.size(2), + Ho, Wo, vals.size(0), roff_idx.size(0) - 1, inp.packed_accessor64(), roff_idx.packed_accessor64(), ker_idx.packed_accessor64(), @@ -90,6 +101,9 @@ namespace disco_kernels { out.packed_accessor64()); })); + // convert to input datatype + out = out.to(inp_dtype); + return out; } diff --git a/torch_harmonics/disco/csrc/disco_cpu.h b/torch_harmonics/disco/csrc/disco_cpu.h index 29b242c2..96c768a6 100644 --- a/torch_harmonics/disco/csrc/disco_cpu.h +++ 
b/torch_harmonics/disco/csrc/disco_cpu.h @@ -32,6 +32,8 @@ #include "disco.h" +#include "cppmacro.h" + #define CACHE_BLOCK_SIZE (64) namespace disco_kernels { @@ -55,10 +57,10 @@ namespace disco_kernels { const int64_t nblock_wo = static_cast((Wo + block_wo - 1) / block_wo); // loop over matrix entries - #pragma omp parallel for collapse(3) + #pragma omp parallel for simd collapse(3) for (int64_t b = 0; b < B; b++) { - for (int64_t c = 0; c < C; c++) { - for (int64_t row = 0; row < nnr; row++) { + for (int64_t row = 0; row < nnr; row++) { + for (int64_t c = 0; c < C; c++) { // since the rows are ordered accordingly, we can compute ho and ker in here int64_t ho = row_idx[roff_idx[row]]; @@ -89,12 +91,12 @@ namespace disco_kernels { for (int64_t wo = wo_start; wo < wo_end; wo++) { // compute shifted w int64_t wipp = static_cast((wi + pscale * wo) % Wi); - out_tmp[wo-wo_start] += val * inp[b][c][hi][wipp]; + out_tmp[wo-wo_start] += val * inp[b][hi][wipp][c]; } } // write out for (int64_t wo = wo_start; wo < wo_end; wo++) { - out[b][c][ker][ho][wo] = out_tmp[wo-wo_start]; + out[b][ho][wo][c][ker] = out_tmp[wo-wo_start]; } } } @@ -117,7 +119,7 @@ namespace disco_kernels { const int64_t pscale = static_cast(Wo / Wi); // loop over matrix entries - #pragma omp parallel for collapse(2) + #pragma omp parallel for simd collapse(2) for (int64_t b = 0; b < B; b++) { for (int64_t c = 0; c < C; c++) { @@ -142,7 +144,7 @@ namespace disco_kernels { for (int64_t wi = 0; wi < Wi; wi++) { // compute shifted w int64_t wopp = static_cast((wo + pscale * wi) % Wo); - out[b][c][ho][wopp] += val * inp[b][c][ker][hi][wi]; + out[b][ho][wopp][c] += val * inp[b][hi][wi][c][ker]; } } } diff --git a/torch_harmonics/disco/csrc/disco_cuda.cuh b/torch_harmonics/disco/csrc/disco_cuda.cuh index ad8e276d..64d2d6cd 100644 --- a/torch_harmonics/disco/csrc/disco_cuda.cuh +++ b/torch_harmonics/disco/csrc/disco_cuda.cuh @@ -35,12 +35,7 @@ #include #include -#define CHECK_CUDA_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCUDA) -#define CHECK_CUDA_INPUT_TENSOR(x) \ - CHECK_CUDA_TENSOR(x); \ - CHECK_CONTIGUOUS_TENSOR(x) - -#define DIV_UP(a, b) (((a) + ((b)-1)) / (b)) +#include "cudamacro.h" #define MIN_THREADS (64) #define ELXTH_MAX (32) @@ -48,11 +43,11 @@ namespace disco_kernels { // forward kernel - torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, - torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo); + torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, + torch::Tensor col_idx, torch::Tensor val, int64_t kernel_size, int64_t Ho, int64_t Wo); // backward kernel - torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, - torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo); + torch::Tensor disco_cuda_bwd(torch::Tensor ograd, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, + torch::Tensor col_idx, torch::Tensor val, int64_t kernel_size, int64_t Ho, int64_t Wo); } diff --git a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu index 00f5e514..0e4aa886 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_bwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_bwd.cu @@ -30,9 +30,16 @@ #include "disco.h" #include "disco_cuda.cuh" +#include "csr_cuda.cuh" + +#define THREADS (64) + +#define MAX_LOCAL_ARR_LEN (20) namespace 
disco_kernels { +using namespace utility_kernels; + template __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int Ho, const int Wo, const int pscale, const int64_t *__restrict__ roff, const int64_t *__restrict__ kers, @@ -198,12 +205,744 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t return; } - torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, +// BEGIN NEW CHANNEL-LAST VERSION + +template +static __global__ void pack_vals_k(const int64_t K, + const int64_t nrows, + const int64_t *__restrict__ row_off, + const VAL_T *__restrict__ val_dat, + VAL_T *__restrict__ val_pck) { + + const int tidx = threadIdx.x; + const int wid = blockIdx.x*blockDim.y + threadIdx.y; + if (wid >= nrows) { + return; + } + + const int64_t rbeg = row_off[wid]; + const int64_t rend = row_off[wid+1]; + + const int rlen = rend-rbeg; + + val_pck += rbeg*K; + + for(int off = tidx; off < rlen; off += blockDim.x) { + for(int ker = 0; ker < K; ker++) { + + val_pck[off*K + ker] = val_dat[ row_off[ker*nrows + wid] + off]; + } + } + + return; +} + +template +static __device__ void processCSR_Kpow2_shm_d(const int wi, + const int rlen, + const int nchans, // no. of input FLOATV_T elements along channel dim + const int nlon_out, + const int pscale, + const int K, + const FLOATV_T *__restrict__ shx, + const int64_t *__restrict__ cols, + const FLOATV_T *__restrict__ vals, + float *__restrict__ shy, + float *__restrict__ y) { + const int tidx = threadIdx.x; + + const int log2_K = __ffs(K)-1; + + const int tidxDivK = tidx >> log2_K; + const int tidxModK = tidx & (K-1); + + vals += tidxModK; + + const int BDIMX_div_K = BDIM_X >> log2_K; + + for(int chan = tidx; chan < nchans; chan += WARP_SIZE) { + shy[chan] = 0; + } + __syncwarp(); + + for(int off = 0; off < rlen; off++) { + + const int64_t col = cols[off]; + + const int ho = col / nlon_out; + const int wo = col - (ho*nlon_out); + + int wop = wo + pscale*wi; + wop -= (wop / nlon_out)*nlon_out; + + float *_y = y + int64_t(ho)*nlon_out*nchans + int64_t(wop)*nchans; + + float *_shy = shy + tidxDivK; + + const FLOATV_T myval = vals[0]; + + for(int i = 0; i < nchans*K; i += WARP_SIZE) { + + float sum = (i+tidx < nchans*K) ? __vred(__vmul(myval, shx[i+tidx])) : 0; + + for(int j = 1; j < K; j *= 2) { + sum += __shfl_xor_sync(FULL_MASK, sum, j); + } + + if (i+tidx < nchans*K && !tidxModK) { + _shy[0] += sum; + } + _shy += BDIMX_div_K; + } + __syncwarp(); + + for(int chan = tidx; chan < nchans; chan += WARP_SIZE) { + atomicAdd(_y+chan, shy[chan]); + shy[chan] = 0; + } + __syncwarp(); + + vals += K; + } + + return; +} + +template +static __device__ void processCSR_Kanyv_shm_d(const int wi, + const int rlen, + const int nchans, // no. 
of input FLOATV_T elements along channel dim + const int nlon_out, + const int pscale, + const int K, + const FLOATV_T *__restrict__ shx, + const int64_t *__restrict__ cols, + const FLOATV_T *__restrict__ vals, + float *__restrict__ shy, + float *__restrict__ y) { + const int tidx = threadIdx.x; + + for(int chan = tidx; chan < nchans; chan += WARP_SIZE) { + shy[chan] = 0; + } + __syncwarp(); + + for(int off = 0; off < rlen; off++) { + + const int64_t col = cols[off]; + + const int ho = col / nlon_out; + const int wo = col - (ho*nlon_out); + + int wop = wo + pscale*wi; + wop -= (wop / nlon_out)*nlon_out; + + float *_y = y + int64_t(ho)*nlon_out*nchans + int64_t(wop)*nchans; + + for(int chan = tidx; chan < nchans*K; chan += WARP_SIZE) { + + const int cDivK = chan / K; + const int cModK = chan - (cDivK*K); + + float sum = __vred(__vmul(vals[cModK], shx[chan])); + + atomicAdd(shy+cDivK, sum); + } + __syncwarp(); + + for(int chan = tidx; chan < nchans; chan += WARP_SIZE) { + atomicAdd(_y+chan, shy[chan]); + shy[chan] = 0; + } + __syncwarp(); + + vals += K; + } + + return; +} + +template // either float or float4 +__global__ +__launch_bounds__(BDIM) +void s2_disco_bwd_generic_vec_k(int nchans, // no. of input float (not FLOATV_T!) elements along channel dim + int nlat_in, + int nlon_in, + int nlat_out, + int nlon_out, + int pscale, + int K, // no. of output FLOATV_T elem along K dim (kernel size) + const FLOATV_T *__restrict__ x, + const int64_t csr_nrow, + const int32_t *__restrict__ row_sort, + const int64_t *__restrict__ row_off, + const int64_t *__restrict__ row_idx, + const int64_t *__restrict__ col_idx, + const FLOATV_T *__restrict__ val_pck, + float *__restrict__ y) { + + constexpr int VEC_SIZE = sizeof(FLOATV_T) / sizeof(float); + + const int tidx = threadIdx.x; + + const int batch = blockIdx.y; + const int ctaid = blockIdx.x*blockDim.y + threadIdx.y; + + if (ctaid >= csr_nrow*nlon_in) { + return; + } + + const int h = ctaid / nlon_in; + const int wi = ctaid - (h*nlon_in); + + // set csr_row to "h" to bypass the row sorting + const int csr_row = row_sort[h]; // h + + const int64_t rbeg = row_off[csr_row ]; + const int64_t rend = row_off[csr_row+1]; + + const int hi = row_idx[rbeg]; // reads only the first "nrow" rows of row_idx and only the first element of each row + + x += int64_t(batch)*nlat_in*nlon_in*nchans*K + int64_t(hi)*nlon_in*nchans*K + int64_t(wi)*nchans*K; + y += int64_t(batch)*nlat_out*nlon_out*nchans; + + extern __shared__ __align__(sizeof(float4)) float shext[]; + + FLOATV_T *shx = reinterpret_cast(shext) + nchans*K*threadIdx.y; + float *shy = reinterpret_cast(shext) + nchans*K*VEC_SIZE*blockDim.y + nchans*threadIdx.y; + + for(int chan = tidx; chan < nchans*K; chan += WARP_SIZE) { + shx[chan] = x[chan]; + } + + col_idx += rbeg; + val_pck += rbeg*K; // val_pck CSR contains K values per element + + const int rlen = rend-rbeg; + + // check if BDIM_X is a multiple of K; since BDIM_X is a power of 2, check if K is also a power of two + if (!(K & K-1) && K <= WARP_SIZE) { processCSR_Kpow2_shm_d(wi, rlen, nchans, nlon_out, pscale, K, shx, col_idx, val_pck, shy, y); } + else { processCSR_Kanyv_shm_d(wi, rlen, nchans, nlon_out, pscale, K, shx, col_idx, val_pck, shy, y); } + + return; +} + +template +static __device__ void processCSR_Kpow2_reg_d(const int wi, + const int rlen, + const int nchans, // no. 
of input FLOATV_T elements along channel dim + const int nlon_out, + const int pscale, + const int K, + const FLOATV_T (&locx)[NLOC], + const int64_t *__restrict__ cols, + const FLOATV_T *__restrict__ vals, + float *(&shYOff)[BDIM_X+SHPAD], + float *__restrict__ shy, // NO LONGER USED + float *__restrict__ y) { + constexpr int NLOC_M1 = NLOC-1; + + const int tidx = threadIdx.x; + + unsigned int subwarp_mask = FULL_MASK; + + if constexpr(BDIM_X <= WARP_SIZE) { + constexpr unsigned int MASK = (1ull << BDIM_X)-1; + unsigned int subwarp_id = threadIdx.y % (WARP_SIZE/BDIM_X); + subwarp_mask = MASK << (subwarp_id*BDIM_X); + } + constexpr int MAX_POW2_K = (BDIM_X < WARP_SIZE) ? BDIM_X : WARP_SIZE; + + // K is a power of two <= BDIM_X + const int log2_K = __popc(K-1); + + const int tidxDivK = tidx >> log2_K; + const int tidxModK = tidx & (K-1); + + cols += tidx; + vals += tidxModK; + + const int BDIMX_div_K = BDIM_X >> log2_K; + + for(int off = 0; off < rlen; off++) { + if ((off % BDIM_X) == 0) { + __sync(); + + const int64_t col = (off+tidx < rlen) ? cols[0] : 0; + + const int ho = col / nlon_out; + const int wo = col - (ho*nlon_out); + + int wop = wo + pscale*wi; + wop -= (wop / nlon_out)*nlon_out; + + shYOff[tidx] = y + int64_t(ho)*nlon_out*nchans + int64_t(wop)*nchans; + cols += BDIM_X; + + __sync(); + } + + float *_y = shYOff[off % BDIM_X] + tidxDivK; + + const FLOATV_T myval = vals[0]; + + float locy[NLOC]; + + #pragma unroll + for(int i = 0; i < NLOC; i++) { + locy[i] = __vred(__vmul(myval, locx[i])); + } + + // K is a power of two <= 32 + #pragma unroll + for(int j = 1; j < MAX_POW2_K; j *= 2) { + + if (j >= K) break; + + #pragma unroll + for(int i = 0; i < NLOC; i++) { + locy[i] += __shfl_xor_sync(subwarp_mask, locy[i], j, MAX_POW2_K); + } + } + + if (!tidxModK) { + // NLOC*BDIM_X >= nchans*K + // NLOC_M1*BDIM_X < nchans*K => NLOC_M1*BDIM_X/K < nchans + + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + atomicAdd(_y + i*BDIMX_div_K, locy[i]); + } + if (NLOC_M1*BDIM_X+tidx < nchans*K) { + atomicAdd(_y + NLOC_M1*BDIMX_div_K, locy[NLOC_M1]); + } + } + vals += K; + } + + return; +} + +template +static __device__ void processCSR_Kanyv_reg_d(const int wi, + const int rlen, + const int nchans, // no. of input FLOATV_T elements along channel dim + const int nlon_out, + const int pscale, + const int K, + const FLOATV_T (&locx)[NLOC], + const int64_t *__restrict__ cols, + const FLOATV_T *__restrict__ vals, + float *(&shYOff)[BDIM_X+SHPAD], + float *__restrict__ shy, + float *__restrict__ y) { + const int tidx = threadIdx.x; + + for(int chan = tidx; chan < nchans; chan += BDIM_X) { + shy[chan] = 0; + } + __sync(); + + cols += tidx; + + for(int off = 0; off < rlen; off++) { + + if ((off % BDIM_X) == 0) { + __sync(); + + const int64_t col = (off+tidx < rlen) ? 
cols[0] : 0; + + const int ho = col / nlon_out; + const int wo = col - (ho*nlon_out); + + int wop = wo + pscale*wi; + wop -= (wop / nlon_out)*nlon_out; + + shYOff[tidx] = y + int64_t(ho)*nlon_out*nchans + int64_t(wop)*nchans; + cols += BDIM_X; + + __sync(); + } + + float *_y = shYOff[off % BDIM_X]; + + // shy is allocated as ceil(nchans / (BDIM_X/K))*(BDIM_X/K) + // so we can just loop NLOC timss + #pragma unroll + for(int i = 0; i < NLOC; i++) { + + const int chan = i*BDIM_X+tidx; + const int cDivK = chan / K; + const int cModK = chan - (cDivK*K); + + float sum = __vred(__vmul(vals[cModK], locx[i])); + + atomicAdd(shy+cDivK, sum); + } + __sync(); + + for(int chan = tidx; chan < nchans; chan += BDIM_X) { + atomicAdd(_y+chan, shy[chan]); + shy[chan] = 0; + } + __sync(); + + vals += K; + } + + return; +} + +template // either float or float4 +__global__ +__launch_bounds__(BDIM_X*BDIM_Y) +void s2_disco_bwd_special_vec_k(int nchans, // no. of input float (not FLOATV_T!) elements along channel dim + int nlat_in, + int nlon_in, + int nlat_out, + int nlon_out, + int pscale, + int K, // no. of input FLOATV_T elem along K dim (kernel size) + const FLOATV_T *__restrict__ x, + const int64_t csr_nrow, + const int32_t *__restrict__ row_sort, + const int64_t *__restrict__ row_off, + const int64_t *__restrict__ row_idx, + const int64_t *__restrict__ col_idx, + const FLOATV_T *__restrict__ val_pck, + float *__restrict__ y) { + + static_assert(0 == (BDIM_X & (BDIM_X-1))); + static_assert(0 == (BDIM_Y & (BDIM_Y-1))); + static_assert((BDIM_X <= 32 && BDIM_Y > 1) || + (BDIM_X > 32 && BDIM_Y == 1)) ; + + constexpr int NLOC_M1 = NLOC-1; + + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + + const int batch = blockIdx.y; + const int ctaid = blockIdx.x*blockDim.y + threadIdx.y; + + if (ctaid >= csr_nrow*nlon_in) { + return; + } + + const int h = ctaid / nlon_in; + const int wi = ctaid - (h*nlon_in); + + // set csr_row to "h" to bypass the row sorting + const int csr_row = row_sort[h]; // h + + const int64_t rbeg = row_off[csr_row ]; + const int64_t rend = row_off[csr_row+1]; + + const int hi = row_idx[rbeg]; // reads only the first "nrow" rows of row_idx and only the first element of each row + + x += int64_t(batch)*nlat_in*nlon_in*nchans*K + int64_t(hi)*nlon_in*nchans*K + int64_t(wi)*nchans*K + tidx; + y += int64_t(batch)*nlat_out*nlon_out*nchans; + + FLOATV_T locx[NLOC]; + + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + locx[i] = x[i*BDIM_X]; + } + locx[NLOC_M1] = __vset(0.0f); + if (NLOC_M1*BDIM_X+tidx < nchans*K) { + locx[NLOC_M1] = x[NLOC_M1*BDIM_X]; + } + + // only used if K is not a multiple of 2 + extern __shared__ __align__(sizeof(float4)) float shext[]; + float *shy = shext + DIV_UP(nchans, BDIM_X)*BDIM_X*threadIdx.y; + + col_idx += rbeg; + val_pck += rbeg*K; // val_pck CSR contains K values per element + + const int rlen = rend-rbeg; + + constexpr int PAD = (BDIM_X < WARP_SIZE) ? 1 : 0; + __shared__ float *shYOffAll[BDIM_Y][BDIM_X+PAD]; + + // check if BDIM_X is a multiple of K; since BDIM_X is a power of 2, check if K is also a power of two + constexpr int MAX_POW2_K = (BDIM_X < WARP_SIZE) ? 
BDIM_X : WARP_SIZE; + if (!(K & K-1) && K <= MAX_POW2_K) { processCSR_Kpow2_reg_d(wi, rlen, nchans, nlon_out, pscale, K, locx, col_idx, val_pck, shYOffAll[tidy], NULL, y); } + else { processCSR_Kanyv_reg_d(wi, rlen, nchans, nlon_out, pscale, K, locx, col_idx, val_pck, shYOffAll[tidy], shy, y); } + + return; + +} + +template +void launch_gen_disco_bwd(int64_t batch_size, + int64_t nchans, + int64_t nlat_in, + int64_t nlon_in, + int64_t nlat_out, + int64_t nlon_out, + int64_t K, + FLOATV_T *__restrict__ _xp, + int64_t nrow, + int32_t *_row_sort, + int64_t *_row_off, + int64_t *_row_idx, + int64_t *_col_idx, + FLOATV_T *_val_pck, + float *__restrict__ _yp, + cudaStream_t stream) { + + dim3 block(WARP_SIZE, THREADS/WARP_SIZE); + dim3 grid(DIV_UP(nrow*nlon_in, block.y), batch_size); + + size_t shsize = (sizeof(FLOATV_T)*(nchans*K) + sizeof(float)*nchans)*block.y; + + const int pscale = nlon_out / nlon_in; +#if 1 + printf("Launching s2_disco_bwd_generic_vec_k<%d, float%s><<<(%d,%d), (%d,%d)..., ..., %zu, ...>>> with:\n" + "\tnchan_out: %ld\n" + "\tK: %ld\n" + "\tpscale: %d\n" + "\tnlat_in: %ld\n" + "\tnlon_in: %ld\n" + "\tnlat_out: %ld\n" + "\tnlon_out: %ld\n\n", + THREADS, sizeof(FLOATV_T)==16?"4":"", grid.x, grid.y, block.x, block.y, shsize, nchans, K, pscale, + nlat_in, nlon_in, nlat_out, nlon_out); +#endif + // will use only the first 1/K-th of the CSR, i.e. only the first nlat_out rows + s2_disco_bwd_generic_vec_k + <<>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K, + _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp); + CHECK_ERROR("s2_disco_bwd_generic_vec_k"); + + return; +} + +template +void launch_spc_disco_bwd(int nloc, // "BDIM_X*nloc" >= nchans + int64_t batch_size, + int64_t nchans, + int64_t nlat_in, + int64_t nlon_in, + int64_t nlat_out, + int64_t nlon_out, + int64_t K, + FLOATV_T *__restrict__ _xp, + int64_t nrow, + int32_t *_row_sort, + int64_t *_row_off, + int64_t *_row_idx, + int64_t *_col_idx, + FLOATV_T *_val_pck, + float *__restrict__ _yp, + cudaStream_t stream) { + + if (CUR_LOC_SIZE == nloc) { + + constexpr int BDIM_Y = (BDIM_X <= WARP_SIZE) ? THREADS / BDIM_X : 1; + + dim3 block(BDIM_X, BDIM_Y); + dim3 grid(DIV_UP(nrow*nlon_in, block.y), batch_size); + + // could be (BDIM_X/K) instead of BDIM_X but let's keep it simple + size_t shsize = (K & (K-1)) ? 
sizeof(float)*DIV_UP(nchans, BDIM_X)*BDIM_X*block.y : 0; + + const int pscale = nlon_out / nlon_in; +#if 1 + printf("Launching s2_disco_bwd_special_vec_k<%d, %d, %d, float%s><<<(%d, %d), (%d, %d), ..., %zu, ...>>> with:\n" + "\tnchans: %ld\n" + "\tK: %ld\n" + "\tpscale: %d\n" + "\tnlat_in: %ld\n" + "\tnlon_in: %ld\n" + "\tnlat_out: %ld\n" + "\tnlon_in: %ld\n\n", + BDIM_X, BDIM_Y, CUR_LOC_SIZE, sizeof(FLOATV_T)==16?"4":"", grid.x, grid.y, block.x, block.y, shsize, nchans, K, pscale, + nlat_in, nlon_in, nlat_out, nlon_out); +#endif + s2_disco_bwd_special_vec_k + <<>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K, + _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp); + + CHECK_ERROR("s2_disco_bwd_special_vec_k"); + + return; + } + if constexpr(CUR_LOC_SIZE < MAX_LOC_SIZE) { + launch_spc_disco_bwd(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, + K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); + } + return; +} + +static void s2_disco_bwd_dispatch(int64_t batch_size, + int64_t nchans, + int64_t nlat_in, + int64_t nlon_in, + int64_t nlat_out, + int64_t nlon_out, + int64_t K, + at::Tensor xP, + at::Tensor row_off, // CSR non-empty row offsets + at::Tensor row_idx, // CSR non-empty row indices + at::Tensor col_idx, // CSR non-empty col indices + at::Tensor val_dat, // CSR non-empty value data + at::Tensor yP) { + + //static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1))); + if (batch_size <= 0 || + nchans <= 0 || + nlon_in <= 0 || + nlat_out <= 0 || + nlon_out <= 0 || + K <= 0 || + K > WARP_SIZE) { + + fprintf(stderr, + ":%s:%d: invalid value of one or more input parameters!\n", + __FILE__, __LINE__); + exit(EXIT_FAILURE); + } + + // get stream + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + // replace the K sequential CRSs in "val_dat": + // + // val_dat[ 0: nnz/K) for ker = 0 + // val_dat[nnz/K:2*nnz/K) for ker = 1 + // ... + // val_dat[nnz/K:2*nnz/K) for ker = K-1 + // + // with a packed CSR: + // + // val_dat[nnz/K][K], i.e. with a CSR where elements of the original K CSRs are packed in consecutive elements + assert(0 == (val_idx.size(0) % K)); + + int64_t nrow_csr = row_off.size(0)-1; + assert(0 == (nrow_csr % K)); + + int64_t nrow = nrow_csr / K; + + // sort row indices (ho-s) in descending order + // based on (row_off[ho+1]-row_off[ho]) + at::Tensor row_sort = sortRows(nrow, row_off, stream); + + // move into "disco_cuda_utils.cu" IF val_dat format won't be changed upstream in the call chain + int64_t val_dims[] = {val_dat.size(0)}; + auto options = torch::TensorOptions().device(val_dat.device()).dtype(val_dat.dtype()); + torch::Tensor val_pck = torch::zeros(val_dims, options); + { + dim3 block(WARP_SIZE, THREADS/WARP_SIZE); + dim3 grid(DIV_UP(nlat_in, block.y)); + pack_vals_k<<>>(K, nrow, + row_off.data_ptr(), + val_dat.data_ptr(), + val_pck.data_ptr()); + } + // if K is a multiple of VEC_SIZE it will be read with vector lds + + // smallest power of two "bdimx" (>=4) s.t. 
bdimx*MAX_LOCAL_ARR_LEN >= nchans*K + int bdimx; + bdimx = DIV_UP(nchans*K, MAX_LOCAL_ARR_LEN); + bdimx = max(bdimx, WARP_SIZE/8); // min 4 threads per group + bdimx = next_pow2(bdimx); + + float *_xp = reinterpret_cast(xP.data_ptr()); + float *_yp = reinterpret_cast(yP.data_ptr()); + + int32_t *_row_sort = reinterpret_cast(row_sort.data_ptr()); + int64_t *_row_off = reinterpret_cast(row_off.data_ptr()); + int64_t *_row_idx = reinterpret_cast(row_idx.data_ptr()); + int64_t *_col_idx = reinterpret_cast(col_idx.data_ptr()); + float *_val_pck = reinterpret_cast(val_pck.data_ptr()); + + constexpr int VEC_SIZE = sizeof(float4) / sizeof(float); + + if (!is_aligned(_yp) || + !is_aligned(_xp) || + !is_aligned(_val_pck) || + (K % VEC_SIZE) != 0) { + + const int nloc = DIV_UP(nchans*K, bdimx); + + // to avoid the compilation of unused template instances; + // we use a block size BDIM_X that is the smallest power of 2 + // such that BDIM_X*MAX_LOCAL_ARR_LEN >= nchans*K, so + // BDIM_X > 32 are used only for: + // + // (BDIM_X-1)*MAX_LOCAL_ARR_LEN < nchan <= BDIM_X*MAX_LOCAL_ARR_LEN + constexpr int MIN_LOCAL_ARR_LEN = MAX_LOCAL_ARR_LEN/2+1; + + // use 2D blocks only if 32 threads are enough + switch(bdimx) { + case 4: launch_spc_disco_bwd< 4, 1, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 8: launch_spc_disco_bwd< 8, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 16: launch_spc_disco_bwd< 16, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 32: launch_spc_disco_bwd< 32, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 64: launch_spc_disco_bwd< 64, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 128: launch_spc_disco_bwd< 128, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 256: launch_spc_disco_bwd< 256, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 512: launch_spc_disco_bwd< 512, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 1024: launch_spc_disco_bwd<1024, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + default: launch_gen_disco_bwd ( batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + } + + } else { + + float4 *_xp4 = reinterpret_cast(_xp); + + float4 *_val_pck4 = reinterpret_cast(_val_pck); + + K /= VEC_SIZE; + const int nloc 
= DIV_UP(nchans*K, bdimx); + + constexpr int MAX_LOCAL_VEC_LEN = MAX_LOCAL_ARR_LEN / VEC_SIZE; + constexpr int MIN_LOCAL_VEC_LEN = MAX_LOCAL_VEC_LEN/2+1; + + // use 2D blocks only if 32 threads are enough + switch(bdimx) { + case 4: launch_spc_disco_bwd< 4, 1, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp, stream); break; + case 8: launch_spc_disco_bwd< 8, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp, stream); break; + case 16: launch_spc_disco_bwd< 16, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp, stream); break; + case 32: launch_spc_disco_bwd< 32, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp, stream); break; + case 64: launch_spc_disco_bwd< 64, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp, stream); break; + case 128: launch_spc_disco_bwd< 128, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp, stream); break; + case 256: launch_spc_disco_bwd< 256, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp, stream); break; + case 512: launch_spc_disco_bwd< 512, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp, stream); break; + case 1024: launch_spc_disco_bwd<1024, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp, stream); break; + default: launch_gen_disco_bwd ( batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp4, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp, stream); break; + } + } + return; +} + +// END NEW CHANNEL-LAST VERSION + torch::Tensor disco_cuda_bwd(torch::Tensor ograd, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo) { // some sanity checks - CHECK_CUDA_INPUT_TENSOR(inp); + CHECK_CUDA_INPUT_TENSOR(ograd); CHECK_CUDA_INPUT_TENSOR(roff_idx); CHECK_CUDA_INPUT_TENSOR(ker_idx); CHECK_CUDA_INPUT_TENSOR(row_idx); @@ -211,75 +950,58 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t CHECK_CUDA_INPUT_TENSOR(val); // extract some shapes - int64_t B = inp.size(0); - int64_t C = inp.size(1); - int64_t BC = B * C; - int64_t Hi = inp.size(3); - int64_t Wi = inp.size(4); + int64_t batch_size = ograd.size(0); + int64_t nlat_in = ograd.size(1); + int64_t nlon_in = ograd.size(2); + int64_t Co = ograd.size(3); + int64_t Kograd = ograd.size(4); + if (K != Kograd) { + fprintf(stderr, + "%s:%d: error, K (%ld) must match size of dimension 4 of ograd (%ld)!\n", + __func__, __LINE__, K, Kograd); + exit(EXIT_FAILURE); + } + + int64_t nchan = Co * Kograd; 
int64_t nrows = roff_idx.size(0) - 1; + int64_t nlat_out = Ho; + int64_t nlon_out = Wo; // allocate output - int64_t out_dims[] = {B, C, Ho, Wo}; - auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype()); - torch::Tensor out = torch::zeros(out_dims, options); + int64_t out_dims[] = {batch_size, Ho, Wo, Co}; // get stream auto stream = at::cuda::getCurrentCUDAStream().stream(); - // assert - static_assert(0 == (ELXTH_MAX % 2)); - - if (Wo <= 64 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { - launch_kernel<64, 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); - })); - } else if (Wo <= 128 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { - launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); - })); - } else if (Wo <= 256 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { - launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); - })); - } else if (Wo <= 512 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { - launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); - })); - } else if (Wo <= 1024 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { - launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); - })); - } else { - fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo, - 1024 * ELXTH_MAX); - exit(EXIT_FAILURE); - } - return out; + // extract dtype, convert to fp32 and make contiguous + auto x_type = ograd.dtype(); + torch::Tensor xP = ograd.reshape({batch_size, nlat_in, nlon_in, nchan}).to(torch::kFloat32).contiguous(); + + torch::Tensor igrad = torch::zeros(out_dims, xP.options()); + + // call channel-last kernel implementation + s2_disco_bwd_dispatch(batch_size, + Co, //nchan, + nlat_in, + nlon_in, + nlat_out, + nlon_out, + K, + xP, + roff_idx, + row_idx, + col_idx, + val, + igrad); + + // convert back to original dtype + igrad = igrad.to(x_type); + + return igrad; } TORCH_LIBRARY_IMPL(disco_kernels, CUDA, m) { m.impl("backward", &disco_cuda_bwd); } - -} \ No newline at end of file +} diff --git a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu index 8482f76d..38a5feba 100644 --- a/torch_harmonics/disco/csrc/disco_cuda_fwd.cu +++ b/torch_harmonics/disco/csrc/disco_cuda_fwd.cu @@ -28,11 +28,19 @@ // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
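
Both the backward dispatcher above and the forward dispatcher below repack the K per-kernel CSR value arrays, stored back-to-back in val_dat, into a layout where the K kernel values of each nonzero sit next to each other, so they can be fetched with a single vectorized load whenever K is a multiple of 4. The following plain-PyTorch reference of that packing is a sketch reconstructed from the indexing in pack_vals_k; the helper name pack_vals_reference and the assumption that all K per-kernel CSRs share one sparsity pattern are inferred from that indexing, not taken from an existing API:

```python
import torch

def pack_vals_reference(K: int, nrows: int, row_off: torch.Tensor, val_dat: torch.Tensor) -> torch.Tensor:
    # Slow reference for the packing done on the GPU by pack_vals_k:
    # the K per-kernel CSR value arrays stored back-to-back in val_dat are
    # interleaved so that the K kernel values of every nonzero are contiguous.
    # Assumption (taken from the kernel's indexing): all K CSRs share the same
    # sparsity pattern, so row w of kernel k starts at row_off[k * nrows + w].
    val_pck = torch.zeros_like(val_dat)
    for w in range(nrows):
        rbeg = int(row_off[w])
        rlen = int(row_off[w + 1]) - rbeg
        for off in range(rlen):
            for k in range(K):
                val_pck[rbeg * K + off * K + k] = val_dat[int(row_off[k * nrows + w]) + off]
    return val_pck
```

With this layout, val_pck[(rbeg + off) * K : (rbeg + off + 1) * K] holds the K kernel values of a single nonzero, which is what enables the float4 path when K % 4 == 0 and the pointers are aligned.
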
+//#include "cudamacro.h" #include "disco.h" #include "disco_cuda.cuh" +#include "csr_cuda.cuh" + +#define THREADS (64) + +#define MAX_LOCAL_ARR_LEN (20) namespace disco_kernels { +using namespace utility_kernels; + template __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int Ho, const int Wo, const int pscale, const int64_t *__restrict__ roff, const int64_t *__restrict__ kers, @@ -51,14 +59,13 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H const int64_t ker = kers[soff]; const int64_t row = rows[soff]; - inp += bidy * Hi * Wi; - out += bidy * K * Ho * Wo + ker * Ho * Wo + row * Wo; + inp += bidy*Hi*Wi; + out += bidy*K*Ho*Wo + ker*Ho*Wo + row*Wo; REAL_T __reg[ELXTH] = {0}; // align to larger supported fp type - extern __shared__ __align__( - sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)] + extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)] REAL_T *__sh = reinterpret_cast(__sh_ptr); int col_prev = cols[soff]; @@ -185,6 +192,689 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t return; } +template +static __global__ void pack_vals_k(const int64_t K, + const int64_t nrows, + const int64_t *__restrict__ row_off, + const VAL_T *__restrict__ val_dat, + VAL_T *__restrict__ val_pck) { + + const int tidx = threadIdx.x; + const int wid = blockIdx.x*blockDim.y + threadIdx.y; + if (wid >= nrows) { + return; + } + + const int64_t rbeg = row_off[wid]; + const int64_t rend = row_off[wid+1]; + + const int rlen = rend-rbeg; + + val_pck += rbeg*K; + + for(int off = tidx; off < rlen; off += blockDim.x) { + for(int ker = 0; ker < K; ker++) { + + val_pck[off*K + ker] = val_dat[ row_off[ker*nrows + wid] + off]; + } + } + + return; +} + + +// BEGIN VERSION WITH CHANNEL-LAST WITH 2D BLOCKS, 2ND DIM IDENTIFYING CHANNLES, NO EINSUM +template +__device__ void processCSR_Kpow2_shm_d(const int wo, + const int rlen, + const int nchan_in, // no. of input floats (not FLOATV_T!) elements along channel dim + const int nlon_in, + const int pscale, + const int K, + const float *__restrict__ x, + const int64_t *__restrict__ cols, + const FLOATV_T *__restrict__ vals, + FLOATV_T *__restrict__ shy) { + const int tidx = threadIdx.x; + + // only used in K_POWER_2==1 branch + const int log2_K = __ffs(K)-1; + + x += tidx >> log2_K; + vals += tidx & (K-1); + + const int BDIM_XdivK = BDIM_X >> log2_K; + + for(int off = 0; off < rlen; off++) { + + const int64_t col = cols[off]; + + const int hi = col / nlon_in; + const int wi = col - (hi*nlon_in); + + //const int wip = (wi + pscale*wo) % nlon_in; + // value of (wi + pscale*wo) < (Wi + (Wi/Wo)*Wo) = 2*Wi + // so we can replace the modulo with: + int wip = wi + pscale*wo; + if (wip >= nlon_in) wip -= nlon_in; + + const float *_x = x + int64_t(hi)*nlon_in*nchan_in + int64_t(wip)*nchan_in; + + // if BDIM_X is a multiple of K then "i*(j*BDIM_X) % K = const", + // so thread "i" only needs to read vals[off*K + (i % K)] to update the + // whole channel array + + const FLOATV_T myval = vals[0]; //vals[off*K + tidxModK]; + + for(int chan = tidx; chan < nchan_in*K; chan += BDIM_X) { // no. of vectors in nchan_in*K dim on intermediate out + + shy[chan] = __vadd(shy[chan], + __vmul(myval, + __vset(_x[0]))); + _x += BDIM_XdivK; + } + + vals += K; + } + return; +} + +template +__device__ void processCSR_Kanyv_shm_d(const int wo, + const int rlen, + const int nchan_in, // no. 
of input floats (not FLOATV_T!) elements along channel dim + const int nlon_in, + const int pscale, + const int K, + const float *__restrict__ x, + const int64_t *__restrict__ cols, + const FLOATV_T *__restrict__ vals, + FLOATV_T *__restrict__ shy) { + const int tidx = threadIdx.x; + + for(int off = 0; off < rlen; off++) { + + const int64_t col = cols[off]; + + const int hi = col / nlon_in; + const int wi = col - (hi*nlon_in); + + //const int wip = (wi + pscale*wo) % nlon_in; + // value of (wi + pscale*wo) < (Wi + (Wi/Wo)*Wo) = 2*Wi + // so we can replace the modulo with: + int wip = wi + pscale*wo; + if (wip >= nlon_in) wip -= nlon_in; + + const float *_x = x + int64_t(hi)*nlon_in*nchan_in + int64_t(wip)*nchan_in; + + // if BDIM_X is not a multiple of K then "i*(j*BDIM_X) % K = f(i,j)", + // so the mod need to be recomputed at each iteration of update the update loop + for(int chan = tidx; chan < nchan_in*K; chan += BDIM_X) { // no. of vectors in nchan_in*K dim on intermediate out + + const int iDivK = chan / K; + const int iModK = chan - (iDivK*K); + + shy[chan] = __vadd(shy[chan], + __vmul(vals[iModK], + __vset(_x[iDivK]))); + } + + vals += K; + } + return; +} + +template // either float or float4 +__global__ +__launch_bounds__(BDIM) +void s2_disco_fwd_generic_vec_k(int nchan_in, // no. of input float (not FLOATV_T!) elements along channel dim + int nlat_in, + int nlon_in, + int nlat_out, + int nlon_out, + int pscale, + int K, // no. of output FLOATV_T elem along K dim (kernel size) + const float *__restrict__ x, + const int64_t csr_nrow, + const int32_t *__restrict__ row_sort, + const int64_t *__restrict__ row_off, + const int64_t *__restrict__ row_idx, + const int64_t *__restrict__ col_idx, + const FLOATV_T *__restrict__ val_pck, + FLOATV_T *__restrict__ y) { + + const int tidx = threadIdx.x; + + const int batch = blockIdx.y; + const int ctaid = blockIdx.x*blockDim.y + threadIdx.y; + + if (ctaid >= csr_nrow*nlon_out) { + return; + } + + const int h = ctaid / nlon_out; + const int wo = ctaid - (h*nlon_out); + + // set csr_row to "h" to bypass the row sorting + const int csr_row = row_sort[h]; // h + + const int64_t rbeg = row_off[csr_row ]; + const int64_t rend = row_off[csr_row+1]; + + const int ho = row_idx[rbeg]; // reads only the first "nrow" rows of row_idx and only the first element of each row + + const int nchan_out = nchan_in*K; + + extern __shared__ __align__(sizeof(float4)) float shext[]; + FLOATV_T *shy = reinterpret_cast(shext) + threadIdx.y*nchan_out; + + for(int chan = tidx; chan < nchan_out; chan += WARP_SIZE) { + shy[chan] = __vset(0.f); + } + + x += int64_t(batch)*nlat_in*nlon_in*nchan_in; + y += int64_t(batch)*nlat_out*nlon_out*nchan_out + int64_t(ho)*nlon_out*nchan_out + int64_t(wo)*nchan_out; + + col_idx += rbeg; + val_pck += rbeg*K; // val_pck CSR contains K values per element + + const int rlen = rend-rbeg; + + // check if BDIM_X is a multiple of K; since BDIM_X is a power of 2, check if K is also a power of two + if (!(K & K-1)) { processCSR_Kpow2_shm_d(wo, rlen, nchan_in, nlon_in, pscale, K, x, col_idx, val_pck, shy); } + else { processCSR_Kanyv_shm_d(wo, rlen, nchan_in, nlon_in, pscale, K, x, col_idx, val_pck, shy); } + + for(int chan = tidx; chan < nchan_out; chan += WARP_SIZE) { + y[chan] = shy[chan]; + } + + return; +} + +template +__device__ void processCSR_Kpow2_reg_d(const int wo, + const int rlen, + const int nchan_in, // no. of input floats (not FLOATV_T!) 
elements along channel dim + const int nlon_in, + const int pscale, + const int K, // kernel size + const float *__restrict__ x, + const int64_t *__restrict__ cols, + const FLOATV_T *__restrict__ vals, + const float *(&shXOff)[BDIM_X+SHPAD], + FLOATV_T (&locy)[NLOC]) { + + constexpr int NLOC_M1 = NLOC-1; + + const int tidx = threadIdx.x; + + const int log2_K = __ffs(K)-1; + + const int tidxDivK = tidx >> log2_K; + const int tidxModK = tidx & (K-1); + + cols += tidx; + vals += tidxModK; + + const int BDIM_XdivK = BDIM_X >> log2_K; + + for(int off = 0; off < rlen; off++) { + + if ((off % BDIM_X) == 0) { + __sync(); + + const int64_t col = (off+tidx < rlen) ? cols[0] : 0; + + const int hi = col / nlon_in; + const int wi = col - (hi*nlon_in); + + //const int wip = (wi + pscale*wo) % nlon_in; + // value of (wi + pscale*wo) < (Wi + (Wi/Wo)*Wo) = 2*Wi + // so we can replace the modulo with: + int wip = wi + pscale*wo; + if (wip >= nlon_in) wip -= nlon_in; + + shXOff[tidx] = x + int64_t(hi)*nlon_in*nchan_in + int64_t(wip)*nchan_in; + cols += BDIM_X; + + __sync(); + } + + const float *_x = shXOff[off % BDIM_X] + tidxDivK; + + // if BDIM_X is a multiple of K then "i*(j*BDIM_X) % K = const", + // so thread "i" only needs to read vals[off*K + (i % K)] to update the + // whole channel array + + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + locy[i] = __vadd(locy[i], + __vmul(vals[0], + __vset(_x[i*BDIM_XdivK]))); + } + if (NLOC_M1*BDIM_X+tidx < nchan_in*K) { + locy[NLOC_M1] = __vadd(locy[NLOC_M1], + __vmul(vals[0], + __vset(_x[NLOC_M1*BDIM_XdivK]))); + } + + vals += K; + } + return; +} + +template +__device__ void processCSR_Kanyv_reg_d(const int wo, + const int rlen, + const int nchan_in, // no. of input floats (not FLOATV_T!) elements along channel dim + const int nlon_in, + const int pscale, + const int K, // kernel size + const float *__restrict__ x, + const int64_t *__restrict__ cols, + const FLOATV_T *__restrict__ vals, + const float *(&shXOff)[BDIM_X+SHPAD], + FLOATV_T (&locy)[NLOC]) { + + constexpr int NLOC_M1 = NLOC-1; + + const int tidx = threadIdx.x; + + cols += tidx; + + for(int off = 0; off < rlen; off++) { + + if ((off % BDIM_X) == 0) { + __sync(); + + const int64_t col = (off+tidx < rlen) ? 
cols[0] : 0; + + const int hi = col / nlon_in; + const int wi = col - (hi*nlon_in); + + //const int wip = (wi + pscale*wo) % nlon_in; + // value of (wi + pscale*wo) < (Wi + (Wi/Wo)*Wo) = 2*Wi + // so we can replace the modulo with: + int wip = wi + pscale*wo; + if (wip >= nlon_in) wip -= nlon_in; + + shXOff[tidx] = x + int64_t(hi)*nlon_in*nchan_in + int64_t(wip)*nchan_in; + cols += BDIM_X; + + __sync(); + } + + const float *_x = shXOff[off % BDIM_X]; + + // if BDIM_X is not a multiple of K then "i*(j*BDIM_X) % K = f(i,j)", + // so the mod need to be recomputed at each iteration of update the update loop + + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + + const int chan = i*BDIM_X+tidx; + const int iDivK = chan / K; + const int iModK = chan - (iDivK*K); + + const FLOATV_T vval = vals[iModK]; //vals[off*K + iModK]; + const FLOATV_T xval = __vset(_x[iDivK]); + + locy[i] = __vadd(locy[i], __vmul(vval, xval)); + } + if (NLOC_M1*BDIM_X+tidx < nchan_in*K) { + + const int chan = NLOC_M1*BDIM_X+tidx; + const int iDivK = chan / K; + const int iModK = chan - (iDivK*K); + + const FLOATV_T vval = vals[iModK]; //vals[off*K + iModK]; + const FLOATV_T xval = __vset(_x[iDivK]); + + locy[NLOC_M1] = __vadd(locy[NLOC_M1], __vmul(vval, xval)); + } + + vals += K; + } + return; +} + +template // either float or float4 +__global__ +__launch_bounds__(BDIM_X*BDIM_Y) +void s2_disco_fwd_special_vec_k(const int nchan_in, // no. of input float (not FLOATV_T!) elements along channel dim + const int nlat_in, + const int nlon_in, + const int nlat_out, + const int nlon_out, + const int pscale, + const int K, // no. of output FLOATV_T elem along K dim (kernel size) + const float *__restrict__ x, + const int64_t csr_nrow, + const int32_t *__restrict__ row_sort, + const int64_t *__restrict__ row_off, + const int64_t *__restrict__ row_idx, + const int64_t *__restrict__ col_idx, + const FLOATV_T *__restrict__ val_pck, + FLOATV_T *__restrict__ y) { + + static_assert(0 == (BDIM_X & (BDIM_X-1))); + static_assert(0 == (BDIM_Y & (BDIM_Y-1))); + static_assert((BDIM_X <= 32 && BDIM_Y > 1) || + (BDIM_X > 32 && BDIM_Y == 1)) ; + + constexpr int NLOC_M1 = NLOC-1; + + + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + + const int batch = blockIdx.y; + const int ctaid = blockIdx.x*blockDim.y + threadIdx.y; + + if (ctaid >= csr_nrow*nlon_out) { + return; + } + + const int h = ctaid / nlon_out; + const int wo = ctaid - (h*nlon_out); + + // set csr_row to "h" to bypass the row sorting + const int csr_row = row_sort[h]; // h + + const int64_t rbeg = row_off[csr_row ]; + const int64_t rend = row_off[csr_row+1]; + + const int ho = row_idx[rbeg]; // reads only the first "nrow" rows of row_idx and only the first element of each row + + const int nchan_out = nchan_in*K; + + FLOATV_T locy[NLOC]; + + x += int64_t(batch)*nlat_in*nlon_in*nchan_in; + y += int64_t(batch)*nlat_out*nlon_out*nchan_out + int64_t(ho)*nlon_out*nchan_out + int64_t(wo)*nchan_out + tidx; + + #pragma unroll + for(int i = 0; i < NLOC; i++) { + locy[i] = __vset(0.f); + } + + col_idx += rbeg; + val_pck += rbeg*K; // val_pck CSR contains K values per element + + const int rlen = rend-rbeg; + + constexpr int PAD = (BDIM_X < WARP_SIZE) ? 
1 : 0; + __shared__ const float *shXOffAll[BDIM_Y][BDIM_X+PAD]; + + // check if BDIM_X is a multiple of K; since BDIM_X is a power of 2, check if K is also a power of two + const int isKpow2 = !(K & (K-1)); + if (isKpow2) { processCSR_Kpow2_reg_d(wo, rlen, nchan_in, nlon_in, pscale, K, x, col_idx, val_pck, shXOffAll[tidy], locy); } + else { processCSR_Kanyv_reg_d(wo, rlen, nchan_in, nlon_in, pscale, K, x, col_idx, val_pck, shXOffAll[tidy], locy); } + + + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + y[i*BDIM_X] = locy[i]; + } + if (NLOC_M1*BDIM_X+tidx < nchan_out) { + y[NLOC_M1*BDIM_X] = locy[NLOC_M1]; + } + + return; +} + +template +void launch_gen_disco_fwd(int64_t batch_size, + int64_t nchan_in, + int64_t nlat_in, + int64_t nlon_in, + int64_t nlat_out, + int64_t nlon_out, + int64_t K, + float *__restrict__ _xp, + int64_t nrow, + int32_t *_row_sort, + int64_t *_row_off, + int64_t *_row_idx, + int64_t *_col_idx, + FLOATV_T *_val_pck, + FLOATV_T *__restrict__ _yp, + cudaStream_t stream) { + + dim3 block(WARP_SIZE, THREADS/WARP_SIZE); + dim3 grid(DIV_UP(nrow*nlon_out, block.y), batch_size); + + size_t shsize = sizeof(FLOATV_T)*(nchan_in*K)*block.y; + + const int pscale = nlon_in / nlon_out; +#if 1 + printf("Launching s2_disco_fwd_generic_vec_k<%d, float%s><<<..., ..., %zu, ...>>> with:\n" + "\tnchan_in: %ld\n" + "\tK: %ld\n" + "\tpscale: %d\n\n", + THREADS, sizeof(FLOATV_T)==16?"4":"", shsize, nchan_in, K, pscale); +#endif + // will use only the first 1/K-th of the CSR, i.e. only the first nlat_out rows + s2_disco_fwd_generic_vec_k + <<>>(nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K, + _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp); + CHECK_ERROR("s2_disco_fwd_generic_vec_k"); + + return; +} + +template +void launch_spc_disco_fwd(int nloc, // "BDIM_X*nloc" >= nchans + int64_t batch_size, + int64_t nchan_in, + int64_t nlat_in, + int64_t nlon_in, + int64_t nlat_out, + int64_t nlon_out, + int64_t K, + float *__restrict__ _xp, + int64_t nrow, + int32_t *_row_sort, + int64_t *_row_off, + int64_t *_row_idx, + int64_t *_col_idx, + FLOATV_T *_val_pck, + FLOATV_T *__restrict__ _yp, + cudaStream_t stream) { + + if (CUR_LOC_SIZE == nloc) { + + constexpr int BDIM_Y = (BDIM_X <= WARP_SIZE) ? 
THREADS / BDIM_X : 1; + + dim3 block(BDIM_X, BDIM_Y); + dim3 grid(DIV_UP(nrow*nlon_out, block.y), batch_size); + + size_t shsize = 0; //sizeof(float)*chxgrp_out * block.y; + + const int pscale = nlon_in / nlon_out; +#if 1 + printf("Launching s2_disco_fwd_special_vec_k<%d, %d, %d, float%s><<<(%d, %d), (%d, %d), ..., %zu, ...>>> with:\n" + "\tnchan_in: %ld\n" + "\tK: %ld\n" + "\tpscale: %d\n\n", + BDIM_X, BDIM_Y, CUR_LOC_SIZE, sizeof(FLOATV_T)==16?"4":"", grid.x, grid.y, block.x, block.y, shsize, nchan_in, K, pscale); +#endif + s2_disco_fwd_special_vec_k + <<>>(nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, pscale, K, + _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp); + + CHECK_ERROR("s2_disco_fwd_special_vec_k"); + + return; + } + if constexpr(CUR_LOC_SIZE < MAX_LOC_SIZE) { + launch_spc_disco_fwd(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, + K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); + } + return; +} + +static void s2_disco_fwd_dispatch(int64_t batch_size, + int64_t nchan_in, + int64_t nlat_in, + int64_t nlon_in, + int64_t nlat_out, + int64_t nlon_out, + int64_t K, + at::Tensor xP, + at::Tensor row_off, // CSR non-empty row offsets + at::Tensor row_idx, // CSR non-empty row indices + at::Tensor col_idx, // CSR non-empty col indices + at::Tensor val_dat, // CSR non-empty value data + at::Tensor yP) { + + //static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1))); + if (batch_size <= 0 || + nchan_in <= 0 || + nlon_in <= 0 || + nlat_out <= 0 || + nlon_out <= 0 || + K <= 0) { + + fprintf(stderr, + ":%s:%d: invalid value of one or more input parameters!\n", + __FILE__, __LINE__); + exit(EXIT_FAILURE); + } + + // get stream + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + // replace the K sequential CRSs in "val_dat": + // + // val_dat[ 0: nnz/K) for ker = 0 + // val_dat[nnz/K:2*nnz/K) for ker = 1 + // ... + // val_dat[nnz/K:2*nnz/K) for ker = K-1 + // + // with a packed CSR: + // + // val_dat[nnz/K][K], i.e. with a CSR where elements of the original K CSRs are packed in consecutive elements + assert(0 == (val_idx.size(0) % K)); + + int64_t nrow_csr = row_off.size(0)-1; + assert(0 == (nrow_csr % K)); + + int64_t nrow = nrow_csr / K; + + // sort row indices (ho-s) in descending order + // based on (row_off[ho+1]-row_off[ho]) + at::Tensor row_sort = sortRows(nrow, row_off, stream); + + // move into "disco_cuda_utils.cu" IF val_dat format won't be changed upstream in the call chain + int64_t val_dims[] = {val_dat.size(0)}; + auto options = torch::TensorOptions().device(val_dat.device()).dtype(val_dat.dtype()); + torch::Tensor val_pck = torch::zeros(val_dims, options); + { + dim3 block(WARP_SIZE, THREADS/WARP_SIZE); + dim3 grid(DIV_UP(nrow, block.y)); + + pack_vals_k<<>>(K, nrow, + row_off.data_ptr(), + val_dat.data_ptr(), + val_pck.data_ptr()); + } + // if K is a multiple of VEC_SIZE it will be read with vector lds + + // smallest power of two "bdimx" (>=4) s.t. 
bdimx*MAX_LOCAL_ARR_LEN >= nchan_in*K + int bdimx; + bdimx = DIV_UP(nchan_in*K, MAX_LOCAL_ARR_LEN); + bdimx = max(bdimx, WARP_SIZE/8); // min 4 threads per group + bdimx = next_pow2(bdimx); + + float *_xp = reinterpret_cast(xP.data_ptr()); + float *_yp = reinterpret_cast(yP.data_ptr()); + + int32_t *_row_sort = reinterpret_cast(row_sort.data_ptr()); + int64_t *_row_off = reinterpret_cast(row_off.data_ptr()); + int64_t *_row_idx = reinterpret_cast(row_idx.data_ptr()); + int64_t *_col_idx = reinterpret_cast(col_idx.data_ptr()); + float *_val_pck = reinterpret_cast(val_pck.data_ptr()); + + constexpr int VEC_SIZE = sizeof(float4) / sizeof(float); + + if (!is_aligned(_yp) || + !is_aligned(_val_pck) || + (K % VEC_SIZE) != 0) { + + const int nloc = DIV_UP(nchan_in*K, bdimx); + + // to avoid the compilation of unused template instances; + // we use a block size BDIM_X that is the smallest power of 2 + // such that BDIM_X*MAX_LOCAL_ARR_LEN >= nchan_in*K, so + // BDIM_X > 32 are used only for: + // + // (BDIM_X-1)*MAX_LOCAL_ARR_LEN < nchan <= BDIM_X*MAX_LOCAL_ARR_LEN + constexpr int MIN_LOCAL_ARR_LEN = MAX_LOCAL_ARR_LEN/2+1; + + // use 2D blocks only if 32 threads are enough + switch(bdimx) { + case 4: launch_spc_disco_fwd< 4, 1, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 8: launch_spc_disco_fwd< 8, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 16: launch_spc_disco_fwd< 16, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 32: launch_spc_disco_fwd< 32, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 64: launch_spc_disco_fwd< 64, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 128: launch_spc_disco_fwd< 128, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 256: launch_spc_disco_fwd< 256, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 512: launch_spc_disco_fwd< 512, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + case 1024: launch_spc_disco_fwd<1024, MIN_LOCAL_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + default: launch_gen_disco_fwd ( batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck, _yp, stream); break; + } + + } else { + + //float4 *_xp4 = reinterpret_cast(_xp); + float4 *_yp4 = reinterpret_cast(_yp); + + float4 *_val_pck4 = 
reinterpret_cast(_val_pck); + + K /= VEC_SIZE; + const int nloc = DIV_UP(nchan_in*K, bdimx); + + constexpr int MAX_LOCAL_VEC_LEN = MAX_LOCAL_ARR_LEN / VEC_SIZE; + constexpr int MIN_LOCAL_VEC_LEN = MAX_LOCAL_VEC_LEN/2+1; + + // use 2D blocks only if 32 threads are enough + switch(bdimx) { + case 4: launch_spc_disco_fwd< 4, 1, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp4, stream); break; + case 8: launch_spc_disco_fwd< 8, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp4, stream); break; + case 16: launch_spc_disco_fwd< 16, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp4, stream); break; + case 32: launch_spc_disco_fwd< 32, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp4, stream); break; + case 64: launch_spc_disco_fwd< 64, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp4, stream); break; + case 128: launch_spc_disco_fwd< 128, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp4, stream); break; + case 256: launch_spc_disco_fwd< 256, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp4, stream); break; + case 512: launch_spc_disco_fwd< 512, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp4, stream); break; + case 1024: launch_spc_disco_fwd<1024, MIN_LOCAL_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp4, stream); break; + default: launch_gen_disco_fwd ( batch_size, nchan_in, nlat_in, nlon_in, nlat_out, nlon_out, K, _xp, nrow, _row_sort, _row_off, _row_idx, _col_idx, _val_pck4, _yp4, stream); break; + } + } + return; +} +// END VERSION WITH CHANNEL-LAST WITH 2D BLOCKS, 2ND DIM IDENTIFYING CHANNLES, NO EINSUM + + torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo) { @@ -197,78 +887,59 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t CHECK_CUDA_INPUT_TENSOR(col_idx); CHECK_CUDA_INPUT_TENSOR(val); - // extract some shapes + // assume input is B, H, W, C int64_t B = inp.size(0); - int64_t C = inp.size(1); - int64_t BC = B * C; - int64_t Hi = inp.size(2); - int64_t Wi = inp.size(3); + int64_t Hi = inp.size(1); + int64_t Wi = inp.size(2); + int64_t C = inp.size(3); + //int64_t BC = B * C; int64_t nrows = roff_idx.size(0) - 1; - // allocate output - int64_t out_dims[] = {B, C, K, Ho, Wo}; - auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype()); - torch::Tensor out = torch::zeros(out_dims, options); + // rename dimensions 
consistent with attention + int64_t batch_size = B; + int64_t nchan = C; + int64_t nlat_in = Hi; + int64_t nlon_in = Wi; + int64_t nlat_out = Ho; + int64_t nlon_out = Wo; // get stream auto stream = at::cuda::getCurrentCUDAStream().stream(); - // assert - static_assert(0 == (ELXTH_MAX % 2)); - - // pick the correct launch config - if (Wo <= 64 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { - launch_kernel<64, 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); - })); - } else if (Wo <= 128 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { - launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); - })); - } else if (Wo <= 256 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { - launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); - })); - } else if (Wo <= 512 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { - launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); - })); - } else if (Wo <= 1024 * ELXTH_MAX) { - AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { - launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>( - BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr(), - ker_idx.data_ptr(), row_idx.data_ptr(), - col_idx.data_ptr(), val.data_ptr(), - inp.data_ptr(), out.data_ptr(), stream); - })); - } else { - fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo, - 1024 * ELXTH_MAX); - exit(EXIT_FAILURE); - } + // switch to channel-last + // version with fused enisum + + auto x_type = inp.dtype(); + auto xP = inp.to(torch::kFloat32).contiguous(); + + // to test before fusion + int64_t out_dims[] = {batch_size, nlat_out, nlon_out, nchan*K}; + //auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype()); + torch::Tensor yP = torch::zeros(out_dims, xP.options()); + + // call channel-last kernel implementation + s2_disco_fwd_dispatch(batch_size, + nchan, + nlat_in, + nlon_in, + nlat_out, + nlon_out, + K, + xP, + roff_idx, + row_idx, + col_idx, + val, + yP); + + auto out = yP.reshape({batch_size, nlat_out, nlon_out, nchan, K}).to(x_type); return out; } - TORCH_LIBRARY_IMPL(disco_kernels, CUDA, m) - { - m.impl("forward", &disco_cuda_fwd); - } +TORCH_LIBRARY_IMPL(disco_kernels, CUDA, m) +{ + m.impl("forward", &disco_cuda_fwd); +} +} -} \ No newline at end of file diff --git a/torch_harmonics/disco/csrc/disco_helpers.cpp b/torch_harmonics/disco/csrc/disco_helpers.cpp index 1737cec2..67a4a583 100644 --- a/torch_harmonics/disco/csrc/disco_helpers.cpp +++ b/torch_harmonics/disco/csrc/disco_helpers.cpp @@ -28,9 +28,11 @@ // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
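
Before moving on to the helper and interface changes, note the layout contract established by the rewritten disco_cuda_fwd above: the op now takes a channel-last input of shape (B, Hi, Wi, C), computes in fp32 internally, and returns (B, Ho, Wo, C, K). A minimal sketch of how a channels-first caller can adapt to it; the CSR argument order follows the forward schema registered in disco_interface.cpp, and this wrapper is purely illustrative rather than part of the library:

```python
import torch

def disco_fwd_channels_first(x, roff_idx, ker_idx, row_idx, col_idx, vals, K, Ho, Wo):
    # x: (B, C, Hi, Wi) channels-first input.
    # The permutes mirror permute_to_0231 on the way in and the new
    # (B, Ho, Wo, C, K) output layout on the way out, recovering the
    # old (B, C, K, Ho, Wo) shape.
    x_bhwc = x.permute(0, 2, 3, 1).contiguous()
    y = torch.ops.disco_kernels.forward(x_bhwc, roff_idx, ker_idx, row_idx, col_idx, vals, K, Ho, Wo)
    return y.permute(0, 3, 4, 1, 2).contiguous()
```
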
-#include "disco.h" #include +#include "disco.h" +#include "cppmacro.h" + template void preprocess_psi_kernel(int64_t nnz, int64_t K, int64_t Ho, int64_t *ker_h, int64_t *row_h, int64_t *col_h, int64_t *roff_h, REAL_T *val_h, int64_t &nrows) diff --git a/torch_harmonics/disco/csrc/disco_interface.cpp b/torch_harmonics/disco/csrc/disco_interface.cpp index 779616f4..11bf0249 100644 --- a/torch_harmonics/disco/csrc/disco_interface.cpp +++ b/torch_harmonics/disco/csrc/disco_interface.cpp @@ -54,8 +54,8 @@ namespace disco_kernels { // Declare the operators TORCH_LIBRARY(disco_kernels, m) { - m.def("forward(Tensor inp, Tensor roff_idx, Tensor ker_idx, Tensor row_idx, Tensor col_idx, Tensor vals, int kernel_size, int nlat_out, int nlon_out) -> Tensor", {at::Tag::pt2_compliant_tag}); - m.def("backward(Tensor inp, Tensor roff_idx, Tensor ker_idx, Tensor row_idx, Tensor col_idx, Tensor vals, int kernel_size, int nlat_out, int nlon_out) -> Tensor", {at::Tag::pt2_compliant_tag}); + m.def("forward(Tensor inp, Tensor roff_idx, Tensor ker_idx, Tensor row_idx, Tensor col_idx, Tensor vals, int kernel_size, int nlat_out, int nlon_out) -> Tensor"); //, {at::Tag::pt2_compliant_tag}); + m.def("backward(Tensor inp, Tensor roff_idx, Tensor ker_idx, Tensor row_idx, Tensor col_idx, Tensor vals, int kernel_size, int nlat_out, int nlon_out) -> Tensor"); //, {at::Tag::pt2_compliant_tag}); } } diff --git a/torch_harmonics/distributed/distributed_convolution.py b/torch_harmonics/distributed/distributed_convolution.py index 13d7cbb4..d0b5cd41 100644 --- a/torch_harmonics/distributed/distributed_convolution.py +++ b/torch_harmonics/distributed/distributed_convolution.py @@ -29,19 +29,15 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -from typing import List, Tuple, Union, Optional +from typing import Tuple, Union, Optional from itertools import accumulate import torch -import torch.nn as nn -from functools import partial - -from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes from torch_harmonics.disco._disco_utils import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch from torch_harmonics.disco._disco_utils import _disco_s2_contraction_optimized, _disco_s2_transpose_contraction_optimized -from torch_harmonics.filter_basis import get_filter_basis -from disco_helpers import optimized_kernels_is_available, preprocess_psi +from torch_harmonics.utils import permute_to_0231, permute_to_0312 +from disco_helpers import preprocess_psi from torch_harmonics.disco.convolution import ( _precompute_convolution_tensor_s2, DiscreteContinuousConv, @@ -49,17 +45,16 @@ # distirbuted stuff from torch_harmonics.distributed import polar_group_size, azimuth_group_size -from torch_harmonics.distributed import distributed_transpose_azimuth, distributed_transpose_polar +from torch_harmonics.distributed import distributed_transpose_azimuth from torch_harmonics.distributed import reduce_from_polar_region, scatter_to_polar_region, gather_from_polar_region, copy_to_polar_region from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank -from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim +from torch_harmonics.distributed import compute_split_shapes def _split_distributed_convolution_tensor_s2( idx: torch.Tensor, vals: torch.Tensor, in_shape: Tuple[int], - out_shape: Tuple[int], ): """ Splits a pre-computed convolution tensor along the latitude dimension for distributed processing. 
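
The splitting step summarized in this docstring can be pictured as keeping only the CSR entries whose latitude falls into the local polar slab and re-indexing them locally. The sketch below is illustrative only: it assumes the (kernel, row, col) index layout with col = lat * nlon + lon that the rest of this file unpacks, and the actual routine performs additional bookkeeping on top of this idea:

```python
import torch

def split_psi_by_latitude(idx, vals, in_shape, lat_shapes, rank):
    # Keep only the CSR entries whose latitude (decoded from the flattened
    # column index, col = lat * nlon + lon) falls into this polar rank's slab,
    # and make the column indices local to that slab.
    nlat, nlon = in_shape
    start = sum(lat_shapes[:rank])
    end = start + lat_shapes[rank]
    lat = idx[2] // nlon
    mask = (lat >= start) & (lat < end)
    idx_local = idx[:, mask].clone()
    idx_local[2] -= start * nlon
    return idx_local, vals[mask]
```
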
@@ -76,8 +71,6 @@ def _split_distributed_convolution_tensor_s2( Values of the pre-computed convolution tensor in_shape: Tuple[int] Shape of the input tensor (nlat_in, nlon_in) - out_shape: Tuple[int] - Shape of the output tensor (nlat_out, nlon_out) Returns ------- @@ -88,7 +81,6 @@ def _split_distributed_convolution_tensor_s2( """ nlat_in, nlon_in = in_shape - nlat_out, nlon_out = out_shape comm_size_polar = polar_group_size() comm_rank_polar = polar_group_rank() @@ -216,7 +208,7 @@ def __init__( ) # split the convolution tensor along latitude - idx, vals = _split_distributed_convolution_tensor_s2(idx, vals, in_shape, out_shape) + idx, vals = _split_distributed_convolution_tensor_s2(idx, vals, in_shape) # sort the values ker_idx = idx[0, ...].contiguous() @@ -226,7 +218,7 @@ def __init__( if self.optimized_kernel: # preprocessed data-structure for GPU kernel - roff_idx = preprocess_psi(self.kernel_size, self.nlat_out_local, ker_idx, row_idx, col_idx, vals).contiguous() + roff_idx = preprocess_psi(self.kernel_size, self.nlat_out, ker_idx, row_idx, col_idx, vals) self.register_buffer("psi_roff_idx", roff_idx, persistent=False) # save all datastructures @@ -240,7 +232,7 @@ def __init__( self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, self.nlat_in_local, self.nlat_out_local) def extra_repr(self): - return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" + return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize_in * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" @property def psi_idx(self): @@ -256,28 +248,61 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes) if self.optimized_kernel: + # permute input: B, C, Hi, Wi -> B, Hi, Wi, C + xp = permute_to_0231(x) + + # disco contraction: B, Hi, Wi, C -> B, Ho, Wo, C, K x = _disco_s2_contraction_optimized( - x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out_local, self.nlon_out + xp, + self.psi_roff_idx, + self.psi_ker_idx, + self.psi_row_idx, + self.psi_col_idx, + self.psi_vals, + self.kernel_size, + self.nlat_out, + self.nlon_out ) + + # extract shapes + polar_dim = -4 + azimuth_dim = -3 + chan_dim = -2 + else: + # disco contraction: B, C, Hi, Wi -> B, C, K, Ho, Wo x = _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out) + # extract shapes + polar_dim = -2 + azimuth_dim = -1 + chan_dim = -4 + # perform reduce scatter in polar region x = reduce_from_polar_region(x) - x = scatter_to_polar_region(x, -2) + x = scatter_to_polar_region(x, polar_dim) # now we can transpose back the result, so that lon is split and channels are local if self.comm_size_azimuth > 1: chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth) - x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes) + x = distributed_transpose_azimuth.apply(x, (azimuth_dim, chan_dim), chan_shapes) # extract shape - B, C, K, H, W = x.shape - x = x.reshape(B, self.groups, self.groupsize, K, H, W) - - # do weight multiplication - out = torch.einsum("bgckxy,gock->bgoxy", x, 
self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous() - out = out.reshape(out.shape[0], -1, H, W) + if self.optimized_kernel: + # weight multiplication + B, H, W, _, K = x.shape + x = x.reshape(B, H, W, self.groups, self.groupsize_in, K) + outp = torch.einsum("bxygck,gock->bxygo", x, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size)) + outp = outp.reshape(B, H, W, -1).contiguous() + + # permute output + out = permute_to_0312(outp) + else: + # weight multiplication + B, _, K, H, W = x.shape + x = x.reshape(B, self.groups, self.groupsize_in, K, H, W) + out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size)) + out = out.reshape(B, -1, H, W).contiguous() if self.bias is not None: out = out + self.bias.reshape(1, -1, 1, 1) @@ -393,7 +418,7 @@ def __init__( # split the convolution tensor along latitude, again, we need to swap the meaning # of in_shape and out_shape - idx, vals = _split_distributed_convolution_tensor_s2(idx, vals, out_shape, in_shape) + idx, vals = _split_distributed_convolution_tensor_s2(idx, vals, out_shape) # sort the values ker_idx = idx[0, ...].contiguous() @@ -403,7 +428,7 @@ def __init__( if self.optimized_kernel: # preprocessed data-structure for GPU kernel - roff_idx = preprocess_psi(self.kernel_size, self.nlat_in_local, ker_idx, row_idx, col_idx, vals).contiguous() + roff_idx = preprocess_psi(self.kernel_size, self.nlat_in, ker_idx, row_idx, col_idx, vals).contiguous() self.register_buffer("psi_roff_idx", roff_idx, persistent=False) # save all datastructures @@ -417,7 +442,7 @@ def __init__( self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, self.nlat_in_local, self.nlat_out_local, semi_transposed=True) def extra_repr(self): - return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" + return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize_in * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}" @property def psi_idx(self): @@ -426,27 +451,60 @@ def psi_idx(self): def forward(self, x: torch.Tensor) -> torch.Tensor: # extract shape - B, C, H, W = x.shape - x = x.reshape(B, self.groups, self.groupsize, H, W) + B, _, H, W = x.shape - # do weight multiplication - x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous() - x = x.reshape(B, -1, x.shape[-3], H, W) - num_chans = x.shape[1] + if self.optimized_kernel: + # permute input + xp = permute_to_0231(x) + + # weight multiplication + xp = xp.reshape(B, H, W, self.groups, self.groupsize_in) + x = torch.einsum("bxygc,gock->bxygok", xp, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size)) + x = x.reshape(B, H, W, -1, self.kernel_size).contiguous() + # count from front since this does not change + # after disco conv + polar_dim = 1 + azimuth_dim = 2 + chan_dim = 3 + else: + # weight multiplication + x = x.reshape(B, self.groups, self.groupsize_in, H, W) + x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, 
self.groupsize_out, self.groupsize_in, self.kernel_size)) + x = x.reshape(B, -1, self.kernel_size, H, W).contiguous() + # count from back since this changes after disco transpose conv + polar_dim = -2 + azimuth_dim = -1 + chan_dim = 1 + + # store number of channels + num_chans = x.shape[chan_dim] # transpose such that lon is local, channels are split if self.comm_size_azimuth > 1: - x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes) + x = distributed_transpose_azimuth.apply(x, (chan_dim, azimuth_dim), self.lon_in_shapes) # gather input tensor and set up backward reduction hooks - x = gather_from_polar_region(x, -2, self.lat_in_shapes) + x = gather_from_polar_region(x, polar_dim, self.lat_in_shapes) x = copy_to_polar_region(x) if self.optimized_kernel: - out = _disco_s2_transpose_contraction_optimized( - x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out_local, self.nlon_out + # disco contraction + outp = _disco_s2_transpose_contraction_optimized( + x, + self.psi_roff_idx, + self.psi_ker_idx, + self.psi_row_idx, + self.psi_col_idx, + self.psi_vals, + self.kernel_size, + self.nlat_out_local, + self.nlon_out ) + + # permute output + out = permute_to_0312(outp) else: + # disco contraction out = _disco_s2_transpose_contraction_torch(x, self.psi_st.to(x.device), self.nlon_out) # now we can transpose back the result, so that lon is split and channels are local diff --git a/torch_harmonics/filter_basis.py b/torch_harmonics/filter_basis.py index e7163d4e..086518a7 100644 --- a/torch_harmonics/filter_basis.py +++ b/torch_harmonics/filter_basis.py @@ -131,8 +131,11 @@ def _compute_support_vals_isotropic(self, r: torch.Tensor, phi: torch.Tensor, r_ ir = (ikernel + 0.5) * dr # find the indices where the rotated position falls into the support of the kernel - iidx = torch.argwhere(((r - ir).abs() <= dr) & (r <= r_cutoff)) + #iidx = torch.argwhere((r.abs() <= dr) & (r <= r_cutoff)) + #vals = 1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr + iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device)) vals = 1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr + vals = torch.clamp(vals, min=0.0) return iidx, vals @@ -155,11 +158,14 @@ def _compute_support_vals_anisotropic(self, r: torch.Tensor, phi: torch.Tensor, ir = (ikernel // nphi + 0.5) * dr iphi = (ikernel % nphi) * dphi - math.pi + #cond_r = ((r - ir.max()).abs() <= dr) & (r <= r_cutoff) + cond_r = (r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device) + # find the indices where the rotated position falls into the support of the kernel if nr % 2 == 1: # find the support - cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff) - cond_phi = (ikernel == 0) | (_circle_dist(phi, iphi).abs() <= dphi) + cond_phi = (_circle_dist(phi, iphi.max()).abs() <= dphi) + #print(_circle_dist(phi, iphi).abs() <= dphi) # find indices where conditions are met iidx = torch.argwhere(cond_r & cond_phi) # compute the distance to the collocation points @@ -169,27 +175,33 @@ def _compute_support_vals_anisotropic(self, r: torch.Tensor, phi: torch.Tensor, vals = 1 - dist_r / dr vals *= torch.where((iidx[:, 0] > 0), (1 - dist_phi / dphi), 1.0) + # clamp values + vals = torch.clamp(vals, min=0.0) else: # in the even case, the inner basis functions overlap into areas with a negative areas rn = -r phin = torch.where(phi + math.pi >= math.pi, phi - math.pi, phi + math.pi) - # find the 
support - cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff) - cond_phi = _circle_dist(phi, iphi).abs() <= dphi - cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff) - cond_phin = _circle_dist(phin, iphi) <= dphi + cond_phi = _circle_dist(phi, iphi.max()).abs() <= dphi + #cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff) + cond_rn = (rn.abs() <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=rn.device) + cond_phin = _circle_dist(phin, iphi.max()) <= dphi # find indices where conditions are met iidx = torch.argwhere((cond_r & cond_phi) | (cond_rn & cond_phin)) + #iidx = torch.argwhere(cond_r | cond_rn) # & cond_phin) dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() dist_phi = _circle_dist(phi[iidx[:, 1], iidx[:, 2]], iphi[iidx[:, 0], 0, 0]) dist_rn = (rn[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() dist_phin = _circle_dist(phin[iidx[:, 1], iidx[:, 2]], iphi[iidx[:, 0], 0, 0]) # compute the value of the basis functions - vals = cond_r[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_r / dr) - vals *= cond_phi[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_phi / dphi) - valsn = cond_rn[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_rn / dr) - valsn *= cond_phin[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_phin / dphi) + #vals = cond_r[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_r / dr) + #vals *= cond_phi[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_phi / dphi) + #valsn = cond_rn[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_rn / dr) + #valsn *= cond_phin[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_phin / dphi) + #vals += valsn + + vals = torch.clamp((1 - dist_r / dr) * (1 - dist_phi / dphi), min=0.0) + valsn = torch.clamp((1 - dist_rn / dr) * (1 - dist_phin / dphi), min=0.0) vals += valsn return iidx, vals diff --git a/torch_harmonics/utils/__init__.py b/torch_harmonics/utils/__init__.py new file mode 100644 index 00000000..1acbd722 --- /dev/null +++ b/torch_harmonics/utils/__init__.py @@ -0,0 +1,45 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import warnings +import torch + +# we need those helpers +from utility_helpers import optimized_kernels_is_available + +if optimized_kernels_is_available(): + from . import _C + from torch.ops import utility_kernels +else: + utility_kernels = None + warnings.warn("No optimized utility kernels are available. Please compile the extension first setting BUILD_CPP and BUILD_CUDA to 1.") + +from ._utils import permute_to_0231, permute_to_0312 \ No newline at end of file diff --git a/torch_harmonics/utils/_utils.py b/torch_harmonics/utils/_utils.py new file mode 100644 index 00000000..f79aa013 --- /dev/null +++ b/torch_harmonics/utils/_utils.py @@ -0,0 +1,80 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from typing import Optional + +import torch +from utility_helpers import optimized_kernels_is_available +from . 
import utility_kernels + +# custom kernels +if optimized_kernels_is_available(): + + # fake permutations + @torch.library.register_fake("utility_kernels::permute_to_0231") + def _(inp: torch.Tensor) -> torch.Tensor: + B, C, H, W = inp.shape + out_shape = (B, H, W, C) + return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) + + @torch.library.register_fake("utility_kernels::permute_to_0312") + def _(inp: torch.Tensor) -> torch.Tensor: + B, H, W, C = inp.shape + out_shape = (B, C, H, W) + return torch.empty(out_shape, dtype=inp.dtype, device=inp.device) + + # autograds: shallow wrappers around the default kernels + def _permute_to_0231_bwd(ctx, grad_output): + return utility_kernels.permute_to_0312.default(grad_output) + + def _permute_to_0312_bwd(ctx, grad_output): + return utility_kernels.permute_to_0231.default(grad_output) + + torch.library.register_autograd( + "utility_kernels::permute_to_0231", _permute_to_0231_bwd) + + torch.library.register_autograd( + "utility_kernels::permute_to_0312", _permute_to_0312_bwd) + +if optimized_kernels_is_available(): + def permute_to_0231(inp: torch.Tensor) -> torch.Tensor: + return utility_kernels.permute_to_0231.default(inp) + + def permute_to_0312(inp: torch.Tensor) -> torch.Tensor: + return utility_kernels.permute_to_0312.default(inp) +else: + def permute_to_0231(inp: torch.Tensor) -> torch.Tensor: + return inp.permute(0, 2, 3, 1).contiguous() + + def permute_to_0312(inp: torch.Tensor) -> torch.Tensor: + return inp.permute(0, 3, 1, 2).contiguous() + + diff --git a/torch_harmonics/utils/csrc/cppmacro.h b/torch_harmonics/utils/csrc/cppmacro.h new file mode 100644 index 00000000..3dccace8 --- /dev/null +++ b/torch_harmonics/utils/csrc/cppmacro.h @@ -0,0 +1,40 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
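Editor's note: a minimal usage sketch of the Python-level permute wrappers introduced in torch_harmonics/utils/_utils.py above. It only exercises the pure-PyTorch fallback branch, so it runs without the compiled _C extension; the tensor shapes and helper names are illustrative assumptions, not part of the patch.

import torch

def permute_to_0231_ref(inp: torch.Tensor) -> torch.Tensor:
    # mirrors the fallback branch: NCHW -> NHWC, materialized contiguously
    return inp.permute(0, 2, 3, 1).contiguous()

def permute_to_0312_ref(inp: torch.Tensor) -> torch.Tensor:
    # inverse layout change: NHWC -> NCHW
    return inp.permute(0, 3, 1, 2).contiguous()

x = torch.randn(2, 4, 8, 16)                   # (B, C, H, W)
y = permute_to_0231_ref(x)                     # (B, H, W, C)
assert y.shape == (2, 8, 16, 4)
assert torch.equal(permute_to_0312_ref(y), x)  # the two permutations are mutual inverses

Because the two permutations invert each other, the backward registered for permute_to_0231 can simply apply permute_to_0312 to the incoming gradient, and vice versa, which is exactly what the register_autograd wrappers above do.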
+
+#pragma once
+
+#include <torch/extension.h>
+
+#define CHECK_CPU_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCPU, #x " must be on CPU")
+#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_INTERNAL_ASSERT(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT_TENSOR(x) CHECK_CONTIGUOUS_TENSOR(x)
+#define CHECK_CPU_INPUT_TENSOR(x) \
+    CHECK_CPU_TENSOR(x);          \
+    CHECK_CONTIGUOUS_TENSOR(x)
diff --git a/torch_harmonics/utils/csrc/csr_cuda.cu b/torch_harmonics/utils/csrc/csr_cuda.cu
new file mode 100644
index 00000000..ba0edfde
--- /dev/null
+++ b/torch_harmonics/utils/csrc/csr_cuda.cu
@@ -0,0 +1,124 @@
+// coding=utf-8
+//
+// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
+// SPDX-License-Identifier: BSD-3-Clause
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
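Editor's note: the sortRows routine in the csr_cuda.cu code that follows orders CSR rows by descending row length (number of non-zeros per row), which helps balance work across thread blocks. A rough Python equivalent of what the set_rlen_rids_k kernel plus the CUB radix sort compute is sketched here; it is an illustration under the assumption of a standard int64 CSR offsets tensor, not part of the patch.

import torch

def sort_rows_ref(row_off: torch.Tensor) -> torch.Tensor:
    # row lengths from the CSR offsets: rlen[i] = row_off[i+1] - row_off[i]
    rlen = row_off[1:] - row_off[:-1]
    # row indices ordered by descending length, matching the role of SortPairsDescending
    return torch.argsort(rlen, descending=True).to(torch.int32)

row_off = torch.tensor([0, 3, 3, 7, 9], dtype=torch.int64)  # 4 rows with lengths 3, 0, 4, 2
print(sort_rows_ref(row_off))                               # tensor([2, 0, 3, 1], dtype=torch.int32)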
+
+#include <torch/extension.h>
+
+#include <cuda.h>
+
+#include <cuda_runtime.h>
+#include <cub/cub.cuh>
+
+#include "cudamacro.h"
+#include "csr_cuda.cuh"
+
+#define THREADS (64)
+
+#define TRANSP_WARPS_X_TILE_GENERIC (32)
+#define TRANSP_WARPS_X_TILE_SM100 (4)
+
+namespace utility_kernels {
+
+// BEGIN - CSR rows sorting kernels and functions
+__global__ void set_rlen_rids_k(const int n,
+                                const int64_t *__restrict__ offs,
+                                int *__restrict__ rids,
+                                int *__restrict__ rlen) {
+
+    const int nth = gridDim.x*blockDim.x;
+    const int tid = blockIdx.x*blockDim.x + threadIdx.x;
+
+    for(int i = tid; i < n; i += nth) {
+        rids[i] = i;
+        rlen[i] = offs[i+1]-offs[i];
+    }
+
+    return;
+}
+
+torch::Tensor sortRows(int nlat_out, torch::Tensor row_off, cudaStream_t stream) {
+
+    int64_t *_row_off_d = reinterpret_cast<int64_t *>(row_off.data_ptr());
+
+    auto options = torch::TensorOptions().dtype(torch::kInt32).device(row_off.device());
+
+    torch::Tensor rids_d = torch::empty({nlat_out}, options);
+    torch::Tensor rlen_d = torch::empty({nlat_out}, options);
+
+    int *_rids_d = reinterpret_cast<int *>(rids_d.data_ptr());
+    int *_rlen_d = reinterpret_cast<int *>(rlen_d.data_ptr());
+
+    const int grid = DIV_UP(nlat_out, THREADS);
+    const int block = THREADS;
+
+    set_rlen_rids_k<<<grid, block, 0, stream>>>(nlat_out,
+                                                _row_off_d,
+                                                _rids_d,
+                                                _rlen_d);
+
+    torch::Tensor rids_sort_d = torch::empty({nlat_out}, options);
+    torch::Tensor rlen_sort_d = torch::empty({nlat_out}, options);
+
+    int *_rids_sort_d = reinterpret_cast<int *>(rids_sort_d.data_ptr());
+    int *_rlen_sort_d = reinterpret_cast<int *>(rlen_sort_d.data_ptr());
+
+    size_t temp_storage_bytes = 0;
+    CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(NULL, temp_storage_bytes,
+                                                         _rlen_d, _rlen_sort_d,
+                                                         _rids_d, _rids_sort_d,
+                                                         nlat_out, 0, sizeof(*_rlen_d)*8, stream));
+
+    options = torch::TensorOptions().dtype(torch::kByte).device(row_off.device());
+    torch::Tensor temp_storage_d = torch::empty({int64_t(temp_storage_bytes)}, options);
+
+    void *_temp_storage_d = reinterpret_cast<void *>(temp_storage_d.data_ptr());
+
+    CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(_temp_storage_d, temp_storage_bytes,
+                                                         _rlen_d, _rlen_sort_d,
+                                                         _rids_d, _rids_sort_d,
+                                                         nlat_out, 0, sizeof(*_rlen_d)*8, stream));
+
+    return rids_sort_d;
+}
+// END - CSR rows sorting kernels and functions
+
+
+// BEGIN - general host-side functions
+unsigned int next_pow2(unsigned int x) {
+
+    x -= 1;
+
+    #pragma unroll
+    for(int i = 1; i <= sizeof(x)*8 / 2; i *= 2) {
+        x |= x >> i;
+    }
+    return x+1;
+}
+// END - general host-side functions
+
+}
diff --git a/torch_harmonics/utils/csrc/csr_cuda.cuh b/torch_harmonics/utils/csrc/csr_cuda.cuh
new file mode 100644
index 00000000..a9bf6cca
--- /dev/null
+++ b/torch_harmonics/utils/csrc/csr_cuda.cuh
@@ -0,0 +1,223 @@
+// coding=utf-8
+//
+// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
+// SPDX-License-Identifier: BSD-3-Clause
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3.
Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once + +#include +#include +#include +#include + +#include "cudamacro.h" + +#define WARP_SIZE (32) + + +namespace utility_kernels { + +// CSR rows sorting kernels and functions +torch::Tensor sortRows(int nlat_out, torch::Tensor row_off, cudaStream_t stream); + + +// Host tensor dump and CSR manipulation functions +void dump_tensor(const char *fname, torch::Tensor t); +void dump_csr(const char *fname, torch::Tensor roff, torch::Tensor cols); + +int part_csr_rows(int *row_perm, + const torch::Tensor roff, + const torch::Tensor cols, + int **part_off, + int **part_val); + +int verify_part(const int npart, + const int *part_off, + const int *part_val, + const torch::Tensor roff, + const torch::Tensor cols); + +void verify_part_new(const int nlon_out, + const int nlat_in, + const int nlon_in, + const int npart, // partitioning data + const int *part_off, + const int *part_val, + const torch::Tensor roff, + const torch::Tensor cols); + +unsigned int next_pow2(unsigned int x); + + +// utility host functions and templates + +template +int is_aligned(const void *ptr) { + + static_assert(0 == (ALIGN & (ALIGN-1))); + return (0 == (uintptr_t(ptr) & (ALIGN-1))); +} + + +// utility device functions and templates + +template +__device__ FLOATV_T __vset(float x) { + static_assert(sizeof(FLOATV_T) == 0, "Unsupported type for __vset"); + return FLOATV_T{}; +} + +template<> +__device__ float __forceinline__ __vset(float x) { + return x; +} + +__device__ float __forceinline__ __vmul(float a, float b) { + return a*b; +} + +__device__ float __forceinline__ __vadd(float a, float b) { + return a+b; +} + +__device__ float __forceinline__ __vsub(float a, float b) { + return a-b; +} + +__device__ float __forceinline__ __vred(float a) { + return a; +} + +__device__ float __forceinline__ __vscale(float s, float v) { + return v*s; +} + +__device__ float __forceinline__ __vdiv(float s, float v) { + return v/s; +} + +template<> +__device__ float4 __forceinline__ __vset(float x) { + return make_float4(x, x, x, x); +} + +__device__ float4 __forceinline__ __vmul(float4 a, float4 b) { + return make_float4(a.x*b.x, a.y*b.y, a.z*b.z, a.w*b.w); +} + +__device__ float4 __forceinline__ __vadd(float4 a, float4 b) { + return make_float4(a.x+b.x, a.y+b.y, a.z+b.z, a.w+b.w); +} + +__device__ float4 __forceinline__ __vsub(float4 a, float4 b) { + return make_float4(a.x-b.x, a.y-b.y, a.z-b.z, a.w-b.w); +} + +__device__ float __forceinline__ __vred(float4 a) { + return a.x + a.y + a.z + a.w; +} + +__device__ float4 __forceinline__ __vscale(float s, 
float4 v) { + return make_float4(s*v.x, s*v.y, s*v.z, s*v.w); +} + +__device__ float4 __forceinline__ __vdiv(float s, float4 v) { + return make_float4(s/v.x, s/v.y, s/v.z, s/v.w);; +} + +template +static __device__ void __sync() { + + static_assert(BDIM_X > 0 && 0 == (BDIM_X & (BDIM_X-1))); + + if constexpr(BDIM_X > WARP_SIZE) { __syncthreads(); } + else if constexpr(BDIM_X == WARP_SIZE) { __syncwarp(); } + else { // BDIM_X < WARP_SIZE + constexpr unsigned int MASK = (1ull << BDIM_X)-1; + unsigned int subwarp_id = threadIdx.y % (WARP_SIZE/BDIM_X); + unsigned int subwarp_mask = MASK << (subwarp_id*BDIM_X); + __syncwarp(subwarp_mask); + } + return; +} + +template +__device__ VAL_T __warp_sum(VAL_T val) { + + #pragma unroll + for(int i = WARP_SIZE/2; i; i /= 2) { + val += __shfl_xor_sync(FULL_MASK, val, i, WARP_SIZE); + } + return val; +} + +template +__device__ VAL_T __block_sum(VAL_T val) { + + const int NWARP = (BDIM_X*BDIM_Y*BDIM_Z) / WARP_SIZE; + + val = __warp_sum(val); + + if constexpr(NWARP > 1) { + + int tid = threadIdx.x; + if constexpr(BDIM_Y > 1) { tid += threadIdx.y*BDIM_X; } + if constexpr(BDIM_Z > 1) { tid += threadIdx.z*BDIM_X*BDIM_Y; } + + const int lid = tid%WARP_SIZE; + const int wid = tid/WARP_SIZE; + + __shared__ VAL_T sh[NWARP]; + + if (lid == 0) { + sh[wid] = val; + } + __syncthreads(); + + if (wid == 0) { + val = (lid < NWARP) ? sh[lid] : 0; + + val = __warp_sum(val); + __syncwarp(); + + if (!lid) { + sh[0] = val; + } + } + __syncthreads(); + + val = sh[0]; + __syncthreads(); + } + return val; +} + +} diff --git a/torch_harmonics/attention/csrc/cudamacro.h b/torch_harmonics/utils/csrc/cudamacro.h similarity index 81% rename from torch_harmonics/attention/csrc/cudamacro.h rename to torch_harmonics/utils/csrc/cudamacro.h index 0edef184..1dcb8227 100644 --- a/torch_harmonics/attention/csrc/cudamacro.h +++ b/torch_harmonics/utils/csrc/cudamacro.h @@ -30,6 +30,11 @@ #pragma once +#include + +#define DIV_UP(a,b) (((a)+((b)-1))/(b)) +#define FULL_MASK (0xFFFFFFFF) + #define CHECK_CUDA(call) { \ cudaError_t err = call; \ if( cudaSuccess != err) { \ @@ -45,3 +50,9 @@ errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) );\ exit(EXIT_FAILURE); \ }} + +#define CHECK_CUDA_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCUDA, #x " must be on GPU") +#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_INTERNAL_ASSERT(x.is_contiguous(), #x " must be contiguous") +#define CHECK_CUDA_INPUT_TENSOR(x) \ + CHECK_CUDA_TENSOR(x); \ + CHECK_CONTIGUOUS_TENSOR(x) diff --git a/torch_harmonics/attention/csrc/attention_interface.cu b/torch_harmonics/utils/csrc/permute_cpu.cpp similarity index 74% rename from torch_harmonics/attention/csrc/attention_interface.cu rename to torch_harmonics/utils/csrc/permute_cpu.cpp index b18c315c..e604b927 100644 --- a/torch_harmonics/attention/csrc/attention_interface.cu +++ b/torch_harmonics/utils/csrc/permute_cpu.cpp @@ -28,11 +28,25 @@ // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-//#include "attention.cuh"
-//#include
+#include <torch/extension.h>
+#include <torch/library.h>
-//PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
-//{
-// m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
-// m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)");
-//}
+namespace utility_kernels {
+
+torch::Tensor permute_4D_to0231_cpu(torch::Tensor src) {
+    // CPU implementation using standard permute
+    return src.permute({0, 2, 3, 1}).contiguous();
+}
+
+torch::Tensor permute_4D_to0312_cpu(torch::Tensor src) {
+    // CPU implementation using standard permute
+    return src.permute({0, 3, 1, 2}).contiguous();
+}
+
+TORCH_LIBRARY_IMPL(utility_kernels, CPU, m)
+{
+    m.impl("permute_to_0231", &permute_4D_to0231_cpu);
+    m.impl("permute_to_0312", &permute_4D_to0312_cpu);
+}
+
+}
diff --git a/torch_harmonics/utils/csrc/permute_cuda.cu b/torch_harmonics/utils/csrc/permute_cuda.cu
new file mode 100644
index 00000000..4cba299c
--- /dev/null
+++ b/torch_harmonics/utils/csrc/permute_cuda.cu
@@ -0,0 +1,116 @@
+// coding=utf-8
+//
+// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
+// SPDX-License-Identifier: BSD-3-Clause
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
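Editor's note: a small sanity check of the CPU dispatch path registered in permute_cpu.cpp above. It assumes the torch_harmonics.utils._C extension has been built and that importing torch_harmonics.utils registers the utility_kernels library; otherwise torch.ops.utility_kernels will not resolve.

import torch
import torch_harmonics.utils  # loads _C and registers the utility_kernels ops, if built

x = torch.randn(1, 3, 5, 7)   # (B, C, H, W) on CPU

# the CPU impl is plain permute(...).contiguous(), so it must match eager permute exactly
y = torch.ops.utility_kernels.permute_to_0231(x)
assert torch.equal(y, x.permute(0, 2, 3, 1).contiguous())
assert torch.equal(torch.ops.utility_kernels.permute_to_0312(y), x)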
+
+#include <torch/extension.h>
+
+#include <ATen/cuda/CUDAContext.h>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include "cudamacro.h"
+#include "permute_cuda.cuh"
+
+// tile sizes for the permutation kernels
+#define TRANSP_WARPS_X_TILE_GENERIC (32)
+#define TRANSP_WARPS_X_TILE_SM100 (4)
+
+namespace utility_kernels {
+
+// BEGIN - 4D tensor permutation kernels and functions
+__global__ void empty_k() {}
+
+static int getPtxver() {
+    cudaFuncAttributes attrs;
+    CHECK_CUDA(cudaFuncGetAttributes(&attrs, empty_k));
+    return attrs.ptxVersion*10;
+}
+
+torch::Tensor permute_4D_to0231(torch::Tensor src) {
+
+    // make input contiguous
+    auto srcc = src.contiguous();
+
+    auto options = torch::TensorOptions().dtype(srcc.dtype()).device(srcc.device());
+    torch::Tensor dst = torch::empty({srcc.size(0), srcc.size(2), srcc.size(3), srcc.size(1)}, options);
+
+    const int ptxv = getPtxver();
+
+    // to be further specialized for additional archs, if necessary
+    if (ptxv < 100) {
+        AT_DISPATCH_FLOATING_TYPES(srcc.scalar_type(), "permute_to0231_k_tile_generic", ([&] {
+            launch_permute_to0231<scalar_t, TRANSP_WARPS_X_TILE_GENERIC>(srcc, dst);
+        }));
+        CHECK_ERROR("permute_to0231_k_tile_generic");
+    } else {
+        AT_DISPATCH_FLOATING_TYPES(srcc.scalar_type(), "permute_to0231_k_tile_sm100", ([&] {
+            launch_permute_to0231<scalar_t, TRANSP_WARPS_X_TILE_SM100>(srcc, dst);
+        }));
+        CHECK_ERROR("permute_to0231_k_tile_sm100");
+    }
+
+    return dst;
+}
+
+torch::Tensor permute_4D_to0312(torch::Tensor src) {
+
+    // make input contiguous
+    auto srcc = src.contiguous();
+
+    auto options = torch::TensorOptions().dtype(srcc.dtype()).device(srcc.device());
+    torch::Tensor dst = torch::empty({srcc.size(0), srcc.size(3), srcc.size(1), srcc.size(2)}, options);
+
+    const int ptxv = getPtxver();
+
+    // to be further specialized for additional archs, if necessary
+    if (ptxv < 100) {
+        AT_DISPATCH_FLOATING_TYPES(srcc.scalar_type(), "permute_to0312_k_tile_generic", ([&] {
+            launch_permute_to0312<scalar_t, TRANSP_WARPS_X_TILE_GENERIC>(srcc, dst);
+        }));
+        CHECK_ERROR("permute_to0312_k_tile_generic");
+    } else {
+        AT_DISPATCH_FLOATING_TYPES(srcc.scalar_type(), "permute_to0312_k_tile_sm100", ([&] {
+            launch_permute_to0312<scalar_t, TRANSP_WARPS_X_TILE_SM100>(srcc, dst);
+        }));
+        CHECK_ERROR("permute_to0312_k_tile_sm100");
+    }
+
+    return dst;
+}
+
+TORCH_LIBRARY_IMPL(utility_kernels, CUDA, m)
+{
+    m.impl("permute_to_0231", &permute_4D_to0231);
+    m.impl("permute_to_0312", &permute_4D_to0312);
+}
+
+// END - tensor permutation kernels and functions
+
+}
\ No newline at end of file
diff --git a/torch_harmonics/utils/csrc/permute_cuda.cuh b/torch_harmonics/utils/csrc/permute_cuda.cuh
new file mode 100644
index 00000000..5052810b
--- /dev/null
+++ b/torch_harmonics/utils/csrc/permute_cuda.cuh
@@ -0,0 +1,217 @@
+// coding=utf-8
+//
+// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
+// SPDX-License-Identifier: BSD-3-Clause
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#pragma once
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda_runtime.h>
+#include <cassert>
+
+#define WARP_SIZE (32)
+
+
+namespace utility_kernels {
+
+// transpose utils
+template<typename VAL_T, int BDIM_X, int BDIM_Y>
+__global__
+__launch_bounds__(BDIM_X*BDIM_Y)
+void permute_to0231_k(const int nchn,
+                      const int nlat,
+                      const int nlon,
+                      const at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> src,
+                      at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> dst) {
+
+    static_assert(!(BDIM_X & (BDIM_X-1)));
+    static_assert(!(BDIM_Y & (BDIM_Y-1)));
+    static_assert(BDIM_X >= BDIM_Y);
+
+    __shared__ VAL_T sh[BDIM_X][BDIM_X+1];
+
+    const int tidx = threadIdx.x;
+    const int tidy = threadIdx.y;
+
+    const int coff = blockIdx.x*BDIM_X;          // channel offset
+    const int woff = blockIdx.y*BDIM_X;          // width offset
+    const int batch = blockIdx.z / nlat;         // batch (same for all block)
+    const int h = blockIdx.z - (batch * nlat);   // height (same for all block)
+
+    const int nchn_full = (nchn-coff) >= BDIM_X;
+    const int nlon_full = (nlon-woff) >= BDIM_X;
+
+    if (nchn_full && nlon_full) {
+        #pragma unroll
+        for(int j = 0; j < BDIM_X; j += BDIM_Y) {
+            sh[j+tidy][tidx] = src[batch][coff + j+tidy][h][woff+tidx];
+        }
+        __syncthreads();
+
+        #pragma unroll
+        for(int j = 0; j < BDIM_X; j += BDIM_Y) {
+            dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy];
+        }
+    } else {
+        if (woff+tidx < nlon) {
+            #pragma unroll
+            for(int j = 0; j < BDIM_X; j += BDIM_Y) {
+                sh[j+tidy][tidx] = (coff + j+tidy < nchn) ? src[batch][coff + j+tidy][h][woff+tidx] : VAL_T(0);
+            }
+        }
+        __syncthreads();
+
+        if (coff+tidx < nchn) {
+            #pragma unroll
+            for(int j = 0; j < BDIM_X; j += BDIM_Y) {
+                if (woff + j+tidy < nlon) {
+                    dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy];
+                }
+            }
+        }
+    }
+    return;
+}
+
+template<typename VAL_T, int WARPS_X_TILE>
+void launch_permute_to0231(torch::Tensor src, torch::Tensor dst) {
+
+    dim3 block;
+    dim3 grid;
+
+    block.x = WARP_SIZE;
+    block.y = WARPS_X_TILE;
+    grid.x = DIV_UP(src.size(1), block.x);
+    grid.y = DIV_UP(src.size(3), block.x);
+    grid.z = src.size(2)*src.size(0);
+
+    assert(grid.y < 65536);
+    assert(grid.z < 65536);
+
+    // get stream
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+
+    permute_to0231_k<VAL_T, WARP_SIZE, WARPS_X_TILE>
+        <<<grid, block, 0, stream>>>(src.size(1),
+                                     src.size(2),
+                                     src.size(3),
+                                     src.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>(),
+                                     dst.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>());
+}
+
+template<typename VAL_T, int BDIM_X, int BDIM_Y>
+__global__
+__launch_bounds__(BDIM_X*BDIM_Y)
+void permute_to0312_k(const int nchn,
+                      const int nlat,
+                      const int nlon,
+                      const at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> src,
+                      at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> dst) {
+
+    static_assert(!(BDIM_X & (BDIM_X-1)));
+    static_assert(!(BDIM_Y & (BDIM_Y-1)));
+    static_assert(BDIM_X >= BDIM_Y);
+
+    __shared__ VAL_T sh[BDIM_X][BDIM_X+1];
+
+    const int tidx = threadIdx.x;
+    const int tidy = threadIdx.y;
+
+    const int woff = blockIdx.x*BDIM_X;          // width offset
+    const int coff = blockIdx.y*BDIM_X;          // channel offset
+    const int batch = blockIdx.z / nlat;         // batch (same for all block)
+    const int h = blockIdx.z - (batch * nlat);   // height (same for all block)
+
+    const int nchn_full = (nchn-coff) >= BDIM_X;
+    const int nlon_full = (nlon-woff) >= BDIM_X;
+
+    if (nchn_full && nlon_full) {
+        #pragma unroll
+        for(int j = 0; j < BDIM_X; j += BDIM_Y) {
+            sh[j+tidy][tidx] = src[batch][h][woff + j+tidy][coff+tidx];
+        }
+        __syncthreads();
+
+        #pragma unroll
+        for(int j = 0; j < BDIM_X; j += BDIM_Y) {
+            dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];
+        }
+    } else {
+        if (coff+tidx < nchn) {
+            #pragma unroll
+            for(int j = 0; j < BDIM_X; j += BDIM_Y) {
+                sh[j+tidy][tidx] = (woff + j+tidy < nlon) ? src[batch][h][woff + j+tidy][coff+tidx] : VAL_T(0);
+            }
+        }
+        __syncthreads();
+
+        if (woff+tidx < nlon) {
+            #pragma unroll
+            for(int j = 0; j < BDIM_X; j += BDIM_Y) {
+                if (coff + j+tidy < nchn) {
+                    dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];
+                }
+            }
+        }
+    }
+    return;
+}
+
+template<typename VAL_T, int WARPS_X_TILE>
+void launch_permute_to0312(torch::Tensor src, torch::Tensor dst) {
+
+    dim3 block;
+    dim3 grid;
+
+    block.x = WARP_SIZE;
+    block.y = WARPS_X_TILE;
+    grid.x = DIV_UP(src.size(2), block.x);
+    grid.y = DIV_UP(src.size(3), block.x);
+    grid.z = src.size(1)*src.size(0);
+
+    assert(grid.y < 65536);
+    assert(grid.z < 65536);
+
+    // get stream
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+
+    permute_to0312_k<VAL_T, WARP_SIZE, WARPS_X_TILE>
+        <<<grid, block, 0, stream>>>(src.size(3),
+                                     src.size(1),
+                                     src.size(2),
+                                     src.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>(),
+                                     dst.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>());
+}
+
+torch::Tensor permute_4D_to0312(torch::Tensor src);
+torch::Tensor permute_4D_to0231(torch::Tensor src);
+
+}
\ No newline at end of file
diff --git a/torch_harmonics/utils/csrc/utils_helpers.cpp b/torch_harmonics/utils/csrc/utils_helpers.cpp
new file mode 100644
index 00000000..d700ef6a
--- /dev/null
+++ b/torch_harmonics/utils/csrc/utils_helpers.cpp
@@ -0,0 +1,58 @@
+// coding=utf-8
+//
+// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
+// SPDX-License-Identifier: BSD-3-Clause
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include <torch/extension.h>
+
+// set default values for BUILD_CPP and BUILD_CUDA
+#ifndef BUILD_CPP
+#define BUILD_CPP 0
+#endif
+
+#ifndef BUILD_CUDA
+#define BUILD_CUDA 0
+#endif
+
+bool cpp_kernels_is_available() {
+    return static_cast<bool>(BUILD_CPP);
+}
+
+bool cuda_kernels_is_available() {
+    return static_cast<bool>(BUILD_CUDA);
+}
+
+bool optimized_kernels_is_available() {
+    return cuda_kernels_is_available() || cpp_kernels_is_available();
+}
+
+PYBIND11_MODULE(utility_helpers, m)
+{
+    m.def("optimized_kernels_is_available", &optimized_kernels_is_available, "Check if optimized kernels (CUDA or C++) are available.");
+}
+
diff --git a/torch_harmonics/utils/csrc/utils_interface.cpp b/torch_harmonics/utils/csrc/utils_interface.cpp
new file mode 100644
index 00000000..d7d901b2
--- /dev/null
+++ b/torch_harmonics/utils/csrc/utils_interface.cpp
@@ -0,0 +1,62 @@
+// coding=utf-8
+//
+// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
+// SPDX-License-Identifier: BSD-3-Clause
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+//    list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+//    this list of conditions and the following disclaimer in the documentation
+//    and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its
+//    contributors may be used to endorse or promote products derived from
+//    this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include <Python.h>
+#include <torch/library.h>
+
+extern "C" {
+    /* Creates a dummy empty _C module that can be imported from Python.
+       The import from Python will load the .so consisting of this file
+       in this extension, so that the TORCH_LIBRARY static initializers
+       below are run. */
+    PyMODINIT_FUNC PyInit__C(void)
+    {
+        static struct PyModuleDef module_def = {
+            PyModuleDef_HEAD_INIT,
+            "_C",   /* name of module */
+            NULL,   /* module documentation, may be NULL */
+            -1,     /* size of per-interpreter state of the module,
+                       or -1 if the module keeps state in global variables. */
+            NULL,   /* methods */
+        };
+        return PyModule_Create(&module_def);
+    }
+}
+
+namespace utility_kernels {
+
+    // Declare the operators
+    TORCH_LIBRARY(utility_kernels, m) {
+        m.def("permute_to_0231(Tensor inp) -> Tensor", {at::Tag::pt2_compliant_tag});
+        m.def("permute_to_0312(Tensor inp) -> Tensor", {at::Tag::pt2_compliant_tag});
+    }
+
+}
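Editor's note: to close, a short end-to-end sketch combining the pieces added by this patch: the availability probe exposed by the utility_helpers pybind module and the wrappers re-exported from torch_harmonics.utils. It assumes the package was built with the setup.py changes above; when no optimized kernels are available the wrappers fall back to plain permute, so the round trip below holds either way.

import torch
from utility_helpers import optimized_kernels_is_available
from torch_harmonics.utils import permute_to_0231, permute_to_0312

print("optimized kernels available:", optimized_kernels_is_available())

x = torch.randn(2, 8, 16, 32)  # (B, C, H, W)
y = permute_to_0231(x)         # (B, H, W, C): custom op if available, permute fallback otherwise
z = permute_to_0312(y)         # back to (B, C, H, W)
assert torch.equal(x, z)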