Commit 61ef622
Author: niushengxiao
fix: fix the page not enough bug
1 parent 18be818 · commit 61ef622

File tree: 4 files changed, +60 -51 lines changed


lightllm/common/mem_manager.py

Lines changed: 9 additions & 0 deletions
@@ -341,8 +341,17 @@ def __init__(self) -> None:
             SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}")
             for rank_in_node in range(0, self.node_world_size, self.dp_world_size)
         ]
+        self.shared_tp_info_pages = [
+            SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_page_num_{rank_in_node}")
+            for rank_in_node in range(0, self.node_world_size, self.dp_world_size)
+        ]

     def get_unrefed_token_num(self, dp_rank_in_node: int):
         if self.is_multinode_tp:
             return self.shared_tp_infos[0].get_value()
         return self.shared_tp_infos[dp_rank_in_node].get_value()
+
+    def get_unrefed_page_num(self, dp_rank_in_node: int):
+        if self.is_multinode_tp:
+            return self.shared_tp_info_pages[0].get_value()
+        return self.shared_tp_info_pages[dp_rank_in_node].get_value()
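Note: `shared_tp_info_pages` mirrors the existing `shared_tp_infos` list, so the router process can read each dp group's free-page count the same way it already reads free-token counts. Below is a minimal sketch of that cross-process counter pattern; `SharedIntSketch` is an assumption standing in for lightllm's real `SharedInt` (from `lightllm.server.router.dynamic_prompt.shared_arr`), whose implementation may differ.

```python
# Minimal sketch (assumption): a SharedInt-like counter over shared memory.
# lightllm's real SharedInt may be implemented differently.
from multiprocessing import shared_memory
import numpy as np

class SharedIntSketch:
    def __init__(self, name: str):
        try:
            # First opener creates the 8-byte block (the memory-manager side).
            self.shm = shared_memory.SharedMemory(name=name, create=True, size=8)
        except FileExistsError:
            # Later openers just attach to it (the router side).
            self.shm = shared_memory.SharedMemory(name=name)
        self.arr = np.ndarray((1,), dtype=np.int64, buffer=self.shm.buf)

    def set_value(self, v: int):
        self.arr[0] = v

    def get_value(self) -> int:
        return int(self.arr[0])

# The memory-manager side publishes its free-page count ...
writer = SharedIntSketch("demo_mem_manger_can_use_page_num_0")
writer.set_value(1234)

# ... and a reader attached under the same name observes it.
reader = SharedIntSketch("demo_mem_manger_can_use_page_num_0")
print(reader.get_value())  # 1234

reader.shm.close()
writer.shm.close()
writer.shm.unlink()
```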

lightllm/common/page_size_variable_mem_manager.py

Lines changed: 15 additions & 4 deletions
@@ -3,7 +3,9 @@
 from .mem_manager import MemoryManager
 from typing import List, Union
 from lightllm.utils.log_utils import init_logger
-from lightllm.utils.envs_utils import get_page_size
+from lightllm.utils.envs_utils import get_unique_server_name, get_page_size
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
+from lightllm.utils.dist_utils import get_current_rank_in_node


 def cdiv(a, b):
@@ -24,6 +26,12 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
         self.mark_page_start = 0
         self.can_use_page_size = cdiv(self.size, page_size)

+        rank_in_node = get_current_rank_in_node()
+        self.shared_can_use_page_num = SharedInt(
+            f"{get_unique_server_name()}_mem_manger_can_use_page_num_{rank_in_node}"
+        )
+        self.shared_can_use_page_num.set_value(self.can_use_page_size)
+
     def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
         self.kv_buffer = torch.empty(
             (layer_num, cdiv(size, get_page_size()) * get_page_size(), 2 * head_num, head_dim),
@@ -141,6 +149,7 @@ def alloc(self, need_size, b_req_idx, b_seq_len, b_ready_cache_len=None, is_pref
         token_idxs = self.get_paged_token_indexs(b_req_idx, page_size, b_seq_len, b_ready_cache_len, is_prefill)
         self.can_use_mem_size -= need_size
         self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        self.shared_can_use_page_num.set_value(self.can_use_page_size)
         return token_idxs

     def free(self, free_index: Union[torch.Tensor, List[int]]):
@@ -154,12 +163,13 @@ def free(self, free_index: Union[torch.Tensor, List[int]]):
         if len(free_index) == 0:
             return

-        page_indices = free_index // page_size
-        unique_pages = torch.unique(page_indices)
-        for page_idx in sorted(unique_pages, reverse=True):  # put back in reverse order to keep the pool's relative order
+        base_free_index = free_index[free_index % page_size == 0]
+        page_indices = base_free_index // page_size
+        for page_idx in sorted(page_indices, reverse=True):  # put back in reverse order to keep the pool's relative order
             self.mark_page_start -= 1
             self.page_idx_pool[self.mark_page_start] = page_idx
             self.can_use_page_size += 1
+        self.shared_can_use_page_num.set_value(self.can_use_page_size)

         return

@@ -168,6 +178,7 @@ def free_all(self):
         page_size = get_page_size()
         self.mark_page_start = 0
         self.can_use_page_size = cdiv(self.size, page_size)
+        self.shared_can_use_page_num.set_value(self.can_use_page_size)
         self.page_idx_pool = torch.arange(
             0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
         )
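Note: `free()` now recycles a page only when the freed batch contains that page's base slot (`index % page_size == 0`), rather than every page touched by any freed token, and each path that changes `can_use_page_size` publishes the new value through `shared_can_use_page_num`. The toy snippet below contrasts the two selection rules; `page_size = 4` and the `free_index` tensor are made-up values, not data from a running server.

```python
# Toy contrast of the old vs. new page selection in free().
# page_size and free_index are made-up example values.
import torch

page_size = 4
free_index = torch.tensor([8, 9, 10, 11, 14, 15])  # page 2 fully freed, page 3 only partially

# Old rule: every touched page is recycled, so page 3 would go back to the pool
# even though slots 12 and 13 were not part of this free batch.
old_pages = torch.unique(free_index // page_size)
print(old_pages.tolist())  # [2, 3]

# New rule: a page is recycled only when its base slot is among the freed indices.
base_free_index = free_index[free_index % page_size == 0]
new_pages = base_free_index // page_size
print(new_pages.tolist())  # [2]
```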

lightllm/server/router/dynamic_prompt/paged_radix_cache.py

Lines changed: 9 additions & 40 deletions
@@ -159,7 +159,7 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None)
         )
         self.tree_total_tokens_num.arr[0] = 0

-    def _get_page_aligned_key(self, key, value=None):
+    def _get_page_aligned_key(self, key, value=None, free_truncated=False):
         aligned_len = len(key)
         if aligned_len == 0:
             return None, None
@@ -171,6 +171,13 @@ def _get_page_aligned_key(self, key, value=None):
             aligned_len = aligned_len & ~self._page_size_mask
         else:
             aligned_len = (aligned_len // self.page_size) * self.page_size
+
+        # free the truncated tail
+        if free_truncated and aligned_len < len(key) and self.mem_manager is not None:
+            truncated_value = value[aligned_len:] if value is not None else key[aligned_len:]
+            if len(truncated_value) > 0:
+                self.mem_manager.free(truncated_value)
+
         return (
             key[:aligned_len] if aligned_len > 0 else None,
             value[:aligned_len] if value is not None and aligned_len > 0 else None,
@@ -182,7 +189,7 @@ def insert(self, key, value=None):
             value = key

         assert len(key) == len(value)  # and len(key) >= 1
-        key, value = self._get_page_aligned_key(key, value)
+        key, value = self._get_page_aligned_key(key, value, free_truncated=True)
         if key is None:
             return 0
         return self._insert_helper(self.root_node, key, value)
@@ -422,41 +429,3 @@ def release_mem(mem_index):
             mem_index = torch.concat(release_mems)
             self.mem_manager.free(mem_index)
         return
-
-
-class _RadixCacheReadOnlyClient:
-    """
-    Read-only client for the router side; it reads tree statistics from shared memory to estimate prompt-cache scheduling.
-    """
-
-    def __init__(self, unique_name, total_token_num, rank_in_node):
-        self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64)
-        self.tree_total_tokens_num = SharedArray(
-            f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64
-        )
-
-    def get_refed_tokens_num(self):
-        return self.refed_tokens_num.arr[0]
-
-    def get_tree_total_tokens_num(self):
-        return self.tree_total_tokens_num.arr[0]
-
-    def get_unrefed_tokens_num(self):
-        return self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0]
-
-
-class RadixCacheReadOnlyClient:
-    def __init__(self, unique_name, total_token_num, node_world_size, dp_world_size):
-        self.dp_rank_clients: List[_RadixCacheReadOnlyClient] = [
-            _RadixCacheReadOnlyClient(unique_name, total_token_num, rank_in_node)
-            for rank_in_node in range(0, node_world_size, dp_world_size)
-        ]
-
-    def get_refed_tokens_num(self, dp_rank_in_node):
-        return self.dp_rank_clients[dp_rank_in_node].get_refed_tokens_num()
-
-    def get_tree_total_tokens_num(self, dp_rank_in_node):
-        return self.dp_rank_clients[dp_rank_in_node].get_tree_total_tokens_num()
-
-    def get_unrefed_tokens_num(self, dp_rank_in_node):
-        return self.dp_rank_clients[dp_rank_in_node].get_unrefed_tokens_num()
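Note: with `free_truncated=True`, `insert()` now hands the token slots that fall past the last full page back to the memory manager instead of silently dropping them when the key is truncated to a page boundary. A small walk-through of the alignment arithmetic follows; the values (`page_size = 4`, a 10-slot key/value pair) are made up, and the `freed` list stands in for `mem_manager.free()`.

```python
# Walk-through of the page-aligned truncation with made-up values.
import torch

page_size = 4
key = torch.arange(100, 110)    # 10 token ids
value = torch.arange(500, 510)  # 10 kv-cache slot indices

aligned_len = (len(key) // page_size) * page_size  # 8: largest multiple of page_size <= 10
truncated_value = value[aligned_len:]              # slots 508, 509 spill past the last full page

freed = []
if aligned_len < len(key) and len(truncated_value) > 0:
    freed.append(truncated_value)                  # stands in for self.mem_manager.free(...)

aligned_key, aligned_value = key[:aligned_len], value[:aligned_len]
print(len(aligned_key), freed[0].tolist())         # 8 [508, 509]
```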

lightllm/server/router/req_queue/chunked_prefill/impl.py

Lines changed: 27 additions & 7 deletions
@@ -3,6 +3,11 @@
 from ...batch import Batch, Req
 from lightllm.server.router.req_queue.base_queue import BaseQueue
 from lightllm.common.basemodel.infer_lock import g_router_lock
+from lightllm.utils.envs_utils import get_page_size
+
+
+def cdiv(a, b):
+    return (a + b - 1) // b


 class ChunkedPrefillQueue(BaseQueue):
@@ -21,8 +26,9 @@ def _init_cache_list(self, current_batch: Batch, is_busy):
         return

     # @calculate_time(show=True, min_cost_ms=0.1)
-    def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens):
-        self.cache_len_list.append(req.get_tuple_tokens(is_busy, self.router_max_new_token_len))  # hard to analysis
+    def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens, new_batch_prefill_need_pages):
+        token_infos = req.get_tuple_tokens(is_busy, self.router_max_new_token_len)
+        self.cache_len_list.append(token_infos)  # hard to analysis
         self.cache_len_list.sort(key=lambda x: -x[1])

         left_out_len_array = np.array([e[1] for e in self.cache_len_list])
@@ -42,16 +48,29 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens
         new_batch_first_router_need_tokens += req.get_first_router_need_tokens()
         ok_prefill = new_batch_first_router_need_tokens <= self.batch_max_tokens

-        if ok_token_num and ok_req_num and ok_prefill:
+        # check page availability
+        ok_page_num = True
+        if "page_size_variable" in self.router.mode:
+            available_pages = self.router.read_only_statics_mem_manager.get_unrefed_page_num(self.dp_index)
+            page_size = get_page_size()
+            if self.router.radix_cache_client is not None:
+                radix_cache = self.router.radix_cache_client
+                available_pages += radix_cache.get_unrefed_tokens_num(self.dp_index) // page_size
+
+            new_batch_prefill_need_pages += cdiv(req.input_len + req.shm_cur_output_len, page_size)
+            decode_need_pages = cdiv((left_out_len_array * size_array).max(), page_size)
+            ok_page_num = new_batch_prefill_need_pages + decode_need_pages < available_pages
+
+        if ok_token_num and ok_req_num and ok_prefill and ok_page_num:
             self.router.shared_token_load.set_estimated_peak_token_count(need_max_token_num, self.dp_index)
             self.router.shared_token_load.set_dynamic_max_load(
                 (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index))
                 / self.max_total_tokens,
                 self.dp_index,
             )
-            return True, new_batch_first_router_need_tokens
+            return True, new_batch_first_router_need_tokens, new_batch_prefill_need_pages
         else:
-            return False, new_batch_first_router_need_tokens
+            return False, new_batch_first_router_need_tokens, new_batch_prefill_need_pages

     # @calculate_time(show=True, min_cost_ms=10)
     def generate_new_batch(self, current_batch: Batch):
@@ -77,15 +96,16 @@ def generate_new_batch(self, current_batch: Batch):

         waiting_queue = self.waiting_req_list

+        new_batch_prefill_need_pages = cdiv(new_batch_first_router_need_tokens, get_page_size())
         for req in waiting_queue:
             if req.is_aborted:
                 # Because of the management complexity, only requests that have never been scheduled can be dropped from the queue directly on abort.
                 # Paused requests have to be filtered by the router manager after they resume; keep this handling for now, otherwise managed tokens would leak.
                 aborted_count += 1
                 abort_req_list.append(req)
                 continue
-            ok_insert, new_batch_first_router_need_tokens = self._can_add_new_req(
-                req, is_busy, new_batch_first_router_need_tokens
+            ok_insert, new_batch_first_router_need_tokens, new_batch_prefill_need_pages = self._can_add_new_req(
+                req, is_busy, new_batch_first_router_need_tokens, new_batch_prefill_need_pages
             )
             if ok_insert:
                 can_run_list.append(req)
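Note: admission control now budgets pages alongside tokens: the prefill page count accumulates per admitted request, a worst-case decode page count comes from the same `(left_out_len_array * size_array).max()` estimate used for tokens, and their sum must stay below the pages currently unreferenced in the memory manager plus those evictable from the radix cache. The sketch below reproduces that arithmetic with made-up numbers; the page size, request lengths, and pool sizes are assumptions, not values from a live deployment.

```python
# Standalone sketch of the page-budget check with made-up numbers.
def cdiv(a, b):
    return (a + b - 1) // b

page_size = 64
available_pages = 120            # stands in for get_unrefed_page_num(dp_index)
evictable_radix_tokens = 1024    # stands in for radix_cache.get_unrefed_tokens_num(dp_index)
available_pages += evictable_radix_tokens // page_size    # 120 + 16 = 136

# Prefill pages: running total over already-admitted requests plus this one.
new_batch_prefill_need_pages = cdiv(3000, page_size)      # earlier requests -> 47 pages
new_batch_prefill_need_pages += cdiv(900 + 0, page_size)  # input_len + shm_cur_output_len -> +15 pages

# Worst-case decode pages, analogous to cdiv((left_out_len_array * size_array).max(), page_size).
decode_need_pages = cdiv(4096, page_size)                 # 64 pages

ok_page_num = new_batch_prefill_need_pages + decode_need_pages < available_pages
print(new_batch_prefill_need_pages, decode_need_pages, ok_page_num)  # 62 64 True
```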
