 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor
-from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.envs_utils import get_env_start_args, get_page_size
 from lightllm.utils.dist_utils import get_global_world_size
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2
@@ -93,6 +93,18 @@ def _bind_attention(self):
             self._token_attention_kernel = partial(
                 Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding_fp8, self
             )
+        elif "page_size_variable" in self.mode:
+            self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+            if get_env_start_args().enable_fa3:
+                self._token_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention_paged, self
+                )
+            elif get_env_start_args().enable_flashinfer_decode:
+                self._token_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer_paged, self
+                )
+            else:
+                raise Exception("Page size variable mode is not supported in other backends.")
         else:
             self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
             if get_env_start_args().enable_fa3:
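
Note (not part of the commit): the new branch only re-binds methods at construction time, following the same `functools.partial` dispatch idiom the rest of `_bind_attention` uses. Below is a minimal, self-contained sketch of that idiom; the toy class, mode list, and flags are invented for illustration and are not the real lightllm API.

```python
from functools import partial


class _TinyLayerInfer:
    """Hypothetical stand-in showing how a kernel is selected once and bound to self."""

    def __init__(self, mode, enable_fa3=False, enable_flashinfer_decode=False):
        self.mode = mode
        if "page_size_variable" in self.mode:
            if enable_fa3:
                self._token_attention_kernel = partial(_TinyLayerInfer._decode_fa3_paged, self)
            elif enable_flashinfer_decode:
                self._token_attention_kernel = partial(_TinyLayerInfer._decode_flashinfer_paged, self)
            else:
                raise Exception("Page size variable mode is not supported in other backends.")
        else:
            self._token_attention_kernel = partial(_TinyLayerInfer._decode_default, self)

    def _decode_fa3_paged(self, q):
        return "fa3_paged"

    def _decode_flashinfer_paged(self, q):
        return "flashinfer_paged"

    def _decode_default(self, q):
        return "default"


# The bound partial behaves like a method chosen once at construction time.
assert _TinyLayerInfer(["page_size_variable"], enable_fa3=True)._token_attention_kernel(q=None) == "fa3_paged"
```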
@@ -574,6 +586,36 @@ def _token_gqa_decode_attention_flashattention(
         )
         return o_tensor
 
+    def _token_gqa_decode_attention_flashattention_paged(
+        self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ):
+        page_size = get_page_size()
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+        k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, page_size, 1, self.qk_rope_head_dim)
+        kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, page_size, 1, self.kv_lora_rank)
+        k_descale, v_descale = None, None
+        o_tensor = flash_attn_with_kvcache(
+            q=q_rope,
+            k_cache=k_rope,
+            v_cache=kv_nope,
+            qv=q_nope,
+            page_table=infer_state.page_table,
+            cache_seqlens=infer_state.b_seq_len,
+            cu_seqlens_q=infer_state.cu_seqlens_q,
+            cu_seqlens_k_new=infer_state.cu_seqlens_k,
+            max_seqlen_q=1,
+            softmax_scale=self.softmax_scale,
+            causal=True,
+            window_size=(-1, -1),
+            softcap=0.0,
+            k_descale=k_descale,
+            v_descale=v_descale,
+            return_softmax_lse=False,
+        )
+        return o_tensor
+
     def _token_gqa_decode_attention_flashinfer(
         self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
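
Note (not part of the commit): both paged kernels view the flat MLA cache as fixed-size pages addressed through a page table. A small PyTorch sketch of that layout; the sizes, the page table, and the request below are invented for illustration, while the real buffers come from `infer_state.mem_manager` and `get_page_size()`.

```python
import torch

page_size, num_pages = 4, 8
kv_lora_rank, qk_rope_head_dim = 16, 8

# MLA cache: one "head" holding the compressed latent plus the rope part, flat over cache slots.
kv = torch.randn(num_pages * page_size, 1, kv_lora_rank + qk_rope_head_dim)

# The reshape used in _token_gqa_decode_attention_flashattention_paged: flat slots -> (pages, page_size).
k_rope = kv[:, :, -qk_rope_head_dim:].reshape(-1, page_size, 1, qk_rope_head_dim)
kv_nope = kv[:, :, :-qk_rope_head_dim].reshape(-1, page_size, 1, kv_lora_rank)

# A per-request page table maps logical block index -> physical page id, so a sequence of
# length seq_len only touches its first ceil(seq_len / page_size) entries.
page_table = torch.tensor([[3, 0, 5]])  # one request spanning up to 3 pages (hypothetical)
seq_len = 10
needed = (seq_len + page_size - 1) // page_size
gathered = kv_nope[page_table[0, :needed]]  # (needed, page_size, 1, kv_lora_rank)
print(gathered.shape)                       # torch.Size([3, 4, 1, 16])
```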
@@ -593,6 +635,26 @@ def _token_gqa_decode_attention_flashinfer(
         )
         return o_tensor
 
+    def _token_gqa_decode_attention_flashinfer_paged(
+        self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ):
+        page_size = get_page_size()
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype)
+
+        infer_state.decode_wrapper.run(
+            q_nope,
+            q_rope,
+            kv[:, :, : -self.qk_rope_head_dim].reshape(-1, page_size, 1, self.kv_lora_rank),
+            kv[:, :, -self.qk_rope_head_dim :].reshape(-1, page_size, 1, self.qk_rope_head_dim),
+            out=o_tensor,
+            return_lse=False,
+        )
+        return o_tensor
+
     def _token_gqa_decode_attention_flashdecoding(
         self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
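
Note (not part of the commit): for readers unfamiliar with what the paged FA3 and flashinfer calls compute, here is a naive pure-PyTorch reference of the absorbed MLA decode step for a single request. It is a sketch under assumed sizes, not the wrappers' implementation; it omits paging, batching, and any mscale factor in `softmax_scale`, and the tensor names simply mirror the diff.

```python
import torch

num_heads, kv_lora_rank, qk_rope_head_dim, seq_len = 16, 512, 64, 32
softmax_scale = (128 + qk_rope_head_dim) ** -0.5  # assuming qk_nope_head_dim = 128

q_nope = torch.randn(num_heads, kv_lora_rank)    # query already absorbed into latent space via k_b_proj_
q_rope = torch.randn(num_heads, qk_rope_head_dim)
kv_nope = torch.randn(seq_len, kv_lora_rank)     # shared compressed latent cache (single kv "head")
k_rope = torch.randn(seq_len, qk_rope_head_dim)  # shared rope key cache

# Scores combine the latent part and the rope part, then softmax over the cached positions.
scores = (q_nope @ kv_nope.T + q_rope @ k_rope.T) * softmax_scale  # (num_heads, seq_len)
weights = torch.softmax(scores, dim=-1)

# The output stays in the latent space; the caller later projects it up with v_b_proj_.
o = weights @ kv_nope  # (num_heads, kv_lora_rank)
print(o.shape)
```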