|
 except (ImportError, IOError, AttributeError):
     HAS_AITER = False
 
+# [Optional] cutlass_blackwell_fmha backend
+HAS_CUTLASS_BLACKWELL = True
+try:
+    from ai_acceleration.kernels.attentions.cutlass_blackwell_fmha.cutlass_blackwell_fmha_interface import (
+        cutlass_blackwell_fmha_func,
+    )
+    # Disable FA3 for Blackwell as it doesn't work properly
+    HAS_FLASH_V3 = False
+    # Note: We keep FA2 and triton enabled alongside Blackwell for comparison
+except (ImportError, IOError, AttributeError):
+    HAS_CUTLASS_BLACKWELL = False
+
 
 def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
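Aside for reviewers: the block above follows the same optional-backend guard pattern as the existing `HAS_AITER` check, so a missing or broken install degrades to a disabled benchmark rather than an import-time crash. A minimal sketch of the pattern, with a hypothetical backend name standing in for the real import:

```python
# Optional-backend guard: assume the backend is usable, then fall back if the
# import fails. IOError and AttributeError are caught alongside ImportError
# because a partially built or misconfigured extension can raise them at
# import time.
HAS_SOME_BACKEND = True  # hypothetical flag, mirroring HAS_CUTLASS_BLACKWELL
try:
    from some_backend import some_attention_func  # hypothetical module
except (ImportError, IOError, AttributeError):
    HAS_SOME_BACKEND = False
```

Benchmark registration then keys off the flag (see `@register_benchmark(enabled=...)` below), so unavailable backends simply drop out of the run.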
@@ -559,6 +571,34 @@ def fbgemm_gqa_fp8kv(
             cache_logical_dtype_int=1,  # FP8 = 1
         )
 
+    @register_benchmark(enabled=HAS_CUTLASS_BLACKWELL)
+    def cutlass_blackwell_fmha_decode(
+        self,
+        q: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+    ) -> Callable:
+        seq_len_q = q.shape[1]
+
+        # Cutlass Blackwell FMHA currently supports only the decode case (seq_len_q == 1); skip others for now.
+        if seq_len_q != 1:
+            raise NotImplementedError(
+                "Cutlass Blackwell FMHA only supports the decode case"
+            )
+
+        # Convert to fp8 format as required by the decode path
+        _q = q.to(torch.float8_e4m3fn)
+        _k_cache = k_cache.to(torch.float8_e4m3fn)
+        _v_cache = v_cache.to(torch.float8_e4m3fn)
+
+        # Create the seqlen_kv tensor (per-sequence KV lengths) for the generation phase
+        seqlen_kv = cache_seqlens.to(dtype=torch.int32, device=q.device)
+
+        return lambda: cutlass_blackwell_fmha_func(
+            _q, _k_cache, _v_cache, causal=CAUSAL, seqlen_kv=seqlen_kv
+        )
+
     @register_benchmark(enabled=HAS_AITER)
     def aiter_paged_fp8kv(
         self,
|