@@ -57,7 +57,7 @@
 from .cuda_graph_runner import CUDAGraphRunner
 from .guided_decoder import CapturableGuidedDecoder
 from .layerwise_nvtx_marker import LayerwiseNvtxMarker
-from .llm_request import get_draft_token_length
+from .llm_request import LlmRequest, get_draft_token_length
 from .model_loader import ModelLoader
 from .resource_manager import (BaseResourceManager, KVCacheManager,
                                ResourceManager, ResourceManagerType)
@@ -1192,7 +1192,8 @@ def _prepare_tp_inputs(
             spec_metadata: Optional[SpecMetadata] = None,
             new_tensors_device: Optional[SampleStateTensors] = None,
             cache_indirection_buffer: Optional[torch.Tensor] = None,
-            num_accepted_tokens_device: Optional[torch.Tensor] = None):
+            num_accepted_tokens_device: Optional[torch.Tensor] = None,
+            req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None):
         """
         Prepare inputs for Pytorch Model.
         """
@@ -1256,9 +1257,11 @@ def _prepare_tp_inputs(
                 start_idx = len(input_ids)
                 input_ids.extend(prompt_tokens)
                 end_idx = len(input_ids)
+                slot_idx = req_id_to_old_request[
+                    request.py_request_id].py_seq_slot
                 context_input_ids_positions.append(
-                    (start_idx, end_idx - 1, request.py_seq_slot
-                     ))  # end_idx-1 is the last token position
+                    (start_idx, end_idx - 1,
+                     slot_idx))  # end_idx-1 is the last token position
             else:
                 input_ids.extend(prompt_tokens)

@@ -1433,16 +1436,18 @@ def _prepare_tp_inputs(
                 input_ids.extend(prompt_tokens)
                 end_idx = len(input_ids)
                 # For first_draft, we need to replace the last original_max_draft_len+1 tokens
+                slot_idx = req_id_to_old_request[
+                    request.py_request_id].py_seq_slot
                 first_draft_input_ids_positions.append(
-                    (start_idx, end_idx, request.py_seq_slot))
+                    (start_idx, end_idx, slot_idx))

                 # Store info for GPU computation of gather_ids and num_accepted_draft_tokens
                 base_gather_id = len(
                     input_ids) - 1 - self.original_max_draft_len
                 gather_ids.append(
                     base_gather_id)  # Placeholder, will be corrected on GPU
                 first_draft_base_gather_ids.append(base_gather_id)
-                first_draft_seq_slots.append(request.py_seq_slot)
+                first_draft_seq_slots.append(slot_idx)
                 first_draft_request_indices.append(
                     len(num_accepted_draft_tokens))
@@ -1481,8 +1486,10 @@ def _prepare_tp_inputs(
                 start_idx = len(input_ids)
                 input_ids.append(request.get_last_tokens(beam))
                 end_idx = len(input_ids)
+                slot_idx = req_id_to_old_request[
+                    request.py_request_id].py_seq_slot
                 first_draft_input_ids_positions.append(
-                    (start_idx, end_idx, request.py_seq_slot))
+                    (start_idx, end_idx, slot_idx))
             else:
                 input_ids.append(request.get_last_tokens(beam))
                 past_seen_token_num = request.max_beam_num_tokens - 1
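
Taken together, the `_prepare_tp_inputs` hunks above replace every direct read of `request.py_seq_slot` with a lookup through the new `req_id_to_old_request` mapping, so the sequence slot always comes from the original request rather than the draft-side copy, whose slot may differ. A minimal sketch of that lookup pattern, using a hypothetical `FakeRequest` stand-in for `LlmRequest`:

```python
from typing import Dict


class FakeRequest:
    """Hypothetical stand-in for LlmRequest; only the two fields used here."""

    def __init__(self, py_request_id: int, py_seq_slot: int):
        self.py_request_id = py_request_id
        self.py_seq_slot = py_seq_slot


def resolve_seq_slot(request: FakeRequest,
                     req_id_to_old_request: Dict[int, FakeRequest]) -> int:
    # The diff always routes through the mapping: the slot of the *original*
    # request is used, never the slot of the (possibly re-created) draft copy.
    return req_id_to_old_request[request.py_request_id].py_seq_slot


# The draft-side request can carry a different slot than the original.
old = FakeRequest(py_request_id=7, py_seq_slot=3)
draft = FakeRequest(py_request_id=7, py_seq_slot=0)
assert resolve_seq_slot(draft, {7: old}) == 3
```
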
@@ -2328,7 +2335,8 @@ def _prepare_inputs(
             spec_metadata: Optional[SpecMetadata] = None,
             new_tensors_device: Optional[SampleStateTensors] = None,
             cache_indirection_buffer: Optional[torch.Tensor] = None,
-            num_accepted_tokens_device: Optional[torch.Tensor] = None):
+            num_accepted_tokens_device: Optional[torch.Tensor] = None,
+            req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None):
         if self.mapping is not None and 'cp_type' in self.mapping.cp_config:
             cp_type = self.mapping.cp_config['cp_type']
             if CpType.STAR == cp_type:
@@ -2345,7 +2353,8 @@ def _prepare_inputs(
                                        attn_metadata, spec_metadata,
                                        new_tensors_device,
                                        cache_indirection_buffer,
-                                       num_accepted_tokens_device)
+                                       num_accepted_tokens_device,
+                                       req_id_to_old_request)

     @torch.inference_mode()
     @with_model_extra_attrs(lambda self: self.model.extra_attrs)
@@ -2355,7 +2364,8 @@ def forward(self,
                 new_tensors_device: Optional[SampleStateTensors] = None,
                 gather_context_logits: bool = False,
                 cache_indirection_buffer: Optional[torch.Tensor] = None,
-                num_accepted_tokens_device: Optional[torch.Tensor] = None):
+                num_accepted_tokens_device: Optional[torch.Tensor] = None,
+                req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None):
         kv_cache_manager = resource_manager.get_resource_manager(
             self.kv_cache_manager_key)
@@ -2411,7 +2421,7 @@ def forward(self,
         inputs, gather_ids = self._prepare_inputs(
             padded_requests, kv_cache_manager, attn_metadata, spec_metadata,
             new_tensors_device, cache_indirection_buffer,
-            num_accepted_tokens_device)
+            num_accepted_tokens_device, req_id_to_old_request)

         self.iter_counter += 1
         with with_shared_pool(self.cuda_graph_runner.get_graph_pool()):
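
The remaining hunks only thread the new argument through the call chain `forward` → `_prepare_inputs` → `_prepare_tp_inputs`. How the mapping is constructed is not part of this diff; one plausible shape, sketched with the same hypothetical `FakeRequest`, is a dict comprehension keyed by `py_request_id`:

```python
from typing import Dict, List


class FakeRequest:
    """Hypothetical stand-in for LlmRequest (see the sketch above)."""

    def __init__(self, py_request_id: int, py_seq_slot: int):
        self.py_request_id = py_request_id
        self.py_seq_slot = py_seq_slot


def build_req_id_to_old_request(
        old_requests: List[FakeRequest]) -> Dict[int, FakeRequest]:
    # Keyed by py_request_id, matching the lookups added in _prepare_tp_inputs.
    # Assumed caller-side code: only the forward(...) parameter itself is
    # shown in the diff, not how the dict is built.
    return {req.py_request_id: req for req in old_requests}


mapping = build_req_id_to_old_request(
    [FakeRequest(py_request_id=1, py_seq_slot=4),
     FakeRequest(py_request_id=2, py_seq_slot=0)])
assert mapping[1].py_seq_slot == 4
```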