
Commit ae2323f

Aya-ZIbra authored and facebook-github-bot committed
Add cutlass decode kernel to TritonBench
Summary: as title

Differential Revision: D80041532
1 parent 94afdc9 commit ae2323f

File tree

1 file changed (+40, -0 lines)


tritonbench/operators/decoding_attention/operator.py

Lines changed: 40 additions & 0 deletions
@@ -71,6 +71,18 @@
 except (ImportError, IOError, AttributeError):
     HAS_AITER = False
 
+# [Optional] cutlass_blackwell_fmha backend
+HAS_CUTLASS_BLACKWELL = True
+try:
+    from ai_acceleration.kernels.attentions.cutlass_blackwell_fmha.cutlass_blackwell_fmha_interface import (
+        cutlass_blackwell_fmha_func,
+    )
+    # Disable FA3 for Blackwell as it doesn't work properly
+    HAS_FLASH_V3 = False
+    # Note: We keep FA2 and triton enabled alongside Blackwell for comparison
+except (ImportError, IOError, AttributeError):
+    HAS_CUTLASS_BLACKWELL = False
+
 
 def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
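
The hunk above follows the file's existing optional-backend convention: the availability flag starts as True and is reset to False when the interface module cannot be imported, and FA3 is switched off whenever the Blackwell FMHA backend is present. Below is a minimal sketch of checking which backends resolved in a given environment; it is not part of the commit and assumes tritonbench is importable from the current Python environment:

    # Hedged sketch (not from the commit): inspect the backend flags after import.
    from tritonbench.operators.decoding_attention import operator as decoding_op

    print("HAS_CUTLASS_BLACKWELL:", decoding_op.HAS_CUTLASS_BLACKWELL)
    print("HAS_FLASH_V3:", decoding_op.HAS_FLASH_V3)  # forced to False when the Blackwell backend imports
    print("HAS_AITER:", decoding_op.HAS_AITER)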
@@ -559,6 +571,34 @@ def fbgemm_gqa_fp8kv(
             cache_logical_dtype_int=1,  # FP8 = 1
         )
 
+    @register_benchmark(enabled=HAS_CUTLASS_BLACKWELL)
+    def cutlass_blackwell_fmha_decode(
+        self,
+        q: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+    ) -> Callable:
+        seq_len_q = q.shape[1]
+
+        # Cutlass Blackwell FMHA currently supports decode case (seq_len_q == 1)
+        if seq_len_q != 1:
+            # Skip non-decode cases for now
+            raise NotImplementedError("Cutlass Blackwell FMHA only supports decode case")
+            # return lambda: q.new_zeros(q.shape)
+
+        # Convert to fp8 format as required by the decode path
+        _q = q.to(torch.float8_e4m3fn)
+        _k_cache = k_cache.to(torch.float8_e4m3fn)
+        _v_cache = v_cache.to(torch.float8_e4m3fn)
+
+        # Create seqlen_kv tensor for generation phase
+        seqlen_kv = cache_seqlens.to(dtype=torch.int32, device=q.device)
+
+        return lambda: cutlass_blackwell_fmha_func(
+            _q, _k_cache, _v_cache, causal=CAUSAL, seqlen_kv=seqlen_kv
+        )
+
     @register_benchmark(enabled=HAS_AITER)
     def aiter_paged_fp8kv(
         self,

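The new entry registers as cutlass_blackwell_fmha_decode and only runs when HAS_CUTLASS_BLACKWELL is True; it returns a callable (a lambda) so the harness can time the kernel call itself. Below is a hedged sketch of invoking the kernel directly the way the wrapper does; the [batch, seq_len, heads, head_dim] layout, the concrete sizes, and causal=True are illustrative assumptions, while the fp8 (e4m3) inputs, the int32 seqlen_kv tensor, and the seq_len_q == 1 requirement come from the diff above:

    # Hedged sketch (not from the commit): exercise the decode path directly.
    import torch
    from ai_acceleration.kernels.attentions.cutlass_blackwell_fmha.cutlass_blackwell_fmha_interface import (
        cutlass_blackwell_fmha_func,
    )

    batch, heads, head_dim, max_kv_len = 8, 8, 128, 4096  # assumed sizes for illustration
    # Decode phase: a single new query token per sequence (seq_len_q == 1), fp8 inputs.
    q = torch.randn(batch, 1, heads, head_dim, device="cuda").to(torch.float8_e4m3fn)
    k_cache = torch.randn(batch, max_kv_len, heads, head_dim, device="cuda").to(torch.float8_e4m3fn)
    v_cache = torch.randn(batch, max_kv_len, heads, head_dim, device="cuda").to(torch.float8_e4m3fn)
    # Actual KV length per sequence, as an int32 tensor on the same device as q.
    seqlen_kv = torch.full((batch,), 1024, dtype=torch.int32, device="cuda")

    # The benchmark passes the operator's CAUSAL constant; True is just an example value here.
    out = cutlass_blackwell_fmha_func(q, k_cache, v_cache, causal=True, seqlen_kv=seqlen_kv)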