diff --git a/tritonbench/operators/blackwell_attentions/operator.py b/tritonbench/operators/blackwell_attentions/operator.py
index b3c43cee..d528c68c 100644
--- a/tritonbench/operators/blackwell_attentions/operator.py
+++ b/tritonbench/operators/blackwell_attentions/operator.py
@@ -44,6 +44,13 @@
 except (ImportError, IOError, AttributeError):
     HAS_FLASH_V2 = False
 
+# [Optional] CuTe
+try:
+    import flash_attn.cute.interface as facute
+    HAS_FLASH_CUTE = True
+except (ImportError, IOError, AttributeError):
+    HAS_FLASH_CUTE = False
+
 # [Optional] xformers backend
 try:
     import xformers  # @manual=//fair/xformers:xformers
@@ -266,6 +273,19 @@ def sdpa_flash_attention(q, k, v):
             v,
         )
 
+    @register_benchmark(enabled=(IS_B200 and HAS_FLASH_CUTE), label="cutedsl-blackwell", fwd_only=True)
+    def cutedsl_blackwell(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor) -> Callable:
+
+        # [B, H, S, D] -> [B, S, H, D]
+        q = q.transpose(1, 2).contiguous()
+        k = k.transpose(1, 2).contiguous()
+        v = v.transpose(1, 2).contiguous()
+        return lambda: facute.flash_attn_func(q, k, v, self.sm_scale, self.causal)
+
     @register_benchmark()
     def flex_attention(self, q, k, v):
         from torch.nn.attention.flex_attention import create_block_mask, flex_attention