
Commit 3b979cc

niushengxiao committed
fix: fix the page not enough bug
1 parent 49843b2 commit 3b979cc

11 files changed (+72, -76 lines)


lightllm/common/mem_manager.py

Lines changed: 10 additions & 0 deletions
@@ -52,6 +52,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
             layer_num,
         )
         self.HOLD_TOKEN_MEMINDEX = self.size
+        # the MemoryManager also keeps its own backup reference, for internal use
         self.req_to_token_indexs = None

     def get_cell_size(self):
@@ -341,8 +342,17 @@ def __init__(self) -> None:
             SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}")
             for rank_in_node in range(0, self.node_world_size, self.dp_world_size)
         ]
+        self.shared_tp_info_pages = [
+            SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_page_num_{rank_in_node}")
+            for rank_in_node in range(0, self.node_world_size, self.dp_world_size)
+        ]

     def get_unrefed_token_num(self, dp_rank_in_node: int):
         if self.is_multinode_tp:
             return self.shared_tp_infos[0].get_value()
         return self.shared_tp_infos[dp_rank_in_node].get_value()
+
+    def get_unrefed_page_num(self, dp_rank_in_node: int):
+        if self.is_multinode_tp:
+            return self.shared_tp_info_pages[0].get_value()
+        return self.shared_tp_info_pages[dp_rank_in_node].get_value()
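
The new shared_tp_info_pages counters mirror the existing token counters: each rank publishes its free-page count through a named SharedInt, and the stats reader exposes it via get_unrefed_page_num(). A minimal sketch of that producer/consumer pattern, reusing only names that appear in this diff (the concrete rank index and value are illustrative, and the dual attach-by-name usage is an assumption about the SharedInt wrapper):

from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
from lightllm.utils.envs_utils import get_unique_server_name

# Writer side (the paged memory manager): publish how many pages are still free.
page_counter = SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_page_num_0")
page_counter.set_value(1024)

# Reader side (the router stats object): attach to the same named shared integer and read it,
# which is what get_unrefed_page_num(dp_rank_in_node) does for each dp rank.
reader = SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_page_num_0")
print(reader.get_value())  # 1024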

lightllm/common/page_size_variable_mem_manager.py

Lines changed: 15 additions & 4 deletions
@@ -3,7 +3,9 @@
 from .mem_manager import MemoryManager
 from typing import List, Union
 from lightllm.utils.log_utils import init_logger
-from lightllm.utils.envs_utils import get_page_size
+from lightllm.utils.envs_utils import get_unique_server_name, get_page_size
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
+from lightllm.utils.dist_utils import get_current_rank_in_node


 def cdiv(a, b):
@@ -24,6 +26,12 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
         self.mark_page_start = 0
         self.can_use_page_size = cdiv(self.size, page_size)

+        rank_in_node = get_current_rank_in_node()
+        self.shared_can_use_page_num = SharedInt(
+            f"{get_unique_server_name()}_mem_manger_can_use_page_num_{rank_in_node}"
+        )
+        self.shared_can_use_page_num.set_value(self.can_use_page_size)
+
     def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
         self.kv_buffer = torch.empty(
             (layer_num, cdiv(size, get_page_size()) * get_page_size(), 2 * head_num, head_dim),
@@ -141,6 +149,7 @@ def alloc(self, need_size, b_req_idx, b_seq_len, b_ready_cache_len=None, is_pref
         token_idxs = self.get_paged_token_indexs(b_req_idx, page_size, b_seq_len, b_ready_cache_len, is_prefill)
         self.can_use_mem_size -= need_size
         self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        self.shared_can_use_page_num.set_value(self.can_use_page_size)
         return token_idxs

     def free(self, free_index: Union[torch.Tensor, List[int]]):
@@ -154,12 +163,13 @@ def free(self, free_index: Union[torch.Tensor, List[int]]):
         if len(free_index) == 0:
             return

-        page_indices = free_index // page_size
-        unique_pages = torch.unique(page_indices)
-        for page_idx in sorted(unique_pages, reverse=True):  # put pages back in reverse order to keep the pool's relative order
+        base_free_index = free_index[free_index % page_size == 0]
+        page_indices = base_free_index // page_size
+        for page_idx in sorted(page_indices, reverse=True):  # put pages back in reverse order to keep the pool's relative order
             self.mark_page_start -= 1
             self.page_idx_pool[self.mark_page_start] = page_idx
             self.can_use_page_size += 1
+        self.shared_can_use_page_num.set_value(self.can_use_page_size)

         return

@@ -168,6 +178,7 @@ def free_all(self):
         page_size = get_page_size()
         self.mark_page_start = 0
         self.can_use_page_size = cdiv(self.size, page_size)
+        self.shared_can_use_page_num.set_value(self.can_use_page_size)
         self.page_idx_pool = torch.arange(
             0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
         )
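
The substantive change is in free(): previously any freed token returned its whole page to page_idx_pool, so a partially-freed page could be recycled while still holding live tokens and pushed back again later, which is the accounting drift this commit's "page not enough" fix appears to target. Now a page is recycled only when its base token (index % page_size == 0) is freed. A small standalone sketch of that selection rule under an assumed page_size (the helper name is illustrative, not the module's API):

import torch

def pages_to_release(free_index: torch.Tensor, page_size: int) -> torch.Tensor:
    # Old behaviour: torch.unique(free_index // page_size) released every page touched
    # by any freed token, even if the page still held live tokens.
    # Fixed behaviour: release a page only when its first (page-aligned) token index is freed.
    base_free_index = free_index[free_index % page_size == 0]
    return base_free_index // page_size

# With page_size=4, freeing tokens 5..11 releases only page 2 (tokens 8..11);
# page 1 is kept because its base token 4 was not freed.
print(pages_to_release(torch.tensor([5, 6, 7, 8, 9, 10, 11]), page_size=4))  # tensor([2])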

lightllm/models/deepseek2/flashattention_infer_struct.py

Lines changed: 4 additions & 4 deletions
@@ -16,6 +16,7 @@ class Deepseek2FlashAttentionStateInfo(Deepseek2InferStateInfo):

     def __init__(self):
         super().__init__()
+        self.page_size = get_page_size()

     @classmethod
     def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
@@ -43,19 +44,18 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
         self.cu_seqlens_q = self.b1_cu_q_seq_len
         self.cu_seqlens_k = self.b1_cu_kv_seq_len
         max_seq_len_k = self.max_kv_seq_len
-        page_size = get_page_size()
         if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch:
-            length = cdiv(model.graph_max_len_in_batch, page_size)
+            length = cdiv(model.graph_max_len_in_batch, self.page_size)
             page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length)
             self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape(
                 self.batch_size, length
             )
         else:
-            length = cdiv(self.max_len_in_batch, page_size)
+            length = cdiv(self.max_len_in_batch, self.page_size)
             self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32).to(input_ids.device)

         if "page_size_variable" in model.mode:
-            length = cdiv(max_seq_len_k, page_size)
+            length = cdiv(max_seq_len_k, self.page_size)
             self.page_table[:, :length].copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length])
             self.page_table[:, length:].fill_(0)
         else:
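
Both infer-state classes size their page tables with cdiv, the helper declared in page_size_variable_mem_manager.py (its body is not part of this diff, but from its usage it is ceiling division): a sequence of max_seq_len_k tokens needs cdiv(max_seq_len_k, page_size) page-table slots. A one-line illustration with assumed numbers:

def cdiv(a, b):
    # ceiling division: number of pages needed to cover `a` tokens at `b` tokens per page
    return (a + b - 1) // b

print(cdiv(100, 16))  # 7 page-table entries for a 100-token sequence with page_size=16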

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 5 additions & 7 deletions
@@ -26,7 +26,7 @@
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor
-from lightllm.utils.envs_utils import get_env_start_args, get_page_size
+from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.utils.dist_utils import get_global_world_size
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2
@@ -589,12 +589,11 @@ def _token_gqa_decode_attention_flashattention(
     def _token_gqa_decode_attention_flashattention_paged(
         self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
-        page_size = get_page_size()
         q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
         q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
         kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
-        k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, page_size, 1, self.qk_rope_head_dim)
-        kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, page_size, 1, self.kv_lora_rank)
+        k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, infer_state.page_size, 1, self.qk_rope_head_dim)
+        kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, infer_state.page_size, 1, self.kv_lora_rank)
         k_descale, v_descale = None, None
         o_tensor = flash_attn_with_kvcache(
             q=q_rope,
@@ -638,7 +637,6 @@ def _token_gqa_decode_attention_flashinfer(
     def _token_gqa_decode_attention_flashinfer_paged(
         self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
-        page_size = get_page_size()
         q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
         q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)

@@ -648,8 +646,8 @@ def _token_gqa_decode_attention_flashinfer_paged(
         infer_state.decode_wrapper.run(
             q_nope,
             q_rope,
-            kv[:, :, : -self.qk_rope_head_dim].reshape(-1, page_size, 1, self.kv_lora_rank),
-            kv[:, :, -self.qk_rope_head_dim :].reshape(-1, page_size, 1, self.qk_rope_head_dim),
+            kv[:, :, : -self.qk_rope_head_dim].reshape(-1, infer_state.page_size, 1, self.kv_lora_rank),
+            kv[:, :, -self.qk_rope_head_dim :].reshape(-1, infer_state.page_size, 1, self.qk_rope_head_dim),
             out=o_tensor,
             return_lse=False,
         )
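
The paged attention kernels expect the kv cache grouped by page rather than as one flat token dimension, which is what the reshape(-1, infer_state.page_size, ...) calls above provide. A shape-only sketch with made-up dimensions (not the model's real head counts or ranks):

import torch

page_size, num_pages, kv_heads, head_dim = 16, 8, 1, 64

# The layer's kv cache is stored flat as (num_tokens, heads, head_dim) ...
kv = torch.empty(num_pages * page_size, kv_heads, head_dim)

# ... and is viewed as (num_pages, page_size, heads, head_dim) for the paged kernel,
# mirroring the reshape(-1, infer_state.page_size, 1, ...) calls in the diff.
kv_paged = kv.reshape(-1, page_size, kv_heads, head_dim)
print(kv_paged.shape)  # torch.Size([8, 16, 1, 64])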

lightllm/models/llama/flashattention_infer_struct.py

Lines changed: 5 additions & 5 deletions
@@ -18,6 +18,7 @@ class FlashAttentionStateInfo(LlamaInferStateInfo):

     def __init__(self):
         super().__init__()
+        self.page_size = get_page_size()

     @classmethod
     def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
@@ -32,7 +33,7 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor):
         if self.is_prefill:
             self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
             self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
-            length = cdiv(self.max_seq_len, get_page_size())
+            length = cdiv(self.max_seq_len, self.page_size)
             self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32, device=input_ids.device)
             if "page_size_variable" in model.mode:
                 self.page_table.copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length])
@@ -44,17 +45,16 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor):
             self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
             max_seq_len_k = self.max_kv_seq_len
             if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch:
-                page_size = get_page_size()
-                length = cdiv(model.graph_max_len_in_batch, page_size)
+                length = cdiv(model.graph_max_len_in_batch, self.page_size)
                 page_buffer = FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length)
                 self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape(
                     self.batch_size, length
                 )
             else:
-                length = cdiv(self.max_len_in_batch, get_page_size())
+                length = cdiv(self.max_len_in_batch, self.page_size)
                 self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32, device=input_ids.device)

-            length = cdiv(max_seq_len_k, get_page_size())
+            length = cdiv(max_seq_len_k, self.page_size)
             if "page_size_variable" in model.mode:
                 self.page_table[:, :length].copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length])
             else:

lightllm/models/llama/layer_infer/transformer_layer_infer.py

Lines changed: 7 additions & 11 deletions
@@ -27,7 +27,7 @@
 from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_dequantize_kv
 from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor
 from lightllm.utils.log_utils import init_logger
-from lightllm.utils.envs_utils import get_env_start_args, get_page_size
+from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops
 from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant
 from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops
@@ -291,9 +291,8 @@ def _paged_context_attention_flashinfer_kernel(
         self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None
     ) -> torch.Tensor:
         o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
-        page_size = get_page_size()
         kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view(
-            -1, page_size, 2 * self.tp_k_head_num_, self.head_dim_
+            -1, infer_state.page_size, 2 * self.tp_k_head_num_, self.head_dim_
         )
         infer_state.prefill_wrapper.run(
             q.view(q.shape[0], -1, self.head_dim_),
@@ -356,13 +355,12 @@ def _context_attention_kernel_ppl_int8kv(
     def _paged_context_attention_flashattention(
         self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None
     ):
-        page_size = get_page_size()
         cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape(
-            -1, page_size, self.tp_k_head_num_, self.head_dim_
+            -1, infer_state.page_size, self.tp_k_head_num_, self.head_dim_
         )
         cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
             :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
-        ].reshape(-1, page_size, self.tp_v_head_num_, self.head_dim_)
+        ].reshape(-1, infer_state.page_size, self.tp_v_head_num_, self.head_dim_)
         q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_)
         k_descale, v_descale = None, None  # disable quantization
         Lq = q.shape[-1]
@@ -622,9 +620,8 @@ def _paged_token_decode_attention_flashinfer(
         calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)

         o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
-        page_size = get_page_size()
         kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view(
-            -1, page_size, 2 * self.tp_k_head_num_, self.head_dim_
+            -1, infer_state.page_size, 2 * self.tp_k_head_num_, self.head_dim_
         )
         infer_state.decode_wrapper.run(
             q.view(calcu_shape1),
@@ -914,13 +911,12 @@ def _token_decode_attention_gqa_flashdecoding_vsm(
     def _paged_token_decode_attention_flashattention(
         self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None
     ):
-        page_size = get_page_size()
         cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape(
-            -1, page_size, self.tp_k_head_num_, self.head_dim_
+            -1, infer_state.page_size, self.tp_k_head_num_, self.head_dim_
        )
         cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
             :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
-        ].reshape(-1, page_size, self.tp_v_head_num_, self.head_dim_)
+        ].reshape(-1, infer_state.page_size, self.tp_v_head_num_, self.head_dim_)
         q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_)
         k_descale, v_descale = None, None  # disable quantization
         Lq = q.shape[-1]

lightllm/server/api_cli.py

Lines changed: 3 additions & 1 deletion
@@ -179,7 +179,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
         nargs="+",
         help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding
                 | triton_gqa_attention | triton_gqa_flashdecoding | triton_fp8kv | offline_calibration_fp8kv
-                | export_fp8kv_calibration
+                | export_fp8kv_calibration | page_size_variable
                 triton_flashdecoding mode is for long context, current support llama llama2 qwen;
                 triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA;
                 triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel;
@@ -191,6 +191,8 @@ def make_argument_parser() -> argparse.ArgumentParser:
                 Calibration need to disable cudagraph and use fa3 or flashinfer backend.
                 ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel;
                 ppl_fp16 mode use ppl fast fp16 decode attention kernel;
+                page_size_variable allow to use page size > 1, use PAGE_SIZE env to set page size,
+                page_size_variable only support fa3 and flashinfer backend for now
                 you need to read source code to make sure the supported detail mode for all models""",
     )
     parser.add_argument(
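
Taken together with the check added in api_start.py below, a launch of this mode would look roughly like the following (the entry module and model path are illustrative placeholders; only --mode page_size_variable, --enable_fa3 / --enable_flashinfer_prefill / --enable_flashinfer_decode and the PAGE_SIZE environment variable come from this commit):

PAGE_SIZE=16 python -m lightllm.server.api_server --model_dir /path/to/model --mode page_size_variable --enable_fa3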

lightllm/server/api_start.py

Lines changed: 8 additions & 1 deletion
@@ -94,7 +94,7 @@ def normal_or_p_d_start(args):

     if args.graph_max_len_in_batch == 0:
         args.graph_max_len_in_batch = args.max_req_total_len
-
+
     # mode setting check.
     if args.output_constraint_mode != "none":
         assert args.disable_dynamic_prompt_cache is False
@@ -126,6 +126,13 @@ def normal_or_p_d_start(args):
                 "--enable_flashinfer_prefill and --enable_flashinfer_decode"
             )
         assert args.disable_cudagraph is True, "export_fp8kv_calibration mode need disable cudagraph"
+    if "page_size_variable" in args.mode:
+        assert args.enable_fa3 is True or (
+            args.enable_flashinfer_prefill is True and args.enable_flashinfer_decode is True
+        ), (
+            "page_size_variable mode need enable fa3 or flashinfer, add --enable_fa3 or "
+            "--enable_flashinfer_prefill and --enable_flashinfer_decode"
+        )

     # Some modes cannot yet cooperate with the advanced dynamic scheduling algorithms; to do.
     if args.diverse_mode:

lightllm/server/router/dynamic_prompt/paged_radix_cache.py

Lines changed: 9 additions & 40 deletions
@@ -159,7 +159,7 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None)
         )
         self.tree_total_tokens_num.arr[0] = 0

-    def _get_page_aligned_key(self, key, value=None):
+    def _get_page_aligned_key(self, key, value=None, free_truncated=False):
         aligned_len = len(key)
         if aligned_len == 0:
             return None, None
@@ -171,6 +171,13 @@ def _get_page_aligned_key(self, key, value=None):
             aligned_len = aligned_len & ~self._page_size_mask
         else:
             aligned_len = (aligned_len // self.page_size) * self.page_size
+
+        # free the part that was truncated away
+        if free_truncated and aligned_len < len(key) and self.mem_manager is not None:
+            truncated_value = value[aligned_len:] if value is not None else key[aligned_len:]
+            if len(truncated_value) > 0:
+                self.mem_manager.free(truncated_value)
+
         return (
             key[:aligned_len] if aligned_len > 0 else None,
             value[:aligned_len] if value is not None and aligned_len > 0 else None,
@@ -182,7 +189,7 @@ def insert(self, key, value=None):
             value = key

         assert len(key) == len(value)  # and len(key) >= 1
-        key, value = self._get_page_aligned_key(key, value)
+        key, value = self._get_page_aligned_key(key, value, free_truncated=True)
         if key is None:
             return 0
         return self._insert_helper(self.root_node, key, value)
@@ -422,41 +429,3 @@ def release_mem(mem_index):
             mem_index = torch.concat(release_mems)
             self.mem_manager.free(mem_index)
         return
-
-
-class _RadixCacheReadOnlyClient:
-    """
-    Read-only client used on the router side: it reads the tree statistics from shared memory
-    and uses them for prompt cache scheduling estimation.
-    """
-
-    def __init__(self, unique_name, total_token_num, rank_in_node):
-        self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64)
-        self.tree_total_tokens_num = SharedArray(
-            f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64
-        )
-
-    def get_refed_tokens_num(self):
-        return self.refed_tokens_num.arr[0]
-
-    def get_tree_total_tokens_num(self):
-        return self.tree_total_tokens_num.arr[0]
-
-    def get_unrefed_tokens_num(self):
-        return self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0]
-
-
-class RadixCacheReadOnlyClient:
-    def __init__(self, unique_name, total_token_num, node_world_size, dp_world_size):
-        self.dp_rank_clients: List[_RadixCacheReadOnlyClient] = [
-            _RadixCacheReadOnlyClient(unique_name, total_token_num, rank_in_node)
-            for rank_in_node in range(0, node_world_size, dp_world_size)
-        ]
-
-    def get_refed_tokens_num(self, dp_rank_in_node):
-        return self.dp_rank_clients[dp_rank_in_node].get_refed_tokens_num()
-
-    def get_tree_total_tokens_num(self, dp_rank_in_node):
-        return self.dp_rank_clients[dp_rank_in_node].get_tree_total_tokens_num()
-
-    def get_unrefed_tokens_num(self, dp_rank_in_node):
-        return self.dp_rank_clients[dp_rank_in_node].get_unrefed_tokens_num()
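
On the radix-cache side, insert() only stores page-aligned prefixes, and the new free_truncated flag makes _get_page_aligned_key hand the truncated tail's token slots back to the MemoryManager instead of leaving them allocated. A minimal sketch of the alignment arithmetic (the helper name is illustrative, not the class API):

def page_aligned_prefix_len(n_tokens: int, page_size: int) -> int:
    # only whole pages are kept in the paged radix cache
    return (n_tokens // page_size) * page_size

# With page_size=8, a 21-token key keeps only its first 16 tokens; with
# free_truncated=True the remaining 5 token slots are passed to mem_manager.free()
# rather than staying allocated.
print(page_aligned_prefix_len(21, 8))  # 16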
