33import numpy as np
44import torch .distributed as dist
55from lightllm .models .llama .infer_struct import LlamaInferStateInfo
6- from lightllm .utils .envs_utils import get_env_start_args
6+ from lightllm .utils .envs_utils import get_env_start_args , get_page_size
77from lightllm .models .deepseek2 .triton_kernel .repack_kv_index import repack_kv_index
88
99
def cdiv(a, b):
    """Ceiling integer division: smallest integer >= a / b (for positive b).

    Used to convert token counts into page counts for the paged KV cache.
    Implemented with only +, - and //, so it also works elementwise when
    *a* is a torch tensor (call sites pass both Python ints and tensors —
    e.g. ``cdiv(self.b_seq_len, self.page_size)``).
    """
    return (a + b - 1) // b
13+
1014class LlamaFlashInferStateInfo (LlamaInferStateInfo ):
1115 def __init__ (self ):
1216 super ().__init__ ()
1317 self .prefill_wrapper = None
1418 self .decode_wrapper = None
1519 self .flashinfer_extra_state = None
20+ self .page_size = get_page_size ()
1621
1722 def init_some_extra_state (self , model , input_ids : torch .Tensor ):
1823 super ().init_some_extra_state (model , input_ids )
@@ -22,29 +27,41 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
2227
2328 if not self .is_prefill :
2429 if get_env_start_args ().enable_flashinfer_decode :
25- self .kv_last_page_len_buffer = torch .full (
26- (self .batch_size ,), 1 , dtype = torch .int32 , device = input_ids .device
27- )
30+ self .kv_last_page_len = torch .full ((self .batch_size ,), 1 , dtype = torch .int32 , device = input_ids .device )
31+ length = cdiv (self .flashinfer_extra_state .max_seq_length , self .page_size )
2832 if self .batch_size <= model .graph_max_batch_size :
2933 self .kv_indices = self .flashinfer_extra_state .kv_indices_buffer [self .microbatch_index ][
30- : self .batch_size * self . flashinfer_extra_state . max_seq_length
34+ : self .batch_size * length
3135 ]
3236 else :
3337 self .kv_indices = torch .empty (
34- self .batch_size * self . flashinfer_extra_state . max_seq_length ,
38+ self .batch_size * length ,
3539 dtype = torch .int32 ,
3640 device = input_ids .device ,
3741 )
3842
39- repack_kv_index (
40- self .req_manager .req_to_token_indexs ,
41- self .b_req_idx ,
42- self .b_seq_len ,
43- self .b_start_loc ,
44- self .max_len_in_batch ,
45- self .kv_indices ,
46- )
4743 self .kv_starts = self .b1_cu_kv_seq_len .int ()
44+ if "page_size_variable" in model .mode :
45+ b_page_len = cdiv (self .b_seq_len , self .page_size )
46+ self .kv_starts [1 :] = b_page_len .cumsum (0 )
47+ self .kv_last_page_len = self .b_seq_len - (b_page_len - 1 ) * self .page_size
48+ repack_kv_index (
49+ self .req_manager .req_to_page_indexs ,
50+ self .b_req_idx ,
51+ b_page_len ,
52+ self .kv_starts [:- 1 ],
53+ cdiv (self .max_kv_seq_len , self .page_size ),
54+ self .kv_indices ,
55+ )
56+ else :
57+ repack_kv_index (
58+ self .req_manager .req_to_token_indexs ,
59+ self .b_req_idx ,
60+ self .b_seq_len ,
61+ self .b_start_loc ,
62+ self .max_kv_seq_len ,
63+ self .kv_indices ,
64+ )
4865 if self .decode_wrapper is None :
4966 self .decode_wrapper = flashinfer .decode .BatchDecodeWithPagedKVCacheWrapper (
5067 self .flashinfer_extra_state .workspace_buffer ,
@@ -53,16 +70,16 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
5370 use_tensor_cores = True ,
5471 paged_kv_indptr_buffer = self .kv_starts ,
5572 paged_kv_indices_buffer = self .kv_indices ,
56- paged_kv_last_page_len_buffer = self .kv_last_page_len_buffer ,
73+ paged_kv_last_page_len_buffer = self .kv_last_page_len ,
5774 )
5875 self .decode_wrapper .plan (
5976 self .kv_starts ,
6077 self .kv_indices ,
61- self .kv_last_page_len_buffer ,
78+ self .kv_last_page_len ,
6279 self .flashinfer_extra_state .tp_q_head_num ,
6380 self .flashinfer_extra_state .tp_kv_head_num ,
6481 self .flashinfer_extra_state .head_dim ,
65- 1 ,
82+ self . page_size ,
6683 q_data_type = self .flashinfer_extra_state .q_data_type ,
6784 kv_data_type = self .flashinfer_extra_state .kv_data_type ,
6885 non_blocking = True ,
@@ -72,19 +89,33 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
7289 q_starts = self .b1_cu_q_seq_len .int ()
7390 kv_starts = self .b1_cu_kv_seq_len .int ()
7491 kv_last_page_len = torch .full ((self .batch_size ,), 1 , dtype = torch .int32 , device = input_ids .device )
92+ length = cdiv (self .flashinfer_extra_state .max_seq_length , self .page_size )
7593 kv_indices = torch .empty (
76- self .batch_size * self . flashinfer_extra_state . max_seq_length ,
94+ self .batch_size * length ,
7795 dtype = torch .int32 ,
7896 device = input_ids .device ,
7997 )
80- repack_kv_index (
81- self .req_manager .req_to_token_indexs ,
82- self .b_req_idx ,
83- self .b_seq_len ,
84- kv_starts [:- 1 ],
85- self .max_kv_seq_len ,
86- kv_indices ,
87- )
98+ if "page_size_variable" in model .mode :
99+ b_page_len = cdiv (self .b_seq_len , self .page_size )
100+ kv_starts [1 :] = b_page_len .cumsum (0 )
101+ kv_last_page_len = self .b_seq_len - (b_page_len - 1 ) * self .page_size
102+ repack_kv_index (
103+ self .req_manager .req_to_page_indexs ,
104+ self .b_req_idx ,
105+ b_page_len ,
106+ kv_starts [:- 1 ],
107+ cdiv (self .max_kv_seq_len , self .page_size ),
108+ kv_indices ,
109+ )
110+ else :
111+ repack_kv_index (
112+ self .req_manager .req_to_token_indexs ,
113+ self .b_req_idx ,
114+ self .b_seq_len ,
115+ kv_starts [:- 1 ],
116+ self .max_kv_seq_len ,
117+ kv_indices ,
118+ )
88119 self .prefill_wrapper = flashinfer .prefill .BatchPrefillWithPagedKVCacheWrapper (
89120 self .flashinfer_extra_state .workspace_buffer ,
90121 qo_indptr_buf = q_starts ,
@@ -100,7 +131,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
100131 self .flashinfer_extra_state .tp_q_head_num ,
101132 self .flashinfer_extra_state .tp_kv_head_num ,
102133 self .flashinfer_extra_state .head_dim ,
103- 1 ,
134+ self . page_size ,
104135 causal = True ,
105136 pos_encoding_mode = "NONE" ,
106137 logits_soft_cap = 0.0 ,
@@ -115,11 +146,11 @@ def copy_for_cuda_graph(self, new_infer_state):
115146 self .decode_wrapper .plan (
116147 new_infer_state .kv_starts ,
117148 new_infer_state .kv_indices ,
118- new_infer_state .kv_last_page_len_buffer ,
149+ new_infer_state .kv_last_page_len ,
119150 new_infer_state .flashinfer_extra_state .tp_q_head_num ,
120151 new_infer_state .flashinfer_extra_state .tp_kv_head_num ,
121152 new_infer_state .flashinfer_extra_state .head_dim ,
122- 1 ,
153+ self . page_size ,
123154 q_data_type = new_infer_state .flashinfer_extra_state .q_data_type ,
124155 kv_data_type = new_infer_state .flashinfer_extra_state .kv_data_type ,
125156 non_blocking = True ,
0 commit comments