1 change: 1 addition & 0 deletions lightllm/common/basemodel/batch_objs.py
@@ -31,6 +31,7 @@ class ModelInput:
# Parameters used in the prefill stage. They are not consumed by the inference pass itself;
# they are bookkeeping variables for resource management outside of inference.
b_prefill_has_output_cpu: List[bool] = None  # marks whether each prefill request produces an output token
b_next_chunck_first_token_ids_cpu: List[int] = None  # first token id of the next chunk, for chunked prefill MTP

# Model-specific variables: used by certain models in special modes to pass extra
# input variables. They are only used and take effect under those specific model modes.
11 changes: 9 additions & 2 deletions lightllm/server/router/model_infer/infer_batch.py
@@ -393,7 +393,14 @@ def get_input_token_ids(self):
def get_chuncked_input_token_ids(self):
chunked_start = self.cur_kv_len
chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size)
return self.shm_req.shm_prompt_ids.arr[0:chunked_end]

if chunked_end < self.get_cur_total_len():
next_token_id = self.shm_req.shm_prompt_ids.arr[chunked_end]
else:
# padding id for the last chunk; it will be discarded.
next_token_id = self.shm_req.shm_prompt_ids.arr[0]

return self.shm_req.shm_prompt_ids.arr[0:chunked_end], next_token_id

def get_chuncked_input_token_len(self):
chunked_start = self.cur_kv_len
@@ -438,7 +445,7 @@ def _stop_sequences_matched(self, output_len: int):

def prefill_need_token_num(self, is_chuncked_prefill: bool):
if is_chuncked_prefill:
input_token_ids = self.get_chuncked_input_token_ids()
input_token_ids, _ = self.get_chuncked_input_token_ids()
else:
input_token_ids = self.get_input_token_ids()

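For reference, a minimal standalone sketch of the new return contract of `get_chuncked_input_token_ids`: a plain Python list stands in for `shm_req.shm_prompt_ids.arr`, and the `_ChunkedReqSketch` class and its fields are illustrative stand-ins, not the real `InferReq`.

```python
from typing import List, Tuple


class _ChunkedReqSketch:
    """Illustrative stand-in for the request object; not the real InferReq."""

    def __init__(self, token_ids: List[int], cur_kv_len: int, chunked_prefill_size: int):
        self.token_ids = token_ids                       # stands in for shm_req.shm_prompt_ids.arr
        self.cur_kv_len = cur_kv_len                     # tokens already cached from earlier chunks
        self.chunked_prefill_size = chunked_prefill_size

    def get_chuncked_input_token_ids(self) -> Tuple[List[int], int]:
        chunked_start = self.cur_kv_len
        chunked_end = min(len(self.token_ids), chunked_start + self.chunked_prefill_size)
        if chunked_end < len(self.token_ids):
            # Not the last chunk: the "next token" is the first prompt token
            # of the following chunk.
            next_token_id = self.token_ids[chunked_end]
        else:
            # Last chunk: any padding id works, it is discarded downstream.
            next_token_id = self.token_ids[0]
        return self.token_ids[0:chunked_end], next_token_id


req = _ChunkedReqSketch(token_ids=[11, 12, 13, 14, 15], cur_kv_len=0, chunked_prefill_size=3)
chunk, next_id = req.get_chuncked_input_token_ids()
assert chunk == [11, 12, 13] and next_id == 14
```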
@@ -658,6 +658,7 @@ def _sample_and_scatter_token(
is_prefill: bool,
b_prefill_has_output_cpu: torch.Tensor = None,
mask_func: Optional[Callable] = None,
b_next_chunck_first_token_ids_cpu: torch.Tensor = None,
):

if mask_func is not None:
@@ -670,6 +671,11 @@
b_has_out = g_pin_mem_manager.gen_from_list(
key="b_has_out", data=b_prefill_has_output_cpu, dtype=torch.bool
).cuda(non_blocking=True)
if b_next_chunck_first_token_ids_cpu is not None:
b_next_chunck_first_token_ids = g_pin_mem_manager.gen_from_list(
key="b_next_chunck_first_token_ids", data=b_next_chunck_first_token_ids_cpu, dtype=torch.int64
).cuda(non_blocking=True)
next_token_ids = torch.where(b_has_out, next_token_ids, b_next_chunck_first_token_ids)

scatter_token(
next_token_ids=next_token_ids,
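A toy reproduction of the selection above, with made-up shapes and values: where `b_has_out` is False (the request still has chunks left), the sampled token is overridden by the first token id of the next prompt chunk before being scattered back.

```python
import torch

next_token_ids = torch.tensor([101, 102, 103], dtype=torch.int64)          # sampled tokens
b_has_out = torch.tensor([True, False, True])                               # False -> more chunks remain
b_next_chunck_first_token_ids = torch.tensor([7, 8, 9], dtype=torch.int64)  # next chunk's first prompt tokens

next_token_ids = torch.where(b_has_out, next_token_ids, b_next_chunck_first_token_ids)
assert next_token_ids.tolist() == [101, 8, 103]  # 102 is replaced by 8
```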
@@ -190,8 +190,8 @@ def prefill_mtp(
is_prefill=True,
b_prefill_has_output_cpu=model_input.b_prefill_has_output_cpu,
mask_func=self.prefill_mask_func,
b_next_chunck_first_token_ids_cpu=model_input.b_next_chunck_first_token_ids_cpu,
)
# mtp kv fill
self._draft_prefill_forward(
model_input=model_input, model_output=model_output, next_token_ids=next_token_ids
)
@@ -616,6 +616,10 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I
b_has_out_cpu = (
model_input0.b_prefill_has_output_cpu[0:req_num0] + model_input1.b_prefill_has_output_cpu[0:req_num1]
)
b_next_chunck_first_token_ids_cpu = (
model_input0.b_next_chunck_first_token_ids_cpu[0:req_num0]
+ model_input1.b_next_chunck_first_token_ids_cpu[0:req_num1]
)
b_mtp_index = torch.cat((model_input0.b_mtp_index[0:req_num0], model_input1.b_mtp_index[0:req_num1]), dim=0)
b_req_idx = torch.cat((model_input0.b_req_idx[0:req_num0], model_input1.b_req_idx[0:req_num1]), dim=0)

@@ -627,6 +631,7 @@
b_mtp_index=b_mtp_index,
is_prefill=True,
b_prefill_has_output_cpu=b_has_out_cpu,
b_next_chunck_first_token_ids_cpu=b_next_chunck_first_token_ids_cpu,
)

# spec prefill: MTP
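A toy illustration, with made-up request counts and token ids, of how the overlap path splices the two micro-batches' CPU-side lists before sampling (the plain lists below stand in for the `model_input0` / `model_input1` fields):

```python
req_num0, req_num1 = 2, 3

# Stand-ins for the two micro-batch ModelInput objects; trailing zeros are padding entries.
mb0_next_chunck_first_token_ids_cpu = [5, 6, 0]
mb1_next_chunck_first_token_ids_cpu = [7, 8, 9, 0]

b_next_chunck_first_token_ids_cpu = (
    mb0_next_chunck_first_token_ids_cpu[0:req_num0]
    + mb1_next_chunck_first_token_ids_cpu[0:req_num1]
)
assert b_next_chunck_first_token_ids_cpu == [5, 6, 7, 8, 9]
```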
@@ -36,14 +36,16 @@ def padded_prepare_prefill_inputs(
b_ready_cache_len = []
b_mtp_index = []
b_prefill_has_output = []
b_next_chunck_first_token_ids = []

for req in req_objs:

run_reqs.append(req)
batch_multimodal_params.append(req.multimodal_params)
b_req_idx.append(req.req_idx)

input_token_ids = req.get_chuncked_input_token_ids()
input_token_ids, next_token_id = req.get_chuncked_input_token_ids()
b_next_chunck_first_token_ids.append(next_token_id)
b_prefill_has_output.append(False if len(input_token_ids) < req.get_cur_total_len() else True)
seq_len = len(input_token_ids)
input_token_len = seq_len - req.cur_kv_len
@@ -65,6 +67,7 @@
b_q_seq_len.append(1)
b_mtp_index.append(0)
b_prefill_has_output.append(False)
b_next_chunck_first_token_ids.append(0)
b_ready_cache_len.append(0)
total_token_num += 1
prefix_total_token_num += 0
@@ -112,6 +115,7 @@ def padded_prepare_prefill_inputs(
b_ready_cache_len=b_ready_cache_len,
is_prefill=True,
b_prefill_has_output_cpu=b_prefill_has_output,
b_next_chunck_first_token_ids_cpu=b_next_chunck_first_token_ids,
)
if is_multimodal:
model_input.multimodal_params = batch_multimodal_params
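A compressed sketch of the bookkeeping above: one entry per real request taken from mocked `get_chuncked_input_token_ids` results, a padding 0 per dummy request, and then the int64 CPU tensor conversion that the plain `prepare_prefill_inputs` path below performs (all values and the padded request count are made up).

```python
import torch

# Mocked (chunk_token_ids, next_token_id) pairs as returned by
# get_chuncked_input_token_ids(); all values here are made up.
chunk_results = [([11, 12, 13], 14), ([21, 22, 23, 24], 21)]  # 2nd req: last chunk, throwaway id
padded_req_num = 1  # dummy requests appended to pad the batch

b_next_chunck_first_token_ids = [next_id for _, next_id in chunk_results]
b_next_chunck_first_token_ids += [0] * padded_req_num  # padding entries, never read

# prepare_prefill_inputs additionally turns the list into an int64 CPU tensor:
b_next_chunck_first_token_ids = torch.tensor(
    b_next_chunck_first_token_ids, dtype=torch.int64, device="cpu"
)
assert b_next_chunck_first_token_ids.tolist() == [14, 21, 0]
```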
@@ -20,14 +20,16 @@ def prepare_prefill_inputs(
b_ready_cache_len = []
b_mtp_index = []
b_prefill_has_output = []
b_next_chunck_first_token_ids = []

for req in req_objs:
run_reqs.append(req)
batch_multimodal_params.append(req.multimodal_params)
b_req_idx.append(req.req_idx)

if is_chuncked_mode:
input_token_ids = req.get_chuncked_input_token_ids()
input_token_ids, next_token_id = req.get_chuncked_input_token_ids()
b_next_chunck_first_token_ids.append(next_token_id)
else:
input_token_ids = req.get_input_token_ids()

@@ -57,6 +59,7 @@
b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu")
b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cpu")
b_next_chunck_first_token_ids = torch.tensor(b_next_chunck_first_token_ids, dtype=torch.int64, device="cpu")

# reserve tokens for the dynamic prompt cache
g_infer_state_lock.acquire()
@@ -80,6 +83,7 @@
b_ready_cache_len=b_ready_cache_len,
is_prefill=True,
b_prefill_has_output_cpu=b_prefill_has_output,
b_next_chunck_first_token_ids_cpu=b_next_chunck_first_token_ids,
prefix_total_token_num=prefix_total_token_num,
)
if is_multimodal: