
Commit e3b390c

Aya-ZIbra authored and facebook-github-bot committed
Add trtllm to triton bench (#379)
Summary: Pull Request resolved: #379

Run C++:

FLASHINFER_CUBIN_DIR=/data/users/$USER/fbsource/fbcode/deeplearning/flashinfer/fb/cubins/ buck2 run mode/opt mode/inplace -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=b200a -c fbcode.platform010_cuda_version=12.8 //deeplearning/flashinfer/trtllm_kernel_interfaces:run_example

Run Triton bench:

buck2 run mode/opt mode/inplace -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=b200a -c fbcode.platform010_cuda_version=12.8 //pytorch/tritonbench:run -- --op decoding_attention --only trtllm_decode_fmha --seq-len-q 1 --metrics gbps

Todo: Support non-paged case

Differential Revision: D81021980
1 parent 28d5884 commit e3b390c

File tree

2 files changed: +87 −0 lines


tritonbench/operators/decoding_attention/operator.py

Lines changed: 19 additions & 0 deletions
@@ -55,6 +55,11 @@
 torch.ops.load_library(
     "//deeplearning/fbgemm/fbgemm_gpu/experimental:gen_ai_attention_ops"
 )
+torch.ops.load_library(
+    "//deeplearning/flashinfer/trtllm_kernel_interfaces:trtllm_fmha_pybind"
+)
+
+from .trtllm_utils import trtllm_decode_fmha_func
 
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
@@ -720,3 +725,17 @@ def aiter_paged_fp8kv(
             k_scale_asm,
             v_scale_asm,
         )
+
+    @register_benchmark()
+    def trtllm_decode_fmha(
+        self,
+        q: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+    ) -> Callable:
+
+        args = trtllm_decode_fmha_func(q, k_cache, v_cache, cache_seqlens)
+        return lambda: torch.ops.trtllm_kernel_interfaces.trtllm_decode_fmha(
+            *args
+        )
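For context, a minimal sketch of how this registered benchmark is exercised, assuming tritonbench's usual pattern where a method decorated with @register_benchmark returns a callable that the harness then times. The names op, q, k_cache, v_cache, and cache_seqlens below are hypothetical placeholders, not part of the diff:

# Hypothetical usage sketch: the method builds the TRTLLM argument tuple once,
# then returns a closure that the benchmark harness can invoke repeatedly.
fn = op.trtllm_decode_fmha(q, k_cache, v_cache, cache_seqlens)
out = fn()  # each call replays torch.ops.trtllm_kernel_interfaces.trtllm_decode_fmha(*args)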
tritonbench/operators/decoding_attention/trtllm_utils.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
TRTLLM FMHA utility functions for handling tensor conversion and kernel preparation.
"""

import torch


def trtllm_decode_fmha_func(q, k_cache, v_cache, cache_seqlens):
    """
    TRTLLM FMHA decode function that converts standard tensors to paged format
    and calls the TRTLLM FMHA kernel via PyBind extension.
    """

    device = q.device
    # Convert input tensors to paged format for TRTLLM FMHA
    batch_size, seq_len_q, num_qo_heads, head_dim = q.shape
    _, max_seq_len_kv, num_kv_heads, _ = k_cache.shape

    # Use page size of 16 for TRTLLM FMHA
    page_size = 16
    max_num_blocks_per_seq = (max_seq_len_kv + page_size - 1) // page_size
    total_pages = batch_size * max_num_blocks_per_seq

    # Reshape k_cache and v_cache to paged format [total_pages, num_kv_heads, page_size, head_dim]
    k_cache_paged = k_cache.view(batch_size, max_num_blocks_per_seq, page_size, num_kv_heads, head_dim)
    k_cache_paged = k_cache_paged.permute(0, 1, 3, 2, 4).contiguous()
    k_cache_paged = k_cache_paged.view(total_pages, num_kv_heads, page_size, head_dim)

    v_cache_paged = v_cache.view(batch_size, max_num_blocks_per_seq, page_size, num_kv_heads, head_dim)
    v_cache_paged = v_cache_paged.permute(0, 1, 3, 2, 4).contiguous()
    v_cache_paged = v_cache_paged.view(total_pages, num_kv_heads, page_size, head_dim)

    # Create block tables
    block_tables = torch.zeros(
        (batch_size, max_num_blocks_per_seq),
        dtype=torch.int32,
        device=device,
    )
    for i in range(batch_size):
        for j in range(max_num_blocks_per_seq):
            block_tables[i, j] = i * max_num_blocks_per_seq + j

    # Create output tensor
    out = torch.zeros_like(q)

    # Create workspace buffer
    workspace_size = 128 * 1024 * 1024  # 128MB
    workspace_buffer = torch.zeros(workspace_size, dtype=torch.uint8, device=device)

    # Attention parameters
    max_seq_len = cache_seqlens.max().item()
    bmm1_scale = 1.0 / (head_dim ** 0.5)
    bmm2_scale = 1.0
    window_left = -1  # No sliding window
    sm_count = torch.cuda.get_device_properties(device).multi_processor_count

    args = (
        out, q, k_cache_paged, v_cache_paged, workspace_buffer,
        block_tables, cache_seqlens, max_seq_len,
        bmm1_scale, bmm2_scale, window_left, sm_count
    )
    return args
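As a rough illustration of how the helper above might be driven, here is a hedged sketch with made-up decode shapes and dtype (batch 4, one query token, 32 query heads, 8 KV heads, head_dim 128, 1024-slot KV cache). It assumes the new file lives at tritonbench/operators/decoding_attention/trtllm_utils.py next to operator.py, and that the trtllm_fmha_pybind extension has already been loaded via torch.ops.load_library as in the first hunk. Note that the paging step implicitly assumes max_seq_len_kv is a multiple of page_size (16); otherwise the view into pages would fail.

import torch

# Assumed module path for the new helper, based on the relative import in operator.py.
from tritonbench.operators.decoding_attention.trtllm_utils import trtllm_decode_fmha_func

# Hypothetical decode-shaped inputs, matching the unpacking inside the helper:
# q: [batch, seq_len_q, num_qo_heads, head_dim]
# k_cache / v_cache: [batch, max_seq_len_kv, num_kv_heads, head_dim]
q = torch.randn(4, 1, 32, 128, device="cuda", dtype=torch.bfloat16)
k_cache = torch.randn(4, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
v_cache = torch.randn_like(k_cache)
cache_seqlens = torch.full((4,), 512, device="cuda", dtype=torch.int32)

# Build the argument tuple: (out, q, paged K, paged V, workspace, block_tables,
# cache_seqlens, max_seq_len, bmm1_scale, bmm2_scale, window_left, sm_count).
args = trtllm_decode_fmha_func(q, k_cache, v_cache, cache_seqlens)

# Dispatch the TRTLLM FMHA kernel (requires the pybind extension to be loaded).
out = torch.ops.trtllm_kernel_interfaces.trtllm_decode_fmha(*args)

Because the paged K/V buffers come from a dense reshape of the original cache, the block table is just the identity mapping i * max_num_blocks_per_seq + j; a real paged KV cache would instead carry whatever page indices its allocator handed out.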
