
Commit 05c6f25

Author: niushengxiao (committed)
feat: support page size variable for flashinfer

1 parent 407ec03 commit 05c6f25

8 files changed: +843 -37 lines

lightllm/models/llama/flashinfer_struct.py

Lines changed: 60 additions & 29 deletions
@@ -3,16 +3,21 @@
 import numpy as np
 import torch.distributed as dist
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
-from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.envs_utils import get_env_start_args, get_page_size
 from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
 
 
+def cdiv(a, b):
+    return (a + b - 1) // b
+
+
 class LlamaFlashInferStateInfo(LlamaInferStateInfo):
     def __init__(self):
         super().__init__()
         self.prefill_wrapper = None
         self.decode_wrapper = None
         self.flashinfer_extra_state = None
+        self.page_size = get_page_size()
 
     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         super().init_some_extra_state(model, input_ids)
@@ -22,29 +27,41 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
 
         if not self.is_prefill:
             if get_env_start_args().enable_flashinfer_decode:
-                self.kv_last_page_len_buffer = torch.full(
-                    (self.batch_size,), 1, dtype=torch.int32, device=input_ids.device
-                )
+                self.kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=input_ids.device)
+                length = cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size)
                 if self.batch_size <= model.graph_max_batch_size:
                     self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][
-                        : self.batch_size * self.flashinfer_extra_state.max_seq_length
+                        : self.batch_size * length
                     ]
                 else:
                     self.kv_indices = torch.empty(
-                        self.batch_size * self.flashinfer_extra_state.max_seq_length,
+                        self.batch_size * length,
                         dtype=torch.int32,
                         device=input_ids.device,
                     )
 
-                repack_kv_index(
-                    self.req_manager.req_to_token_indexs,
-                    self.b_req_idx,
-                    self.b_seq_len,
-                    self.b_start_loc,
-                    self.max_len_in_batch,
-                    self.kv_indices,
-                )
                 self.kv_starts = self.b1_cu_kv_seq_len.int()
+                if "page_size_variable" in model.mode:
+                    b_page_len = cdiv(self.b_seq_len, self.page_size)
+                    self.kv_starts[1:] = b_page_len.cumsum(0)
+                    self.kv_last_page_len = self.b_seq_len - (b_page_len - 1) * self.page_size
+                    repack_kv_index(
+                        self.req_manager.req_to_page_indexs,
+                        self.b_req_idx,
+                        b_page_len,
+                        self.kv_starts[:-1],
+                        cdiv(self.max_kv_seq_len, self.page_size),
+                        self.kv_indices,
+                    )
+                else:
+                    repack_kv_index(
+                        self.req_manager.req_to_token_indexs,
+                        self.b_req_idx,
+                        self.b_seq_len,
+                        self.b_start_loc,
+                        self.max_kv_seq_len,
+                        self.kv_indices,
+                    )
                 if self.decode_wrapper is None:
                     self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
                         self.flashinfer_extra_state.workspace_buffer,
@@ -53,16 +70,16 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                         use_tensor_cores=True,
                         paged_kv_indptr_buffer=self.kv_starts,
                         paged_kv_indices_buffer=self.kv_indices,
-                        paged_kv_last_page_len_buffer=self.kv_last_page_len_buffer,
+                        paged_kv_last_page_len_buffer=self.kv_last_page_len,
                     )
                 self.decode_wrapper.plan(
                     self.kv_starts,
                     self.kv_indices,
-                    self.kv_last_page_len_buffer,
+                    self.kv_last_page_len,
                     self.flashinfer_extra_state.tp_q_head_num,
                     self.flashinfer_extra_state.tp_kv_head_num,
                     self.flashinfer_extra_state.head_dim,
-                    1,
+                    self.page_size,
                     q_data_type=self.flashinfer_extra_state.q_data_type,
                     kv_data_type=self.flashinfer_extra_state.kv_data_type,
                     non_blocking=True,
@@ -72,19 +89,33 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
             q_starts = self.b1_cu_q_seq_len.int()
             kv_starts = self.b1_cu_kv_seq_len.int()
             kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=input_ids.device)
+            length = cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size)
             kv_indices = torch.empty(
-                self.batch_size * self.flashinfer_extra_state.max_seq_length,
+                self.batch_size * length,
                 dtype=torch.int32,
                 device=input_ids.device,
             )
-            repack_kv_index(
-                self.req_manager.req_to_token_indexs,
-                self.b_req_idx,
-                self.b_seq_len,
-                kv_starts[:-1],
-                self.max_kv_seq_len,
-                kv_indices,
-            )
+            if "page_size_variable" in model.mode:
+                b_page_len = cdiv(self.b_seq_len, self.page_size)
+                kv_starts[1:] = b_page_len.cumsum(0)
+                kv_last_page_len = self.b_seq_len - (b_page_len - 1) * self.page_size
+                repack_kv_index(
+                    self.req_manager.req_to_page_indexs,
+                    self.b_req_idx,
+                    b_page_len,
+                    kv_starts[:-1],
+                    cdiv(self.max_kv_seq_len, self.page_size),
+                    kv_indices,
+                )
+            else:
+                repack_kv_index(
+                    self.req_manager.req_to_token_indexs,
+                    self.b_req_idx,
+                    self.b_seq_len,
+                    kv_starts[:-1],
+                    self.max_kv_seq_len,
+                    kv_indices,
+                )
             self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
                 self.flashinfer_extra_state.workspace_buffer,
                 qo_indptr_buf=q_starts,
@@ -100,7 +131,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 self.flashinfer_extra_state.tp_q_head_num,
                 self.flashinfer_extra_state.tp_kv_head_num,
                 self.flashinfer_extra_state.head_dim,
-                1,
+                self.page_size,
                 causal=True,
                 pos_encoding_mode="NONE",
                 logits_soft_cap=0.0,
@@ -115,11 +146,11 @@ def copy_for_cuda_graph(self, new_infer_state):
             self.decode_wrapper.plan(
                 new_infer_state.kv_starts,
                 new_infer_state.kv_indices,
-                new_infer_state.kv_last_page_len_buffer,
+                new_infer_state.kv_last_page_len,
                 new_infer_state.flashinfer_extra_state.tp_q_head_num,
                 new_infer_state.flashinfer_extra_state.tp_kv_head_num,
                 new_infer_state.flashinfer_extra_state.head_dim,
-                1,
+                self.page_size,
                 q_data_type=new_infer_state.flashinfer_extra_state.q_data_type,
                 kv_data_type=new_infer_state.flashinfer_extra_state.kv_data_type,
                 non_blocking=True,

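For readers following the new page_size_variable branch above, here is a small standalone sketch (not part of the commit; the batch lengths and page size are made up) of the bookkeeping it computes: per-request page counts via cdiv, a page-based indptr written into kv_starts, and the fill level of each request's last page.

import torch

def cdiv(a, b):
    # ceiling division, same helper the commit adds
    return (a + b - 1) // b

page_size = 64                              # default from get_page_size() after this commit
b_seq_len = torch.tensor([1, 64, 65, 200])  # hypothetical per-request KV lengths

b_page_len = cdiv(b_seq_len, page_size)     # pages per request -> [1, 1, 2, 4]
kv_starts = torch.zeros(b_seq_len.numel() + 1, dtype=torch.int32)
kv_starts[1:] = b_page_len.cumsum(0)        # page-based indptr -> [0, 1, 2, 4, 8]
kv_last_page_len = b_seq_len - (b_page_len - 1) * page_size  # tokens in the last page -> [1, 64, 1, 8]

print(kv_starts.tolist(), kv_last_page_len.tolist())

With page_size = 1 (the non-paged path), b_page_len equals b_seq_len and kv_last_page_len is all ones, which matches the values the pre-existing code passed to the flashinfer wrappers.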
lightllm/models/llama/layer_infer/transformer_layer_infer.py

Lines changed: 49 additions & 3 deletions
@@ -107,9 +107,16 @@ def _bind_attention(self):
                 raise Exception(f"Unsupported mode for fa3 backend: {self.mode}")
             return
         elif get_env_start_args().enable_flashinfer_prefill:
-            self._context_attention_kernel = partial(
-                LlamaTransformerLayerInfer._context_attention_flashinfer_kernel, self
-            )
+            if "page_size_variable" in self.mode:
+                self._context_attention_kernel = partial(
+                    LlamaTransformerLayerInfer._paged_context_attention_flashinfer_kernel, self
+                )
+            elif not self.mode:
+                self._context_attention_kernel = partial(
+                    LlamaTransformerLayerInfer._context_attention_flashinfer_kernel, self
+                )
+            else:
+                raise Exception(f"Unsupported mode for flashinfer backend: {self.mode}")
         else:
             self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self)
         if "ppl_int8kv" in self.mode:
@@ -174,6 +181,12 @@ def _bind_attention(self):
             self._copy_kv_to_mem_cache = partial(
                 LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self
             )
+        elif "page_size_variable" in self.mode:
+            assert get_env_start_args().enable_flashinfer_prefill and get_env_start_args().enable_flashinfer_decode
+            self._token_attention_kernel = partial(
+                LlamaTransformerLayerInfer._paged_token_decode_attention_flashinfer, self
+            )
+            self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
         elif not self.mode:
             if get_env_start_args().enable_flashinfer_decode:
                 self._token_attention_kernel = partial(
@@ -274,6 +287,21 @@ def _context_attention_flashinfer_kernel(
         )
         return o_tensor
 
+    def _paged_context_attention_flashinfer_kernel(
+        self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None
+    ) -> torch.Tensor:
+        o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
+        page_size = get_page_size()
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view(
+            -1, page_size, 2 * self.tp_k_head_num_, self.head_dim_
+        )
+        infer_state.prefill_wrapper.run(
+            q.view(q.shape[0], -1, self.head_dim_),
+            (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]),
+            out=o_tensor.view(q.shape[0], -1, self.head_dim_),
+        )
+        return o_tensor
+
     def _context_attention_kernel(
         self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None
     ) -> torch.Tensor:
@@ -587,6 +615,24 @@ def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStat
         )
         return o_tensor
 
+    def _paged_token_decode_attention_flashinfer(
+        self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None
+    ):
+        batch_size = infer_state.batch_size
+        calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)
+
+        o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
+        page_size = get_page_size()
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view(
+            -1, page_size, 2 * self.tp_k_head_num_, self.head_dim_
+        )
+        infer_state.decode_wrapper.run(
+            q.view(calcu_shape1),
+            (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]),
+            out=o_tensor.view(calcu_shape1),
+        )
+        return o_tensor
+
     def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
         total_token_num = infer_state.total_token_num
         batch_size = infer_state.batch_size

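The two new _paged_* kernels above reinterpret the per-layer KV buffer as pages before handing it to the flashinfer wrapper. A minimal shape-only sketch (the buffer layout and sizes here are assumed for illustration, not taken from mem_manager):

import torch

num_pages, page_size = 8, 64
tp_k_head_num, head_dim = 4, 128

# assumed layout: K and V stacked along the head axis of a flat token-major buffer
kv_buffer = torch.randn(num_pages * page_size, 2 * tp_k_head_num, head_dim)

# view the flat token dimension as (num_pages, page_size), as the paged kernels do
kv = kv_buffer.view(-1, page_size, 2 * tp_k_head_num, head_dim)

k_cache = kv[:, :, :tp_k_head_num, :]  # first half of the head axis -> K
v_cache = kv[:, :, tp_k_head_num:, :]  # second half of the head axis -> V
print(k_cache.shape, v_cache.shape)    # both: (num_pages, page_size, tp_k_head_num, head_dim)

The wrappers are planned with self.page_size and run with this (K, V) tuple, so each entry in kv_indices now addresses a whole page rather than a single token.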
lightllm/utils/envs_utils.py

Lines changed: 1 addition & 1 deletion
@@ -153,7 +153,7 @@ def get_kv_quant_calibration_inference_count():
 def get_page_size():
     try:
         args = get_env_start_args()
-        return int(os.getenv("PAGE_SIZE", 4)) if "page_size_variable" in args.mode else 1
+        return int(os.getenv("PAGE_SIZE", 64)) if "page_size_variable" in args.mode else 1
     except:
         return 1
 

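The default page size moves from 4 to 64 and still only applies when the server runs in page_size_variable mode. A standalone sketch of that selection logic (args.mode is replaced by a plain list here for illustration):

import os

def get_page_size_sketch(mode):
    # page size is only honored in page_size_variable mode; otherwise 1 (token-sized pages)
    try:
        return int(os.getenv("PAGE_SIZE", 64)) if "page_size_variable" in mode else 1
    except Exception:
        return 1

print(get_page_size_sketch(["page_size_variable"]))  # 64, or whatever PAGE_SIZE is set to
print(get_page_size_sketch([]))                      # 1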
unit_tests/models/llama/test_context_flashattention_nopad.py

Lines changed: 1 addition & 4 deletions
@@ -10,7 +10,6 @@
     context_attention_fwd_no_prompt_cache,
 )
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
-from lightllm.common.req_manager import ReqManager
 
 logger = init_logger(__name__)
 
@@ -56,8 +55,6 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim):
     infer_state.batch_size = Z
     infer_state.max_len_in_batch = N_CTX
     infer_state.total_token_num = Z * N_CTX
-    infer_state.req_manager = ReqManager(Z, N_CTX, None)
-    infer_state.req_manager.req_to_token_indexs = req_to_token_indexs
     infer_state.b_req_idx = b_req_idx
     infer_state.b_seq_len = b_seq_len
     infer_state.b_ready_cache_len = b_ready_cache_len
@@ -73,7 +70,7 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim):
         infer_state.b_seq_len,
         infer_state.b_ready_cache_len,
         infer_state.max_len_in_batch,
-        infer_state.req_manager.req_to_token_indexs,
+        req_to_token_indexs,
     )
 
     batch_size = Z
