tritonbench/operators/decoding_attention/operator.py: 60 additions & 0 deletions
@@ -77,6 +77,18 @@
except (ImportError, IOError, AttributeError):
    HAS_AITER = False

# [Optional] cutlass_blackwell_fmha backend
HAS_CUTLASS_BLACKWELL = True
try:
    from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha import (
        cutlass_blackwell_fmha_interface as blackwell,
    )

    # Disable FA3 for Blackwell as it doesn't work properly
    HAS_FLASH_V3 = False
    # Note: We keep FA2 and triton enabled alongside Blackwell for comparison
except (ImportError, IOError, AttributeError):
    HAS_CUTLASS_BLACKWELL = False


# [Optional] flash_fwd cute-DSL backend
HAS_FLASH_CUTE = True
@@ -591,6 +603,54 @@ def flash_cute_dsl(
            q, k_cache, v_cache, causal=CAUSAL, pack_gqa=(q_heads != kv_heads)
        )

    @register_benchmark(enabled=HAS_CUTLASS_BLACKWELL)
    def cutlass_blackwell_fmha_decode_fp8qkv(
        self,
        q: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        cache_seqlens: torch.Tensor,
    ) -> Callable:
        seq_len_q = q.shape[1]

        # Cutlass Blackwell FMHA currently supports only the decode case (seq_len_q == 1)
        if seq_len_q != 1:
            # Skip non-decode cases for now
            raise NotImplementedError(
                "Cutlass Blackwell FMHA only supports the decode case (seq_len_q == 1)"
            )

        # Convert Q/K/V to fp8 (e4m3) as required by the fp8 decode path
        _q = q.to(torch.float8_e4m3fn)
        _k_cache = k_cache.to(torch.float8_e4m3fn)
        _v_cache = v_cache.to(torch.float8_e4m3fn)

        # Per-batch KV sequence lengths for the generation (decode) phase
        seqlen_kv = cache_seqlens.to(dtype=torch.int32, device=q.device)

        return lambda: blackwell.cutlass_blackwell_fmha_func(
            _q, _k_cache, _v_cache, causal=CAUSAL, seqlen_kv=seqlen_kv
        )

    @register_benchmark(enabled=HAS_CUTLASS_BLACKWELL)
    def cutlass_blackwell_fmha_decode(
        self,
        q: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        cache_seqlens: torch.Tensor,
    ) -> Callable:
        seq_len_q = q.shape[1]

        # Cutlass Blackwell FMHA currently supports only the decode case (seq_len_q == 1)
        if seq_len_q != 1:
            # Skip non-decode cases for now
            raise NotImplementedError(
                "Cutlass Blackwell FMHA only supports the decode case (seq_len_q == 1)"
            )

        # Per-batch KV sequence lengths for the generation (decode) phase
        seqlen_kv = cache_seqlens.to(dtype=torch.int32, device=q.device)

        return lambda: blackwell.cutlass_blackwell_fmha_func(
            q, k_cache, v_cache, causal=CAUSAL, seqlen_kv=seqlen_kv
        )

    @register_benchmark(enabled=HAS_AITER)
    def aiter_paged_fp8kv(
        self,
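For review context, below is a minimal standalone sketch of how the two new decode paths call the fbgemm_gpu Cutlass Blackwell interface. The function name, keyword arguments, and fp8 conversion mirror the benchmark code above; the tensor shapes, sizes, dtypes, and the decision to pass causal=True (the operator's CAUSAL constant) are illustrative assumptions only, since the real inputs come from the decoding_attention operator's input generator.

# Sketch only. Assumed layouts: q is [batch, 1, q_heads, head_dim];
# k_cache / v_cache are [batch, max_seq_len, kv_heads, head_dim].
import torch
from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha import (
    cutlass_blackwell_fmha_interface as blackwell,
)

batch, q_heads, kv_heads, head_dim, max_seq_len = 4, 8, 1, 128, 4096
q = torch.randn(batch, 1, q_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k_cache = torch.randn(
    batch, max_seq_len, kv_heads, head_dim, device="cuda", dtype=torch.bfloat16
)
v_cache = torch.randn_like(k_cache)
# Number of valid KV entries per batch element (decode / generation phase)
cache_seqlens = torch.full((batch,), 1024, device="cuda", dtype=torch.int32)

# bf16 decode path (mirrors cutlass_blackwell_fmha_decode)
out = blackwell.cutlass_blackwell_fmha_func(
    q, k_cache, v_cache, causal=True, seqlen_kv=cache_seqlens
)

# fp8 decode path (mirrors cutlass_blackwell_fmha_decode_fp8qkv)
out_fp8 = blackwell.cutlass_blackwell_fmha_func(
    q.to(torch.float8_e4m3fn),
    k_cache.to(torch.float8_e4m3fn),
    v_cache.to(torch.float8_e4m3fn),
    causal=True,
    seqlen_kv=cache_seqlens,
)

Within tritonbench itself these paths are exercised through the decoding_attention operator's standard benchmarking flow; the sketch is only meant to make the call signature easier to review.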