@@ -1476,7 +1476,15 @@ def _prepare_tp_inputs(
     # skip adding input_ids of CUDA graph dummy requests so that new_tokens_device
     # can be aligned to the correct positions.
     if not request.is_cuda_graph_dummy:
-        input_ids.append(request.get_last_tokens(beam))
+        # Track position for GPU update (draft model only)
+        if self.is_draft_model and num_accepted_tokens_device is not None:
+            start_idx = len(input_ids)
+            input_ids.append(request.get_last_tokens(beam))
+            end_idx = len(input_ids)
+            first_draft_input_ids_positions.append(
+                (start_idx, end_idx, request.py_seq_slot))
+        else:
+            input_ids.append(request.get_last_tokens(beam))
     past_seen_token_num = request.max_beam_num_tokens - 1
 else:
     # the request has previous tensor
@@ -1842,6 +1850,7 @@ def previous_seq_slots_device():
     self.iter_states['num_ctx_requests'] = num_ctx_requests
     self.iter_states['num_ctx_tokens'] = num_ctx_tokens
     self.iter_states['num_generation_tokens'] = num_generation_tokens
+    print(f"DEBUG: is_draft_model: {self.is_draft_model}, inputs: {inputs}")
     return inputs, self.gather_ids_cuda[:len(
         gather_ids)] if self.enable_spec_decode else None
18471856
0 commit comments