WIP: Add fused Butina clustering with Triton similarity kernels #125

Open
moradza wants to merge 8 commits into NVIDIA-Digital-Bio:main from moradza:amoradzadeh/butina_clustering

Conversation


@moradza moradza commented Mar 31, 2026

Add fused_butina() that computes similarities on-the-fly using custom
Triton kernels, avoiding full distance matrix storage for large datasets.
Includes Tanimoto similarity neighbor computation, cluster subtraction,
and largest-cluster removal operations.
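For context, the Tanimoto arithmetic these kernels fuse can be sketched in plain NumPy (a minimal sketch with illustrative names, not the PR's API; real fingerprints here are packed into 32-bit words exactly as the kernels assume):

```python
import numpy as np

def popcount(words: np.ndarray) -> int:
    # Count set bits by viewing the packed words as raw bytes.
    return int(np.unpackbits(words.view(np.uint8)).sum())

def tanimoto_packed(a: np.ndarray, b: np.ndarray) -> float:
    # a, b: fingerprints packed into uint32 words (K words = K*32 bits).
    dot = popcount(a & b)       # |A ∩ B|
    pa = popcount(a)            # |A|
    pb = popcount(b)            # |B|
    denom = pa + pb - dot       # |A ∪ B|
    return dot / denom if denom else 1.0

a = np.array([0b1011, 0], dtype=np.uint32)   # bits {0, 1, 3}
b = np.array([0b0011, 0], dtype=np.uint32)   # bits {0, 1}
print(tanimoto_packed(a, b))                  # 2 common bits / 3 total bits
```

The fused kernels perform the same three popcounts per pair on the GPU, which is why no N×N distance matrix ever needs to be stored.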

moradza and others added 2 commits March 31, 2026 13:33
Add fused_butina() that computes similarities on-the-fly using custom
Triton kernels, avoiding full distance matrix storage for large datasets.
Includes Tanimoto similarity neighbor computation, cluster subtraction,
and largest-cluster removal operations.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Remove hardcoded BIT_COUNTS/TILE_K constants and make the K dimension
(number of int32 words per fingerprint) a runtime parameter. All three
kernels (_similarity_neighbor, _subtract_similarity_neighbor,
_remove_largest_cluster) now iterate over K in blocked tiles, enabling
fingerprints of arbitrary width instead of being fixed at 1024 bits.
Update validation to check x/y shape agreement instead of enforcing a
fixed column count, and extend the test harness with variable num_words
configs.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
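The blocked iteration over K described in this commit can be mimicked on the host (a NumPy sketch under assumed names; the real kernels do this with tl.static_range and load masks):

```python
import numpy as np

BLOCK_K = 4   # illustrative tile width; the kernels use a compile-time BLOCK_K

def blocked_popcount(row: np.ndarray) -> int:
    # Walk the K packed words in BLOCK_K-sized tiles; the final tile is
    # naturally shorter, mirroring the kernels' tail masking for fingerprint
    # widths that are not a multiple of the tile size.
    total = 0
    for k0 in range(0, row.shape[0], BLOCK_K):
        tile = np.ascontiguousarray(row[k0:k0 + BLOCK_K])
        total += int(np.unpackbits(tile.view(np.uint8)).sum())
    return total

row = np.arange(10, dtype=np.uint32)   # K = 10 words, not a multiple of BLOCK_K
print(blocked_popcount(row))           # same result as an unblocked popcount
```

Making K a runtime parameter with this tiling is what lifts the previous fixed 1024-bit restriction.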

greptile-apps bot commented Mar 31, 2026

Greptile Summary

This WIP PR introduces fused_butina(), a GPU-native Butina clustering implementation that avoids materialising the full O(N²) distance matrix by computing Tanimoto/cosine similarities on-the-fly via custom Triton kernels (_fused_Butina.py). The main loop maintains front/back pointers into a shared cluster_indices buffer, incrementally subtracting processed-molecule contributions from the running neighbour-count vector.
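For orientation, the greedy procedure summarised above can be sketched against a dense similarity matrix (plain NumPy with illustrative names; the PR's fused version never materialises this matrix, which is the whole point):

```python
import numpy as np

def butina_reference(sim: np.ndarray, threshold: float):
    # sim: dense (n, n) similarity matrix; neighbours are pairs with sim >= threshold.
    n = sim.shape[0]
    neighbours = sim >= threshold
    np.fill_diagonal(neighbours, True)     # every molecule neighbours itself
    free = np.ones(n, dtype=bool)
    clusters = []
    while free.any():
        # neighbour counts among still-free molecules, zeroed for assigned rows
        counts = (neighbours & free).sum(axis=1) * free
        centroid = int(counts.argmax())    # centre of the largest remaining cluster
        members = np.flatnonzero(neighbours[centroid] & free)
        clusters.append((centroid, members.tolist()))
        free[members] = False              # subtract the cluster from the pool
    return clusters

sim = np.array([[1.0, 0.9, 0.1],
                [0.9, 1.0, 0.2],
                [0.1, 0.2, 1.0]])
print(butina_reference(sim, threshold=0.5))   # [(0, [0, 1]), (2, [2])]
```

The fused version replaces the dense `counts` recomputation with incremental subtraction of each removed cluster's contributions.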

Key changes:

  • New nvmolkit/_fused_Butina.py with Triton kernels and Python wrappers update_neighbor_counts / extract_cluster_and_singletons.
  • fused_butina() added to nvmolkit/clustering.py; existing butina() and its imports are intact.
  • A comprehensive new test suite in test_clustering.py and a benchmark in benchmarks/.
  • triton added as a core dependency.

Issues not yet addressed from prior review rounds: the off-by-one loop-exit bug, where the meeting-point slot in cluster_indices is never written and one molecule is silently dropped or replaced by molecule 0; and the stream parameter still not being threaded through the Triton kernel launches.

New findings in this revision:

  • cutoff is never validated to be in [0, 1]; an out-of-range value silently produces wrong clustering output.
  • x is not validated at the fused_butina entry point, so dtype/device errors surface with internal parameter names.
  • The Triton kernel grid uses TILE_X=32 on axis-0 and TILE_Y=64 on axis-1, which is suboptimal for L2 cache reuse.
  • Test helpers and benchmark data-generation functions hardcode device="cuda" (GPU 0).

Confidence Score: 4/5

Not ready to merge as-is — the off-by-one loop-exit bug (prior thread) can silently drop a molecule and emit a spurious molecule-0 entry; the new missing cutoff validation also produces silently wrong output for out-of-range inputs.

Score of 4 reflects that the PR is a WIP with at least two P1 issues remaining: the previously-flagged off-by-one cluster_count convergence bug and the newly-found missing cutoff range guard. The architecture is sound, most earlier critical issues (broken butina() imports, hardcoded GPU 0 tensors, missing SPDX header) have been resolved in this revision, and the test suite provides good coverage. The remaining P1 issues are relatively small in scope to fix.

nvmolkit/clustering.py (cutoff validation, off-by-one loop exit) and nvmolkit/_fused_Butina.py (grid axis ordering) need the most attention before merge.

Important Files Changed

Filename Overview
nvmolkit/clustering.py Adds fused_butina() using Triton kernels; missing cutoff range validation and x tensor validation at entry point; off-by-one loop-exit bug (prior thread) still present; original imports retained correctly.
nvmolkit/_fused_Butina.py New Triton kernel file implementing pairwise similarity counting, cluster extraction, and singleton collection; correctness concerns around off-by-one pointer convergence (noted in prior threads) and minor grid-axis ordering inefficiency flagged here.
nvmolkit/tests/test_clustering.py Good coverage of edge cases for fused_butina (single item, all-identical, all-singletons, stream, centroids); helper functions hardcode device="cuda" (GPU 0 assumption).
benchmarks/fused_butina_clustering_bench.py New benchmark comparing Triton vs RDKit Butina across various sizes and thresholds; depends on undocumented benchmark_timing internal package and hardcodes device="cuda".
pyproject.toml Adds triton as a core dependency; straightforward change.
.gitignore Adds .cursor to ignored paths; trivial change.

Reviews (5): Last reviewed commit: "Move fused_butina benchmark from cluster..."

Comment on lines +20 to +25
# from nvmolkit import _clustering
# from nvmolkit._arrayHelpers import * # noqa: F403
# from nvmolkit.types import AsyncGpuResult
AsyncGpuResult = None
from nvmolkit._similarity_neighbor import similarity_neighbor, subtract_similarity_neighbor, remove_largest_cluster
# TODO: rename and add cosine similarity

P1 Existing butina() is completely broken by this PR

AsyncGpuResult is replaced with None and the _clustering C-extension import is commented out. Any call to the existing butina() function will crash:

  • _clustering.butina(...) on line 76 raises NameError: name '_clustering' is not defined
  • return AsyncGpuResult(result) on lines 85–86 raises TypeError: 'NoneType' object is not callable

All six tests in test_clustering.py that exercise butina() will fail immediately. The fix is to restore the original imports alongside the new ones:

from nvmolkit import _clustering
from nvmolkit._arrayHelpers import *  # noqa: F403
from nvmolkit.types import AsyncGpuResult
from nvmolkit._similarity_neighbor import (
    similarity_neighbor,
    subtract_similarity_neighbor,
    remove_largest_cluster,
)

Comment on lines +121 to +122
if stream is not None and not isinstance(stream, torch.cuda.Stream):
    raise TypeError(f"stream must be a torch.cuda.Stream or None, got {type(stream).__name__}")

P1 stream parameter is accepted but silently ignored

fused_butina validates the stream argument but never passes it to any Triton kernel invocation. All three kernel calls launch on the default CUDA stream regardless of what is passed. Callers relying on stream ordering will get silent incorrect behavior. The parameter should either be wired through via torch.cuda.stream(stream) context, or removed until stream support is implemented.

Comment on lines +127 to +138
cluster_count = torch.zeros(2).int().cuda()
cluster_count[1] = n_start - 1
cluster_indices = torch.zeros(n_start, dtype=torch.int32).cuda()
cluster_indices.fill_(-1)
cluster_sizes = [0]
centroids = []
is_free = torch.ones(n_start, dtype=torch.int32).cuda()
neigh = torch.zeros(n_start).int().cuda()
threshold = float(1 - cutoff)

first_run = True
while cluster_count[0].item() < cluster_count[1].item():

P1 n_start == 1 edge case produces a cluster with centroid/member -1

When n_start == 1, cluster_count[1] initialises to n_start - 1 = 0. The while-loop condition 0 < 0 is False, so the loop is never entered and the single molecule's index is never stored in cluster_indices. The post-loop reads cluster_indices[0] == -1 (the fill value) as the centroid, returning [(-1,)] instead of [(0,)].

Comment on lines +19 to +27
def _check_fp_tensor(name: str, x: torch.Tensor) -> None:
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"{name} must be a torch.Tensor")
    if not x.is_cuda:
        raise ValueError(f"{name} must be a CUDA tensor")
    if x.dtype != torch.int32:
        raise ValueError(f"{name} must have dtype int32")
    if x.ndim != 2:
        raise ValueError(f"{name} must be 2D, got shape={tuple(x.shape)}")

P2 _check_fp_tensor is misnamed — it validates int32, not floating-point tensors

The name _check_fp_tensor ("fp" = floating-point) is misleading; it enforces dtype == torch.int32. The check is correct but the name creates confusion for future maintainers. Consider renaming to _check_int32_tensor and updating all three call sites.

Comment on lines +1 to +4
import torch
import triton
import triton.language as tl


P2 Missing SPDX license header

Every other Python source file in the project starts with the standard NVIDIA Apache-2.0 SPDX block. This new file omits it entirely.

Comment on lines +190 to +207
pa = tl.zeros((), dtype=tl.int32)
pb = tl.zeros((), dtype=tl.int32)
dot = tl.zeros((), dtype=tl.int32)

for k_block in range(0, tl.cdiv(K, BLOCK_K)):
    k_offset = k_block * BLOCK_K
    for kk in tl.static_range(0, BLOCK_K):
        k_idx = k_offset + kk
        k_mask = k_idx < K
        center_k = tl.load(x_ptr + center_id * x_stride_n + k_idx * x_stride_k, mask=k_mask, other=0)
        row_k = tl.load(
            x_ptr + row * x_stride_n + k_idx * x_stride_k,
            mask=row_mask & k_mask,
            other=0,
        )
        pa += _popcount32(center_k)
        pb += _popcount32(row_k)
        dot += _popcount32(row_k & center_k)

P2 Center molecule's popcount (pa) recomputed redundantly by every thread

In _remove_largest_cluster_kernel each of the n thread-blocks independently reloads and computes pa (the popcount of center_id's fingerprint) via a full K-dimensional scan. Since every thread reads the same center_id row and produces the same scalar, this is O(n × K/32) redundant work. Pre-computing pa on the host and passing it as a scalar parameter would eliminate this waste.
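The suggested host-side precomputation can be sketched as follows (NumPy for illustration; `center_popcount` and `fps` are assumed names, not the PR's API — the kernel would then receive `pa` as a plain scalar argument):

```python
import numpy as np

def center_popcount(fps: np.ndarray, center_id: int) -> int:
    # Popcount of one packed fingerprint row, computed once on the host
    # and passed to the kernel as a scalar instead of recomputed per block.
    return int(np.unpackbits(fps[center_id].view(np.uint8)).sum())

fps = np.array([[0b1111, 0], [0b0001, 0b1]], dtype=np.uint32)
pa = center_popcount(fps, 0)
print(pa)   # 4 set bits in row 0
```

This turns O(n × K/32) redundant per-block work into a single O(K/32) host computation per cluster removal.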

Comment on lines +175 to +322
if __name__ == "__main__":
    import time
    try:
        from rdkit import DataStructs
        from rdkit.DataStructs import ExplicitBitVect
        from rdkit.ML.Cluster import Butina
        HAS_RDKIT = True
    except ImportError:
        HAS_RDKIT = False
        print("RDKit not found. RDKit comparison tests will be skipped.")

    def get_rdkit_clusters(bit_tensor, threshold=0.5):
        """Convert int32 tensor to RDKit ExplicitBitVects and run Butina"""
        n = bit_tensor.shape[0]
        num_words = bit_tensor.shape[1]
        fps = []
        for i in range(n):
            bv = ExplicitBitVect(num_words * 32)
            bits = bit_tensor[i].cpu().numpy()
            for word_idx in range(num_words):
                word = int(bits[word_idx])
                for bit_idx in range(32):
                    if (word >> bit_idx) & 1:
                        bv.SetBit(word_idx * 32 + bit_idx)
            fps.append(bv)

        # Calculate pairwise distances (1 - Tanimoto similarity)
        dists = []
        for i in range(n):
            dists.extend(DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i], returnDistance=True))
        # Run Butina clustering (cutoff is maximum distance, so 1.0 - threshold)
        cutoff = 1.0 - threshold
        clusters = Butina.ClusterData(dists, n, cutoff, isDistData=True, reordering=True)
        return clusters

    def generate_data(n, num_clusters, noise_range=2, seed=42, num_words=64):
        """Generate random bit vectors with underlying cluster structure."""
        torch.manual_seed(seed)
        base_vectors = torch.randint(-(2**31 - 1), 2**31 - 1, size=(num_clusters, num_words), dtype=torch.int32).cuda()
        x_tr = torch.zeros((n, num_words), dtype=torch.int32).cuda()
        for i in range(n):
            base_idx = i % num_clusters
            x_tr[i] = base_vectors[base_idx]
            noise = torch.randint(0, noise_range, size=(num_words,), dtype=torch.int32).cuda()
            x_tr[i] = x_tr[i] ^ noise
        return x_tr

    def run_test(n, threshold, num_clusters, noise_range=2, seed=42, num_words=64):
        """Run a single comparison test between Triton and RDKit Butina clustering."""
        print(f"\n{'='*60}")
        print(f"Test: n={n}, threshold={threshold}, num_clusters={num_clusters}, noise_range={noise_range}, num_words={num_words}")
        print(f"{'='*60}")

        x_tr = generate_data(n, num_clusters, noise_range=noise_range, seed=seed, num_words=num_words)

        print("Running Triton clustering...")
        # fused_butina expects a distance cutoff, so we pass 1.0 - threshold
        fused_butina(x_tr, cutoff=1.0 - threshold)
        torch.cuda.synchronize()
        print("Done Triton clustering, starting second run...")
        start = time.time()
        warp_clusters, _ = fused_butina(x_tr, cutoff=1.0 - threshold)
        torch.cuda.synchronize()
        warp_time = time.time() - start
        print(f"Triton took {warp_time:.4f}s, found {len(warp_clusters)} clusters")

        if not HAS_RDKIT:
            return True

        print("Running RDKit Butina...")
        start = time.time()
        rdkit_failed = False
        try:
            rdkit_clusters = get_rdkit_clusters(x_tr, threshold=threshold)
        except Exception as e:
            print(f"Error running RDKit: {e}")
            rdkit_failed = True
        rdkit_time = time.time() - start
        print(f"RDKit took {rdkit_time:.4f}s, found {len(rdkit_clusters)} clusters")

        if rdkit_failed:
            return True
        rdkit_set = set(tuple(sorted(c)) for c in rdkit_clusters)
        warp_set = set(tuple(sorted(c)) for c in warp_clusters)

        passed = rdkit_set == warp_set
        if passed:
            print("SUCCESS: Clusters match exactly!")
        else:
            print("DIFFERENCE DETECTED!")
            print(f" Clusters only in RDKit: {len(rdkit_set - warp_set)}")
            print(f" Clusters only in Warp: {len(warp_set - rdkit_set)}")
            print(f" RDKit diff: {rdkit_set - warp_set}")
            print(f" Warp diff: {warp_set - rdkit_set}")

        return passed

    def main():
        test_configs = [
            # (n, threshold, num_clusters, noise_range, num_words)
            (100, 0.3, 20, 2, 32),
            (100, 0.5, 20, 2, 32),
            (100, 0.7, 20, 2, 64),
            (100, 0.9, 20, 2, 64),
            (500, 0.4, 50, 2, 32),
            (500, 0.6, 50, 2, 32),
            (500, 0.8, 50, 2, 64),
            (1000, 0.3, 100, 2, 32),
            (1000, 0.5, 100, 2, 64),
            (1000, 0.7, 100, 2, 64),
            (5000, 0.5, 200, 2, 32),
            (5000, 0.7, 200, 2, 64),
            (10000, 0.5, 500, 2, 32),
            (10000, 0.5, 2000, 2, 64),

            # Denser clusters (lower noise) with tight threshold
            (1000, 0.9, 100, 1, 32),
            # Sparser clusters (higher noise) with loose threshold
            (1000, 0.3, 100, 4, 64),
            # Many small clusters
            (2000, 0.5, 1000, 2, 32),
            # Few large clusters
            (2000, 0.5, 10, 2, 64),
            (100000, 0.7, 100, 128, 32),
        ]

        results = []
        for n, threshold, num_clusters, noise_range, num_words in test_configs:
            passed = run_test(n, threshold, num_clusters, noise_range=noise_range, num_words=num_words)
            results.append((n, threshold, num_clusters, noise_range, num_words, passed))

        print(f"\n{'='*60}")
        print("SUMMARY")
        print(f"{'='*60}")
        all_passed = True
        for n, threshold, num_clusters, noise_range, num_words, passed in results:
            status = "PASS" if passed else "FAIL"
            if not passed:
                all_passed = False
            print(f" [{status}] n={n:>5}, threshold={threshold}, clusters={num_clusters:>4}, noise={noise_range}, words={num_words}")

        total = len(results)
        n_passed = sum(1 for *_, p in results if p)
        print(f"\n{n_passed}/{total} tests passed.")
        if not all_passed:
            exit(1)

    main()

P2 Test/benchmark harness embedded directly in the library module

The if __name__ == "__main__": block is 148 lines of RDKit comparison tests and benchmarks inside clustering.py. The project already has a dedicated nvmolkit/tests/ directory. This should be moved to a proper test file so CI can pick it up via pytest.

moradza and others added 2 commits April 1, 2026 06:51
…tric

Consolidate the duplicated add/subtract neighbor-count kernels into a
single `_update_neighbor_count_kernel` parameterised by ADD_MODE, and
unify the public API into `update_neighbor_counts` and
`extract_cluster_and_singletons`.  Replace the SWAR popcount fallback
with PTX inline assembly (`popc.b32`) and add cosine similarity as a
second supported metric alongside Tanimoto.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Comment on lines +121 to +133
n_start = x.shape[0]
indices = torch.arange(n_start, dtype=torch.int32).cuda()
cluster_count = torch.zeros(2).int().cuda()
cluster_count[1] = n_start - 1
cluster_indices = torch.zeros(n_start, dtype=torch.int32).cuda()
cluster_sizes = [0]
centroids = []
is_free = torch.ones(n_start, dtype=torch.int32).cuda()
neigh = torch.zeros(n_start).int().cuda()
threshold = float(1 - cutoff)
y = x
first_run = True
while cluster_count[0].item() < cluster_count[1].item():

P1 Off-by-one in loop exit: last free molecule is silently dropped

cluster_count[1] is initialised to n_start - 1 and the loop runs while cluster_count[0] < cluster_count[1]. When exactly n - 1 molecules (where n is the current free count) are assigned in a single extract_cluster_and_singletons launch (c cluster members + s singletons, with c + s = n - 1), both counters converge to the same value k (cluster_count[0] = c, cluster_count[1] = n_start - 1 - s = c). The strict < check then evaluates to False and the loop exits, leaving the one remaining free molecule unprocessed.

That molecule is never written into cluster_indices. The post-loop for i in range(n_start - cluster_sizes[-1]) then iterates s + 1 times, but cluster_indices[k] is zero-initialised (never written), so molecule 0 is spuriously emitted as a singleton centroid while the actual free molecule is silently dropped from the output entirely.

Concrete trace with n_start = 5, one iteration producing 3 cluster members + 1 singleton:

  • cluster_count = [3, 3] → loop exits
  • cluster_indices = [A, B, C, 0(uninit), D]
  • Post-loop reads indices 3 and 4 → emits 0 (wrong) and D (correct); molecule E is lost

Fix: change the stopping condition to <= and guard against the empty-x case:

# current (buggy)
while cluster_count[0].item() < cluster_count[1].item():

# candidate fix
while cluster_count[0].item() <= cluster_count[1].item() and x.shape[0] > 0:
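The trace above can be checked with a toy counter model (illustrative only — `drain` is not part of the PR; it only mimics the <-versus-<= exit condition on the front/back counters):

```python
def drain(n_start: int, batches, strict: bool) -> int:
    # Toy model of the counter convergence: 'front' counts molecules written
    # as cluster members, 'back' shrinks by the singletons found per pass.
    front, back = 0, n_start - 1
    processed = 0
    for members, singletons in batches:
        done = front >= back if strict else front > back
        if done:
            break
        processed += members + singletons
        front += members
        back -= singletons
    return processed

# n_start = 5; first pass assigns 3 cluster members + 1 singleton,
# leaving one free molecule for a second pass.
print(drain(5, [(3, 1), (1, 0)], strict=True))    # 4 -- '<' exits early, one molecule dropped
print(drain(5, [(3, 1), (1, 0)], strict=False))   # 5 -- '<=' processes the last molecule
```

With the strict comparison the counters meet at 3 after the first pass and the loop exits; relaxing it to <= allows the final pass that picks up the remaining molecule.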

Comment on lines +121 to +129
n_start = x.shape[0]
indices = torch.arange(n_start, dtype=torch.int32).cuda()
cluster_count = torch.zeros(2).int().cuda()
cluster_count[1] = n_start - 1
cluster_indices = torch.zeros(n_start, dtype=torch.int32).cuda()
cluster_sizes = [0]
centroids = []
is_free = torch.ones(n_start, dtype=torch.int32).cuda()
neigh = torch.zeros(n_start).int().cuda()

P1 Auxiliary tensors hardcoded to GPU 0 — breaks multi-GPU setups

Every helper tensor in fused_butina uses the bare .cuda() call, which always places the tensor on the default CUDA device (GPU 0), regardless of which device x lives on. If x is on any other GPU, update_neighbor_counts raises a device-mismatch ValueError immediately.

# all of these silently target GPU 0
indices = torch.arange(n_start, dtype=torch.int32).cuda()
cluster_count = torch.zeros(2).int().cuda()
cluster_indices = torch.zeros(n_start, dtype=torch.int32).cuda()
is_free = torch.ones(n_start, dtype=torch.int32).cuda()
neigh = torch.zeros(n_start).int().cuda()

Replace every .cuda() with device=x.device:

indices         = torch.arange(n_start, dtype=torch.int32, device=x.device)
cluster_count   = torch.zeros(2, dtype=torch.int32, device=x.device)
cluster_count[1] = n_start - 1
cluster_indices = torch.zeros(n_start, dtype=torch.int32, device=x.device)
is_free         = torch.ones(n_start, dtype=torch.int32, device=x.device)
neigh           = torch.zeros(n_start, dtype=torch.int32, device=x.device)

moradza and others added 2 commits April 1, 2026 07:29
Add pytest tests for fused_butina covering basic correctness with
greedy-optimality verification, edge cases (single item, all-identical,
all-singletons), return_centroids validation, CUDA stream usage, and
input validation for both tanimoto and cosine metrics.

Also restore real imports in clustering.py and wrap fused_butina in a
torch.cuda.stream context so the stream parameter is actually honored.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace bare .cuda() calls with device=x.device so helper tensors are
allocated on the same GPU as the input fingerprints.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

from nvmolkit import _clustering
from nvmolkit._arrayHelpers import * # noqa: F403
from nvmolkit._arrayHelpers import *

P1 # noqa: F403 suppression removed, breaking ruff CI

The original line had # noqa: F403 to suppress the wildcard import lint warning. This PR removes it:

- from nvmolkit._arrayHelpers import *  # noqa: F403
+ from nvmolkit._arrayHelpers import *

The project's pyproject.toml selects the full "F" rule set (pyflakes) in ruff without ignoring F403. clustering.py has no per-file-ignores entry, so ruff will now flag this line and fail CI.

Suggested change
from nvmolkit._arrayHelpers import *
from nvmolkit._arrayHelpers import * # noqa: F403

return AsyncGpuResult(clusters), AsyncGpuResult(centroids)
return AsyncGpuResult(result)

#TODO: add cosine similarity support for longer fingerprint sizes

P1 #TODO: missing space violates E265, breaking ruff CI

The E rule set is selected in pyproject.toml and E265 ("block comment should start with # ") is not in the ignore list. #TODO: without a space after # will fail ruff.

Suggested change
#TODO: add cosine similarity support for longer fingerprint sizes
# TODO: add cosine similarity support for longer fingerprint sizes

moradza and others added 2 commits April 1, 2026 08:17
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Extract the inline __main__ block into a standalone benchmark script
using the shared time_it utility, pandas output, and CLI metric
selection.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Comment on lines +115 to +118
if metric not in ["tanimoto", "cosine"]:
    raise ValueError(f"metric must be one of ['tanimoto', 'cosine'], got {metric}")
if stream is not None and not isinstance(stream, torch.cuda.Stream):
    raise TypeError(f"stream must be a torch.cuda.Stream or None, got {type(stream).__name__}")

P1 cutoff not validated — silently incorrect results outside [0, 1]

threshold = float(1 - cutoff) is computed but the function never validates that cutoff is in [0.0, 1.0]. A caller passing a negative cutoff or a value above 1 produces threshold > 1.0 or threshold < 0.0. Since all fingerprint similarities are in [0, 1], a threshold above 1 means no neighbors are ever found (all molecules become singletons silently), and a negative threshold means everything is a neighbour. This differs from the existing butina() convention where cutoff is a distance, but the real failure mode is that no exception is raised and the output looks plausible.

Suggested change

if metric not in ["tanimoto", "cosine"]:
    raise ValueError(f"metric must be one of ['tanimoto', 'cosine'], got {metric}")
if not (0.0 <= cutoff <= 1.0):
    raise ValueError(f"cutoff must be in [0, 1], got {cutoff}")
if stream is not None and not isinstance(stream, torch.cuda.Stream):
    raise TypeError(f"stream must be a torch.cuda.Stream or None, got {type(stream).__name__}")

Comment on lines +119 to +120
with torch.cuda.stream(stream):
n_start = x.shape[0]

P1 x tensor not validated at entry point — unhelpful errors surface later

fused_butina performs no validation on x itself before diving into the kernel pipeline. A CPU tensor, an int64 tensor, or a 3-D array will only raise an error deep inside update_neighbor_counts or extract_cluster_and_singletons, with an error message that names the internal parameter ("x must be a CUDA tensor") rather than pointing the caller to their top-level API call. A direct call to _check_fingerprint_matrix at the entry point gives a consistent, early, actionable error:

with torch.cuda.stream(stream):
    _check_fingerprint_matrix("x", x)   # <-- add this line
    n_start = x.shape[0]
