diff --git a/tritonbench/operators/decoding_attention/operator.py b/tritonbench/operators/decoding_attention/operator.py
index 53f394d35..5ec11060f 100644
--- a/tritonbench/operators/decoding_attention/operator.py
+++ b/tritonbench/operators/decoding_attention/operator.py
@@ -77,6 +77,18 @@
 except (ImportError, IOError, AttributeError):
     HAS_AITER = False
 
+# [Optional] cutlass_blackwell_fmha backend
+HAS_CUTLASS_BLACKWELL = True
+try:
+    from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha import (
+        cutlass_blackwell_fmha_interface as blackwell,
+    )
+    # Disable FA3 for Blackwell as it doesn't work properly
+    HAS_FLASH_V3 = False
+    # Note: We keep FA2 and triton enabled alongside Blackwell for comparison
+except (ImportError, IOError, AttributeError):
+    HAS_CUTLASS_BLACKWELL = False
+
 # [Optional] flash_fwd cute-DSL backend
 HAS_FLASH_CUTE = True
 
@@ -591,6 +603,56 @@ def flash_cute_dsl(
             q, k_cache, v_cache, causal=CAUSAL, pack_gqa=(q_heads != kv_heads)
         )
 
+    @register_benchmark(enabled=HAS_CUTLASS_BLACKWELL)
+    def cutlass_blackwell_fmha_decode_fp8qkv(
+        self,
+        q: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+    ) -> Callable:
+        seq_len_q = q.shape[1]
+
+        # Cutlass Blackwell FMHA only supports the decode case (seq_len_q == 1)
+        if seq_len_q != 1:
+            # Skip non-decode cases for now
+            raise NotImplementedError("Cutlass Blackwell FMHA only supports decode case")
+            # return lambda: q.new_zeros(q.shape)
+
+        # Convert to fp8 format as required by the decode path
+        _q = q.to(torch.float8_e4m3fn)
+        _k_cache = k_cache.to(torch.float8_e4m3fn)
+        _v_cache = v_cache.to(torch.float8_e4m3fn)
+
+        # Create seqlen_kv tensor for generation phase
+        seqlen_kv = cache_seqlens.to(dtype=torch.int32, device=q.device)
+
+        return lambda: blackwell.cutlass_blackwell_fmha_func(
+            _q, _k_cache, _v_cache, causal=CAUSAL, seqlen_kv=seqlen_kv
+        )
+
+    @register_benchmark(enabled=HAS_CUTLASS_BLACKWELL)
+    def cutlass_blackwell_fmha_decode(
+        self,
+        q: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+    ) -> Callable:
+        seq_len_q = q.shape[1]
+
+        # Cutlass Blackwell FMHA only supports the decode case (seq_len_q == 1)
+        if seq_len_q != 1:
+            # Skip non-decode cases for now
+            raise NotImplementedError("Cutlass Blackwell FMHA only supports decode case")
+
+        # Create seqlen_kv tensor for generation phase
+        seqlen_kv = cache_seqlens.to(dtype=torch.int32, device=q.device)
+
+        return lambda: blackwell.cutlass_blackwell_fmha_func(
+            q, k_cache, v_cache, causal=CAUSAL, seqlen_kv=seqlen_kv
+        )
+
     @register_benchmark(enabled=HAS_AITER)
     def aiter_paged_fp8kv(
         self,
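
For reviewers who want to sanity-check the new backends outside the benchmark harness, below is a minimal sketch of the decode-shaped call that both new benchmarks wrap. The import path and the `cutlass_blackwell_fmha_func(q, k, v, causal=..., seqlen_kv=...)` signature come from the diff above; the tensor shapes, dtypes, and sizes are illustrative assumptions, and the snippet assumes a Blackwell GPU with an fbgemm_gpu build that ships the experimental kernel.

```python
# Minimal sketch (assumptions: Blackwell GPU, fbgemm_gpu experimental build,
# flash-attention style [batch, seq, heads, head_dim] layout as used by the operator).
import torch
from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha import (
    cutlass_blackwell_fmha_interface as blackwell,
)

batch, heads, head_dim, max_seq = 8, 8, 128, 4096  # illustrative sizes
q = torch.randn(batch, 1, heads, head_dim, device="cuda", dtype=torch.bfloat16)
k_cache = torch.randn(batch, max_seq, heads, head_dim, device="cuda", dtype=torch.bfloat16)
v_cache = torch.randn_like(k_cache)
# Per-batch KV lengths for the generation phase, mirroring cache_seqlens in the operator.
seqlen_kv = torch.full((batch,), 1024, dtype=torch.int32, device="cuda")

# bf16 path (cutlass_blackwell_fmha_decode); for the fp8 path, cast q/k_cache/v_cache
# to torch.float8_e4m3fn first, as cutlass_blackwell_fmha_decode_fp8qkv does.
out = blackwell.cutlass_blackwell_fmha_func(
    q, k_cache, v_cache, causal=True, seqlen_kv=seqlen_kv
)
print(out.shape)  # expected: [batch, 1, heads, head_dim]
```

Within the harness, the new entries should appear alongside the existing decode backends; with the usual tritonbench CLI this would look roughly like `python run.py --op decoding_attention --only cutlass_blackwell_fmha_decode,cutlass_blackwell_fmha_decode_fp8qkv` (flag spelling may differ per checkout).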