Skip to content

Commit 03d5bb7

Browse files
committed
Fix the AR drop in non-cdl
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent 440c99d commit 03d5bb7

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1476,7 +1476,15 @@ def _prepare_tp_inputs(
14761476
# skip adding input_ids of CUDA graph dummy requests so that new_tokens_device
14771477
# can be aligned to the correct positions.
14781478
if not request.is_cuda_graph_dummy:
1479-
input_ids.append(request.get_last_tokens(beam))
1479+
# Track position for GPU update (draft model only)
1480+
if self.is_draft_model and num_accepted_tokens_device is not None:
1481+
start_idx = len(input_ids)
1482+
input_ids.append(request.get_last_tokens(beam))
1483+
end_idx = len(input_ids)
1484+
first_draft_input_ids_positions.append(
1485+
(start_idx, end_idx, request.py_seq_slot))
1486+
else:
1487+
input_ids.append(request.get_last_tokens(beam))
14801488
past_seen_token_num = request.max_beam_num_tokens - 1
14811489
else:
14821490
# the request has previous tensor
@@ -1842,6 +1850,7 @@ def previous_seq_slots_device():
18421850
self.iter_states['num_ctx_requests'] = num_ctx_requests
18431851
self.iter_states['num_ctx_tokens'] = num_ctx_tokens
18441852
self.iter_states['num_generation_tokens'] = num_generation_tokens
1853+
print(f"DEBUG: is_draft_model: {self.is_draft_model}, inputs: {inputs}")
18451854
return inputs, self.gather_ids_cuda[:len(
18461855
gather_ids)] if self.enable_spec_decode else None
18471856

0 commit comments

Comments (0)