 from tensorrt_llm.mapping import CpType
 from tensorrt_llm.runtime.generation import CUASSERT
 
+from ..attention_backend.trtllm import TrtllmAttention
 from ..distributed import Distributed
 from ..models.modeling_utils import DecoderModelForCausalLM
 from ..modules.decoder_layer import DecoderLayer
 from ..speculative.drafter import Drafter
+from ..speculative.mtp import SampleStateTensorsMTP
 from ..speculative.speculation_gate import SpeculationGate
 from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
 from .guided_decoder import GuidedDecoder
@@ -271,8 +273,18 @@ def __init__(self,
 
         if self.dist.pp_size > 1:
             self.event_loop = self._executor_loop_pp
+        elif self.disable_overlap_scheduler:
+            self.event_loop = self._executor_loop
+        elif self.drafter is not None and (
+                not self.drafter.use_static_draft_loop or not issubclass(
+                    self.draft_model_engine.attn_backend, TrtllmAttention)):
+            logger.warning(
+                "Overlap scheduler is disabled for draft model engine with non-CDL or non-TrtllmAttention attention backend."
+            )
+            self.disable_overlap_scheduler = True
+            self.event_loop = self._executor_loop
         else:
-            self.event_loop = self._executor_loop if disable_overlap_scheduler else self._executor_loop_overlap
+            self.event_loop = self._executor_loop_overlap
         if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
             self.event_loop = trace_func(self.event_loop)
 
@@ -1045,14 +1057,11 @@ def _prepare_and_schedule_batch(self):
                 0
             ] * max_total_draft_tokens if max_total_draft_tokens > 0 else []
 
-        # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch,
-        # we don't need to initialize py_draft_tokens at this stage because we haven't append the accepted tokens to the request yet.
-        if not self.has_previous_draft_tokens:
-            # If speculation is off, this function sets py_draft_tokens to []
-            # for all active requests. If it's on, we initialize py_draft_tokens
-            # with dummy draft tokens to make the scheduler aware of the fact
-            # that speculation is about to happen.
-            self._prepare_draft_requests()
+        # If speculation is off, this function sets py_draft_tokens to []
+        # for all active requests. If it's on, we initialize py_draft_tokens
+        # with dummy draft tokens to make the scheduler aware of the fact
+        # that speculation is about to happen.
+        self._prepare_draft_requests()
 
         scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
         )
@@ -1256,6 +1265,8 @@ def _executor_loop_overlap(self):
         with self._profiler() as profile_step:
             iter_start_time = time.time()
             iter_stats = None
+            target_inputs = None
+            previous_tensors_device = None
             can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True
             while True:
                 profile_step()
@@ -1329,31 +1340,29 @@ def _executor_loop_overlap(self):
                     self.guided_decoder.init_disagg_gen_requests()
 
                 previous_tensors = self.previous_batch and self.previous_batch.sample_state
-                target_inputs = None
-                draft_outputs = None
                 # If there are previous draft tokens, we need to update the target requests to accept some draft tokens.
                 # When there's any accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model,
                 # so we'll set the target model's input to None and skip updating the target requests after target model forward.
                 use_previous_draft_tokens = self.has_previous_draft_tokens
                 if self.drafter is not None and (self.use_spec_decode or
                                                  use_previous_draft_tokens):
-                    target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding(
-                        scheduled_batch, previous_tensors)
+                    target_inputs = self._handle_speculative_decoding(
+                        scheduled_batch, previous_tensors,
+                        previous_tensors_device)
 
                 # Use the draft_model's outputs if we've launched the draft model.
                 # Otherwise, use the previous batch's outputs.
-                if target_inputs is not None or use_previous_draft_tokens:
+                if (target_inputs is not None
+                        and target_inputs.next_draft_tokens
+                        is not None) or use_previous_draft_tokens:
                     previous_tensors_device = target_inputs
                 else:
                     previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device
 
                 batch_outputs = self._forward_step(scheduled_batch,
                                                    previous_tensors_device)
 
-                if target_inputs is not None:
-                    self._process_draft_results(scheduled_batch,
-                                                draft_outputs, draft_batch)
-                elif self.previous_batch is not None and not use_previous_draft_tokens:
+                if self.previous_batch is not None:
                     self._update_requests(self.previous_batch.sample_state)
 
                 if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
@@ -1368,6 +1377,10 @@ def _executor_loop_overlap(self):
                                 (req, block_id,
                                  self.ctx_in_transmission_counter))
 
+                if self.drafter is not None and self.use_spec_decode:
+                    # Cleanup previous draft resources used in the draft model
+                    self.drafter.cleanup_previous_draft_resources()
+
                 if self.guided_decoder is not None:
                     # add_batch must be called again to have updated new tokens.
                     self.guided_decoder.add_batch(scheduled_batch)
@@ -1402,6 +1415,94 @@ def _executor_loop_overlap(self):
 
         self._kv_connector_terminate_requests()
 
+    def _accept_draft_tokens(
+            self, scheduled_batch: ScheduledRequests,
+            target_outputs: SampleStateTensors,
+            target_inputs: Optional[SampleStateTensors]
+    ) -> Tuple[SampleStateTensorsMTP, Optional[torch.Tensor]]:
+        """
+        Prepare target device inputs after computing draft token acceptance.
+
+        This function:
+        1. If draft tokens exist: compares sampled tokens with draft tokens to compute acceptance
+        2. If no draft tokens: directly uses the first sampled token
+        3. Creates new_tokens by extracting accepted tokens per request
+
+        Args:
+            scheduled_batch: The scheduled requests
+            target_outputs: Contains new_tokens [max_draft_len + 1, batch_size, beam_width]
+                or [1, batch_size, beam_width] if no draft tokens
+            target_inputs: Contains next_draft_tokens [batch_size, max_draft_len]
+        Returns:
+            Tuple of:
+            - SampleStateTensorsMTP with new_tokens set to the accepted tokens,
+              new_tokens_lens and next_draft_tokens set to None
+            - num_accepted_tokens: [batch_size] tensor with acceptance counts per request
+              (all zeros if there are no draft tokens)
+        """
+        has_draft_tokens = target_inputs is not None and isinstance(
+            target_inputs, SampleStateTensorsMTP
+        ) and target_inputs.next_draft_tokens is not None
+        target_tokens = target_outputs.new_tokens  # [max_draft_len + 1, batch_size, beam_width] or [1, batch_size, beam_width]
+        new_tokens = torch.zeros_like(target_tokens)
+
+        # Squeeze the beam dimension (beam_width=1 for greedy or single beam)
+        target_tokens = target_tokens.squeeze(
+            -1)  # [max_draft_len + 1, batch_size] or [1, batch_size]
+
+        batch_size = target_tokens.shape[1]
+        device = target_tokens.device
+        # Number of accepted tokens per request
+        num_accepted_tokens = torch.zeros(batch_size,
+                                          dtype=torch.int32,
+                                          device=device)
+        if has_draft_tokens:
+            # Draft tokens exist: compute acceptance
+            draft_tokens = target_inputs.next_draft_tokens  # [batch_size, max_draft_len]
+            max_draft_len = draft_tokens.shape[1]
+
+            # Generation requests: compare with draft tokens to find acceptance
+            num_contexts = len(scheduled_batch.context_requests)
+            if batch_size > num_contexts:
+                # Transpose: [max_draft_len + 1, num_gens] -> [num_gens, max_draft_len + 1]
+                gen_target_tokens = target_tokens[:,
+                                                  num_contexts:].T  # [num_gens, max_draft_len + 1]
+
+                # Compare draft tokens with target tokens to find acceptance;
+                # cumprod zeroes everything after the first rejection point.
+                draft_tokens_gen = draft_tokens[
+                    num_contexts:, :].int()  # [num_gens, max_draft_len]
+                num_accepted_tokens[num_contexts:] += torch.cumprod(
+                    (draft_tokens_gen == gen_target_tokens[:, :max_draft_len]
+                     ).int(),
+                    dim=-1).sum(dim=1)
+
+            # Vectorized extraction using advanced indexing (no GPU-CPU sync):
+            # use num_accepted_tokens as indices to gather the right tokens.
+            batch_indices = torch.arange(batch_size, device=device)
+            new_tokens[0, :, 0] = target_tokens[num_accepted_tokens,
+                                                batch_indices]
+        else:
+            # No draft tokens to accept, just use the first (and only) sampled token
+            batch_indices = torch.arange(batch_size, device=device)
+            new_tokens[0, :, 0] = target_tokens[0, batch_indices]
+
+        # Create the updated SampleStateTensorsMTP;
+        # new_tokens_lens and next_draft_tokens are left as None.
+        result_tensors = SampleStateTensorsMTP(
+            new_tokens=new_tokens,
+            log_probs=target_outputs.log_probs,
+            new_tokens_lens=None,
+            next_draft_tokens=None)
+
+        # Copy logits if available
+        if hasattr(target_outputs, 'logits'):
+            result_tensors.logits = target_outputs.logits
+
+        return result_tensors, num_accepted_tokens
+
     def _process_previous_batch(self):
         if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
             for req in self.previous_batch.ctx_transmission_reqs:
@@ -2298,7 +2399,8 @@ def _remove_inflight_ids(self, scheduled_requests):
         for req in scheduled_requests.all_requests():
             self.inflight_req_ids.erase(req.request_id)
 
-    def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
+    def _handle_speculative_decoding(self, scheduled_batch, previous_tensors,
+                                     target_inputs):
         with request_context(is_draft=self.draft_model_engine is not None,
                              scheduled_requests=scheduled_batch):
             # Do an early checking to see if we need to forward the draft model.
@@ -2308,20 +2410,25 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
                 self.previous_batch is not None and self.use_spec_decode
                 and self.drafter.should_forward_draft_model(scheduled_batch))
 
-            if has_draft_batch or self.has_previous_draft_tokens:
-                self._update_requests(self.previous_batch.sample_state)
-                if self.has_previous_draft_tokens:
-                    self._prepare_draft_requests()
+            new_target_inputs = None
+            if has_draft_batch:
+                target_outputs = self.previous_batch.sample_state and self.previous_batch.sample_state.device
+                assert target_outputs is not None, "target_outputs should not be None"
+                new_target_inputs, num_accepted_tokens_device = self._accept_draft_tokens(
+                    scheduled_batch=scheduled_batch,
+                    target_inputs=target_inputs,
+                    target_outputs=target_outputs)
 
             if has_draft_batch:
-                target_inputs, draft_outputs, draft_batch = self.drafter.generate_draft_tokens_with_overlap(
+                self.drafter.generate_draft_tokens_with_overlap(
                     scheduled_batch, self.resource_manager,
-                    previous_tensors.device if previous_tensors else None)
+                    previous_tensors.device if previous_tensors else None,
+                    new_target_inputs, num_accepted_tokens_device)
 
-                self.has_previous_draft_tokens = target_inputs is not None and target_inputs.next_draft_tokens is not None
+                # Pad draft tokens to the max draft length for CUDA graph compatibility
+                self.has_previous_draft_tokens = new_target_inputs is not None and new_target_inputs.next_draft_tokens is not None
             else:
                 self.has_previous_draft_tokens = False
-                target_inputs, draft_outputs, draft_batch = None, None, None
                 # We are not running the draft model. Remove the draft tokens and turn off spec
                 # decode so that the requests get handled correctly.
                 # One corner case: when we have at least one context request, we have to keep spec
@@ -2334,34 +2441,7 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
             for request in scheduled_batch.all_requests():
                 request.py_draft_tokens = []
 
-            return target_inputs, draft_outputs, draft_batch
-
-    def _process_draft_results(self, scheduled_batch, draft_outputs,
-                               draft_batch):
-        """
-        Append the draft tokens to the target requests, and clean up the draft resources.
-        """
-        with request_context(is_draft=self.draft_model_engine is not None,
-                             scheduled_requests=scheduled_batch):
-            req_id_to_old_request = {
-                req.py_request_id: req
-                for req in scheduled_batch.all_requests()
-            }
-
-            if self.drafter.use_static_draft_loop:
-                self.drafter.process_static_draft_outputs(
-                    draft_outputs, draft_batch, req_id_to_old_request)
-            elif draft_outputs is not None:
-                self.drafter.process_dynamic_draft_outputs(
-                    draft_outputs, req_id_to_old_request)
-
-            # Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
-            self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
-            # add_batch must be called again to restore to target requests with updated draft tokens.
-            if self.guided_decoder is not None:
-                self.guided_decoder.add_batch(scheduled_batch)
-                if hasattr(self.drafter, "guided_decoder"):
-                    self.guided_decoder.rollback_draft_tokens()
+            return new_target_inputs
 
 
 class DisaggPPTerminationHandler:
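
For reference, the acceptance rule implemented in _accept_draft_tokens can be illustrated in isolation. The following is a minimal standalone sketch, not part of the patch: it uses a transposed [num_gens, max_draft_len + 1] layout for the target tokens (the method itself works on [max_draft_len + 1, batch_size] and transposes only the generation slice), and all tensor values below are made up.

# Standalone sketch of cumprod-based draft-token acceptance (illustrative values).
import torch

# Hypothetical example: 2 generation requests, 3 draft tokens each.
draft_tokens = torch.tensor([[5, 7, 9],
                             [4, 4, 4]])       # [num_gens, max_draft_len]
target_tokens = torch.tensor([[5, 7, 2, 8],
                              [1, 6, 3, 0]])   # [num_gens, max_draft_len + 1]

# A draft token counts as accepted only if it matches the target model's sampled
# token and every earlier draft token was also accepted; cumprod zeroes out
# everything after the first mismatch.
matches = (draft_tokens == target_tokens[:, :draft_tokens.shape[1]]).int()
num_accepted = torch.cumprod(matches, dim=-1).sum(dim=1)   # tensor([2, 0])

# The "new" token for each request is the target token right after the last
# accepted draft token (position num_accepted), gathered with advanced indexing.
batch_indices = torch.arange(target_tokens.shape[0])
new_tokens = target_tokens[batch_indices, num_accepted]    # tensor([2, 1])
print(num_accepted.tolist(), new_tokens.tolist())

Keeping both the cumprod reduction and the gather on device is what lets the overlap loop compute acceptance without a GPU-CPU synchronization.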