
Commit 6f42d17

Authored and committed by niushengxiao
feat: add b_last_mem_indx in the InferReq
1 parent 64f649f commit 6f42d17

8 files changed, +66 -30 lines changed


lightllm/common/basemodel/basemodel.py

Lines changed: 2 additions & 4 deletions
@@ -687,9 +687,7 @@ def _check_max_len_infer(self):
         b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
         b_seq_len[:] = self.batch_max_tokens
         b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
-        mem_indexes = self.req_manager.alloc_mem_indices(
-            len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len
-        ).cuda()
+        mem_indexes = self.req_manager.alloc_mem_indices(len(dummy_input_ids), b_seq_len, b_ready_cache_len).cuda()
         total_token_num = self.batch_max_tokens
         b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
@@ -765,7 +763,7 @@ def _autotune_warmup(self):
         total_token_num = input_len
         b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
         mem_indexes = self.req_manager.alloc_mem_indices(
-            len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len
+            len(dummy_input_ids), b_seq_len, b_ready_cache_len
         ).cuda()
         model_input = ModelInput(
             batch_size=1,
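Both call sites above are warmup/probe paths doing a dummy prefill, so they only drop the now-removed b_req_idx argument. For orientation, the allocator surface introduced by this commit (shown in full in req_manager.py below) can be summarized by the following sketch; the stub body is not the real implementation:

import torch

def alloc_mem_indices(need_size, b_seq_len=None, b_ready_cache_len=None, b_last_mem_index=None) -> torch.Tensor:
    """Sketch of the new call surface only.

    Prefill callers pass b_seq_len and b_ready_cache_len; decode callers pass
    b_seq_len and b_last_mem_index. b_req_idx is gone because the previous KV
    slot now travels with the request itself (InferReq.last_kv_mem_index).
    """
    raise NotImplementedError("see lightllm/common/req_manager.py for the real logic")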

lightllm/common/basemodel/cuda_graph.py

Lines changed: 8 additions & 2 deletions
@@ -201,8 +201,11 @@ def warmup(self, model):
         )
         b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
         b_seq_len.fill_(seq_len)
+        b_last_mem_index = torch.zeros_like(b_seq_len)
         b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-        mem_indexes = model.req_manager.alloc_mem_indices(len(input_ids), b_req_idx, b_seq_len).cuda()
+        mem_indexes = model.req_manager.alloc_mem_indices(
+            len(input_ids), b_seq_len, b_last_mem_index=b_last_mem_index
+        ).cuda()

         model_input = ModelInput(
             batch_size=batch_size,
@@ -257,8 +260,11 @@ def warmup_overlap(self, model):
         )
         b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
         b_seq_len.fill_(seq_len)
+        b_last_mem_index = torch.zeros_like(b_seq_len)
         b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-        mem_indexes = model.req_manager.alloc_mem_indices(len(input_ids), b_req_idx, b_seq_len).cuda()
+        mem_indexes = model.req_manager.alloc_mem_indices(
+            len(input_ids), b_seq_len, b_last_mem_index=b_last_mem_index
+        ).cuda()

         micro_batch = ModelInput(
             is_prefill=False,
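Warmup has no real request history to draw on, so it fabricates a zero-filled b_last_mem_index just to satisfy the new decode-path signature; since warmup outputs are thrown away, the placeholder appears to need only the right shape, dtype, and device. A minimal sketch of that placeholder construction (sizes are illustrative, and the real code keeps these tensors on CUDA):

import torch

batch_size, seq_len = 8, 33                     # illustrative warmup sizes
b_seq_len = torch.empty(batch_size, dtype=torch.int32)
b_seq_len.fill_(seq_len)

# Same shape/dtype as b_seq_len, all zeros: a stand-in "previous KV slot" per request.
b_last_mem_index = torch.zeros_like(b_seq_len)

# In cuda_graph.py this is then forwarded as a keyword argument:
# mem_indexes = model.req_manager.alloc_mem_indices(
#     len(input_ids), b_seq_len, b_last_mem_index=b_last_mem_index
# ).cuda()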

lightllm/common/req_manager.py

Lines changed: 21 additions & 13 deletions
@@ -71,10 +71,21 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana
     def calc_real_need_token_num(self, need_token_num, b_seq_len, b_ready_cache_len=None):
         return max(need_token_num, self._get_need_paged_token_num(b_seq_len, b_ready_cache_len))

-    def alloc_mem_indices(self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None) -> torch.Tensor:
+    def calc_last_mem_index_in_prefill(self, mem_indices, b_seq_len, b_ready_cache_len=None):
+        b_token_len = b_seq_len
+        if b_ready_cache_len is not None:
+            b_token_len = b_seq_len - b_ready_cache_len
+        b_token_len_cumsum = torch.cumsum(b_token_len, dim=0)
+        b_last_mem_index = mem_indices[b_token_len_cumsum - 1]
+        return b_last_mem_index
+
+    # b_last_mem_index is only needed when b_ready_cache_len is None (i.e. in decode)
+    def alloc_mem_indices(
+        self, need_size, b_seq_len=None, b_ready_cache_len=None, b_last_mem_index=None
+    ) -> torch.Tensor:
         page_size = get_page_size()
-        if page_size > 1 and b_req_idx is not None and b_seq_len is not None:
-            return self._alloc_paged_mem_indices(b_req_idx, page_size, b_seq_len, b_ready_cache_len)
+        if page_size > 1 and b_seq_len is not None:
+            return self._alloc_paged_mem_indices(page_size, b_seq_len, b_ready_cache_len, b_last_mem_index)
         else:
             return self.mem_manager.alloc(need_size)

@@ -114,12 +125,11 @@ def _expand_by_page_size(self, b_token_len, page_size):
         p_token_len[last_page_positions] = remainders
         return need_pages_num, p_token_len

-    def _alloc_paged_mem_indices(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len):
+    def _alloc_paged_mem_indices(self, page_size, b_seq_len, b_ready_cache_len, b_last_mem_index):
+        b_seq_len = b_seq_len.cpu()
         if b_ready_cache_len is not None:
             # prefill
-            b_seq_len = b_seq_len.cpu()
             b_ready_cache_len = b_ready_cache_len.cpu()
-
             b_token_len = b_seq_len - b_ready_cache_len
             total_pages_needed, p_token_len = self._expand_by_page_size(b_token_len, page_size)
             paged_token_idxs = self.mem_manager.alloc(total_pages_needed * page_size)
@@ -128,19 +138,17 @@ def _alloc_paged_mem_indices(self, b_req_idx, page_size, b_seq_len, b_ready_cach
             return pages[mask]
         else:
             # decode
-            b_seq_len = b_seq_len.cuda()
-            b_req_idx = b_req_idx.cuda()
+            assert b_last_mem_index is not None
+            b_last_mem_index = b_last_mem_index.cpu()
             need_new_page_mask = (b_seq_len - 1) % page_size == 0
-            new_pages_num = need_new_page_mask.sum().cpu()
+            new_pages_num = need_new_page_mask.sum()
             token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device)
             if new_pages_num > 0:
-                new_pages_tokens = self.mem_manager.alloc(new_pages_num * page_size).cuda()
+                new_pages_tokens = self.mem_manager.alloc(new_pages_num * page_size)
                 token_idxs[need_new_page_mask] = new_pages_tokens[::page_size]
-
             mask = ~need_new_page_mask
             if mask.any():
-                seq_lens = b_seq_len[mask]
-                token_idxs[mask] = self.req_to_token_indexs[b_req_idx[mask], seq_lens - 2] + 1
+                token_idxs[mask] = b_last_mem_index[mask] + 1
             return token_idxs

     def _get_need_paged_token_num(self, b_seq_len, b_ready_cache_len=None):
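The heart of the change is in these two methods. In prefill, the flat tensor returned by the allocator is laid out request by request, so each request's last newly written KV slot sits at position (cumulative new-token count - 1); in decode, the next slot on the same page is simply the stored last slot plus one, which removes the gather from req_to_token_indexs and, with it, the b_req_idx argument. A standalone sketch of the prefill-side cumsum step with dummy tensors (no ReqManager involved):

import torch

# Two requests: request 0 writes 3 new tokens this prefill, request 1 writes 5
# (sequence length minus the prefix already cached).
b_seq_len = torch.tensor([4, 7])
b_ready_cache_len = torch.tensor([1, 2])
b_token_len = b_seq_len - b_ready_cache_len               # tensor([3, 5])

# Flat allocation result: request 0 owns the first 3 slots, request 1 the next 5.
mem_indices = torch.tensor([10, 11, 12, 40, 41, 42, 43, 44])

b_token_len_cumsum = torch.cumsum(b_token_len, dim=0)     # tensor([3, 8])
b_last_mem_index = mem_indices[b_token_len_cumsum - 1]    # tensor([12, 44])
print(b_last_mem_index)

On the decode side, a request whose new token does not start a fresh page ((b_seq_len - 1) % page_size != 0) gets b_last_mem_index + 1 directly; only page-boundary requests trigger a new page allocation and take its first slot (new_pages_tokens[::page_size]). As a side effect, the paged decode path now stays on CPU tensors (the .cuda() conversions are gone).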

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 1 addition & 0 deletions
@@ -288,6 +288,7 @@ def __init__(
         self.shm_index = shm_index
         self.multimodal_params = multimodal_params
         self.vocab_size = vocab_size
+        self.last_kv_mem_index = -1

         # the request needs to be paused
         self.wait_pause = False

lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py

Lines changed: 12 additions & 2 deletions
@@ -82,8 +82,13 @@ def padded_prepare_prefill_inputs(
         )
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num)
     mem_indexes = g_infer_context.req_manager.alloc_mem_indices(
-        input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, b_ready_cache_len
+        input_ids.shape[0] - padded_req_num, b_seq_len, b_ready_cache_len
     )
+    b_last_mem_index = g_infer_context.req_manager.calc_last_mem_index_in_prefill(
+        mem_indexes, b_seq_len, b_ready_cache_len
+    )
+    for i, req in enumerate(req_objs):
+        req.last_kv_mem_index = b_last_mem_index[i].item()

     g_infer_state_lock.release()

@@ -123,6 +128,7 @@ def padded_prepare_decode_inputs(
     b_req_idx = []
     b_mtp_index = []
     b_seq_len = []
+    b_last_mem_index = []
     for req in req_objs:
         run_reqs.append(req)
         b_req_idx.append(req.req_idx)
@@ -132,6 +138,7 @@
         total_token_num += seq_len
         max_len_in_batch = max(max_len_in_batch, seq_len)
         b_mtp_index.append(0)
+        b_last_mem_index.append(req.last_kv_mem_index)
         # process the draft tokens.
         for step in range(req.mtp_step):
             run_reqs.append(req)
@@ -164,15 +171,18 @@
     b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu")
     b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
     b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu")
+    b_last_mem_index = torch.tensor(b_last_mem_index, dtype=torch.int32, device="cpu")

     # prepare tokens for the dynamic prompt cache
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         token_num = g_infer_context.req_manager.calc_real_need_token_num(b_seq_len.shape[0] - padded_req_num, b_seq_len)
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num)
     mem_indexes = g_infer_context.req_manager.alloc_mem_indices(
-        b_seq_len.shape[0] - padded_req_num, b_req_idx, b_seq_len
+        b_seq_len.shape[0] - padded_req_num, b_seq_len, b_last_mem_index=b_last_mem_index
     )
+    for i, req in enumerate(req_objs):
+        req.last_kv_mem_index = mem_indexes[i]
     g_infer_state_lock.release()

     if padded_req_num > 0:
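The pattern in both prepare functions is a round trip through the request objects: prefill computes each request's last allocated slot and stores it on the request, and every decode step gathers those stored values back into a batch tensor before allocating the next slots. A small sketch of that gather-then-update cycle using a hypothetical stand-in for InferReq (the FakeReq class and the numbers are illustrative, not from the repo):

import torch

class FakeReq:
    """Stand-in for InferReq: carries only the field this commit adds."""
    def __init__(self, last_kv_mem_index: int):
        self.last_kv_mem_index = last_kv_mem_index

# After prefill each request remembers its last written KV slot.
req_objs = [FakeReq(12), FakeReq(44)]

# Decode step: gather the per-request values into the tensor handed to alloc_mem_indices(...).
b_last_mem_index = torch.tensor(
    [req.last_kv_mem_index for req in req_objs], dtype=torch.int32, device="cpu"
)

# Pretend the allocator returned one fresh slot per request.
mem_indexes = torch.tensor([13, 45], dtype=torch.int32)

# Write the fresh slots back so the next decode step sees them as the new "last" slots.
for i, req in enumerate(req_objs):
    req.last_kv_mem_index = mem_indexes[i].item()

The sketch stores plain ints via .item(); the patch itself keeps .item() only on the prefill side and assigns the tensor element directly in the decode path, which torch.tensor(...) can re-wrap on the next step.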

lightllm/server/router/model_infer/mode_backend/generic_pre_process.py

Lines changed: 13 additions & 3 deletions
@@ -59,9 +59,12 @@ def prepare_prefill_inputs(
             input_ids.shape[0], b_seq_len, b_ready_cache_len
         )
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num)
-    mem_indexes = g_infer_context.req_manager.alloc_mem_indices(
-        input_ids.shape[0], b_req_idx, b_seq_len, b_ready_cache_len
+    mem_indexes = g_infer_context.req_manager.alloc_mem_indices(input_ids.shape[0], b_seq_len, b_ready_cache_len)
+    b_last_mem_index = g_infer_context.req_manager.calc_last_mem_index_in_prefill(
+        mem_indexes, b_seq_len, b_ready_cache_len
     )
+    for i, req in enumerate(req_objs):
+        req.last_kv_mem_index = b_last_mem_index[i].item()
     g_infer_state_lock.release()

     model_input = ModelInput(
@@ -90,6 +93,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
     b_req_idx = []
     b_mtp_index = []
     b_seq_len = []
+    b_last_mem_index = []
     for req in req_objs:
         run_reqs.append(req)
         b_req_idx.append(req.req_idx)
@@ -99,6 +103,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
         total_token_num += seq_len
         max_len_in_batch = max(max_len_in_batch, seq_len)
         b_mtp_index.append(0)
+        b_last_mem_index.append(req.last_kv_mem_index)
         # process the draft tokens.
         for step in range(req.mtp_step):
             run_reqs.append(req)
@@ -112,13 +117,18 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
     b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu")
     b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
     b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu")
+    b_last_mem_index = torch.tensor(b_last_mem_index, dtype=torch.int32, device="cpu")

     # prepare tokens for the dynamic prompt cache
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         token_num = g_infer_context.req_manager.calc_real_need_token_num(b_seq_len.shape[0], b_seq_len)
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num)
-    mem_indexes = g_infer_context.req_manager.alloc_mem_indices(b_seq_len.shape[0], b_req_idx, b_seq_len)
+    mem_indexes = g_infer_context.req_manager.alloc_mem_indices(
+        b_seq_len.shape[0], b_seq_len, b_last_mem_index=b_last_mem_index
+    )
+    for i, req in enumerate(req_objs):
+        req.last_kv_mem_index = mem_indexes[i]
     g_infer_state_lock.release()

     model_input = ModelInput(
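To make the decode-side page arithmetic concrete, here is a worked sketch with an assumed page_size of 4 (illustrative numbers, not taken from the repo): a request whose new token is the first on a page takes the first slot of a freshly allocated page, while everyone else advances one slot past its stored last_kv_mem_index.

import torch

page_size = 4
# After appending the new token, request 0 has 9 tokens, request 1 has 10.
b_seq_len = torch.tensor([9, 10], dtype=torch.int32)
b_last_mem_index = torch.tensor([31, 32], dtype=torch.int32)  # last KV slots written so far

need_new_page_mask = (b_seq_len - 1) % page_size == 0         # tensor([True, False])
token_idxs = torch.zeros_like(b_seq_len)

# Request 0: token 9 starts a new page, so it takes the first slot of a (pretend) new page.
new_page_first_slots = torch.tensor([48], dtype=torch.int32)  # stand-in for mem_manager.alloc(...)
token_idxs[need_new_page_mask] = new_page_first_slots

# Request 1: token 10 stays on its current page, one slot past the previous one.
mask = ~need_new_page_mask
token_idxs[mask] = b_last_mem_index[mask] + 1                 # 32 + 1 = 33
print(token_idxs)                                             # tensor([48, 33], dtype=torch.int32)

This mirrors the decode branch in _alloc_paged_mem_indices above, with new_page_first_slots standing in for mem_manager.alloc(new_pages_num * page_size)[::page_size].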

test/benchmark/static_inference/model_infer.py

Lines changed: 6 additions & 2 deletions
@@ -258,7 +258,8 @@ def run_forward_once(
         b_seq_len[i] = input_len

     total_token_num = batch_size * input_len
-    mem_indexes = model_part.req_manager.alloc_mem_indices(test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len)
+    mem_indexes = model_part.req_manager.alloc_mem_indices(test_data.shape[0], b_seq_len, b_ready_cache_len)
+    b_last_mem_index = model_part.req_manager.calc_last_mem_index_in_prefill(mem_indexes, b_seq_len, b_ready_cache_len)
     b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
     rank_id = model_kvargs["rank_id"]

@@ -321,7 +322,10 @@ def run_forward_once(
         step_start = time.time()
         total_token_num += batch_size
         b_seq_len += 1
-        mem_indexes = model_part.req_manager.alloc_mem_indices(predict_ids.shape[0], b_req_idx, b_seq_len)
+        mem_indexes = model_part.req_manager.alloc_mem_indices(
+            predict_ids.shape[0], b_seq_len, b_last_mem_index=b_last_mem_index
+        )
+        b_last_mem_index = mem_indexes
         max_len_in_batch = input_len + i + 1
         logits = decode_fn(
             model_part,
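Because this static benchmark generates exactly one token per request per step, the slots returned by one step are, by construction, the previous slots for the next step, so b_last_mem_index can simply be reassigned from mem_indexes. A condensed sketch of that loop shape, with a hypothetical alloc() standing in for the req_manager call:

import torch

def alloc(need_size: int) -> torch.Tensor:
    """Stand-in for req_manager.alloc_mem_indices(...): hands out consecutive dummy slots."""
    alloc.cursor += need_size
    return torch.arange(alloc.cursor - need_size, alloc.cursor, dtype=torch.int32)

alloc.cursor = 0
batch_size, steps = 2, 3
b_last_mem_index = alloc(batch_size)   # pretend these are the prefill-time last slots

for _ in range(steps):
    mem_indexes = alloc(batch_size)    # one fresh KV slot per request this step
    # ... run the decode forward pass with mem_indexes ...
    b_last_mem_index = mem_indexes     # this step's slots become next step's "last" slots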

test/benchmark/static_inference/model_infer_mtp.py

Lines changed: 3 additions & 4 deletions
@@ -124,9 +124,8 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
         b_seq_len[i] = input_len

     total_token_num = input_len * batch_size
-    mem_indexes = main_model.req_manager.alloc_mem_indices(
-        test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len
-    ).cuda()
+    mem_indexes = main_model.req_manager.alloc_mem_indices(test_data.shape[0], b_seq_len, b_ready_cache_len).cuda()
+    b_last_mem_index = main_model.req_manager.calc_last_mem_index_in_prefill(mem_indexes, b_seq_len, b_ready_cache_len)
     # Main model Prefill
     model_input = ModelInput(
         batch_size=batch_size,
@@ -194,7 +193,7 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
         nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda")
         nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda")
         mem_indexes = main_model.req_manager.alloc_mem_indices(
-            batch_size * (len(draft_models) + 1), nopad_b_seq_idx, nopad_b_seq_len
+            batch_size * (len(draft_models) + 1), nopad_b_seq_len, b_last_mem_index=b_last_mem_index
         ).cuda()

         model_input = ModelInput(
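With MTP/speculative decoding, each step needs one KV slot for the main model's token plus one per draft model for every request, which is where the batch_size * (len(draft_models) + 1) allocation size comes from. A quick arithmetic sketch under assumed sizes:

# Assumed sizes; not the benchmark's actual CLI defaults.
batch_size = 4
num_draft_models = 1                      # len(draft_models)
slots_per_decode_step = batch_size * (num_draft_models + 1)
print(slots_per_decode_step)              # 8 KV slots allocated per decode step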
