 from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_dequantize_kv
 from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor
 from lightllm.utils.log_utils import init_logger
-from lightllm.utils.envs_utils import get_env_start_args, get_page_size
+from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops
 from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant
 from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops
@@ -291,9 +291,8 @@ def _paged_context_attention_flashinfer_kernel(
         self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None
     ) -> torch.Tensor:
         o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
-        page_size = get_page_size()
         kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view(
-            -1, page_size, 2 * self.tp_k_head_num_, self.head_dim_
+            -1, infer_state.page_size, 2 * self.tp_k_head_num_, self.head_dim_
         )
         infer_state.prefill_wrapper.run(
             q.view(q.shape[0], -1, self.head_dim_),
@@ -356,13 +355,12 @@ def _context_attention_kernel_ppl_int8kv(
     def _paged_context_attention_flashattention(
         self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None
     ):
-        page_size = get_page_size()
         cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape(
-            -1, page_size, self.tp_k_head_num_, self.head_dim_
+            -1, infer_state.page_size, self.tp_k_head_num_, self.head_dim_
         )
         cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
             :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
-        ].reshape(-1, page_size, self.tp_v_head_num_, self.head_dim_)
+        ].reshape(-1, infer_state.page_size, self.tp_v_head_num_, self.head_dim_)
         q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_)
         k_descale, v_descale = None, None  # disable quantization
         Lq = q.shape[-1]
@@ -622,9 +620,8 @@ def _paged_token_decode_attention_flashinfer(
         calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)

         o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
-        page_size = get_page_size()
         kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view(
-            -1, page_size, 2 * self.tp_k_head_num_, self.head_dim_
+            -1, infer_state.page_size, 2 * self.tp_k_head_num_, self.head_dim_
         )
         infer_state.decode_wrapper.run(
             q.view(calcu_shape1),
@@ -914,13 +911,12 @@ def _token_decode_attention_gqa_flashdecoding_vsm(
     def _paged_token_decode_attention_flashattention(
         self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None
     ):
-        page_size = get_page_size()
         cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape(
-            -1, page_size, self.tp_k_head_num_, self.head_dim_
+            -1, infer_state.page_size, self.tp_k_head_num_, self.head_dim_
         )
         cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
             :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
-        ].reshape(-1, page_size, self.tp_v_head_num_, self.head_dim_)
+        ].reshape(-1, infer_state.page_size, self.tp_v_head_num_, self.head_dim_)
         q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_)
         k_descale, v_descale = None, None  # disable quantization
         Lq = q.shape[-1]
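
For context, a minimal sketch of the pattern all of these hunks converge on: the paged attention paths now read the page size from the per-request infer_state instead of the global get_page_size() env helper, then view the flat KV cache buffer as pages. DummyInferState and paged_kv_view below are hypothetical stand-ins, not lightllm APIs; only the view(-1, page_size, num_kv_heads, head_dim) reshape mirrors the diff.

# Hypothetical sketch -- DummyInferState and paged_kv_view are illustrative
# stand-ins, not lightllm classes; only the reshape pattern matches the diff.
import torch


class DummyInferState:
    """Stand-in for an inference state object that carries the page size."""

    def __init__(self, page_size: int):
        self.page_size = page_size


def paged_kv_view(kv_buffer: torch.Tensor, infer_state: DummyInferState) -> torch.Tensor:
    # kv_buffer: [num_tokens, num_kv_heads, head_dim]; num_tokens must be a
    # multiple of infer_state.page_size for the view to be valid.
    num_tokens, num_kv_heads, head_dim = kv_buffer.shape
    return kv_buffer.view(-1, infer_state.page_size, num_kv_heads, head_dim)


state = DummyInferState(page_size=16)
kv = torch.zeros(64, 2 * 8, 128)  # 64 cached tokens, 8 K heads + 8 V heads, head_dim 128
print(paged_kv_view(kv, state).shape)  # torch.Size([4, 16, 16, 128])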
|