 from tensorrt_llm.mapping import CpType
 from tensorrt_llm.runtime.generation import CUASSERT

+from ..attention_backend.trtllm import TrtllmAttention
 from ..distributed import Distributed
 from ..models.modeling_utils import DecoderModelForCausalLM
 from ..modules.decoder_layer import DecoderLayer
 from ..speculative.drafter import Drafter
+from ..speculative.model_drafter import ModelDrafter
+from ..speculative.mtp import SampleStateTensorsMTP
 from ..speculative.speculation_gate import SpeculationGate
 from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
 from .guided_decoder import GuidedDecoder
@@ -275,8 +278,23 @@ def __init__(self,

         if self.dist.pp_size > 1:
             self.event_loop = self._executor_loop_pp
+        elif self.disable_overlap_scheduler:
+            self.event_loop = self._executor_loop
         else:
-            self.event_loop = self._executor_loop if disable_overlap_scheduler else self._executor_loop_overlap
+            # TODO: The overlap scheduler is not yet supported when:
+            # 1. a non-CDL draft loop is used, or
+            # 2. a non-TrtllmAttention attention backend is used.
+            overlap_not_supported = self.drafter is not None and isinstance(
+                self.drafter, ModelDrafter) and (
+                    not self.drafter.use_static_draft_loop or not issubclass(
+                        self.draft_model_engine.attn_backend, TrtllmAttention))
+
+            if overlap_not_supported:
+                logger.warning(
+                    "Overlap scheduler is disabled because the draft model engine uses a non-CDL draft loop or a non-TrtllmAttention attention backend."
+                )
+                self.disable_overlap_scheduler = True
+            self.event_loop = self._executor_loop if overlap_not_supported else self._executor_loop_overlap
         if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
             self.event_loop = trace_func(self.event_loop)

@@ -1060,14 +1078,11 @@ def _prepare_and_schedule_batch(self):
                 0
             ] * max_total_draft_tokens if max_total_draft_tokens > 0 else []

-            # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch,
-            # we don't need to initialize py_draft_tokens at this stage because we haven't append the accepted tokens to the request yet.
-            if not self.has_previous_draft_tokens:
-                # If speculation is off, this function sets py_draft_tokens to []
-                # for all active requests. If it's on, we initialize py_draft_tokens
-                # with dummy draft tokens to make the scheduler aware of the fact
-                # that speculation is about to happen.
-                self._prepare_draft_requests()
+            # If speculation is off, this function sets py_draft_tokens to []
+            # for all active requests. If it's on, we initialize py_draft_tokens
+            # with dummy draft tokens to make the scheduler aware of the fact
+            # that speculation is about to happen.
+            self._prepare_draft_requests()

         scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
         )
@@ -1317,6 +1332,8 @@ def _executor_loop_overlap(self):
         with self._profiler() as profile_step:
             iter_start_time = time.time()
             iter_stats = None
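+            # Carried across loop iterations: the draft-updated target inputs and the
+            # device tensors fed to the next target-model forward (overlap scheduling).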
+            target_inputs = None
+            previous_tensors_device = None
             can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True
             while True:
                 profile_step()
@@ -1397,31 +1414,29 @@ def _executor_loop_overlap(self):
                         self.guided_decoder.init_disagg_gen_requests()

                     previous_tensors = self.previous_batch and self.previous_batch.sample_state
-                    target_inputs = None
-                    draft_outputs = None
                     # If there are previous draft tokens, we need to update the target requests to accept some draft tokens.
                     # When there's any accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model,
                     # so we'll set the target model's input to None and skip updating the target requests after target model forward.
                     use_previous_draft_tokens = self.has_previous_draft_tokens
                     if self.drafter is not None and (self.use_spec_decode or
                                                      use_previous_draft_tokens):
-                        target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding(
-                            scheduled_batch, previous_tensors)
+                        target_inputs = self._handle_speculative_decoding(
+                            scheduled_batch, previous_tensors,
+                            previous_tensors_device)

                     # Use the draft_model's outputs if we've launched the draft model.
                     # Otherwise, use the previous batch's outputs.
-                    if target_inputs is not None or use_previous_draft_tokens:
+                    if (target_inputs is not None
+                            and target_inputs.next_draft_tokens
+                            is not None) or use_previous_draft_tokens:
                         previous_tensors_device = target_inputs
                     else:
                         previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device

                     batch_outputs = self._forward_step(scheduled_batch,
                                                        previous_tensors_device)

-                    if target_inputs is not None:
-                        self._process_draft_results(scheduled_batch,
-                                                    draft_outputs, draft_batch)
-                    elif self.previous_batch is not None and not use_previous_draft_tokens:
+                    if self.previous_batch is not None:
                         self._update_requests(self.previous_batch.sample_state)

                     if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
@@ -1436,6 +1451,10 @@ def _executor_loop_overlap(self):
                                 (req, block_id,
                                  self.ctx_in_transmission_counter))

+                    if self.drafter is not None and self.use_spec_decode:
+                        # Clean up the draft resources used by the previous draft-model run
+                        self.drafter.cleanup_previous_draft_resources()
+
                     if self.guided_decoder is not None:
                         # add_batch must be called again to have updated new tokens.
                         self.guided_decoder.add_batch(scheduled_batch)
@@ -1470,6 +1489,94 @@ def _executor_loop_overlap(self):

                 self._kv_connector_terminate_requests()

+    def _accept_draft_tokens(
+            self, scheduled_batch: ScheduledRequests,
+            target_outputs: SampleStateTensors,
+            target_inputs: Optional[SampleStateTensors]
+    ) -> Tuple[SampleStateTensorsMTP, Optional[torch.Tensor]]:
+        """
+        Prepare the target model's device inputs after computing draft-token acceptance.
+
+        This function:
+        1. If draft tokens exist: compares the sampled tokens with the draft tokens to compute acceptance
+        2. If no draft tokens exist: directly uses the first sampled token
+        3. Creates new_tokens by extracting the accepted tokens per request
+
+        Args:
+            scheduled_batch: The scheduled requests
+            target_outputs: Contains new_tokens [max_draft_len + 1, batch_size, beam_width],
+                or [1, batch_size, beam_width] if there are no draft tokens
+            target_inputs: Contains next_draft_tokens [batch_size, max_draft_len]
+
+        Returns:
+            Tuple of:
+            - SampleStateTensorsMTP with new_tokens set to the accepted tokens and
+              new_tokens_lens and next_draft_tokens set to None
+            - num_accepted_tokens: [batch_size] tensor with the acceptance count per
+              request (all zeros when there are no draft tokens)
+        """
+        has_draft_tokens = target_inputs is not None and isinstance(
+            target_inputs, SampleStateTensorsMTP
+        ) and target_inputs.next_draft_tokens is not None
+        target_tokens = target_outputs.new_tokens  # [max_draft_len + 1, batch_size, beam_width] or [1, batch_size, beam_width]
+        new_tokens = torch.zeros_like(target_tokens)
+
+        # Squeeze the beam dimension (beam_width == 1 for greedy or a single beam)
+        target_tokens = target_tokens.squeeze(
+            -1)  # [max_draft_len + 1, batch_size] or [1, batch_size]
+
+        batch_size = target_tokens.shape[1]
+        device = target_tokens.device
+        # Number of accepted tokens per request, zero by default
+        num_accepted_tokens = torch.zeros(batch_size,
+                                          dtype=torch.int32,
+                                          device=device)
+        if has_draft_tokens:
+            # Draft tokens exist, compute acceptance
+            draft_tokens = target_inputs.next_draft_tokens  # [batch_size, max_draft_len]
+            max_draft_len = draft_tokens.shape[1]
+
+            # Generation requests: compare with draft tokens to find acceptance
+            num_contexts = len(scheduled_batch.context_requests)
+            if batch_size > num_contexts:
+                # Use .T to transpose: [max_draft_len + 1, num_gens] -> [num_gens, max_draft_len + 1]
+                gen_target_tokens = target_tokens[:,
+                                                  num_contexts:].T  # [num_gens, max_draft_len + 1]
+
+                # Compare draft tokens with target tokens to find acceptance.
+                # cumprod zeroes everything after the first mismatch, so the sum
+                # is the length of the accepted prefix.
+                draft_tokens_gen = draft_tokens[
+                    num_contexts:, :].int()  # [num_gens, max_draft_len]
+                num_accepted_tokens[num_contexts:] += torch.cumprod(
+                    (draft_tokens_gen == gen_target_tokens[:, :max_draft_len]
+                     ).int(),
+                    dim=-1).sum(dim=1)
+
+            # Vectorized extraction using advanced indexing (no GPU-CPU sync).
+            # For request i, position num_accepted_tokens[i] holds the first token
+            # sampled by the target model after the accepted prefix.
+            batch_indices = torch.arange(batch_size, device=device)
+            new_tokens[0, :, 0] = target_tokens[num_accepted_tokens,
+                                                batch_indices]
+        else:
+            # No draft tokens to accept, just use the first (and only) sampled token
+            batch_indices = torch.arange(batch_size, device=device)
+            new_tokens[0, :, 0] = target_tokens[0, batch_indices]
+
+        # Create the updated SampleStateTensorsMTP.
+        # new_tokens_lens and next_draft_tokens are left as None.
+        result_tensors = SampleStateTensorsMTP(
+            new_tokens=new_tokens,
+            log_probs=target_outputs.log_probs,
+            new_tokens_lens=None,
+            next_draft_tokens=None)
+
+        # Copy logits if available
+        if hasattr(target_outputs, 'logits'):
+            result_tensors.logits = target_outputs.logits
+
+        return result_tensors, num_accepted_tokens
+
     def _process_previous_batch(self):
         if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
             for req in self.previous_batch.ctx_transmission_reqs:
@@ -2366,7 +2473,8 @@ def _remove_inflight_ids(self, scheduled_requests):
         for req in scheduled_requests.all_requests():
             self.inflight_req_ids.erase(req.request_id)

-    def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
+    def _handle_speculative_decoding(self, scheduled_batch, previous_tensors,
+                                     target_inputs):
         with request_context(is_draft=self.draft_model_engine is not None,
                              scheduled_requests=scheduled_batch):
             # Do an early checking to see if we need to forward the draft model.
@@ -2376,20 +2484,25 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
                 self.previous_batch is not None and self.use_spec_decode
                 and self.drafter.should_forward_draft_model(scheduled_batch))

-            if has_draft_batch or self.has_previous_draft_tokens:
-                self._update_requests(self.previous_batch.sample_state)
-                if self.has_previous_draft_tokens:
-                    self._prepare_draft_requests()
+            new_target_inputs = None
+            if has_draft_batch:
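+                # Run draft-token acceptance on the previous target outputs to build
+                # the device inputs for this iteration's target-model forward.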
+                target_outputs = self.previous_batch.sample_state and self.previous_batch.sample_state.device
+                assert target_outputs is not None, "target_outputs should not be None"
+                new_target_inputs, num_accepted_tokens_device = self._accept_draft_tokens(
+                    scheduled_batch=scheduled_batch,
+                    target_inputs=target_inputs,
+                    target_outputs=target_outputs)

             if has_draft_batch:
-                target_inputs, draft_outputs, draft_batch = self.drafter.generate_draft_tokens_with_overlap(
+                self.drafter.generate_draft_tokens_with_overlap(
                     scheduled_batch, self.resource_manager,
-                    previous_tensors.device if previous_tensors else None)
+                    previous_tensors.device if previous_tensors else None,
+                    new_target_inputs, num_accepted_tokens_device)

-                self.has_previous_draft_tokens = target_inputs is not None and target_inputs.next_draft_tokens is not None
+                # Record whether new draft tokens were produced so the next iteration
+                # can run acceptance against the target outputs.
+                self.has_previous_draft_tokens = new_target_inputs is not None and new_target_inputs.next_draft_tokens is not None
             else:
                 self.has_previous_draft_tokens = False
-                target_inputs, draft_outputs, draft_batch = None, None, None
                 # We are not running the draft model. Remove the draft tokens and turn off spec
                 # decode so that the requests get handled correctly.
                 # One corner case: when we have at least one context request, we have to keep spec
@@ -2402,34 +2515,7 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
                 for request in scheduled_batch.all_requests():
                     request.py_draft_tokens = []

-            return target_inputs, draft_outputs, draft_batch
-
-    def _process_draft_results(self, scheduled_batch, draft_outputs,
-                               draft_batch):
-        """
-        Append the draft tokens to the target requests, and clean up the draft resources.
-        """
-        with request_context(is_draft=self.draft_model_engine is not None,
-                             scheduled_requests=scheduled_batch):
-            req_id_to_old_request = {
-                req.py_request_id: req
-                for req in scheduled_batch.all_requests()
-            }
-
-            if self.drafter.use_static_draft_loop:
-                self.drafter.process_static_draft_outputs(
-                    draft_outputs, draft_batch, req_id_to_old_request)
-            elif draft_outputs is not None:
-                self.drafter.process_dynamic_draft_outputs(
-                    draft_outputs, req_id_to_old_request)
-
-            # Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
-            self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
-            # add_batch must be called again to restore to target requests with updated draft tokens.
-            if self.guided_decoder is not None:
-                self.guided_decoder.add_batch(scheduled_batch)
-                if hasattr(self.drafter, "guided_decoder"):
-                    self.guided_decoder.rollback_draft_tokens()
+            return new_target_inputs

     def reset_prefix_cache(self):
         self.kv_cache_manager.reset_reuse_state()