
Commit 64f649f

Author: niushengxiao
feat: remove page_size_variable mode
1 parent 0202681 commit 64f649f

24 files changed: +223 −571 lines

lightllm/common/basemodel/basemodel.py

Lines changed: 2 additions & 2 deletions
@@ -687,7 +687,7 @@ def _check_max_len_infer(self):
         b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
         b_seq_len[:] = self.batch_max_tokens
         b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
-        mem_indexes = self.req_manager.alloc_token_indices(
+        mem_indexes = self.req_manager.alloc_mem_indices(
             len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len
         ).cuda()
         total_token_num = self.batch_max_tokens
@@ -764,7 +764,7 @@ def _autotune_warmup(self):
         b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
         total_token_num = input_len
         b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
-        mem_indexes = self.req_manager.alloc_token_indices(
+        mem_indexes = self.req_manager.alloc_mem_indices(
             len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len
         ).cuda()
         model_input = ModelInput(

lightllm/common/basemodel/cuda_graph.py

Lines changed: 2 additions & 2 deletions
@@ -202,7 +202,7 @@ def warmup(self, model):
         b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
         b_seq_len.fill_(seq_len)
         b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-        mem_indexes = model.req_manager.alloc_token_indices(len(input_ids), b_req_idx, b_seq_len).cuda()
+        mem_indexes = model.req_manager.alloc_mem_indices(len(input_ids), b_req_idx, b_seq_len).cuda()

         model_input = ModelInput(
             batch_size=batch_size,
@@ -258,7 +258,7 @@ def warmup_overlap(self, model):
         b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
         b_seq_len.fill_(seq_len)
         b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-        mem_indexes = model.req_manager.alloc_token_indices(len(input_ids), b_req_idx, b_seq_len).cuda()
+        mem_indexes = model.req_manager.alloc_mem_indices(len(input_ids), b_req_idx, b_seq_len).cuda()

         micro_batch = ModelInput(
             is_prefill=False,

lightllm/common/deepseek2_mem_manager.py

Lines changed: 7 additions & 1 deletion
@@ -8,6 +8,7 @@
 from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
 from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node
 from lightllm.distributed.pynccl import PyNcclCommunicator
+from lightllm.utils.envs_utils import get_page_size

 logger = init_logger(__name__)

@@ -20,7 +21,12 @@ def get_cell_size(self):
         return self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)

     def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
-        self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+        page_size = get_page_size()
+        self.kv_buffer = torch.empty(
+            (layer_num, (size // page_size + 1) * page_size, head_num, head_dim),
+            dtype=dtype,
+            device="cuda",
+        )

         # todo, etp or edp use the same work buffer here
         # also it can be used for any kernels for work buffer witout save info only
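
With this change the KV buffer's token dimension is rounded up to a whole number of pages instead of the old size + 1 padding slot, so the reserved padding token still fits and every page lies fully inside the buffer. A minimal sketch of that arithmetic, with made-up values for size and page_size (illustrative only, not taken from the commit):

# Illustrative values only; get_page_size() would supply the real page size.
size = 1000        # usable KV token slots
page_size = 64     # assumed page size for this sketch

padded_len = (size // page_size + 1) * page_size
print(padded_len)                    # 1024

assert padded_len >= size + 1        # room left for the reserved padding token at index `size`
assert padded_len % page_size == 0   # token dimension is page-aligned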

lightllm/common/deepseek2_paged_mem_manager.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

lightllm/common/mem_manager.py

Lines changed: 22 additions & 8 deletions
@@ -2,14 +2,15 @@
 import os
 import torch
 import torch.distributed as dist
+import triton
 from typing import List, Union
 from lightllm.server.pd_io_struct import KVMoveTask
 from lightllm.utils.log_utils import init_logger
 from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
 from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
 from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
 from lightllm.utils.dist_utils import get_current_rank_in_node
-from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
+from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args, get_page_size
 from lightllm.distributed.pynccl import PyNcclCommunicator
 from lightllm.utils.dist_utils import get_current_device_id

@@ -81,7 +82,12 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
         # allocated; these slots are not actually managed internally. This token is reserved for
         # special running modes (e.g. multi-dp with overlap microbatch) where some requests are
         # padded so that inference can still run; its index value is `size`, stored in the
         # HOLD_TOKEN_MEMINDEX member variable, which plays a role similar to HOLD_REQUEST_ID in req_manager.
-        self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda")
+        page_size = get_page_size()
+        self.kv_buffer = torch.empty(
+            (layer_num, (size // page_size + 1) * page_size, 2 * head_num, head_dim),
+            dtype=dtype,
+            device="cuda",
+        )

     def alloc_kv_move_buffer(self, max_req_total_len):
         """
@@ -244,6 +250,7 @@ def _free_buffers(self):
         self.kv_buffer = None

     def alloc(self, need_size) -> torch.Tensor:
+        assert need_size % get_page_size() == 0, "Need size must be a multiple of page size"
         if need_size > self.mark_end - self.mark_start:
             logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
             assert False, "error alloc state"
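
alloc() now assumes the caller already rounded its request up to a page multiple (the paged allocation path always requests whole pages). A tiny sketch of that rounding in plain Python; the numbers are made up and cdiv here just mirrors triton.cdiv:

def cdiv(a, b):
    # same result as triton.cdiv for positive integers
    return (a + b - 1) // b

page_size = 4      # assumed page size for this sketch
need_tokens = 9    # tokens a request actually needs

aligned_need = cdiv(need_tokens, page_size) * page_size
print(aligned_need)                     # 12
assert aligned_need % page_size == 0    # satisfies alloc()'s new assertion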
@@ -265,18 +272,25 @@ def free(self, free_index: Union[torch.Tensor, List[int]]):
         """

         end = self.mark_start
-        start = self.mark_start - len(free_index)
-        assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}"
+        page_size = get_page_size()
+        free_len = page_size * triton.cdiv(len(free_index), page_size)
+        start = self.mark_start - free_len
+        assert start >= 0, f"error free state start: {self.mark_start} free len {free_len}"

         if isinstance(free_index, list):
-            self.mem_state.numpy()[start:end] = free_index
+            free_index = torch.tensor(free_index)
+
+        # the copy from GPU to CPU is a blocking operation within the stream
+        if page_size > 1:
+            base_free_index = free_index[free_index % page_size == 0]
+            token_idxs = base_free_index[:, None] + torch.arange(page_size)
+            self.mem_state[start:end] = token_idxs.flatten()
         else:
-            # the copy from GPU to CPU is a blocking operation within the stream
             self.mem_state[start:end] = free_index

-        self.mark_start -= len(free_index)
+        self.mark_start -= free_len

-        self.can_use_mem_size += len(free_index)
+        self.can_use_mem_size += free_len
         self.shared_can_use_token_num.set_value(self.can_use_mem_size)

         if self.can_use_mem_size == len(self.mem_state):
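
The free path now hands back whole pages: the freed length is rounded up to a page multiple, and when page_size > 1 only the page-base indices (those divisible by page_size) are kept and re-expanded into each page's token slots before being written back into mem_state. A self-contained sketch of that expansion with made-up indices (it omits the manager's mark_start/mem_state bookkeeping):

import torch

page_size = 4
# Suppose two pages starting at token indices 8 and 20 are being freed.
free_index = torch.tensor([8, 9, 10, 11, 20, 21, 22, 23])

free_len = page_size * ((len(free_index) + page_size - 1) // page_size)  # triton.cdiv equivalent
base_free_index = free_index[free_index % page_size == 0]                # page bases: tensor([8, 20])
token_idxs = base_free_index[:, None] + torch.arange(page_size)          # rebuild each page's slots

print(free_len)               # 8
print(token_idxs.flatten())   # tensor([ 8,  9, 10, 11, 20, 21, 22, 23])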

lightllm/common/mem_utils.py

Lines changed: 0 additions & 4 deletions
@@ -4,7 +4,6 @@
 from lightllm.common.export_calibration_mem_manager import ExportCalibrationMemoryManager
 from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
 from lightllm.common.ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
-from lightllm.common.paged_mem_manager import PagedMemoryManager
 from lightllm.utils.log_utils import init_logger

 logger = init_logger(__name__)
@@ -29,9 +28,6 @@ def select_mem_manager_class(mode):
     elif "export_fp8kv_calibration" in mode:
         memory_manager_class = ExportCalibrationMemoryManager
         logger.info("Using mode export fp8kv calibration")
-    elif "page_size_variable" in mode:
-        memory_manager_class = PagedMemoryManager
-        logger.info("Page size will be variable")
     else:
         memory_manager_class = MemoryManager
         logger.info("Model kv cache using mode normal")

lightllm/common/paged_mem_manager.py

Lines changed: 0 additions & 69 deletions
This file was deleted.

lightllm/common/req_manager.py

Lines changed: 41 additions & 44 deletions
@@ -1,5 +1,6 @@
 import torch
 import collections
+import triton
 from lightllm.utils.log_utils import init_logger
 from .mem_manager import MemoryManager
 from typing import List, Optional
@@ -11,10 +12,6 @@
 logger = init_logger(__name__)


-def cdiv(a, b):
-    return (a + b - 1) // b
-
-
 class _ReqNode:
     def __init__(self, index):
         self.index = index
@@ -71,25 +68,60 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana
         self.max_request_num = max_request_num
         self.HOLD_REQUEST_ID = max_request_num

+    def calc_real_need_token_num(self, need_token_num, b_seq_len, b_ready_cache_len=None):
+        return max(need_token_num, self._get_need_paged_token_num(b_seq_len, b_ready_cache_len))
+
+    def alloc_mem_indices(self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None) -> torch.Tensor:
+        page_size = get_page_size()
+        if page_size > 1 and b_req_idx is not None and b_seq_len is not None:
+            return self._alloc_paged_mem_indices(b_req_idx, page_size, b_seq_len, b_ready_cache_len)
+        else:
+            return self.mem_manager.alloc(need_size)
+
+    def alloc(self):
+        return self.req_list.alloc()
+
+    def free(self, free_req_indexes: List[int], free_token_index):
+        for req_index in free_req_indexes:
+            self.req_list.free(req_index)
+
+        if self.req_list.is_all_free():
+            logger.debug(f"freed all request size {self.req_list.can_alloc_size}")
+        self.mem_manager.free(free_token_index)
+
+    def free_req(self, free_req_index: int):
+        self.req_list.free(free_req_index)
+        if self.req_list.is_all_free():
+            logger.debug(f"freed all request size {self.req_list.can_alloc_size}")
+        return
+
+    def free_token(self, free_token_index):
+        self.mem_manager.free(free_token_index)
+        return
+
+    def free_all(self):
+        self.req_list = _ReqLinkedList(self.max_request_num)
+        return
+
     def _expand_by_page_size(self, b_token_len, page_size):
-        # expand seq_len into whole pages, e.g. seq_len = [9,9,9] with page_size = 4 -> page_len = [4,4,1,4,4,1,4,4,1]
-        b_page_len = cdiv(b_token_len, page_size)
+        # expand seq_len into whole pages, e.g. seq_len = [9,9,9] with page_size = 4 -> p_token_len = [4,4,1,4,4,1,4,4,1]
+        b_page_len = triton.cdiv(b_token_len, page_size)
         need_pages_num = b_page_len.sum()
         p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype, device=b_token_len.device)
         cumsum_pages = torch.cumsum(b_page_len, dim=0)
         last_page_positions = cumsum_pages - 1
         remainders = b_token_len - (b_page_len - 1) * page_size
         p_token_len[last_page_positions] = remainders
-        return need_pages_num, b_page_len, p_token_len
+        return need_pages_num, p_token_len

-    def _alloc_paged_token_indices(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len):
+    def _alloc_paged_mem_indices(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len):
         if b_ready_cache_len is not None:
             # prefill
             b_seq_len = b_seq_len.cpu()
             b_ready_cache_len = b_ready_cache_len.cpu()

             b_token_len = b_seq_len - b_ready_cache_len
-            total_pages_needed, b_page_len, p_token_len = self._expand_by_page_size(b_token_len, page_size)
+            total_pages_needed, p_token_len = self._expand_by_page_size(b_token_len, page_size)
             paged_token_idxs = self.mem_manager.alloc(total_pages_needed * page_size)
             pages = paged_token_idxs.view(-1, page_size)
             mask = torch.arange(page_size, device=p_token_len.device) < p_token_len.unsqueeze(1)
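
_expand_by_page_size turns per-request token counts into per-page counts, and _alloc_paged_mem_indices then masks off the unused tail of each request's last page. A self-contained sketch of both steps, reproducing the example from the comment (seq_len = [9, 9, 9], page_size = 4); the arange stands in for whatever indices the memory manager would actually return, and the ceil-division is done with plain torch ops instead of triton.cdiv:

import torch

page_size = 4
b_token_len = torch.tensor([9, 9, 9])                     # new tokens needed per request

# Step 1: per-page token counts, as in _expand_by_page_size.
b_page_len = (b_token_len + page_size - 1) // page_size   # [3, 3, 3]
need_pages_num = int(b_page_len.sum())                    # 9 pages in total
p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype)
last_page_positions = torch.cumsum(b_page_len, dim=0) - 1
p_token_len[last_page_positions] = b_token_len - (b_page_len - 1) * page_size
print(p_token_len)            # tensor([4, 4, 1, 4, 4, 1, 4, 4, 1])

# Step 2: allocate whole pages, then keep only the slots each page really uses,
# mirroring the mask in _alloc_paged_mem_indices.
paged_token_idxs = torch.arange(need_pages_num * page_size)   # stand-in for mem_manager.alloc(...)
pages = paged_token_idxs.view(-1, page_size)
mask = torch.arange(page_size) < p_token_len.unsqueeze(1)
mem_indices = pages[mask]
print(mem_indices.numel())    # 27 == int(b_token_len.sum())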
@@ -126,41 +158,6 @@ def _get_need_paged_token_num(self, b_seq_len, b_ready_cache_len=None):
         need_new_pages = mask.sum()
         return need_new_pages * page_size

-    def calc_real_need_token_num(self, need_token_num, b_seq_len, b_ready_cache_len=None):
-        return max(need_token_num, self._get_need_paged_token_num(b_seq_len, b_ready_cache_len))
-
-    def alloc_token_indices(self, need_size, b_req_idx, b_seq_len, b_ready_cache_len=None) -> torch.Tensor:
-        page_size = get_page_size()
-        if page_size > 1:
-            return self._alloc_paged_token_indices(b_req_idx, page_size, b_seq_len, b_ready_cache_len)
-        else:
-            return self.mem_manager.alloc(need_size)
-
-    def alloc(self):
-        return self.req_list.alloc()
-
-    def free(self, free_req_indexes: List[int], free_token_index):
-        for req_index in free_req_indexes:
-            self.req_list.free(req_index)
-
-        if self.req_list.is_all_free():
-            logger.debug(f"freed all request size {self.req_list.can_alloc_size}")
-        self.mem_manager.free(free_token_index)
-
-    def free_req(self, free_req_index: int):
-        self.req_list.free(free_req_index)
-        if self.req_list.is_all_free():
-            logger.debug(f"freed all request size {self.req_list.can_alloc_size}")
-        return
-
-    def free_token(self, free_token_index):
-        self.mem_manager.free(free_token_index)
-        return
-
-    def free_all(self):
-        self.req_list = _ReqLinkedList(self.max_request_num)
-        return
-

 class ReqSamplingParamsManager:
     """
