
Commit 015ed0c

niushengxiao committed
feat: replace page idxs with token idxs in paged_mem_manager
1 parent e70c85a commit 015ed0c

6 files changed: +56 −89 lines
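The commit message describes the core change: PagedMemoryManager stops handing out page indices from its own page pool and instead allocates and frees plain token indices, with pages surviving only as an alignment constraint. A minimal standalone sketch of the index relationship this implies (illustrative values, not code from the commit):

import torch

page_size = 4  # illustrative page size

# Old scheme: the allocator returned page indices, and callers derived
# token slots as page_idx * page_size + offset.
page_idxs = torch.tensor([2, 5])
token_idxs = (page_idxs[:, None] * page_size + torch.arange(page_size)).flatten()
print(token_idxs)  # tensor([ 8,  9, 10, 11, 20, 21, 22, 23])

# New scheme: the allocator returns those token indices directly; callers only
# guarantee that allocations come in multiples of page_size.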

lightllm/common/paged_mem_manager.py
Lines changed: 11 additions & 36 deletions

@@ -18,12 +18,6 @@ def cdiv(a, b):
 class PagedMemoryManager(MemoryManager):
     def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
         super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction)
-        page_size = get_page_size()
-        self.mem_page_state = torch.arange(
-            0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
-        )
-        self.mark_page_start = 0
-        self.can_use_page_size = cdiv(self.size, page_size)

     def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
         self.kv_buffer = torch.empty(
@@ -53,42 +47,23 @@ def check_cache_page_valid(self, values: torch.Tensor):
         return True

     def alloc(self, need_size) -> torch.Tensor:
-        if self.can_use_page_size < need_size:
-            raise RuntimeError(
-                f"No available pages for alloc. remaining: {self.can_use_page_size}, needed: {need_size}"
-            )
-        new_pages = self.mem_page_state[self.mark_page_start : self.mark_page_start + need_size].cuda()
-        self.mark_page_start += need_size
-        self.can_use_page_size -= need_size
-        self.can_use_mem_size -= need_size * get_page_size()
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-        return new_pages
+        assert need_size % get_page_size() == 0, "Need size must be a multiple of page size"
+        return super().alloc(need_size)

     def free(self, free_index: Union[torch.Tensor, List[int]]):
-        self.can_use_mem_size += len(free_index)
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-
         page_size = get_page_size()
-        if isinstance(free_index, list):
-            free_index = torch.tensor(free_index, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True)
-
-        if len(free_index) == 0:
-            return
+        if page_size == 1:
+            return super().free(free_index)

+        if isinstance(free_index, list):
+            free_index = torch.tensor(free_index)
         base_free_index = free_index[free_index % page_size == 0]
-        page_indices = base_free_index // page_size
-        for page_idx in sorted(page_indices, reverse=True):  # put pages back in reverse order to preserve the pool's relative order
-            self.mark_page_start -= 1
-            self.mem_page_state[self.mark_page_start] = page_idx
-            self.can_use_page_size += 1
-
+        if len(base_free_index) == 0:
+            return
+        token_idxs = base_free_index[:, None] + torch.arange(page_size, device=free_index.device)
+        token_idxs = token_idxs.flatten()
+        super().free(token_idxs)
         return

     def free_all(self):
         super().free_all()
-        page_size = get_page_size()
-        self.mark_page_start = 0
-        self.can_use_page_size = cdiv(self.size, page_size)
-        self.mem_page_state = torch.arange(
-            0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
-        )
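The rewritten free() no longer returns page slots to a private pool; it keeps only the page-aligned base indices it is given, expands each base back into the page_size consecutive token slots of its page, and delegates to the parent manager's free(). A standalone sketch of that expansion, assuming page_size = 4:

import torch

page_size = 4  # assumed for the example
free_index = torch.tensor([8, 9, 10, 11, 20, 21, 22, 23])

# Keep only page-aligned bases, then rebuild every token slot of each page.
base_free_index = free_index[free_index % page_size == 0]        # tensor([ 8, 20])
token_idxs = base_free_index[:, None] + torch.arange(page_size)  # one row per page
print(token_idxs.flatten())  # tensor([ 8,  9, 10, 11, 20, 21, 22, 23])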

lightllm/common/req_manager.py
Lines changed: 31 additions & 21 deletions

@@ -71,7 +71,7 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana
         self.max_request_num = max_request_num
         self.HOLD_REQUEST_ID = max_request_num

-    def expand_by_page_size(self, b_token_len, page_size):
+    def _expand_by_page_size(self, b_token_len, page_size):
         # Expand seq_len into whole pages, e.g. seq_len = [9,9,9] -> page_len = [4,4,1,4,4,1,4,4,1] with page_size = 4
         b_page_len = cdiv(b_token_len, page_size)
         need_pages_num = b_page_len.sum()
@@ -82,36 +82,28 @@ def expand_by_page_size(self, b_token_len, page_size):
         p_token_len[last_page_positions] = remainders
         return need_pages_num, b_page_len, p_token_len

-    def alloc_paged_token_indices(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len):
+    def _alloc_paged_token_indices(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len):
         if b_ready_cache_len is not None:
             # prefill
-            b_req_idx = b_req_idx.cuda()
-            b_seq_len = b_seq_len.cuda()
-            b_ready_cache_len = b_ready_cache_len.cuda()
+            b_seq_len = b_seq_len.cpu()
+            b_ready_cache_len = b_ready_cache_len.cpu()

             b_token_len = b_seq_len - b_ready_cache_len
-            total_pages_needed, b_page_len, p_token_len = self.expand_by_page_size(b_token_len, page_size)
-            allocated_pages = self.mem_manager.alloc(total_pages_needed)
-
-            def get_offsets_by_length(b_len, max_len):
-                # e.g. b_len = [3,4,5] -> [0,1,2,0,1,2,3,0,1,2,3,4]
-                offsets = torch.arange(max_len, dtype=b_len.dtype, device=b_len.device)
-                offset_mask = offsets.unsqueeze(0) < b_len.unsqueeze(1)
-                return torch.masked_select(offsets, offset_mask)
-
-            token_offsets = get_offsets_by_length(p_token_len, page_size)
-            page_bases = allocated_pages * page_size
-            return torch.repeat_interleave(page_bases, p_token_len) + token_offsets
+            total_pages_needed, b_page_len, p_token_len = self._expand_by_page_size(b_token_len, page_size)
+            paged_token_idxs = self.mem_manager.alloc(total_pages_needed * page_size)
+            pages = paged_token_idxs.view(-1, page_size)
+            mask = torch.arange(page_size, device=p_token_len.device) < p_token_len.unsqueeze(1)
+            return pages[mask]
         else:
             # decode
             b_seq_len = b_seq_len.cuda()
             b_req_idx = b_req_idx.cuda()
             need_new_page_mask = (b_seq_len - 1) % page_size == 0
-            new_pages_num = need_new_page_mask.sum()
+            new_pages_num = need_new_page_mask.sum().cpu()
             token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device)
             if new_pages_num > 0:
-                new_pages = self.mem_manager.alloc(new_pages_num)
-                token_idxs[need_new_page_mask] = new_pages * page_size
+                new_pages_tokens = self.mem_manager.alloc(new_pages_num * page_size).cuda()
+                token_idxs[need_new_page_mask] = new_pages_tokens[::page_size]

             mask = ~need_new_page_mask
             if mask.any():
@@ -122,10 +114,28 @@ def get_offsets_by_length(b_len, max_len):
             )
             return token_idxs

+    def _get_need_paged_token_num(self, b_seq_len, b_ready_cache_len=None):
+        page_size = get_page_size()
+        if page_size == 1:
+            return 0
+
+        need_new_pages = 0
+        if b_ready_cache_len is not None:
+            need_tokens_array = b_seq_len - b_ready_cache_len
+            need_pages_array = (need_tokens_array + page_size - 1) // page_size
+            need_new_pages = need_pages_array.sum()
+        else:
+            mask = (b_seq_len - 1) % page_size == 0
+            need_new_pages = mask.sum()
+        return need_new_pages * page_size
+
+    def calc_real_need_token_num(self, need_token_num, b_seq_len, b_ready_cache_len=None):
+        return max(need_token_num, self._get_need_paged_token_num(b_seq_len, b_ready_cache_len))
+
     def alloc_token_indices(self, need_size, b_req_idx, b_seq_len, b_ready_cache_len=None) -> torch.Tensor:
         page_size = get_page_size()
         if page_size > 1:
-            return self.alloc_paged_token_indices(b_req_idx, page_size, b_seq_len, b_ready_cache_len)
+            return self._alloc_paged_token_indices(b_req_idx, page_size, b_seq_len, b_ready_cache_len)
         else:
             return self.mem_manager.alloc(need_size)
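For prefill, _alloc_paged_token_indices now asks the memory manager for whole pages' worth of token indices in a single call, views the result as a (pages, page_size) grid, and masks away the unused tail slots of each request's last page. A standalone sketch of the masking step, assuming page_size = 4 and two requests that each need 9 new tokens:

import torch

page_size = 4
# Two requests x 9 tokens -> 3 pages each; per-page token counts [4, 4, 1, 4, 4, 1].
p_token_len = torch.tensor([4, 4, 1, 4, 4, 1])
total_pages = p_token_len.numel()

# Stand-in for mem_manager.alloc(total_pages * page_size): page-aligned token indices.
paged_token_idxs = torch.arange(total_pages * page_size)

pages = paged_token_idxs.view(-1, page_size)                # (pages, page_size)
mask = torch.arange(page_size) < p_token_len.unsqueeze(1)   # valid slots per page
token_idxs = pages[mask]                                    # 9 + 9 = 18 indices
print(token_idxs.shape)  # torch.Size([18])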

lightllm/server/router/dynamic_prompt/paged_radix_cache.py
Lines changed: 5 additions & 27 deletions

@@ -391,39 +391,17 @@ def _print_helper(self, node: TreeNode, indent):
             self._print_helper(child, indent=indent + 2)
         return

-    def free_radix_cache_to_get_enough_token(self, need_token_num=None, b_seq_len=None, b_ready_cache_len=None):
+    def free_radix_cache_to_get_enough_token(self, need_token_num):
         assert self.mem_manager is not None
-        need_pages = 0
-        can_use_pages = 0
-        if hasattr(self.mem_manager, "can_use_page_size") and self.page_size > 1 and b_seq_len is not None:
-
-            def get_need_page_size(page_size, b_seq_len, b_ready_cache_len=None):
-                need_new_pages = 0
-                if b_ready_cache_len is not None:
-                    need_tokens_array = b_seq_len - b_ready_cache_len
-                    need_pages_array = (need_tokens_array + page_size - 1) // page_size
-                    need_new_pages = need_pages_array.sum()
-                else:
-                    mask = (b_seq_len - 1) % page_size == 0
-                    need_new_pages = mask.sum()
-                return need_new_pages
-
-            need_pages = get_need_page_size(self.page_size, b_seq_len, b_ready_cache_len)
-            can_use_pages = self.mem_manager.can_use_page_size
-        if need_token_num > self.mem_manager.can_use_mem_size or need_pages > can_use_pages:
-            need_evict_single_token_num = need_token_num - self.mem_manager.can_use_mem_size
-            need_evict_page_token_num = (need_pages - can_use_pages) * self.page_size
-            need_evict_token_num = max(need_evict_single_token_num, need_evict_page_token_num)
-            remaining_tokens = self.get_tree_total_tokens_num() - self.get_refed_tokens_num()
-            need_evict_token_num = min(need_evict_token_num, remaining_tokens)
+        if need_token_num > self.mem_manager.can_use_mem_size:
+            need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size
             release_mems = []

             def release_mem(mem_index):
                 release_mems.append(mem_index)
                 return

             self.evict(need_evict_token_num, release_mem)
-            if release_mems:
-                mem_index = torch.concat(release_mems)
-                self.mem_manager.free(mem_index)
+            mem_index = torch.concat(release_mems)
+            self.mem_manager.free(mem_index)
         return
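With the page-aware sizing gone from the cache, rounding a request's need up to whole pages moves into ReqManager.calc_real_need_token_num, so free_radix_cache_to_get_enough_token only ever sees a token count. A standalone sketch of the prefill-side arithmetic, assuming page_size = 4:

import torch

page_size = 4
b_seq_len = torch.tensor([9, 9, 9])
b_ready_cache_len = torch.tensor([0, 0, 0])

# Tokens needed per request, rounded up to whole pages and expressed as tokens.
need_tokens = b_seq_len - b_ready_cache_len
need_pages = (need_tokens + page_size - 1) // page_size
paged_token_num = int(need_pages.sum()) * page_size   # 9 pages -> 36 tokens

# The caller reserves whichever is larger: the raw count or the page-padded one.
need_token_num = int(need_tokens.sum())               # 27
print(max(need_token_num, paged_token_num))           # 36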

lightllm/server/router/dynamic_prompt/radix_cache.py
Lines changed: 1 addition & 1 deletion

@@ -333,7 +333,7 @@ def _print_helper(self, node: TreeNode, indent):
             self._print_helper(child, indent=indent + 2)
         return

-    def free_radix_cache_to_get_enough_token(self, need_token_num=None, b_seq_len=None, b_ready_cache_len=None):
+    def free_radix_cache_to_get_enough_token(self, need_token_num):
         assert self.mem_manager is not None
         if need_token_num > self.mem_manager.can_use_mem_size:
             need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size

lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py
Lines changed: 4 additions & 2 deletions

@@ -77,9 +77,10 @@ def padded_prepare_prefill_inputs(
     # prepare tokens for the dynamic prompt cache
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
-        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(
+        token_num = g_infer_context.req_manager.calc_real_need_token_num(
             input_ids.shape[0] - padded_req_num, b_seq_len, b_ready_cache_len
         )
+        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num)
     mem_indexes = g_infer_context.req_manager.alloc_token_indices(
         input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, b_ready_cache_len
     )
@@ -167,7 +168,8 @@ def padded_prepare_decode_inputs(
     # prepare tokens for the dynamic prompt cache
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
-        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_req_num, b_seq_len)
+        token_num = g_infer_context.req_manager.calc_real_need_token_num(b_seq_len.shape[0] - padded_req_num, b_seq_len)
+        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num)
     mem_indexes = g_infer_context.req_manager.alloc_token_indices(
         b_seq_len.shape[0] - padded_req_num, b_req_idx, b_seq_len
     )
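The decode side of the same calculation only reserves fresh pages for requests whose next token lands on a page boundary; every other request reuses a slot in a page it already owns. A standalone sketch of how calc_real_need_token_num can exceed the raw one-token-per-request need, assuming page_size = 4:

import torch

page_size = 4
b_seq_len = torch.tensor([5, 8, 9])  # lengths including the token being decoded

# Only requests whose newest token starts a new page need a fresh allocation.
starts_new_page = (b_seq_len - 1) % page_size == 0
paged_token_num = int(starts_new_page.sum()) * page_size  # 2 requests -> 8 tokens

need_token_num = b_seq_len.shape[0]                        # raw need: 1 token per request
print(max(need_token_num, paged_token_num))                # 8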

lightllm/server/router/model_infer/mode_backend/generic_pre_process.py
Lines changed: 4 additions & 2 deletions

@@ -55,9 +55,10 @@ def prepare_prefill_inputs(
     # prepare tokens for the dynamic prompt cache
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
-        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(
+        token_num = g_infer_context.req_manager.calc_real_need_token_num(
             input_ids.shape[0], b_seq_len, b_ready_cache_len
         )
+        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num)
     mem_indexes = g_infer_context.req_manager.alloc_token_indices(
         input_ids.shape[0], b_req_idx, b_seq_len, b_ready_cache_len
     )
@@ -115,7 +116,8 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
     # prepare tokens for the dynamic prompt cache
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
-        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0], b_seq_len)
+        token_num = g_infer_context.req_manager.calc_real_need_token_num(b_seq_len.shape[0], b_seq_len)
+        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num)
     mem_indexes = g_infer_context.req_manager.alloc_token_indices(b_seq_len.shape[0], b_req_idx, b_seq_len)
     g_infer_state_lock.release()
