32 commits
c3fd859
Test for a new disco kernel using vectorized operations on channel-last
mauro-bis Aug 22, 2025
2932853
there is more work to do
azrael417 Aug 26, 2025
97daca4
dbg
azrael417 Aug 28, 2025
6509e68
compiling code
azrael417 Aug 28, 2025
6738964
almost working refactor
azrael417 Aug 28, 2025
7389a95
fixed a lot of issues
azrael417 Aug 29, 2025
da3c206
fixed fwd pass
azrael417 Aug 29, 2025
2967983
re-implementing transpose conv
azrael417 Aug 29, 2025
b0b4d44
working CPU branch
azrael417 Sep 1, 2025
544d0c7
Added preliminary version of channel-last BWD kernel.
mauro-bis Sep 3, 2025
02ba8c1
commenting out tests
azrael417 Sep 3, 2025
8068939
adding reshapes
azrael417 Sep 3, 2025
357e620
adding contiguous call before passing to kernel
azrael417 Sep 3, 2025
f208ccd
dbg
azrael417 Sep 8, 2025
27d5414
fixing some selection criterion
azrael417 Sep 8, 2025
9d3f351
removing some debug prints
azrael417 Sep 8, 2025
f45d3ce
small fix
azrael417 Sep 8, 2025
fb1e589
cleaning up attention
azrael417 Sep 9, 2025
4a1d785
fixing raw meta kernels
azrael417 Sep 9, 2025
1a7f96d
better comments
azrael417 Sep 9, 2025
1b715b3
snapshot
azrael417 Oct 6, 2025
5dcda4d
Fixed absolute difference errors w.r.t. CPU and torch DISCO
mauro-bis Oct 8, 2025
001c28a
adding permute with contig call
azrael417 Oct 9, 2025
104bf6d
dbg
azrael417 Oct 13, 2025
455ae90
after rebase
azrael417 Oct 20, 2025
8fc8427
small fixes after rebase
azrael417 Oct 20, 2025
d9b0a2a
adding theta cutoff to test
azrael417 Oct 27, 2025
52f5253
cleanups
azrael417 Oct 28, 2025
cfc5338
removing prints
azrael417 Oct 28, 2025
4826971
further cleanup
azrael417 Oct 29, 2025
09e8ebf
Increased max no. of elements per thread to 20 for both fwd and bwd and
mauro-bis Oct 30, 2025
c4ed35f
Merge branch 'maurob/disco_bwd_fix' of https://github.com/NVIDIA/torc…
mauro-bis Oct 30, 2025
54 changes: 53 additions & 1 deletion setup.py
@@ -93,31 +93,77 @@ def get_helpers_compile_args():
def get_ext_modules():
    """Get list of extension modules to compile."""

    # define setup dir
    setup_dir = os.path.abspath(os.path.dirname(__file__))

    ext_modules = []
    cmdclass = {}

    print(f"Compiling helper routines for torch-harmonics.")

    # Utility helpers
    ext_modules.append(
        CppExtension(
            "utility_helpers",
            [
                "torch_harmonics/utils/csrc/utils_helpers.cpp",
            ],
            extra_compile_args=get_helpers_compile_args(),
        )
    )

    # DISCO helpers
    ext_modules.append(
        CppExtension(
            "disco_helpers",
            [
                "torch_harmonics/disco/csrc/disco_helpers.cpp",
            ],
            include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")],
            extra_compile_args=get_helpers_compile_args(),
        )
    )

    # Attention helpers
    ext_modules.append(
        CppExtension(
            "attention_helpers",
            [
                "torch_harmonics/attention/csrc/attention_helpers.cpp",
            ],
            include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")],
            extra_compile_args=get_helpers_compile_args(),
        )
    )

    if BUILD_CPP:
        # HELPERS
        utility_sources = [
            "torch_harmonics/utils/csrc/utils_interface.cpp",
            "torch_harmonics/utils/csrc/permute_cpu.cpp",
        ]

        if BUILD_CUDA:
            print(f"Compiling custom CUDA kernels for torch-harmonics.")
            utility_sources.extend([
                "torch_harmonics/utils/csrc/permute_cuda.cu",
            ])
            ext_modules.append(
                CUDAExtension(
                    "torch_harmonics.utils._C",
                    utility_sources,
                    extra_compile_args=get_compile_args("utils")
                )
            )
        else:
            ext_modules.append(
                CppExtension(
                    "torch_harmonics.utils._C",
                    utility_sources,
                    extra_compile_args=get_compile_args("utils")
                )
            )

        # DISCO
        # Create a single extension that includes both CPU and CUDA code
        disco_sources = [
@@ -128,13 +174,15 @@ def get_ext_modules():
        if BUILD_CUDA:
            print(f"Compiling custom CUDA kernels for torch-harmonics.")
            disco_sources.extend([
                "torch_harmonics/utils/csrc/csr_cuda.cu",
                "torch_harmonics/disco/csrc/disco_cuda_fwd.cu",
                "torch_harmonics/disco/csrc/disco_cuda_bwd.cu",
            ])
            ext_modules.append(
                CUDAExtension(
                    "torch_harmonics.disco._C",
                    disco_sources,
                    include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")],
                    extra_compile_args=get_compile_args("disco")
                )
            )
@@ -143,10 +191,10 @@ def get_ext_modules():
                CppExtension(
                    "torch_harmonics.disco._C",
                    disco_sources,
                    include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")],
                    extra_compile_args=get_compile_args("disco")
                )
            )
-        cmdclass["build_ext"] = BuildExtension

        # ATTENTION
        # Create a single extension that includes both CPU and CUDA code
@@ -167,6 +215,7 @@ def get_ext_modules():
                CUDAExtension(
                    "torch_harmonics.attention._C",
                    attention_sources,
                    include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")],
                    extra_compile_args=get_compile_args("attention")
                )
            )
@@ -175,9 +224,12 @@ def get_ext_modules():
                CppExtension(
                    "torch_harmonics.attention._C",
                    attention_sources,
                    include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")],
                    extra_compile_args=get_compile_args("attention")
                )
            )

+    # set cmdclass
+    cmdclass["build_ext"] = BuildExtension

    return ext_modules, cmdclass
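
Review note: the new torch_harmonics.utils._C extension bundles the permute kernels (permute_cpu.cpp, permute_cuda.cu) that the channel-last DISCO path relies on, and the shared include_dirs entry lets the disco and attention extensions reuse the utils headers. Below is a minimal pure-PyTorch sketch of the layout round trip those kernels target; the shapes and variable names are illustrative only, since the Python bindings are not shown in this diff.

import torch

# Hypothetical illustration: the channel-first <-> channel-last round trip
# that the new permute kernels accelerate. The compiled torch_harmonics.utils._C
# bindings are not shown in this diff, so only the semantics are sketched here.
x = torch.randn(2, 8, 91, 180)                # batch, channels, lat, lon
x_cl = x.permute(0, 2, 3, 1).contiguous()     # to channel-last, materialized in memory
x_cf = x_cl.permute(0, 3, 1, 2).contiguous()  # back to channel-first
assert torch.equal(x, x_cf)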
11 changes: 6 additions & 5 deletions tests/test_attention.py
@@ -54,8 +54,8 @@
# CPU results normalized to 16 OpenMP threads,
# GPU results normalized to V100 16 GB GPU
# this is just to detect performance regressions, not for absolute performance
-_perf_test_thresholds = {"cpu": {"fwd_ms": 1000, "bwd_ms": 8000},
-                         "cuda": {"fwd_ms": 50, "bwd_ms": 150}}
+_perf_test_thresholds = {"cpu": {"fwd_ms": 800, "bwd_ms": 6000},
+                         "cuda": {"fwd_ms": 10, "bwd_ms": 30}}
_run_perf_tests = (os.getenv("TORCH_HARMONICS_RUN_PERF_TESTS", "0") == "1")


@@ -326,12 +326,12 @@ def test_optimized_pt2_compatibility(self, batch_size, channels, heads, in_shape
    @parameterized.expand(
        [
            # self attention
-            [1, 256, 1, (91, 180), (91, 180), "equiangular", "equiangular", 1e-5, 1e-5],
+            [1, 256, 1, (91, 180), (91, 180), "equiangular", "equiangular", None],
        ],
        skip_on_empty=True,
    )
    @unittest.skipUnless(optimized_kernels_is_available() and _run_perf_tests, "skipping performance test because optimized kernels are not available or perf tests are disabled")
-    def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=False):
+    def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, theta_cutoff, verbose=False):

        if (self.device.type == "cuda") and (not cuda_kernels_is_available()):
            raise unittest.SkipTest("skipping test because CUDA kernels are not available")
@@ -350,7 +350,8 @@ def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, g

        att_optimized = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
                                                in_shape=in_shape, out_shape=out_shape,
-                                                grid_in=grid_in, grid_out=grid_out, bias=True,
+                                                grid_in=grid_in, grid_out=grid_out,
+                                                theta_cutoff=theta_cutoff, bias=True,
                                                optimized_kernel=True).to(self.device)

        # random weights
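
Review note: the perf test now threads theta_cutoff through the parameterization instead of the unused atol/rtol pair, and perf runs stay opt-in via TORCH_HARMONICS_RUN_PERF_TESTS=1. A hedged sketch of the resulting construction follows; the top-level import path, the theta_cutoff=None default, and the input layout are assumptions based on this diff rather than confirmed API.

import torch
from torch_harmonics import NeighborhoodAttentionS2  # import path assumed

# Construction mirrors the updated test; theta_cutoff=None is assumed to
# select the default neighborhood size.
att = NeighborhoodAttentionS2(in_channels=16, num_heads=1,
                              in_shape=(91, 180), out_shape=(91, 180),
                              grid_in="equiangular", grid_out="equiangular",
                              theta_cutoff=None, bias=True,
                              optimized_kernel=True)  # requires the compiled kernels

x = torch.randn(1, 16, 91, 180)  # (batch, channels, nlat, nlon) layout assumed;
                                 # fewer channels than the perf test's 256 to keep it cheap
y = att(x)
print(y.shape)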
41 changes: 40 additions & 1 deletion tests/test_cache.py
@@ -34,8 +34,11 @@
import math
import torch

from torch_harmonics.cache import lru_cache
from testutils import compare_tensors

class TestCacheConsistency(unittest.TestCase):

    def test_consistency(self, verbose=False):
        if verbose:
            print("Testing that cache values do not get modified externally")
@@ -47,7 +50,43 @@ def test_consistency(self, verbose=False):
        # perform in-place modification of leg1
        leg1 *= -1.0
        leg2 = _precompute_legpoly(10, 10, cost)
-        self.assertFalse(torch.allclose(leg1, leg2))
+        self.assertFalse(compare_tensors("legendre", leg2, leg1, verbose=verbose))


    def test_pytorch_tensors(self, verbose=False):
        if verbose:
            print("Testing that PyTorch tensors are cached")

        @lru_cache(typed=True, copy=True)
        def test_func(tens1, tens2):
            return tens1, tens2

        # initial tensors
        tens1 = torch.randn(4, 4, dtype=torch.float32)
        tens2 = torch.randn(4, 4, dtype=torch.float32)

        # retrieve from cache
        tens1c, tens2c = test_func(tens1, tens2)

        # modify copies
        tens1c *= -1.0
        tens2c *= -1.0

        # retrieve from cache again
        tens1cc, tens2cc = test_func(tens1, tens2)

        if verbose:
            print("first tensor", tens1)
            print("first tensor after modification", tens1c)
            print("first tensor cached", tens1cc)
            print("second tensor", tens2)
            print("second tensor after modification", tens2c)
            print("second tensor cached", tens2cc)

        self.assertFalse(compare_tensors("first cached", tens1cc, tens1c, verbose=verbose))
        self.assertFalse(compare_tensors("second cached", tens2cc, tens2c, verbose=verbose))
        self.assertTrue(compare_tensors("first raw", tens1, tens1cc, verbose=verbose))
        self.assertTrue(compare_tensors("second raw", tens2, tens2cc, verbose=verbose))


if __name__ == "__main__":
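
Review note: test_pytorch_tensors pins down the copy=True contract of the cached decorator: callers receive copies, so in-place edits must not leak back into the cache. The following is a minimal sketch of that copy-on-return behavior built on functools.lru_cache; it is an illustration, not the actual torch_harmonics.cache implementation.

import copy
import functools

import torch

# Minimal sketch of copy-on-return caching; illustration only, not the
# actual torch_harmonics.cache.lru_cache implementation.
def copying_lru_cache(maxsize=128, typed=False):
    def decorator(fn):
        cached = functools.lru_cache(maxsize=maxsize, typed=typed)(fn)
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # deep-copy the cached result so in-place edits by the caller
            # cannot mutate the value stored in the cache
            return copy.deepcopy(cached(*args, **kwargs))
        return wrapper
    return decorator

@copying_lru_cache(typed=True)
def make_weights(n):
    return torch.eye(n)

w = make_weights(4)
w *= -1.0                            # mutates the returned copy only
assert make_weights(4)[0, 0] == 1.0  # the cached original is untouched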