
Commit 4a8bd69

fix: MTP in chunked prefill mode
1 parent db1b64c commit 4a8bd69

6 files changed, +49 -9 lines changed

lightllm/common/basemodel/batch_objs.py

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ class ModelInput:
     # Parameters used during the prefill stage; they are not consumed by the
     # inference process itself but by resource management outside of inference.
     b_prefill_has_output_cpu: List[bool] = None  # marks whether a prefill request produces output
+    b_chunked_prefill_next_token_ids_cpu: List[int] = None  # for chunked prefill mtp

     # Dedicated variables for special models and special modes, used to pass
     # special input values; they only take effect under those special model modes.
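
The new field mirrors b_prefill_has_output_cpu: both carry one entry per request in the batch, and the sentinel -1 is used exactly for requests whose current chunk is the final one. A minimal, self-contained sketch of that pairing (the trimmed class and the values below are illustrative, not the real ModelInput):

from dataclasses import dataclass, field
from typing import List

@dataclass
class ModelInputSketch:  # hypothetical, trimmed to the two fields discussed here
    b_prefill_has_output_cpu: List[bool] = field(default_factory=list)
    b_chunked_prefill_next_token_ids_cpu: List[int] = field(default_factory=list)

inp = ModelInputSketch(
    b_prefill_has_output_cpu=[False, True],           # req0 mid-prompt, req1 on its last chunk
    b_chunked_prefill_next_token_ids_cpu=[1234, -1],  # req0: next prompt token, req1: sentinel
)
assert len(inp.b_prefill_has_output_cpu) == len(inp.b_chunked_prefill_next_token_ids_cpu)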

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 8 additions & 2 deletions
@@ -393,7 +393,13 @@ def get_input_token_ids(self):
     def get_chuncked_input_token_ids(self):
         chunked_start = self.cur_kv_len
         chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size)
-        return self.shm_req.shm_prompt_ids.arr[0:chunked_end]
+
+        if chunked_end < self.get_cur_total_len():
+            next_token_id = self.shm_req.shm_prompt_ids.arr[chunked_end]
+        else:
+            next_token_id = -1  # last chunk
+
+        return self.shm_req.shm_prompt_ids.arr[0:chunked_end], next_token_id

     def get_chuncked_input_token_len(self):
         chunked_start = self.cur_kv_len
@@ -438,7 +444,7 @@ def _stop_sequences_matched(self, output_len: int):

     def prefill_need_token_num(self, is_chuncked_prefill: bool):
         if is_chuncked_prefill:
-            input_token_ids = self.get_chuncked_input_token_ids()
+            input_token_ids, _ = self.get_chuncked_input_token_ids()
         else:
             input_token_ids = self.get_input_token_ids()
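
The reworked getter now also reports the first token of the following chunk, so a non-final chunk can hand the MTP draft model a real next token instead of whatever gets sampled from an incomplete prompt. A standalone sketch of that boundary logic, using plain lists instead of the shared-memory prompt array (names and values are illustrative):

def chunked_slice(prompt_ids, cur_kv_len, chunked_prefill_size, cur_total_len):
    chunked_start = cur_kv_len
    chunked_end = min(cur_total_len, chunked_start + chunked_prefill_size)
    if chunked_end < cur_total_len:
        next_token_id = prompt_ids[chunked_end]  # first token of the next chunk
    else:
        next_token_id = -1  # last chunk: the real next token will be sampled
    return prompt_ids[0:chunked_end], next_token_id

prompt = [11, 12, 13, 14, 15]
print(chunked_slice(prompt, cur_kv_len=0, chunked_prefill_size=2, cur_total_len=5))  # ([11, 12], 13)
print(chunked_slice(prompt, cur_kv_len=4, chunked_prefill_size=2, cur_total_len=5))  # ([11, 12, 13, 14, 15], -1)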

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py

Lines changed: 7 additions & 1 deletion
@@ -192,8 +192,14 @@ def prefill_mtp(
             mask_func=self.prefill_mask_func,
         )
         # mtp kv fill
+        b_has_out = torch.tensor(model_input.b_prefill_has_output_cpu, dtype=torch.bool, device="cuda")
+        b_chunked_next_token_ids = torch.tensor(
+            model_input.b_chunked_prefill_next_token_ids_cpu, dtype=torch.int64, device="cuda"
+        )
+        mtp_next_token_ids = torch.where(b_has_out, next_token_ids, b_chunked_next_token_ids)
+
         self._draft_prefill_forward(
-            model_input=model_input, model_output=model_output, next_token_ids=next_token_ids
+            model_input=model_input, model_output=model_output, next_token_ids=mtp_next_token_ids
         )
         sync_event = torch.cuda.Event()
         sync_event.record()
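
The fix itself is the torch.where selection above: requests that finished their prompt (b_prefill_has_output_cpu is True) keep the freshly sampled next_token_ids, while requests still mid-prompt fall back to the pre-recorded first token of their next chunk, so every entry handed to _draft_prefill_forward is a valid token. A self-contained illustration of that selection with made-up values (CPU tensors here; the real code builds them on CUDA):

import torch

next_token_ids = torch.tensor([501, 502, 503], dtype=torch.int64)   # sampled at prefill
b_prefill_has_output_cpu = [True, False, True]                      # is this the request's last chunk?
b_chunked_prefill_next_token_ids_cpu = [-1, 777, -1]                # next prompt token, or -1 sentinel

b_has_out = torch.tensor(b_prefill_has_output_cpu, dtype=torch.bool)
b_chunked_next_token_ids = torch.tensor(b_chunked_prefill_next_token_ids_cpu, dtype=torch.int64)
mtp_next_token_ids = torch.where(b_has_out, next_token_ids, b_chunked_next_token_ids)
print(mtp_next_token_ids)  # tensor([501, 777, 503])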

lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py

Lines changed: 24 additions & 4 deletions
@@ -354,7 +354,13 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]
         # mtp kv fill
         draft_next_token_ids_gpu = torch.zeros((model_input.batch_size), dtype=torch.int64, device="cuda")
         if req_num > 0:
-            draft_next_token_ids_gpu[0:req_num].copy_(next_token_ids)
+            b_has_out = torch.tensor(b_has_out_cpu, dtype=torch.bool, device="cuda")
+            b_chunked_next_token_ids = torch.tensor(
+                model_input.b_chunked_prefill_next_token_ids_cpu[0:req_num], dtype=torch.int64, device="cuda"
+            )
+            mtp_next_token_ids = torch.where(b_has_out, next_token_ids, b_chunked_next_token_ids)
+            draft_next_token_ids_gpu[0:req_num].copy_(mtp_next_token_ids)
+
         self._draft_prefill_forward(
             model_input=model_input,
             model_output=model_output,
@@ -633,13 +639,27 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I
         draft_model_input0, draft_model_input1 = model_input0, model_input1
         draft_next_token_ids_gpu0 = torch.zeros((model_input0.batch_size), dtype=torch.int64, device="cuda")
         if req_num0 > 0:
-            draft_next_token_ids_gpu0[0:req_num0].copy_(next_token_ids[0:req_num0], non_blocking=True)
+            b_has_out0 = torch.tensor(
+                model_input0.b_prefill_has_output_cpu[0:req_num0], dtype=torch.bool, device="cuda"
+            )
+            b_chunked_next_token_ids0 = torch.tensor(
+                model_input0.b_chunked_prefill_next_token_ids_cpu[0:req_num0], dtype=torch.int64, device="cuda"
+            )
+            mtp_next_token_ids0 = torch.where(b_has_out0, next_token_ids[0:req_num0], b_chunked_next_token_ids0)
+            draft_next_token_ids_gpu0[0:req_num0].copy_(mtp_next_token_ids0, non_blocking=True)

         draft_next_token_ids_gpu1 = torch.zeros((model_input1.batch_size), dtype=torch.int64, device="cuda")
         if req_num1 > 0:
-            draft_next_token_ids_gpu1[0:req_num1].copy_(
-                next_token_ids[req_num0 : (req_num0 + req_num1)], non_blocking=True
+            b_has_out1 = torch.tensor(
+                model_input1.b_prefill_has_output_cpu[0:req_num1], dtype=torch.bool, device="cuda"
+            )
+            b_chunked_next_token_ids1 = torch.tensor(
+                model_input1.b_chunked_prefill_next_token_ids_cpu[0:req_num1], dtype=torch.int64, device="cuda"
+            )
+            mtp_next_token_ids1 = torch.where(
+                b_has_out1, next_token_ids[req_num0 : (req_num0 + req_num1)], b_chunked_next_token_ids1
             )
+            draft_next_token_ids_gpu1[0:req_num1].copy_(mtp_next_token_ids1, non_blocking=True)

         draft_model_output0, draft_model_output1 = model_output0, model_output1
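
In the overlap path the batch is split into two microbatches, so the same where-selection runs per slice: microbatch 0 uses next_token_ids[0:req_num0] and microbatch 1 uses next_token_ids[req_num0 : req_num0 + req_num1], each paired with its own b_prefill_has_output_cpu and b_chunked_prefill_next_token_ids_cpu lists. A small sketch of that per-slice selection (CPU tensors, made-up values):

import torch

req_num0, req_num1 = 2, 2
next_token_ids = torch.tensor([11, 12, 13, 14])  # sampled tokens: microbatch 0 then microbatch 1

has_out0 = torch.tensor([True, False])
chunk_next0 = torch.tensor([-1, 900])
mtp0 = torch.where(has_out0, next_token_ids[0:req_num0], chunk_next0)
print(mtp0)  # tensor([ 11, 900])

has_out1 = torch.tensor([False, True])
chunk_next1 = torch.tensor([901, -1])
mtp1 = torch.where(has_out1, next_token_ids[req_num0 : (req_num0 + req_num1)], chunk_next1)
print(mtp1)  # tensor([901,  14])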

lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py

Lines changed: 5 additions & 1 deletion
@@ -36,14 +36,16 @@ def padded_prepare_prefill_inputs(
     b_ready_cache_len = []
     b_mtp_index = []
     b_prefill_has_output = []
+    b_chunked_prefill_next_token_ids = []

     for req in req_objs:

         run_reqs.append(req)
         batch_multimodal_params.append(req.multimodal_params)
         b_req_idx.append(req.req_idx)

-        input_token_ids = req.get_chuncked_input_token_ids()
+        input_token_ids, next_token_id = req.get_chuncked_input_token_ids()
+        b_chunked_prefill_next_token_ids.append(next_token_id)
         b_prefill_has_output.append(False if len(input_token_ids) < req.get_cur_total_len() else True)
         seq_len = len(input_token_ids)
         input_token_len = seq_len - req.cur_kv_len
@@ -65,6 +67,7 @@ def padded_prepare_prefill_inputs(
         b_q_seq_len.append(1)
         b_mtp_index.append(0)
         b_prefill_has_output.append(False)
+        b_chunked_prefill_next_token_ids.append(-1)
         b_ready_cache_len.append(0)
         total_token_num += 1
         prefix_total_token_num += 0
@@ -112,6 +115,7 @@ def padded_prepare_prefill_inputs(
         b_ready_cache_len=b_ready_cache_len,
         is_prefill=True,
         b_prefill_has_output_cpu=b_prefill_has_output,
+        b_chunked_prefill_next_token_ids_cpu=b_chunked_prefill_next_token_ids,
     )
     if is_multimodal:
         model_input.multimodal_params = batch_multimodal_params
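
Padded dummy slots also get a -1 entry, so the new list stays index-aligned with b_prefill_has_output across real and padding requests. A tiny sketch of that invariant (illustrative values; the real loop iterates over req_objs):

# (req's next-chunk token or -1, does this chunk produce real output?)
reqs = [(1234, False), (-1, True)]

b_prefill_has_output = []
b_chunked_prefill_next_token_ids = []
for next_token_id, has_output in reqs:
    b_prefill_has_output.append(has_output)
    b_chunked_prefill_next_token_ids.append(next_token_id)

# one padding slot, mirrored in both lists
b_prefill_has_output.append(False)
b_chunked_prefill_next_token_ids.append(-1)

assert len(b_prefill_has_output) == len(b_chunked_prefill_next_token_ids) == 3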

lightllm/server/router/model_infer/mode_backend/generic_pre_process.py

Lines changed: 4 additions & 1 deletion
@@ -20,14 +20,16 @@ def prepare_prefill_inputs(
     b_ready_cache_len = []
     b_mtp_index = []
     b_prefill_has_output = []
+    b_chunked_prefill_next_token_ids = []

     for req in req_objs:
         run_reqs.append(req)
         batch_multimodal_params.append(req.multimodal_params)
         b_req_idx.append(req.req_idx)

         if is_chuncked_mode:
-            input_token_ids = req.get_chuncked_input_token_ids()
+            input_token_ids, next_token_id = req.get_chuncked_input_token_ids()
+            b_chunked_prefill_next_token_ids.append(next_token_id)
         else:
             input_token_ids = req.get_input_token_ids()

@@ -80,6 +82,7 @@ def prepare_prefill_inputs(
         b_ready_cache_len=b_ready_cache_len,
         is_prefill=True,
         b_prefill_has_output_cpu=b_prefill_has_output,
+        b_chunked_prefill_next_token_ids_cpu=b_chunked_prefill_next_token_ids,
         prefix_total_token_num=prefix_total_token_num,
     )
     if is_multimodal:
