Commit f31e2e5

niushengxiao committed
feat: add protected code
1 parent bb9d6fc commit f31e2e5

File tree

6 files changed: +31 -28 lines

lightllm/models/deepseek2/flashattention_infer_struct.py

Lines changed: 4 additions & 4 deletions

@@ -16,6 +16,7 @@ class Deepseek2FlashAttentionStateInfo(Deepseek2InferStateInfo):
 
     def __init__(self):
         super().__init__()
+        self.page_size = get_page_size()
 
     @classmethod
     def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
@@ -43,19 +44,18 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
         self.cu_seqlens_q = self.b1_cu_q_seq_len
         self.cu_seqlens_k = self.b1_cu_kv_seq_len
         max_seq_len_k = self.max_kv_seq_len
-        page_size = get_page_size()
         if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch:
-            length = cdiv(model.graph_max_len_in_batch, page_size)
+            length = cdiv(model.graph_max_len_in_batch, self.page_size)
             page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length)
             self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape(
                 self.batch_size, length
             )
         else:
-            length = cdiv(self.max_len_in_batch, page_size)
+            length = cdiv(self.max_len_in_batch, self.page_size)
             self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32).to(input_ids.device)
 
         if "page_size_variable" in model.mode:
-            length = cdiv(max_seq_len_k, page_size)
+            length = cdiv(max_seq_len_k, self.page_size)
             self.page_table[:, :length].copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length])
             self.page_table[:, length:].fill_(0)
         else:
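
The recurring computation in this commit is page-table sizing: a sequence of N tokens needs cdiv(N, page_size) page slots, and the page size is now read once in __init__ and reused instead of being fetched on every call. Below is a minimal sketch of that sizing pattern; the cdiv helper and the PagedStateSketch class are stand-ins for illustration, not lightllm's actual utilities.

import torch


def cdiv(a: int, b: int) -> int:
    # ceiling division: number of pages needed to cover `a` tokens
    return (a + b - 1) // b


class PagedStateSketch:
    # stand-in for the infer-state classes above; page_size is passed in
    # explicitly instead of coming from lightllm's get_page_size()
    def __init__(self, page_size: int = 16):
        self.page_size = page_size

    def alloc_page_table(self, batch_size: int, max_len_in_batch: int) -> torch.Tensor:
        # one row per request, one int32 slot per page of the longest sequence
        length = cdiv(max_len_in_batch, self.page_size)
        return torch.empty((batch_size, length), dtype=torch.int32)


# e.g. 8 requests, longest sequence 1000 tokens, 16-token pages -> 63 slots per row
page_table = PagedStateSketch(page_size=16).alloc_page_table(8, 1000)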

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 5 additions & 7 deletions

@@ -26,7 +26,7 @@
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor
-from lightllm.utils.envs_utils import get_env_start_args, get_page_size
+from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.utils.dist_utils import get_global_world_size
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2
@@ -590,12 +590,11 @@ def _token_gqa_decode_attention_flashattention(
     def _token_gqa_decode_attention_flashattention_paged(
         self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
-        page_size = get_page_size()
         q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
         q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
         kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
-        k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, page_size, 1, self.qk_rope_head_dim)
-        kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, page_size, 1, self.kv_lora_rank)
+        k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, infer_state.page_size, 1, self.qk_rope_head_dim)
+        kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, infer_state.page_size, 1, self.kv_lora_rank)
         k_descale, v_descale = None, None
         o_tensor = flash_attn_with_kvcache(
             q=q_rope,
@@ -639,7 +638,6 @@ def _token_gqa_decode_attention_flashinfer(
     def _token_gqa_decode_attention_flashinfer_paged(
         self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
-        page_size = get_page_size()
         q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
         q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
 
@@ -649,8 +647,8 @@ def _token_gqa_decode_attention_flashinfer_paged(
         infer_state.decode_wrapper.run(
             q_nope,
             q_rope,
-            kv[:, :, : -self.qk_rope_head_dim].reshape(-1, page_size, 1, self.kv_lora_rank),
-            kv[:, :, -self.qk_rope_head_dim :].reshape(-1, page_size, 1, self.qk_rope_head_dim),
+            kv[:, :, : -self.qk_rope_head_dim].reshape(-1, infer_state.page_size, 1, self.kv_lora_rank),
+            kv[:, :, -self.qk_rope_head_dim :].reshape(-1, infer_state.page_size, 1, self.qk_rope_head_dim),
             out=o_tensor,
             return_lse=False,
         )
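
Both paged decode paths above view the flat MLA kv cache as (num_pages, page_size, 1, head_dim) slices before handing them to the attention kernel, now using the cached infer_state.page_size. The snippet below is only a toy illustration of that split-and-reshape step; the sizes for kv_lora_rank, qk_rope_head_dim and page_size are invented, not the model's real configuration.

import torch

kv_lora_rank, qk_rope_head_dim, page_size = 512, 64, 16
num_tokens = 4 * page_size  # the cache length is a multiple of page_size

# flat cache: one fused (nope | rope) vector per cached token
kv = torch.randn(num_tokens, 1, kv_lora_rank + qk_rope_head_dim)

# split the fused feature dimension, then fold tokens into pages
kv_nope = kv[:, :, :-qk_rope_head_dim].reshape(-1, page_size, 1, kv_lora_rank)
k_rope = kv[:, :, -qk_rope_head_dim:].reshape(-1, page_size, 1, qk_rope_head_dim)

assert kv_nope.shape == (num_tokens // page_size, page_size, 1, kv_lora_rank)
assert k_rope.shape == (num_tokens // page_size, page_size, 1, qk_rope_head_dim)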

lightllm/models/llama/flashattention_infer_struct.py

Lines changed: 5 additions & 5 deletions

@@ -18,6 +18,7 @@ class FlashAttentionStateInfo(LlamaInferStateInfo):
 
     def __init__(self):
         super().__init__()
+        self.page_size = get_page_size()
 
     @classmethod
     def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
@@ -33,7 +34,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
        if self.is_prefill:
            self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
            self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
-            length = cdiv(self.max_seq_len, get_page_size())
+            length = cdiv(self.max_seq_len, self.page_size)
            self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32, device=input_ids.device)
            if "page_size_variable" in model.mode:
                self.page_table.copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length])
@@ -45,17 +46,16 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
            self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
            max_seq_len_k = self.max_kv_seq_len
            if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch:
-                page_size = get_page_size()
-                length = cdiv(model.graph_max_len_in_batch, page_size)
+                length = cdiv(model.graph_max_len_in_batch, self.page_size)
                page_buffer = FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length)
                self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape(
                    self.batch_size, length
                )
            else:
-                length = cdiv(self.max_len_in_batch, get_page_size())
+                length = cdiv(self.max_len_in_batch, self.page_size)
                self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32, device=input_ids.device)
 
-            length = cdiv(max_seq_len_k, get_page_size())
+            length = cdiv(max_seq_len_k, self.page_size)
            if "page_size_variable" in model.mode:
                self.page_table[:, :length].copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length])
            else:
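
When a batch fits inside the CUDA-graph limits, the page table is carved out of a preallocated flat buffer (one per microbatch) rather than freshly allocated, so graph replay always sees the same storage; only the out-of-graph fallback creates a new tensor. A rough sketch of that slice-and-reshape pattern follows; the buffer and batch sizes are made up, and the single shared tensor stands in for the real get_page_table_buffer() cache.

import torch

graph_max_batch_size, length = 64, 128  # capacity of the shared buffer
page_buffer = torch.empty(graph_max_batch_size * length, dtype=torch.int32)

batch_size = 8  # current batch fits the graph limits
# take the leading batch_size * length entries and view them as (batch, pages)
page_table = page_buffer[: batch_size * length].reshape(batch_size, length)

# same underlying storage, no new allocation
assert page_table.data_ptr() == page_buffer.data_ptr()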

lightllm/models/llama/layer_infer/transformer_layer_infer.py

Lines changed: 7 additions & 11 deletions

@@ -27,7 +27,7 @@
 from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_dequantize_kv
 from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor
 from lightllm.utils.log_utils import init_logger
-from lightllm.utils.envs_utils import get_env_start_args, get_page_size
+from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops
 from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant
 from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops
@@ -291,9 +291,8 @@ def _paged_context_attention_flashinfer_kernel(
         self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None
     ) -> torch.Tensor:
         o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
-        page_size = get_page_size()
         kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view(
-            -1, page_size, 2 * self.tp_k_head_num_, self.head_dim_
+            -1, infer_state.page_size, 2 * self.tp_k_head_num_, self.head_dim_
         )
         infer_state.prefill_wrapper.run(
             q.view(q.shape[0], -1, self.head_dim_),
@@ -356,13 +355,12 @@ def _context_attention_kernel_ppl_int8kv(
     def _paged_context_attention_flashattention(
         self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None
     ):
-        page_size = get_page_size()
         cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape(
-            -1, page_size, self.tp_k_head_num_, self.head_dim_
+            -1, infer_state.page_size, self.tp_k_head_num_, self.head_dim_
         )
         cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
             :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
-        ].reshape(-1, page_size, self.tp_v_head_num_, self.head_dim_)
+        ].reshape(-1, infer_state.page_size, self.tp_v_head_num_, self.head_dim_)
         q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_)
         k_descale, v_descale = None, None  # disable quantization
         Lq = q.shape[-1]
@@ -622,9 +620,8 @@ def _paged_token_decode_attention_flashinfer(
         calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)
 
         o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
-        page_size = get_page_size()
         kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view(
-            -1, page_size, 2 * self.tp_k_head_num_, self.head_dim_
+            -1, infer_state.page_size, 2 * self.tp_k_head_num_, self.head_dim_
         )
         infer_state.decode_wrapper.run(
             q.view(calcu_shape1),
@@ -914,13 +911,12 @@ def _token_decode_attention_gqa_flashdecoding_vsm(
     def _paged_token_decode_attention_flashattention(
         self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None
     ):
-        page_size = get_page_size()
         cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape(
-            -1, page_size, self.tp_k_head_num_, self.head_dim_
+            -1, infer_state.page_size, self.tp_k_head_num_, self.head_dim_
         )
         cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
             :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
-        ].reshape(-1, page_size, self.tp_v_head_num_, self.head_dim_)
+        ].reshape(-1, infer_state.page_size, self.tp_v_head_num_, self.head_dim_)
         q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_)
         k_descale, v_descale = None, None  # disable quantization
         Lq = q.shape[-1]
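
In the llama paths above, the per-layer kv_buffer stores the K heads and V heads side by side along the head axis; the paged attention kernels first slice them apart and then fold tokens into pages of infer_state.page_size tokens. The toy example below mirrors that slicing with invented head counts, head_dim and page_size.

import torch

tp_k_head_num, tp_v_head_num, head_dim, page_size = 4, 4, 128, 16
num_tokens = 8 * page_size

# per-layer cache: K heads and V heads concatenated on the head axis
kv_buffer = torch.randn(num_tokens, tp_k_head_num + tp_v_head_num, head_dim)

cache_k = kv_buffer[:, 0:tp_k_head_num, :].reshape(-1, page_size, tp_k_head_num, head_dim)
cache_v = kv_buffer[:, tp_k_head_num : tp_k_head_num + tp_v_head_num, :].reshape(
    -1, page_size, tp_v_head_num, head_dim
)

assert cache_k.shape == (num_tokens // page_size, page_size, tp_k_head_num, head_dim)
assert cache_v.shape == (num_tokens // page_size, page_size, tp_v_head_num, head_dim)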

lightllm/server/api_cli.py

Lines changed: 3 additions & 1 deletion

@@ -165,7 +165,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
         nargs="+",
         help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding
                 | triton_gqa_attention | triton_gqa_flashdecoding | triton_fp8kv | offline_calibration_fp8kv
-                | export_fp8kv_calibration
+                | export_fp8kv_calibration | page_size_variable
                 triton_flashdecoding mode is for long context, current support llama llama2 qwen;
                 triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA;
                 triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel;
@@ -177,6 +177,8 @@ def make_argument_parser() -> argparse.ArgumentParser:
                 Calibration need to disable cudagraph and use fa3 or flashinfer backend.
                 ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel;
                 ppl_fp16 mode use ppl fast fp16 decode attention kernel;
+                page_size_variable allow to use page size > 1, use PAGE_SIZE env to set page size,
+                page_size_variable only support fa3 and flashinfer backend for now
                 you need to read source code to make sure the supported detail mode for all models""",
     )
     parser.add_argument(
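
The help text ties page_size_variable to a PAGE_SIZE environment variable, but this diff does not show how get_page_size reads it. The snippet below is only a guess at what such an env-backed helper typically looks like, not lightllm's actual envs_utils implementation.

import os
from functools import lru_cache


@lru_cache(maxsize=None)
def get_page_size() -> int:
    # hypothetical: default to 1 (one token per page) when PAGE_SIZE is unset
    return int(os.getenv("PAGE_SIZE", "1"))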

lightllm/server/api_start.py

Lines changed: 7 additions & 0 deletions

@@ -125,6 +125,13 @@ def normal_or_p_d_start(args):
             "--enable_flashinfer_prefill and --enable_flashinfer_decode"
         )
         assert args.disable_cudagraph is True, "export_fp8kv_calibration mode need disable cudagraph"
+    if "page_size_variable" in args.mode:
+        assert args.enable_fa3 is True or (
+            args.enable_flashinfer_prefill is True and args.enable_flashinfer_decode is True
+        ), (
+            "page_size_variable mode need enable fa3 or flashinfer, add --enable_fa3 or "
+            "--enable_flashinfer_prefill and --enable_flashinfer_decode"
+        )
 
     # Some modes cannot yet work together with the advanced dynamic scheduling algorithms, to do.
     if args.diverse_mode:
