 from tensorrt_llm.mapping import CpType
 from tensorrt_llm.runtime.generation import CUASSERT

+from ..attention_backend.trtllm import TrtllmAttention
 from ..distributed import Distributed
 from ..models.modeling_utils import DecoderModelForCausalLM
 from ..modules.decoder_layer import DecoderLayer
 from ..speculative.drafter import Drafter
+from ..speculative.model_drafter import ModelDrafter
+from ..speculative.mtp import SampleStateTensorsMTP
 from ..speculative.speculation_gate import SpeculationGate
 from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
 from .guided_decoder import GuidedDecoder
@@ -271,8 +274,23 @@ def __init__(self,

         if self.dist.pp_size > 1:
             self.event_loop = self._executor_loop_pp
+        elif self.disable_overlap_scheduler:
+            self.event_loop = self._executor_loop
         else:
-            self.event_loop = self._executor_loop if disable_overlap_scheduler else self._executor_loop_overlap
+            # TODO: The overlap scheduler is not supported in the following cases:
+            #   1. non-CDL is used
+            #   2. a non-TrtllmAttention attention backend is used
+            overlap_not_supported = self.drafter is not None and isinstance(
+                self.drafter, ModelDrafter) and (
+                    not self.drafter.use_static_draft_loop or not issubclass(
+                        self.draft_model_engine.attn_backend, TrtllmAttention))
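+            # In words: when a ModelDrafter is in use, overlap stays enabled only
+            # if it runs the static draft loop and the draft engine uses the
+            # TrtllmAttention backend; otherwise we fall back to the plain
+            # executor loop below.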
+
+            if overlap_not_supported:
+                logger.warning(
+                    "Overlap scheduler is disabled for draft model engine with non-CDL or non-TrtllmAttention attention backend."
+                )
+                self.disable_overlap_scheduler = True
+            self.event_loop = self._executor_loop_overlap if not overlap_not_supported else self._executor_loop
         if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
             self.event_loop = trace_func(self.event_loop)

@@ -1045,14 +1063,11 @@ def _prepare_and_schedule_batch(self):
                      0
                  ] * max_total_draft_tokens if max_total_draft_tokens > 0 else []

-            # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch,
-            # we don't need to initialize py_draft_tokens at this stage because we haven't append the accepted tokens to the request yet.
-            if not self.has_previous_draft_tokens:
-                # If speculation is off, this function sets py_draft_tokens to []
-                # for all active requests. If it's on, we initialize py_draft_tokens
-                # with dummy draft tokens to make the scheduler aware of the fact
-                # that speculation is about to happen.
-                self._prepare_draft_requests()
+            # If speculation is off, this function sets py_draft_tokens to []
+            # for all active requests. If it's on, we initialize py_draft_tokens
+            # with dummy draft tokens to make the scheduler aware of the fact
+            # that speculation is about to happen.
+            self._prepare_draft_requests()

         scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
         )
@@ -1256,6 +1271,8 @@ def _executor_loop_overlap(self):
         with self._profiler() as profile_step:
             iter_start_time = time.time()
             iter_stats = None
+            target_inputs = None
+            previous_tensors_device = None
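+            # Initialized before the loop so the first iteration sees None; on
+            # later iterations these carry the tensors produced in the previous
+            # step into _handle_speculative_decoding and _forward_step below.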
             can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True
             while True:
                 profile_step()
@@ -1329,31 +1346,29 @@ def _executor_loop_overlap(self):
                         self.guided_decoder.init_disagg_gen_requests()

                     previous_tensors = self.previous_batch and self.previous_batch.sample_state
-                    target_inputs = None
-                    draft_outputs = None
                     # If there are previous draft tokens, we need to update the target requests to accept some draft tokens.
                     # When there's any accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model,
                     # so we'll set the target model's input to None and skip updating the target requests after target model forward.
                     use_previous_draft_tokens = self.has_previous_draft_tokens
                     if self.drafter is not None and (self.use_spec_decode or
                                                      use_previous_draft_tokens):
-                        target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding(
-                            scheduled_batch, previous_tensors)
+                        target_inputs = self._handle_speculative_decoding(
+                            scheduled_batch, previous_tensors,
+                            previous_tensors_device)

                     # Use the draft_model's outputs if we've launched the draft model.
                     # Otherwise, use the previous batch's outputs.
-                    if target_inputs is not None or use_previous_draft_tokens:
+                    if (target_inputs is not None
+                            and target_inputs.next_draft_tokens
+                            is not None) or use_previous_draft_tokens:
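+                        # When draft handling produced target_inputs, those device
+                        # tensors already embed the accepted tokens, so they become
+                        # the next forward step's inputs instead of the raw
+                        # previous-batch sample state.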
                         previous_tensors_device = target_inputs
                     else:
                         previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device

                     batch_outputs = self._forward_step(scheduled_batch,
                                                        previous_tensors_device)

-                    if target_inputs is not None:
-                        self._process_draft_results(scheduled_batch,
-                                                    draft_outputs, draft_batch)
-                    elif self.previous_batch is not None and not use_previous_draft_tokens:
+                    if self.previous_batch is not None:
                         self._update_requests(self.previous_batch.sample_state)

                     if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
@@ -1368,6 +1383,10 @@ def _executor_loop_overlap(self):
                                 (req, block_id,
                                  self.ctx_in_transmission_counter))

+                    if self.drafter is not None and self.use_spec_decode:
+                        # Clean up draft resources used by the draft model in the previous iteration
+                        self.drafter.cleanup_previous_draft_resources()
+
                     if self.guided_decoder is not None:
                         # add_batch must be called again to have updated new tokens.
                         self.guided_decoder.add_batch(scheduled_batch)
@@ -1402,6 +1421,94 @@ def _executor_loop_overlap(self):

                 self._kv_connector_terminate_requests()

+    def _accept_draft_tokens(
+            self, scheduled_batch: ScheduledRequests,
+            target_outputs: SampleStateTensors,
+            target_inputs: Optional[SampleStateTensors]
+    ) -> Tuple[SampleStateTensorsMTP, Optional[torch.Tensor]]:
+        """
+        Prepare target device inputs after computing draft token acceptance.
+
+        This function:
+        1. If draft tokens exist: compares sampled tokens with draft tokens to compute acceptance
+        2. If no draft tokens: directly uses the first sampled token
+        3. Creates new_tokens by extracting accepted tokens per request
+
+        Args:
+            scheduled_batch: The scheduled requests
+            target_outputs: Contains new_tokens [max_draft_len + 1, batch_size, beam_width],
+                or [1, batch_size, beam_width] if there are no draft tokens
+            target_inputs: Contains next_draft_tokens [batch_size, max_draft_len];
+                may be None or carry no draft tokens
+
+        Returns:
+            Tuple of:
+            - SampleStateTensorsMTP with new_tokens set to the accepted tokens,
+              and new_tokens_lens and next_draft_tokens set to None
+            - num_accepted_tokens: [batch_size] tensor with acceptance counts per
+              request (all zeros when there are no draft tokens)
+        """
+        has_draft_tokens = target_inputs is not None and isinstance(
+            target_inputs, SampleStateTensorsMTP
+        ) and target_inputs.next_draft_tokens is not None
+        target_tokens = target_outputs.new_tokens  # [max_draft_len + 1, batch_size, beam_width] or [1, batch_size, beam_width]
+        new_tokens = torch.zeros_like(target_tokens)
+
+        # Squeeze the beam dimension (beam_width=1 for greedy or single beam)
+        target_tokens = target_tokens.squeeze(
+            -1)  # [max_draft_len + 1, batch_size] or [1, batch_size]
+
+        batch_size = target_tokens.shape[1]
+        device = target_tokens.device
+        # Number of accepted tokens per request, zero-initialized
+        num_accepted_tokens = torch.zeros(batch_size,
+                                          dtype=torch.int32,
+                                          device=device)
+        # Two cases: acceptance against draft tokens vs. no draft tokens at all
+        if has_draft_tokens:
+            # Draft tokens exist, compute acceptance
+            draft_tokens = target_inputs.next_draft_tokens  # [batch_size, max_draft_len]
+            max_draft_len = draft_tokens.shape[1]
+
+            # Generation requests: compare sampled tokens with draft tokens to
+            # find how many were accepted
+            num_contexts = len(scheduled_batch.context_requests)
+            if batch_size > num_contexts:
+                # Use .T to transpose: [max_draft_len + 1, num_gens] -> [num_gens, max_draft_len + 1]
+                gen_target_tokens = target_tokens[:,
+                                                  num_contexts:].T  # [num_gens, max_draft_len + 1]
+
+                # Compare draft tokens with target tokens to find acceptance
+                # Use cumprod to find the first rejection point
+                draft_tokens_gen = draft_tokens[
+                    num_contexts:, :].int()  # [num_gens, max_draft_len]
+                num_accepted_tokens[num_contexts:] += torch.cumprod(
+                    (draft_tokens_gen == gen_target_tokens[:, :max_draft_len]
+                     ).int(),
+                    dim=-1).sum(dim=1)
+
+            # Vectorized extraction using advanced indexing (no GPU-CPU sync)
+            # Use num_accepted_tokens as indices to gather the right tokens
+            batch_indices = torch.arange(batch_size, device=device)
+            new_tokens[0, :, 0] = target_tokens[num_accepted_tokens,
+                                                batch_indices]
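+            # For each request b, num_accepted_tokens[b] indexes the first
+            # position whose token did not come from a draft copy, i.e. the
+            # token the target model actually sampled there (the bonus token
+            # when every draft was accepted), so the gather above selects it.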
+        else:
+            # No draft tokens to accept, just use the first (and only) sampled token
+            batch_indices = torch.arange(batch_size, device=device)
+            new_tokens[0, :, 0] = target_tokens[0, batch_indices]
+
+        # Create the updated SampleStateTensorsMTP
+        # new_tokens_lens and next_draft_tokens are left as None
+        result_tensors = SampleStateTensorsMTP(
+            new_tokens=new_tokens,
+            log_probs=target_outputs.log_probs,
+            new_tokens_lens=None,
+            next_draft_tokens=None)
+
+        # Copy logits if available
+        if hasattr(target_outputs, 'logits'):
+            result_tensors.logits = target_outputs.logits
+
+        return result_tensors, num_accepted_tokens
+
     def _process_previous_batch(self):
         if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
             for req in self.previous_batch.ctx_transmission_reqs:
@@ -2298,7 +2405,8 @@ def _remove_inflight_ids(self, scheduled_requests):
         for req in scheduled_requests.all_requests():
             self.inflight_req_ids.erase(req.request_id)

-    def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
+    def _handle_speculative_decoding(self, scheduled_batch, previous_tensors,
+                                     target_inputs):
         with request_context(is_draft=self.draft_model_engine is not None,
                              scheduled_requests=scheduled_batch):
             # Do an early checking to see if we need to forward the draft model.
@@ -2308,20 +2416,25 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
                 self.previous_batch is not None and self.use_spec_decode
                 and self.drafter.should_forward_draft_model(scheduled_batch))

-            if has_draft_batch or self.has_previous_draft_tokens:
-                self._update_requests(self.previous_batch.sample_state)
-                if self.has_previous_draft_tokens:
-                    self._prepare_draft_requests()
+            new_target_inputs = None
+            if has_draft_batch:
+                target_outputs = self.previous_batch.sample_state and self.previous_batch.sample_state.device
+                assert target_outputs is not None, "target_outputs should not be None"
+                new_target_inputs, num_accepted_tokens_device = self._accept_draft_tokens(
+                    scheduled_batch=scheduled_batch,
+                    target_inputs=target_inputs,
+                    target_outputs=target_outputs)
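+                # new_target_inputs now carries the accepted tokens on device; it is
+                # handed to generate_draft_tokens_with_overlap below together with
+                # the per-request acceptance counts, and its next_draft_tokens field
+                # is checked afterwards to decide whether draft tokens were produced.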

             if has_draft_batch:
-                target_inputs, draft_outputs, draft_batch = self.drafter.generate_draft_tokens_with_overlap(
+                self.drafter.generate_draft_tokens_with_overlap(
                     scheduled_batch, self.resource_manager,
-                    previous_tensors.device if previous_tensors else None)
+                    previous_tensors.device if previous_tensors else None,
+                    new_target_inputs, num_accepted_tokens_device)

-                self.has_previous_draft_tokens = target_inputs is not None and target_inputs.next_draft_tokens is not None
+                # Pad draft tokens to the max draft length for CUDA graph compatibility
+                self.has_previous_draft_tokens = new_target_inputs is not None and new_target_inputs.next_draft_tokens is not None
             else:
                 self.has_previous_draft_tokens = False
-                target_inputs, draft_outputs, draft_batch = None, None, None
                 # We are not running the draft model. Remove the draft tokens and turn off spec
                 # decode so that the requests get handled correctly.
                 # One corner case: when we have at least one context request, we have to keep spec
@@ -2334,34 +2447,7 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
                 for request in scheduled_batch.all_requests():
                     request.py_draft_tokens = []

-        return target_inputs, draft_outputs, draft_batch
-
-    def _process_draft_results(self, scheduled_batch, draft_outputs,
-                               draft_batch):
-        """
-        Append the draft tokens to the target requests, and clean up the draft resources.
-        """
-        with request_context(is_draft=self.draft_model_engine is not None,
-                             scheduled_requests=scheduled_batch):
-            req_id_to_old_request = {
-                req.py_request_id: req
-                for req in scheduled_batch.all_requests()
-            }
-
-            if self.drafter.use_static_draft_loop:
-                self.drafter.process_static_draft_outputs(
-                    draft_outputs, draft_batch, req_id_to_old_request)
-            elif draft_outputs is not None:
-                self.drafter.process_dynamic_draft_outputs(
-                    draft_outputs, req_id_to_old_request)
-
-            # Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
-            self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
-            # add_batch must be called again to restore to target requests with updated draft tokens.
-            if self.guided_decoder is not None:
-                self.guided_decoder.add_batch(scheduled_batch)
-            if hasattr(self.drafter, "guided_decoder"):
-                self.guided_decoder.rollback_draft_tokens()
+        return new_target_inputs


 class DisaggPPTerminationHandler: