From 12659b296c10355bb12b2daa8c11bf98e7a7529d Mon Sep 17 00:00:00 2001 From: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com> Date: Tue, 4 Nov 2025 00:32:08 -0800 Subject: [PATCH 1/2] [TRTLLM-8084][feat] Enhance overlap scheduler for two-model spec decoding Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com> --- .../_torch/pyexecutor/model_engine.py | 305 ++++++++++++++-- tensorrt_llm/_torch/pyexecutor/py_executor.py | 196 +++++++--- tensorrt_llm/_torch/pyexecutor/sampler.py | 4 +- .../_torch/speculative/model_drafter.py | 342 +++++++++++------- .../_torch/speculative/test_eagle3.py | 39 +- 5 files changed, 650 insertions(+), 236 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index e3c12e36b49..9118662fc16 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -58,7 +58,7 @@ from .cuda_graph_runner import CUDAGraphRunner from .guided_decoder import CapturableGuidedDecoder from .layerwise_nvtx_marker import LayerwiseNvtxMarker -from .llm_request import get_draft_token_length +from .llm_request import LlmRequest, get_draft_token_length from .model_loader import ModelLoader, _construct_checkpoint_loader from .resource_manager import (BaseResourceManager, KVCacheManager, ResourceManager, ResourceManagerType) @@ -73,14 +73,13 @@ def get_max_num_sequences(self) -> int: raise NotImplementedError @abstractmethod - def forward( - self, - scheduled_requests: ScheduledRequests, - resource_manager: ResourceManager, - new_tensors_device: Optional[SampleStateTensors], - gather_context_logits: bool = False, - cache_indirection_buffer: Optional[torch.Tensor] = None, - ): + def forward(self, + scheduled_requests: ScheduledRequests, + resource_manager: ResourceManager, + new_tensors_device: Optional[SampleStateTensors], + gather_context_logits: bool = False, + cache_indirection_buffer: Optional[torch.Tensor] = None, + num_accepted_tokens_device: Optional[torch.Tensor] = None): raise NotImplementedError def warmup(self, resource_manager: ResourceManager) -> None: @@ -367,6 +366,31 @@ def __init__( (3, 1, self.max_num_tokens), dtype=torch.int, device='cuda') self.iter_counter = 0 + # Pre-allocated buffers for draft model to avoid implicit synchronization + # These are used to build index tensors without creating tensors from Python lists + if is_draft_model: + # Buffers for context and first_draft input_ids updates + self.draft_ctx_token_indices_cuda = torch.empty((self.batch_size, ), + dtype=torch.long, + device='cuda') + self.draft_ctx_seq_slots_cuda = torch.empty((self.batch_size, ), + dtype=torch.long, + device='cuda') + # Buffers for first_draft requests (max_draft_len+1 tokens per request) + max_first_draft_tokens = self.batch_size * ( + self.original_max_draft_len + + 1) if spec_config else self.batch_size + self.draft_first_draft_indices_cuda = torch.empty( + (max_first_draft_tokens, ), dtype=torch.long, device='cuda') + self.draft_first_draft_seq_slots_cuda = torch.empty( + (max_first_draft_tokens, ), dtype=torch.long, device='cuda') + # Buffers for seq_slots and request indices + self.draft_seq_slots_buffer_cuda = torch.empty((self.batch_size, ), + dtype=torch.int, + device='cuda') + self.draft_request_indices_buffer_cuda = torch.empty( + (self.batch_size, ), dtype=torch.int, device='cuda') + # We look up this key in resource_manager during forward to find the # kv cache manager. 
Can be changed to support multiple model engines # with different KV cache managers. @@ -586,6 +610,13 @@ def _run_autotuner_warmup(self, resource_manager: ResourceManager): with self._release_batch_context(warmup_request, resource_manager) as batch: if batch is not None: + # Reset the flag is_first_draft for the draft model. + # This is necessary for overlap scheduler. + spec_resource_manager = resource_manager.get_resource_manager( + ResourceManagerType.SPEC_RESOURCE_MANAGER) + if self.is_draft_model and isinstance( + spec_resource_manager, Eagle3ResourceManager): + spec_resource_manager.is_first_draft = True self.forward(batch, new_tensors_device=None, resource_manager=resource_manager) @@ -1223,10 +1254,13 @@ def _prepare_tp_inputs( attn_metadata: AttentionMetadata, spec_metadata: Optional[SpecMetadata] = None, new_tensors_device: Optional[SampleStateTensors] = None, - cache_indirection_buffer: Optional[torch.Tensor] = None): + cache_indirection_buffer: Optional[torch.Tensor] = None, + num_accepted_tokens_device: Optional[torch.Tensor] = None, + req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None): """ Prepare inputs for Pytorch Model. """ + new_tokens_device, new_tokens_lens_device, next_draft_tokens_device = None, None, None if new_tensors_device is not None: # speculative decoding cases: [batch, 1 + draft_len], others: [batch] @@ -1257,6 +1291,19 @@ def _prepare_tp_inputs( mrope_position_ids = [] num_accepted_draft_tokens = [] # per request + # Variables for updating the inputs of draft model + # Base values for gather_ids computation + first_draft_base_gather_ids = [] + # seq_slots to index into num_accepted_tokens_device + first_draft_seq_slots = [] + # Indices in the num_accepted_draft_tokens list + first_draft_request_indices = [] + + # (start_idx, end_idx, seq_slot) for context requests + context_input_ids_positions = [] + # (start_idx, end_idx, seq_slot) for first_draft requests + first_draft_input_ids_positions = [] + for request in scheduled_requests.context_requests: request_ids.append(request.py_request_id) all_prompt_tokens = request.get_tokens(0) @@ -1266,7 +1313,20 @@ def _prepare_tp_inputs( prompt_tokens = all_prompt_tokens[begin_compute:end_compute] position_ids.extend( range(begin_compute, begin_compute + len(prompt_tokens))) - input_ids.extend(prompt_tokens) + + # Track position for updating the inputs of draft model + if self.is_draft_model and num_accepted_tokens_device is not None: + start_idx = len(input_ids) + input_ids.extend(prompt_tokens) + end_idx = len(input_ids) + slot_idx = req_id_to_old_request[ + request.py_request_id].py_seq_slot + context_input_ids_positions.append( + (start_idx, end_idx - 1, + slot_idx)) # end_idx-1 is the last token position + else: + input_ids.extend(prompt_tokens) + gather_ids.append(len(input_ids) - 1) sequence_lengths.append(len(prompt_tokens)) num_accepted_draft_tokens.append(len(prompt_tokens) - 1) @@ -1410,14 +1470,9 @@ def _prepare_tp_inputs( previous_batch_indices.append(previous_batch_idx) previous_pos_indices.extend([previous_batch_idx] * (1 + self.runtime_draft_len)) - if self.spec_config.spec_dec_mode.has_draft_model(): - # In the overlap scheduler workflow, if having draft model, we already updated the previous batch before launching the target model, - # so we only need to add the runtime_draft_len to the past_seen_token_num. 
- num_cached_tokens_per_seq.append(past_seen_token_num + - self.runtime_draft_len) - else: - num_cached_tokens_per_seq.append(past_seen_token_num + - self.runtime_draft_len + 1) + + num_cached_tokens_per_seq.append(past_seen_token_num + + self.runtime_draft_len + 1) request.cached_tokens = num_cached_tokens_per_seq[-1] if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx( self.attn_backend): @@ -1435,13 +1490,39 @@ def _prepare_tp_inputs( prompt_tokens = all_prompt_tokens[begin_compute:end_compute] position_ids.extend( range(begin_compute, begin_compute + len(prompt_tokens))) - input_ids.extend(prompt_tokens) - gather_ids.append( - len(input_ids) - 1 - (self.original_max_draft_len - - request.py_num_accepted_draft_tokens)) + + # Track position for updating the inputs of draft model + if self.is_draft_model and num_accepted_tokens_device is not None: + start_idx = len(input_ids) + input_ids.extend(prompt_tokens) + end_idx = len(input_ids) + # For first_draft, we need to replace the last original_max_draft_len+1 tokens + slot_idx = req_id_to_old_request[ + request.py_request_id].py_seq_slot + first_draft_input_ids_positions.append( + (start_idx, end_idx, slot_idx)) + + # Store info for GPU computation of gather_ids and num_accepted_draft_tokens + base_gather_id = len( + input_ids) - 1 - self.original_max_draft_len + # Placeholder, will be corrected on GPU + gather_ids.append(base_gather_id) + first_draft_base_gather_ids.append(base_gather_id) + first_draft_seq_slots.append(slot_idx) + first_draft_request_indices.append( + len(num_accepted_draft_tokens)) + + # Placeholder, will be corrected on GPU + num_accepted_draft_tokens.append(0) + else: + input_ids.extend(prompt_tokens) + gather_ids.append( + len(input_ids) - 1 - (self.original_max_draft_len - + request.py_num_accepted_draft_tokens)) + num_accepted_draft_tokens.append( + request.py_num_accepted_draft_tokens) + sequence_lengths.append(1 + self.original_max_draft_len) - num_accepted_draft_tokens.append( - request.py_num_accepted_draft_tokens) prompt_lengths.append(request.py_prompt_len) past_seen_token_num = begin_compute num_cached_tokens_per_seq.append(past_seen_token_num) @@ -1461,7 +1542,17 @@ def _prepare_tp_inputs( # skip adding input_ids of CUDA graph dummy requests so that new_tokens_device # can be aligned to the correct positions. 
if not request.is_cuda_graph_dummy: - input_ids.append(request.get_last_tokens(beam)) + # Track position for GPU update (draft model only) + if self.is_draft_model and num_accepted_tokens_device is not None: + start_idx = len(input_ids) + input_ids.append(request.get_last_tokens(beam)) + end_idx = len(input_ids) + slot_idx = req_id_to_old_request[ + request.py_request_id].py_seq_slot + first_draft_input_ids_positions.append( + (start_idx, end_idx, slot_idx)) + else: + input_ids.append(request.get_last_tokens(beam)) past_seen_token_num = request.max_beam_num_tokens - 1 else: # the request has previous tensor @@ -1536,6 +1627,79 @@ def previous_seq_slots_device(): dtype=torch.int, pin_memory=True) self.input_ids_cuda[:num_tokens].copy_(input_ids, non_blocking=True) + + # Update input_ids_cuda with new tokens from new_tensors_device (draft model only) + if self.is_draft_model and num_accepted_tokens_device is not None: + # For context requests: replace the last token with new_tensors_device[0, seq_slot, 0] + if len(context_input_ids_positions) > 0: + # Build tensors on CPU first, then copy to GPU to avoid implicit sync + num_ctx_positions = len(context_input_ids_positions) + ctx_token_indices_cpu = torch.tensor([ + last_token_idx + for _, last_token_idx, _ in context_input_ids_positions + ], + dtype=torch.long, + pin_memory=True) + ctx_seq_slots_cpu = torch.tensor([ + seq_slot + for _, _, seq_slot in context_input_ids_positions + ], + dtype=torch.long, + pin_memory=True) + # Copy to pre-allocated GPU buffers + self.draft_ctx_token_indices_cuda[:num_ctx_positions].copy_( + ctx_token_indices_cpu, non_blocking=True) + self.draft_ctx_seq_slots_cuda[:num_ctx_positions].copy_( + ctx_seq_slots_cpu, non_blocking=True) + self.input_ids_cuda[ + self. + draft_ctx_token_indices_cuda[:num_ctx_positions]] = new_tensors_device.new_tokens[ + 0, + self.draft_ctx_seq_slots_cuda[:num_ctx_positions], + 0] + + # For first_draft requests: replace the last (original_max_draft_len+1) tokens + # with new_tensors_device[:, seq_slot, 0] + if len(first_draft_input_ids_positions) > 0: + # All first_draft requests have same token length (original_max_draft_len + 1) + # Build index tensors on CPU first, then copy to GPU to avoid implicit sync + num_requests = len(first_draft_input_ids_positions) + tokens_per_request = first_draft_input_ids_positions[0][ + 1] - first_draft_input_ids_positions[0][0] + + # Create flat index array for all tokens to update on CPU + all_indices = [] + all_seq_slots = [] + for start_idx, end_idx, seq_slot in first_draft_input_ids_positions: + all_indices.extend(range(start_idx, end_idx)) + all_seq_slots.extend([seq_slot] * (end_idx - start_idx)) + + # Create CPU tensors with pinned memory + total_tokens = len(all_indices) + idx_tensor_cpu = torch.tensor(all_indices, + dtype=torch.long, + pin_memory=True) + seq_slots_tensor_cpu = torch.tensor(all_seq_slots, + dtype=torch.long, + pin_memory=True) + + # Copy to pre-allocated GPU buffers + self.draft_first_draft_indices_cuda[:total_tokens].copy_( + idx_tensor_cpu, non_blocking=True) + self.draft_first_draft_seq_slots_cuda[:total_tokens].copy_( + seq_slots_tensor_cpu, non_blocking=True) + + # Create token position indices (repeating 0..tokens_per_request for each request) + token_positions = torch.arange( + tokens_per_request, dtype=torch.long, + device='cuda').repeat(num_requests) + + self.input_ids_cuda[ + self. + draft_first_draft_indices_cuda[:total_tokens]] = new_tensors_device.new_tokens[ + token_positions, self. 
+ draft_first_draft_seq_slots_cuda[:total_tokens], 0] + if num_draft_tokens > 0: draft_tokens = torch.tensor(draft_tokens, dtype=torch.int, @@ -1549,6 +1713,33 @@ def previous_seq_slots_device(): self.num_accepted_draft_tokens_cuda[:len( num_accepted_draft_tokens)].copy_(num_accepted_draft_tokens, non_blocking=True) + + # Update num_accepted_draft_tokens_cuda for first_draft_requests directly from num_accepted_tokens_device (draft model only) + if self.is_draft_model and len(first_draft_seq_slots) > 0: + # Build tensors on CPU first, then copy to GPU to avoid implicit sync + num_first_draft = len(first_draft_seq_slots) + first_draft_seq_slots_cpu = torch.tensor(first_draft_seq_slots, + dtype=torch.int, + pin_memory=True) + first_draft_indices_cpu = torch.tensor( + first_draft_request_indices, + dtype=torch.int, + pin_memory=True) + + # Copy to pre-allocated GPU buffers + self.draft_seq_slots_buffer_cuda[:num_first_draft].copy_( + first_draft_seq_slots_cpu, non_blocking=True) + self.draft_request_indices_buffer_cuda[:num_first_draft].copy_( + first_draft_indices_cpu, non_blocking=True) + + # Extract accepted tokens for first_draft requests from device tensor + accepted_tokens = num_accepted_tokens_device[ + self.draft_seq_slots_buffer_cuda[:num_first_draft]] + # Update the correct positions in num_accepted_draft_tokens_cuda + self.num_accepted_draft_tokens_cuda[ + self. + draft_request_indices_buffer_cuda[: + num_first_draft]] = accepted_tokens if next_draft_tokens_device is not None: # Initialize these two values to zeros self.previous_pos_id_offsets_cuda *= 0 @@ -1654,6 +1845,34 @@ def previous_seq_slots_device(): gather_ids, dtype=torch.int, pin_memory=True), non_blocking=True) + # Update gather_ids for first_draft_requests on GPU (draft model only) + if self.is_draft_model and len(first_draft_seq_slots) > 0: + # Build tensors on CPU first, then copy to GPU to avoid implicit sync + num_first_draft = len(first_draft_seq_slots) + first_draft_seq_slots_cpu = torch.tensor(first_draft_seq_slots, + dtype=torch.int, + pin_memory=True) + first_draft_indices_cpu = torch.tensor( + first_draft_request_indices, + dtype=torch.int, + pin_memory=True) + + # Copy to pre-allocated GPU buffers + self.draft_seq_slots_buffer_cuda[:num_first_draft].copy_( + first_draft_seq_slots_cpu, non_blocking=True) + self.draft_request_indices_buffer_cuda[:num_first_draft].copy_( + first_draft_indices_cpu, non_blocking=True) + + # Extract accepted tokens for first_draft requests from device tensor + accepted_tokens = num_accepted_tokens_device[ + self.draft_seq_slots_buffer_cuda[:num_first_draft]] + # Update gather_ids: gather_id = base_gather_id + num_accepted_tokens + # (since gather_id = len(input_ids) - 1 - (max_draft_len - num_accepted)) + self.gather_ids_cuda[ + self. + draft_request_indices_buffer_cuda[: + num_first_draft]] += accepted_tokens + if not attn_metadata.is_cuda_graph: # Assumes seq lens do not change between CUDA graph invocations. This applies # to draft sequences too. This means that all draft sequences must be padded. 
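For context on the synchronization issue these hunks work around: building index tensors from Python lists at use time triggers blocking host-to-device copies (a copy from pageable host memory cannot be asynchronous), which stalls the overlap scheduler. The hunks above instead stage the host-side indices in pinned memory, copy them asynchronously into pre-allocated device buffers, and perform the update purely with device-side advanced indexing. A minimal sketch of that pattern, with illustrative buffer names rather than the engine's actual attributes:

import torch

# Pre-allocated once, e.g. in __init__ (sizes are illustrative).
MAX_TOKENS = 4096
idx_buf = torch.empty(MAX_TOKENS, dtype=torch.long, device='cuda')
slot_buf = torch.empty(MAX_TOKENS, dtype=torch.long, device='cuda')

def scatter_new_tokens(input_ids_cuda, new_tokens, positions, seq_slots):
    # positions / seq_slots are plain Python lists assembled on the host.
    n = len(positions)
    # Stage in pinned host memory so the H2D copy can be non-blocking.
    idx_cpu = torch.tensor(positions, dtype=torch.long, pin_memory=True)
    slot_cpu = torch.tensor(seq_slots, dtype=torch.long, pin_memory=True)
    idx_buf[:n].copy_(idx_cpu, non_blocking=True)
    slot_buf[:n].copy_(slot_cpu, non_blocking=True)
    # Gather/scatter entirely on the device; no implicit sync is triggered.
    input_ids_cuda[idx_buf[:n]] = new_tokens[0, slot_buf[:n], 0]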
@@ -1670,8 +1889,11 @@ def previous_seq_slots_device(): is_cuda_graph_during_warmup = self.is_warmup and attn_metadata.is_cuda_graph if cache_indirection_buffer is not None: #Copy cache indirection to local buffer with offsets changing: seq_slots[i] -> i + # Convert to GPU tensor to avoid implicit sync + gen_request_seq_slots_tensor = torch.tensor( + gen_request_seq_slots, dtype=torch.long, device='cuda') self.cache_indirection_attention[:num_generation_requests].copy_( - cache_indirection_buffer[gen_request_seq_slots]) + cache_indirection_buffer[gen_request_seq_slots_tensor]) if cache_indirection_buffer is not None or is_cuda_graph_during_warmup: attn_metadata.beam_width = self.max_beam_width else: @@ -2252,7 +2474,9 @@ def _prepare_inputs( attn_metadata: AttentionMetadata, spec_metadata: Optional[SpecMetadata] = None, new_tensors_device: Optional[SampleStateTensors] = None, - cache_indirection_buffer: Optional[torch.Tensor] = None): + cache_indirection_buffer: Optional[torch.Tensor] = None, + num_accepted_tokens_device: Optional[torch.Tensor] = None, + req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None): if self.mapping is not None and 'cp_type' in self.mapping.cp_config: cp_type = self.mapping.cp_config['cp_type'] if CpType.STAR == cp_type: @@ -2268,19 +2492,21 @@ def _prepare_inputs( return self._prepare_tp_inputs(scheduled_requests, kv_cache_manager, attn_metadata, spec_metadata, new_tensors_device, - cache_indirection_buffer) + cache_indirection_buffer, + num_accepted_tokens_device, + req_id_to_old_request) @torch.inference_mode() @with_model_extra_attrs(lambda self: self.model.extra_attrs) - def forward( - self, - scheduled_requests: ScheduledRequests, - resource_manager: ResourceManager, - new_tensors_device: Optional[SampleStateTensors] = None, - gather_context_logits: bool = False, - cache_indirection_buffer: Optional[torch.Tensor] = None, - spec_decoding_tensor: Optional[SpecDecodingTensor] = None, - ): + def forward(self, + scheduled_requests: ScheduledRequests, + resource_manager: ResourceManager, + new_tensors_device: Optional[SampleStateTensors] = None, + gather_context_logits: bool = False, + cache_indirection_buffer: Optional[torch.Tensor] = None, + spec_decoding_tensor: Optional[SpecDecodingTensor] = None, + num_accepted_tokens_device: Optional[torch.Tensor] = None, + req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None): kv_cache_manager = resource_manager.get_resource_manager( self.kv_cache_manager_key) @@ -2335,7 +2561,8 @@ def forward( inputs, gather_ids = self._prepare_inputs( padded_requests, kv_cache_manager, attn_metadata, spec_metadata, - new_tensors_device, cache_indirection_buffer) + new_tensors_device, cache_indirection_buffer, + num_accepted_tokens_device, req_id_to_old_request) self.iter_counter += 1 with with_shared_pool(self.cuda_graph_runner.get_graph_pool()): diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 3eb9181f6c1..5030590c2bf 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -35,10 +35,13 @@ from tensorrt_llm.mapping import CpType from tensorrt_llm.runtime.generation import CUASSERT +from ..attention_backend.trtllm import TrtllmAttention from ..distributed import Distributed from ..models.modeling_utils import DecoderModelForCausalLM from ..modules.decoder_layer import DecoderLayer from ..speculative.drafter import Drafter +from ..speculative.model_drafter import ModelDrafter +from ..speculative.mtp 
import SampleStateTensorsMTP from ..speculative.speculation_gate import SpeculationGate from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem from .guided_decoder import GuidedDecoder @@ -275,8 +278,23 @@ def __init__(self, if self.dist.pp_size > 1: self.event_loop = self._executor_loop_pp + elif self.disable_overlap_scheduler: + self.event_loop = self._executor_loop else: - self.event_loop = self._executor_loop if disable_overlap_scheduler else self._executor_loop_overlap + # TODO: Overlap scheduler is not supported for below cases: + # 1. non-CDL is used + # 2. non-TrtllmAttention attention backend is used + overlap_not_supported = self.drafter is not None and isinstance( + self.drafter, ModelDrafter) and ( + not self.drafter.use_static_draft_loop or not issubclass( + self.draft_model_engine.attn_backend, TrtllmAttention)) + + if overlap_not_supported: + logger.warning( + "Overlap scheduler is disabled for draft model engine with non-CDL or non-TrtllmAttention attention backend." + ) + self.disable_overlap_scheduler = True + self.event_loop = self._executor_loop_overlap if not overlap_not_supported else self._executor_loop if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"): self.event_loop = trace_func(self.event_loop) @@ -1060,14 +1078,11 @@ def _prepare_and_schedule_batch(self): 0 ] * max_total_draft_tokens if max_total_draft_tokens > 0 else [] - # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch, - # we don't need to initialize py_draft_tokens at this stage because we haven't append the accepted tokens to the request yet. - if not self.has_previous_draft_tokens: - # If speculation is off, this function sets py_draft_tokens to [] - # for all active requests. If it's on, we initialize py_draft_tokens - # with dummy draft tokens to make the scheduler aware of the fact - # that speculation is about to happen. - self._prepare_draft_requests() + # If speculation is off, this function sets py_draft_tokens to [] + # for all active requests. If it's on, we initialize py_draft_tokens + # with dummy draft tokens to make the scheduler aware of the fact + # that speculation is about to happen. + self._prepare_draft_requests() scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule( ) @@ -1316,6 +1331,8 @@ def _executor_loop_overlap(self): with self._profiler() as profile_step: iter_start_time = time.time() iter_stats = None + target_inputs = None + previous_tensors_device = None can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True while True: profile_step() @@ -1396,20 +1413,21 @@ def _executor_loop_overlap(self): self.guided_decoder.init_disagg_gen_requests() previous_tensors = self.previous_batch and self.previous_batch.sample_state - target_inputs = None - draft_outputs = None # If there are previous draft tokens, we need to update the target requests to accept some draft tokens. # When there's any accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model, # so we'll set the target model's input to None and skip updating the target requests after target model forward. 
use_previous_draft_tokens = self.has_previous_draft_tokens if self.drafter is not None and (self.use_spec_decode or use_previous_draft_tokens): - target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding( - scheduled_batch, previous_tensors) + target_inputs = self._handle_speculative_decoding( + scheduled_batch, previous_tensors, + previous_tensors_device) # Use the draft_model's outputs if we've launched the draft model. # Otherwise, use the previous batch's outputs. - if target_inputs is not None or use_previous_draft_tokens: + if (target_inputs is not None + and target_inputs.next_draft_tokens + is not None) or use_previous_draft_tokens: previous_tensors_device = target_inputs else: previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device @@ -1417,10 +1435,7 @@ def _executor_loop_overlap(self): batch_outputs = self._forward_step(scheduled_batch, previous_tensors_device) - if target_inputs is not None: - self._process_draft_results(scheduled_batch, - draft_outputs, draft_batch) - elif self.previous_batch is not None and not use_previous_draft_tokens: + if self.previous_batch is not None: self._update_requests(self.previous_batch.sample_state) if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: @@ -1435,6 +1450,10 @@ def _executor_loop_overlap(self): (req, block_id, self.ctx_in_transmission_counter)) + if self.drafter is not None and self.use_spec_decode: + # Cleanup previous draft resources used in the draft model + self.drafter.cleanup_previous_draft_resources() + if self.guided_decoder is not None: # add_batch must be called again to have updated new tokens. self.guided_decoder.add_batch(scheduled_batch) @@ -1469,6 +1488,94 @@ def _executor_loop_overlap(self): self._kv_connector_terminate_requests() + def _accept_draft_tokens( + self, scheduled_batch: ScheduledRequests, + target_outputs: SampleStateTensors, + target_inputs: Optional[SampleStateTensors] + ) -> Tuple[SampleStateTensorsMTP, Optional[torch.Tensor]]: + """ + Prepare target device inputs after computing draft token acceptance. + + This function: + 1. If draft tokens exist: compares sampled tokens with draft tokens to compute acceptance + 2. If no draft tokens: directly uses the first sampled token + 3. 
Creates new_tokens by extracting accepted tokens per request + + Args: + scheduled_batch: The scheduled requests + target_outputs: Contains new_tokens [max_draft_len + 1, batch_size, beam_width] + or [1, batch_size, beam_width] if no draft tokens + target_inputs: Contains next_draft_tokens [batch_size, max_draft_len] + Returns: + Tuple of: + - SampleStateTensorsMTP with new_tokens set to accepted tokens, + new_tokens_lens and next_draft_tokens set to None + - num_accepted_tokens: [batch_size] tensor with acceptance counts per request, + or None if no draft tokens + """ + has_draft_tokens = target_inputs is not None and isinstance( + target_inputs, SampleStateTensorsMTP + ) and target_inputs.next_draft_tokens is not None + target_tokens = target_outputs.new_tokens # [max_draft_len + 1, batch_size, beam_width] or [1, batch_size, beam_width] + new_tokens = torch.zeros_like(target_tokens) + + # Squeeze the beam dimension (beam_width=1 for greedy or single beam) + target_tokens = target_tokens.squeeze( + -1) # [max_draft_len + 1, batch_size] or [1, batch_size] + + batch_size = target_tokens.shape[1] + device = target_tokens.device + # Compute number of accepted tokens per request + num_accepted_tokens = torch.zeros(batch_size, + dtype=torch.int32, + device=device) + # Handle case where there are no draft tokens + if has_draft_tokens: + # Draft tokens exist, compute acceptance + draft_tokens = target_inputs.next_draft_tokens # [batch_size, max_draft_len] + max_draft_len = draft_tokens.shape[1] + + # Compute number of accepted tokens per request + # Generation requests: compare with draft tokens to find acceptance + num_contexts = len(scheduled_batch.context_requests) + if batch_size > num_contexts: + # Use .T to transpose: [max_draft_len + 1, num_gens] -> [num_gens, max_draft_len + 1] + gen_target_tokens = target_tokens[:, + num_contexts:].T # [num_gens, max_draft_len + 1] + + # Compare draft tokens with target tokens to find acceptance + # Use cumprod to find the first rejection point + draft_tokens_gen = draft_tokens[ + num_contexts:, :].int() # [num_gens, max_draft_len] + num_accepted_tokens[num_contexts:] += torch.cumprod( + (draft_tokens_gen == gen_target_tokens[:, :max_draft_len] + ).int(), + dim=-1).sum(dim=1) + + # Vectorized extraction using advanced indexing (no GPU-CPU sync) + # Use num_accepted_tokens as indices to gather the right tokens + batch_indices = torch.arange(batch_size, device=device) + new_tokens[0, :, 0] = target_tokens[num_accepted_tokens, + batch_indices] + else: + # No draft tokens to accept, just use the first (and only) sampled token + batch_indices = torch.arange(batch_size, device=device) + new_tokens[0, :, 0] = target_tokens[0, batch_indices] + + # Create the updated SampleStateTensorsMTP + # new_tokens_lens and next_draft_tokens are left as None + result_tensors = SampleStateTensorsMTP( + new_tokens=new_tokens, + log_probs=target_outputs.log_probs, + new_tokens_lens=None, + next_draft_tokens=None) + + # Copy logits if available + if hasattr(target_outputs, 'logits'): + result_tensors.logits = target_outputs.logits + + return result_tensors, num_accepted_tokens + def _process_previous_batch(self): if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs: for req in self.previous_batch.ctx_transmission_reqs: @@ -2365,7 +2472,8 @@ def _remove_inflight_ids(self, scheduled_requests): for req in scheduled_requests.all_requests(): self.inflight_req_ids.erase(req.request_id) - def _handle_speculative_decoding(self, scheduled_batch, previous_tensors): + 
def _handle_speculative_decoding(self, scheduled_batch, previous_tensors, + target_inputs): with request_context(is_draft=self.draft_model_engine is not None, scheduled_requests=scheduled_batch): # Do an early checking to see if we need to forward the draft model. @@ -2375,20 +2483,25 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors): self.previous_batch is not None and self.use_spec_decode and self.drafter.should_forward_draft_model(scheduled_batch)) - if has_draft_batch or self.has_previous_draft_tokens: - self._update_requests(self.previous_batch.sample_state) - if self.has_previous_draft_tokens: - self._prepare_draft_requests() + new_target_inputs = None + if has_draft_batch: + target_outputs = self.previous_batch.sample_state and self.previous_batch.sample_state.device + assert target_outputs is not None, "target_outputs should not be None" + new_target_inputs, num_accepted_tokens_device = self._accept_draft_tokens( + scheduled_batch=scheduled_batch, + target_inputs=target_inputs, + target_outputs=target_outputs) if has_draft_batch: - target_inputs, draft_outputs, draft_batch = self.drafter.generate_draft_tokens_with_overlap( + self.drafter.generate_draft_tokens_with_overlap( scheduled_batch, self.resource_manager, - previous_tensors.device if previous_tensors else None) + previous_tensors.device if previous_tensors else None, + new_target_inputs, num_accepted_tokens_device) - self.has_previous_draft_tokens = target_inputs is not None and target_inputs.next_draft_tokens is not None + # Pad draft tokens to the max draft length for CUDA graph compatibility + self.has_previous_draft_tokens = new_target_inputs is not None and new_target_inputs.next_draft_tokens is not None else: self.has_previous_draft_tokens = False - target_inputs, draft_outputs, draft_batch = None, None, None # We are not running the draft model. Remove the draft tokens and turn off spec # decode so that the requests get handled correctly. # One corner case: when we have at least one context request, we have to keep spec @@ -2401,34 +2514,7 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors): for request in scheduled_batch.all_requests(): request.py_draft_tokens = [] - return target_inputs, draft_outputs, draft_batch - - def _process_draft_results(self, scheduled_batch, draft_outputs, - draft_batch): - """ - Append the draft tokens to the target requests, and clean up the draft resources. - """ - with request_context(is_draft=self.draft_model_engine is not None, - scheduled_requests=scheduled_batch): - req_id_to_old_request = { - req.py_request_id: req - for req in scheduled_batch.all_requests() - } - - if self.drafter.use_static_draft_loop: - self.drafter.process_static_draft_outputs( - draft_outputs, draft_batch, req_id_to_old_request) - elif draft_outputs is not None: - self.drafter.process_dynamic_draft_outputs( - draft_outputs, req_id_to_old_request) - - # Pad draft tokens to the max draft length. This is for CUDA graph compatibility. - self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch) - # add_batch must be called again to restore to target requests with updated draft tokens. 
- if self.guided_decoder is not None: - self.guided_decoder.add_batch(scheduled_batch) - if hasattr(self.drafter, "guided_decoder"): - self.guided_decoder.rollback_draft_tokens() + return new_target_inputs def reset_prefix_cache(self): self.kv_cache_manager.reset_reuse_state() diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 01b5e6d0d87..e486d67201b 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -280,7 +280,9 @@ def _group_requests_by_strategy_key( ) for req_index, req in enumerate(requests): strategy = _request_strategy(req, vocab_size=vocab_size) - speculation_needs_probs = req.py_draft_logits is not None and strategy is not GREEDY + # In the overlap path, py_draft_logits is not updated yet, + # so we use get_draft_token_length() for the checking. + speculation_needs_probs = get_draft_token_length(req) > 0 and strategy is not GREEDY strategy_key = strategy_to_key(strategy, speculation_needs_probs) group_dict_entry = group_dict[(strategy_key, speculation_needs_probs)] group_dict_entry[0].append(req_index) diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py index 30045cfef53..7f1e7d57756 100644 --- a/tensorrt_llm/_torch/speculative/model_drafter.py +++ b/tensorrt_llm/_torch/speculative/model_drafter.py @@ -28,7 +28,8 @@ # Place the tool function here to avoid circular import def get_draft_model_prompt(spec_dec_mode: SpeculativeDecodingMode, - request: LlmRequest) -> List[int]: + request: LlmRequest, + disable_overlap_scheduler: bool) -> List[int]: """ Can be used to modify prompts for speculative algorithms that need to update tokens before drafting. @@ -36,7 +37,11 @@ def get_draft_model_prompt(spec_dec_mode: SpeculativeDecodingMode, draft_input_tokens = request.get_tokens(0) if spec_dec_mode.is_eagle3() or spec_dec_mode.is_mtp_eagle(): # EAGLE3 always throws away the first token when processing draft inputs + if not disable_overlap_scheduler: + # Add a fake golden token here since the real one has not been generated. + draft_input_tokens.extend([0]) draft_input_tokens = draft_input_tokens[1:] + if request.is_context_init_state: # A draft request's prompt is its target request's prompt adding the first golden token. # Add a fake golden token here since the real one has not been generated. 
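A small worked example of the prompt shift above (token values are illustrative): with the overlap scheduler, the newest target token only exists on the device when the draft prompt is assembled on the host, so a placeholder 0 is appended before EAGLE3's usual drop-first-token shift; the placeholder position is later overwritten on the device from new_tokens, as shown in the model_engine.py changes.

# Target-side tokens known on the host for one request (hypothetical values).
tokens = [101, 7, 42, 13]

# Non-overlap path: the newest target token is already on the host,
# so EAGLE3 simply drops the first token.
draft_prompt = tokens[1:]              # [7, 42, 13]

# Overlap path: append a placeholder for the not-yet-materialized token,
# then apply the same shift; the 0 is patched on the device later.
draft_prompt = (tokens + [0])[1:]      # [7, 42, 13, 0]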
@@ -73,6 +78,7 @@ def __init__( self.draft_model_engine = draft_model_engine self.draft_seq_slot_manager = draft_seq_slot_manager self.spec_resource_manager = spec_resource_manager + self.disable_overlap_scheduler = True # Configuration self.spec_config = spec_config @@ -88,6 +94,15 @@ def __init__( assert guided_decoder is None assert spec_config._allow_greedy_draft_tokens + # Create accumulator for draft tokens in non-CDL mode + self.draft_tokens_accumulator: Dict[int, List[int]] = {} + # Track previous draft batch for overlap scheduling + self.previous_draft_batch: Optional[ScheduledRequests] = None + self.previous_draft_outputs: Optional[Any] = None + self.previous_scheduled_batch: Optional[ScheduledRequests] = None + # Map from request ID to original request + self.req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None + def _create_draft_request(self, request: LlmRequest, input_tokens: Optional[List]) -> LlmRequest: """Create a draft request with common parameters.""" @@ -109,6 +124,9 @@ def _create_draft_request(self, request: LlmRequest, def _initialize_draft_tokens(self, request: LlmRequest) -> Tuple[int, int]: """Initialize draft token tracking for a request.""" + if not self.disable_overlap_scheduler: + return self.max_draft_len, 0 + num_draft_tokens = len( request.py_last_draft_tokens ) if request.py_last_draft_tokens is not None else 0 @@ -173,18 +191,22 @@ def _create_draft_request_for_request( """Create a draft request based on the original request state.""" num_draft_tokens, num_accepted_tokens = self._initialize_draft_tokens( request) + input_tokens = get_draft_model_prompt(self.spec_config.spec_dec_mode, - request) + request, + self.disable_overlap_scheduler) is_eagle_style = self.spec_config.spec_dec_mode.is_eagle3( ) or self.spec_config.spec_dec_mode.is_mtp_eagle() # First time seeing this request - context request - if request.max_beam_num_tokens - 1 == request.py_prompt_len: + num_overlap_tokens = 0 if self.disable_overlap_scheduler else 1 + if request.max_beam_num_tokens - 1 + num_overlap_tokens == request.py_prompt_len: # This is the first time the draft model is seeing this request. # Prepare a context request. We discard the first token and take # the newly decoded one - this is the convention for EAGLE 2 and 3. - assert num_draft_tokens == 0 + if self.disable_overlap_scheduler: + assert num_draft_tokens == 0 return self._create_context_request(request, input_tokens) # For TRTLLM attention backend, we need to create a generation request for both no tokens accepted and tokens accepted @@ -256,7 +278,8 @@ def _prepare_draft_batch( # a prefill chunk on the last iteration. Now, we need to fill in the KV cache # for the draft model too. input_tokens = get_draft_model_prompt( - self.spec_config.spec_dec_mode, request) + self.spec_config.spec_dec_mode, request, + self.disable_overlap_scheduler) new_request = self._create_context_request( request, input_tokens) @@ -279,7 +302,8 @@ def _prepare_draft_batch( # that we want to do spec decoding, so no need to do anything else here. # This makes the perf for this case suboptimal, but that's OK - this is # a corner case for weird models like the llama 3.1 8b EAGLE3 implementation. 
- if request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len: + num_overlap_tokens = 0 if self.disable_overlap_scheduler else 1 + if request.max_beam_num_tokens - 1 + num_overlap_tokens >= self.draft_model_engine.max_seq_len: continue draft_request = self._create_draft_request_for_request(request) @@ -302,12 +326,14 @@ def _should_disable_cuda_graph(self, is_first_draft_token: bool) -> bool: return False return self.spec_config.spec_dec_mode.needs_kv_cache_recompute() + @nvtx_range("forward_draft_model") def forward_draft_model( self, draft_batch: ScheduledRequests, resource_manager: ResourceManager, is_first_draft_token: bool, - previous_tensors: Optional[SampleStateTensors] = None + previous_tensors: Optional[SampleStateTensors] = None, + num_accepted_tokens_device: Optional[torch.Tensor] = None ) -> Dict[str, Any]: """Forward pass through the draft model.""" if self._should_disable_cuda_graph(is_first_draft_token): @@ -315,12 +341,16 @@ def forward_draft_model( outputs = self.draft_model_engine.forward( draft_batch, resource_manager, - new_tensors_device=previous_tensors) + new_tensors_device=previous_tensors, + num_accepted_tokens_device=num_accepted_tokens_device, + req_id_to_old_request=self.req_id_to_old_request) else: outputs = self.draft_model_engine.forward( draft_batch, resource_manager, - new_tensors_device=previous_tensors) + new_tensors_device=previous_tensors, + num_accepted_tokens_device=num_accepted_tokens_device, + req_id_to_old_request=self.req_id_to_old_request) # Handle d2t data if available. Static drafting loops should incorporate d2t # in their implementations. @@ -330,6 +360,7 @@ def forward_draft_model( return outputs + @nvtx_range("sample_async") def sample_async( self, draft_batch: ScheduledRequests, @@ -386,26 +417,37 @@ def update_requests( """Update requests with sample state.""" self.sampler.update_requests(sample_state, resource_manager) - def process_decoded_tokens(self, draft_batch: ScheduledRequests, - req_id_to_old_request: Dict[int, LlmRequest], - draft_position: int) -> List[LlmRequest]: + def process_decoded_tokens( + self, + draft_batch: ScheduledRequests, + draft_position: int, + cleanup_resources: bool = True) -> List[LlmRequest]: """Process decoded tokens and determine which requests to continue processing.""" new_requests = [] for req in draft_batch.all_requests(): - target_model_req = req_id_to_old_request[req.py_request_id] + target_model_req = self.req_id_to_old_request[req.py_request_id] if target_model_req.state != LlmRequestState.GENERATION_IN_PROGRESS: # This is a chunked prefill request and we have more prefill chunks # to process. Defer adding draft tokens until the whole prompt is processed. 
self.draft_seq_slot_manager.free_resources(req) continue - target_model_req.py_draft_tokens[draft_position - - 1] = req.get_last_tokens(0) + # Save tokens to accumulator instead of directly modifying target_model_req.py_draft_tokens + if req.py_request_id not in self.draft_tokens_accumulator: + self.draft_tokens_accumulator[ + req.py_request_id] = [0] * self.max_total_draft_tokens + self.draft_tokens_accumulator[req.py_request_id][ + draft_position - 1] = req.get_last_tokens(0) target_model_req.py_draft_logits = req.py_result.generation_logits # forwards Nones - if req.state != LlmRequestState.GENERATION_COMPLETE and draft_position < self.max_draft_len: + + # Check against the accumulator length instead + accumulated_tokens_count = len( + self.draft_tokens_accumulator[req.py_request_id]) + if req.state != LlmRequestState.GENERATION_COMPLETE and accumulated_tokens_count < target_model_req.py_draft_pages_allocated: new_requests.append(req) else: - self.draft_seq_slot_manager.free_resources(req) + if cleanup_resources: + self.draft_seq_slot_manager.free_resources(req) return new_requests @@ -442,10 +484,11 @@ def should_forward_draft_model(self, return False - def _convert_draft_tensors( + def _initialize_draft_tokens_for_target_inputs( self, scheduled_batch: ScheduledRequests, - new_tensors_device: Optional[SampleStateTensors] = None + target_inputs: Optional[SampleStateTensorsMTP] = None, + num_accepted_tokens_device: Optional[torch.Tensor] = None ) -> Optional[SampleStateTensorsMTP]: """ Convert tensors for draft model processing. @@ -457,67 +500,51 @@ def _convert_draft_tensors( Returns: SampleStateTensorsMTP: Converted tensors or None """ - if new_tensors_device is None: + if target_inputs is None: return None # Get device from the new_tokens tensor - device = new_tensors_device.new_tokens.device + device = target_inputs.new_tokens.device - # Use the same shape as new_tensors_device.new_tokens - new_tokens = torch.zeros_like(new_tensors_device.new_tokens) new_tokens_lens = None next_draft_tokens = None has_draft_tokens = False - batch_size = new_tokens.shape[1] + batch_size = target_inputs.new_tokens.shape[1] # Iterate through generation requests and copy tokens based on accepted draft tokens for request in scheduled_batch.all_requests(): - idx = request.py_seq_slot - if request.state != LlmRequestState.GENERATION_IN_PROGRESS: - num_accepted_tokens = request.py_num_accepted_draft_tokens - new_tokens[0, idx] = new_tensors_device.new_tokens[ - num_accepted_tokens, idx] - else: + if request.state == LlmRequestState.GENERATION_IN_PROGRESS: has_draft_tokens = True - num_accepted_tokens = request.py_num_accepted_draft_tokens - new_tokens[0, idx] = new_tensors_device.new_tokens[ - num_accepted_tokens, idx] if has_draft_tokens: # We already updated the target state, so the new_tokens_lens should be all ones. 
new_tokens_lens = torch.ones(batch_size, device=device) + new_tokens_lens += num_accepted_tokens_device next_draft_tokens = torch.zeros(batch_size, self.max_draft_len, device=device) + target_inputs.new_tokens_lens = new_tokens_lens + target_inputs.next_draft_tokens = next_draft_tokens + return target_inputs - # Create a new SampleStateTensorsMTP object with the additional fields - updated_tensors = SampleStateTensorsMTP( - new_tokens=new_tokens, - log_probs=new_tensors_device.log_probs, - new_tokens_lens=new_tokens_lens, - next_draft_tokens=next_draft_tokens) - - if hasattr(new_tensors_device, 'logits'): - updated_tensors.logits = new_tensors_device.logits - - return updated_tensors - - def _update_target_inputs_with_draft_tokens( + @nvtx_range("_update_draft_tokens_for_target_inputs") + def _update_draft_tokens_for_target_inputs( self, target_inputs: SampleStateTensorsMTP, draft_tensors: Optional[torch.Tensor], draft_position: int, - draft_length: int, draft_batch: ScheduledRequests, - req_id_to_old_request: Dict[int, LlmRequest]) -> None: + draft_length: int, draft_batch: ScheduledRequests) -> None: """ Update target inputs with new draft tokens from sample state. """ + if target_inputs.next_draft_tokens is None: + return + if draft_tensors is not None: for req_idx, request in enumerate(draft_batch.all_requests()): - # Skip prefill requests - if target_inputs.next_draft_tokens is None: + target_req = self.req_id_to_old_request[request.py_request_id] + if target_req.state != LlmRequestState.GENERATION_IN_PROGRESS: + # Skip prefill requests continue - # Get the index of the draft/target tokens in the device tensor draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot - target_idx = req_id_to_old_request[ - request.py_request_id].py_seq_slot + target_idx = target_req.py_seq_slot target_inputs.new_tokens[draft_position + 1:draft_position + draft_length + 1, target_idx, 0] = draft_tensors[0:draft_length, @@ -527,8 +554,8 @@ def _update_target_inputs_with_draft_tokens( draft_length] = draft_tensors[0:draft_length, draft_idx] def _setup_draft_batch_and_resources( - self, scheduled_batch: ScheduledRequests - ) -> Tuple[Optional[ScheduledRequests], Optional[Dict[int, LlmRequest]]]: + self, + scheduled_batch: ScheduledRequests) -> Optional[ScheduledRequests]: """ Setup draft batch and prepare resources. 
@@ -536,39 +563,36 @@ def _setup_draft_batch_and_resources( scheduled_batch: The scheduled requests Returns: - Tuple of (draft_batch, req_id_to_old_request) or (None, None) if no batch + draft_batch or None if no batch """ draft_batch = self._prepare_draft_batch(scheduled_batch) if draft_batch.batch_size == 0: - return None, None + return None - req_id_to_old_request = { + self.req_id_to_old_request = { req.py_request_id: req for req in scheduled_batch.all_requests() } for request in draft_batch.all_requests(): - target_model_req = req_id_to_old_request[request.py_request_id] + target_model_req = self.req_id_to_old_request[request.py_request_id] if target_model_req.is_context_init_state: continue target_model_req.py_draft_tokens = [0] * self.max_draft_len self.draft_seq_slot_manager.prepare_resources(draft_batch) - return draft_batch, req_id_to_old_request + return draft_batch - def process_static_draft_outputs( - self, - outputs: dict[str, torch.Tensor] | tuple[torch.Tensor, SampleState], - draft_batch: ScheduledRequests, - req_id_to_old_request: Dict[int, LlmRequest]) -> None: + def process_static_draft_outputs(self, outputs: dict[str, torch.Tensor] + | tuple[torch.Tensor, SampleState], + draft_batch: ScheduledRequests) -> None: """ Process outputs from static draft loop, update target requests, and clean up resources. Args: outputs: The outputs from the draft model draft_batch: The draft batch that was processed - req_id_to_old_request: Mapping from draft request ID to original request """ if isinstance(outputs, dict): @@ -580,7 +604,7 @@ def process_static_draft_outputs( outputs[1].sampler_event.synchronize() for req_idx, req in enumerate(draft_batch.all_requests()): - target_model_req = req_id_to_old_request[req.py_request_id] + target_model_req = self.req_id_to_old_request[req.py_request_id] if target_model_req.state != LlmRequestState.GENERATION_IN_PROGRESS: # Chunked prefill request in progress; no need to append draft tokens continue @@ -592,27 +616,34 @@ def process_static_draft_outputs( py_draft_logits.append(draft_logits[token_idx][req_idx]) target_model_req.py_draft_logits = torch.stack(py_draft_logits) - # Clean up draft resources - for req in draft_batch.all_requests(): - self.draft_seq_slot_manager.free_resources(req) - def process_dynamic_draft_outputs( self, outputs: Any, - req_id_to_old_request: Dict[int, LlmRequest], - resource_manager: Optional[ResourceManager] = None) -> None: + resource_manager: Optional[ResourceManager] = None, + cleanup_resources: bool = True) -> None: """ Process outputs from dynamic draft loop, update target requests, and clean up resources. 
""" self.update_requests(outputs, resource_manager) + + # Create accumulator for draft tokens and process them self.process_decoded_tokens(outputs.scheduled_requests, - req_id_to_old_request, self.max_draft_len) + self.max_draft_len, cleanup_resources) + + # Update py_draft_tokens after processing + for req_id, tokens in self.draft_tokens_accumulator.items(): + target_model_req = self.req_id_to_old_request[req_id] + target_model_req.py_draft_tokens = tokens + @nvtx_range("_execute_draft_iteration") def _execute_draft_iteration( - self, draft_batch: ScheduledRequests, - resource_manager: ResourceManager, - previous_draft_state: Optional[SampleState], - cur_draft_layer_idx: int) -> Tuple[Any, Optional[SampleState]]: + self, + draft_batch: ScheduledRequests, + resource_manager: ResourceManager, + previous_draft_state: Optional[SampleState], + cur_draft_layer_idx: int, + num_accepted_tokens_device: Optional[torch.Tensor] = None + ) -> Tuple[Any, Optional[SampleState]]: self.update_cur_draft_layer_idx( cur_draft_layer_idx, resource_manager ) # Update the current draft layer index in the resource manager. @@ -622,7 +653,8 @@ def _execute_draft_iteration( resource_manager, is_first_draft_token=False, previous_tensors=previous_draft_state.device - if previous_draft_state else None) + if previous_draft_state else None, + num_accepted_tokens_device=num_accepted_tokens_device) if previous_draft_state is not None: self.update_requests(previous_draft_state, resource_manager) @@ -637,11 +669,51 @@ def _execute_draft_iteration( return outputs, sample_state + @nvtx_range("_process_previous_draft_results") + def _process_previous_draft_results(self) -> None: + """ + Process the previous draft batch results. + This should be called after the current draft forward to enable overlap scheduling. 
+ """ + if (self.previous_draft_batch is None + or self.previous_draft_outputs is None + or self.previous_scheduled_batch is None): + return + + # Save current req_id_to_old_request temporarily + current_req_id_to_old_request = self.req_id_to_old_request + + # Set req_id_to_old_request for the previous batch, + # this will be used in process_static_draft_outputs and process_dynamic_draft_outputs + self.req_id_to_old_request = { + req.py_request_id: req + for req in self.previous_scheduled_batch.all_requests() + } + + if self.use_static_draft_loop: + self.process_static_draft_outputs(self.previous_draft_outputs, + self.previous_draft_batch) + elif self.previous_draft_outputs is not None: + self.process_dynamic_draft_outputs(self.previous_draft_outputs, + cleanup_resources=False) + + self.req_id_to_old_request = current_req_id_to_old_request + + # Pad draft tokens to the max draft length for CUDA graph compatibility + self.pad_draft_tokens_for_cuda_graph(self.previous_scheduled_batch) + + def cleanup_previous_draft_resources(self) -> None: + if self.previous_draft_batch is None: + return + + # Free draft_seq_slot_manager resources for all requests in the previous draft batch + for req in self.previous_draft_batch.all_requests(): + self.draft_seq_slot_manager.free_resources(req) + def _execute_draft_loop( self, draft_batch: ScheduledRequests, resource_manager: ResourceManager, - req_id_to_old_request: Dict[int, LlmRequest], target_inputs: Optional[SampleStateTensorsMTP] = None, num_draft_reqs: Optional[int] = None, initial_draft_state: Optional[SampleState] = None @@ -652,7 +724,6 @@ def _execute_draft_loop( Args: draft_batch: The draft batch to process resource_manager: The resource manager - req_id_to_old_request: Mapping from request ID to original request target_inputs: Optional target inputs to update (for overlap mode) num_draft_reqs: Number of draft requests (for overlap mode) initial_draft_state: The initial draft state from the first forward pass @@ -667,7 +738,8 @@ def _execute_draft_loop( draft_batch.context_requests = [] previous_draft_state = initial_draft_state - + # reset draft tokens accumulator + self.draft_tokens_accumulator = {} # Generate remaining draft tokens iteratively for i in range(self.max_draft_len - 1): if len(draft_batch.generation_requests) == 0: @@ -679,18 +751,16 @@ def _execute_draft_loop( # Update target inputs if provided (for overlap mode) if target_inputs is not None and num_draft_reqs is not None: draft_tensors = sample_state and sample_state.device and sample_state.device.new_tokens - self._update_target_inputs_with_draft_tokens( + self._update_draft_tokens_for_target_inputs( target_inputs, draft_tensors, draft_position=i + 1, draft_length=1, - draft_batch=draft_batch, - req_id_to_old_request=req_id_to_old_request) + draft_batch=draft_batch) if sample_state is not None and previous_draft_state is not None: new_requests = self.process_decoded_tokens( previous_draft_state.scheduled_requests, - req_id_to_old_request, draft_position=i + 1) else: new_requests = [] @@ -700,12 +770,14 @@ def _execute_draft_loop( return previous_draft_state + @nvtx_range("generate_draft_tokens_with_overlap") def generate_draft_tokens_with_overlap( - self, scheduled_batch: ScheduledRequests, - resource_manager: ResourceManager, - previous_tensors: Optional[SampleStateTensors] - ) -> Tuple[Optional[SampleStateTensorsMTP], Optional[Any], - Optional[ScheduledRequests]]: + self, + scheduled_batch: ScheduledRequests, + resource_manager: ResourceManager, + previous_tensors: 
Optional[SampleStateTensors], + target_inputs: Optional[SampleStateTensorsMTP], + num_accepted_tokens_device: Optional[torch.Tensor] = None) -> None: """ Generate draft tokens with overlap scheduling support. @@ -720,35 +792,42 @@ def generate_draft_tokens_with_overlap( - Updated target inputs or None - Draft sample state or None """ - draft_batch, req_id_to_old_request = self._setup_draft_batch_and_resources( - scheduled_batch) - if draft_batch is None: - return None, None, None - target_inputs = self._convert_draft_tensors(scheduled_batch, - previous_tensors) + self.disable_overlap_scheduler = False if target_inputs is None: - return None, None, None + return + + draft_batch = self._setup_draft_batch_and_resources(scheduled_batch) + if draft_batch is None: + return + + self._initialize_draft_tokens_for_target_inputs( + scheduled_batch, target_inputs, num_accepted_tokens_device) # Initial forward pass self.update_cur_draft_layer_idx( 0, resource_manager ) # Update the current draft layer index in the resource manager. - outputs = self.forward_draft_model(draft_batch, - resource_manager, - is_first_draft_token=True, - previous_tensors=previous_tensors) + outputs = self.forward_draft_model( + draft_batch, + resource_manager, + is_first_draft_token=True, + previous_tensors=previous_tensors, + num_accepted_tokens_device=num_accepted_tokens_device) + + # Process previous draft results after current forward pass + # This enables overlap scheduling: process old batch while new batch is prepared + self._process_previous_draft_results() num_draft_reqs = len(draft_batch.all_requests()) if self.use_static_draft_loop: # Only update target inputs, cleanup will be done in executor loop - self._update_target_inputs_with_draft_tokens( + self._update_draft_tokens_for_target_inputs( target_inputs, outputs["new_draft_tokens"], draft_position=0, draft_length=self.max_draft_len, - draft_batch=draft_batch, - req_id_to_old_request=req_id_to_old_request) + draft_batch=draft_batch) new_tokens_host = outputs["new_draft_tokens"].to(device='cpu', non_blocking=True) @@ -762,8 +841,13 @@ def generate_draft_tokens_with_overlap( host=SampleStateTensors(new_tokens=new_tokens_host), sampler_event=sampler_event) - return target_inputs, (outputs["draft_logits"], - sample_state), draft_batch + # Store current batch for processing in next iteration + self.previous_draft_batch = draft_batch + self.previous_draft_outputs = (outputs["draft_logits"], + sample_state) + self.previous_scheduled_batch = scheduled_batch + + return # Handle guided decoder and sampling for non-static loop if self.guided_decoder is not None: @@ -775,22 +859,25 @@ def generate_draft_tokens_with_overlap( # Update target inputs with first iteration results draft_tensors = draft_sample_state and draft_sample_state.device and draft_sample_state.device.new_tokens - self._update_target_inputs_with_draft_tokens( - target_inputs, - draft_tensors, - draft_position=0, - draft_length=1, - draft_batch=draft_batch, - req_id_to_old_request=req_id_to_old_request) + self._update_draft_tokens_for_target_inputs(target_inputs, + draft_tensors, + draft_position=0, + draft_length=1, + draft_batch=draft_batch) self.update_request_states(draft_batch) # Execute the iterative draft loop - previous_draft_state = self._execute_draft_loop( - draft_batch, resource_manager, req_id_to_old_request, target_inputs, - num_draft_reqs, draft_sample_state) + previous_draft_state = self._execute_draft_loop(draft_batch, + resource_manager, + target_inputs, + num_draft_reqs, + 
draft_sample_state) - return target_inputs, previous_draft_state, draft_batch + # Store current batch for processing in next iteration + self.previous_draft_batch = draft_batch + self.previous_draft_outputs = previous_draft_state + self.previous_scheduled_batch = scheduled_batch @nvtx_range("prepare_draft_tokens") def prepare_draft_tokens( @@ -805,6 +892,7 @@ def prepare_draft_tokens( scheduled_requests: The scheduled requests for this iteration resource_manager: The resource manager for this iteration """ + self.disable_overlap_scheduler = True if not self.draft_model_engine: raise ValueError("Draft model engine is not set") @@ -812,7 +900,7 @@ def prepare_draft_tokens( raise ValueError("Resource manager is required") try: - draft_batch, req_id_to_old_request = self._setup_draft_batch_and_resources( + draft_batch = self._setup_draft_batch_and_resources( scheduled_requests) if draft_batch is None: return @@ -827,8 +915,10 @@ def prepare_draft_tokens( is_first_draft_token=True) if self.use_static_draft_loop: - self.process_static_draft_outputs(outputs, draft_batch, - req_id_to_old_request) + self.process_static_draft_outputs(outputs, draft_batch) + # Clean up draft_seq_slot_manager resources + for req in draft_batch.all_requests(): + self.draft_seq_slot_manager.free_resources(req) return if self.guided_decoder is not None: @@ -841,14 +931,16 @@ def prepare_draft_tokens( # Execute the iterative draft loop previous_draft_state = self._execute_draft_loop( - draft_batch, resource_manager, req_id_to_old_request, None, - None, sample_state) + draft_batch, resource_manager, None, None, sample_state) # Final cleanup if previous_draft_state is not None: - self.process_dynamic_draft_outputs(previous_draft_state, - req_id_to_old_request) + self.process_dynamic_draft_outputs(previous_draft_state) + # Update py_draft_tokens after the loop completes + for req_id, tokens in self.draft_tokens_accumulator.items(): + target_model_req = self.req_id_to_old_request[req_id] + target_model_req.py_draft_tokens = tokens except Exception as e: traceback.print_exc() error_msg = str(e) diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py index 26d54330b1d..094b2f8e506 100644 --- a/tests/unittest/_torch/speculative/test_eagle3.py +++ b/tests/unittest/_torch/speculative/test_eagle3.py @@ -122,21 +122,28 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str, tok_ids.append(llm_spec.tokenizer.encode(prompts)) sampling_params = SamplingParams(max_tokens=128, temperature=0) - for i in range(len(tok_ids)): - num_tokens = 0 - num_drafted = 0 - num_accepted = 0 - - for output in llm_spec.generate_async(tok_ids[i], - sampling_params, - streaming=True): - new_tokens = output.outputs[0].token_ids - num_drafted += max_draft_len - num_accepted += len(new_tokens) - num_tokens - 1 - num_tokens = len(new_tokens) - - accept_rate = num_accepted / num_drafted - assert accept_rate > 0.15 + run_ar_test = True + # Overlap scheduler is disabled for non-CDL or non-TrtllmAttention attention backend, + # so it would fallback to the non-overlap scheduler. 
+ if not disable_overlap_scheduler and (attn_backend != "TRTLLM" + or not use_chain_drafter): + run_ar_test = False + if run_ar_test: + for i in range(len(tok_ids)): + num_tokens = 0 + num_drafted = 0 + num_accepted = 0 + + for output in llm_spec.generate_async(tok_ids[i], + sampling_params, + streaming=True): + new_tokens = output.outputs[0].token_ids + num_drafted += max_draft_len + num_accepted += len(new_tokens) - num_tokens - 1 + num_tokens = len(new_tokens) + + accept_rate = num_accepted / num_drafted + assert accept_rate > 0.15 # Output tests sampling_params = SamplingParams(max_tokens=10, temperature=0) @@ -515,7 +522,7 @@ def test_eagle3_cdl_sampling(disable_overlap_scheduler: bool): prompts = ["The president of the United States is"] - sampling_params = SamplingParams(max_tokens=20, temperature=0, top_p=0.9) + sampling_params = SamplingParams(max_tokens=20, temperature=1.0, top_p=0.9) llm_spec.generate(prompts, sampling_params) llm_spec.shutdown() From 30fac0399c013a66036cc4da68357caf5b64ec10 Mon Sep 17 00:00:00 2001 From: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com> Date: Sun, 9 Nov 2025 23:47:44 -0800 Subject: [PATCH 2/2] Resolve the comments Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 21 +--------- .../_torch/pyexecutor/py_executor_creator.py | 11 ++++++ .../_torch/speculative/model_drafter.py | 5 +-- .../_torch/speculative/test_eagle3.py | 38 ++++++++----------- 4 files changed, 30 insertions(+), 45 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 5030590c2bf..6e3e03b6c89 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -35,12 +35,10 @@ from tensorrt_llm.mapping import CpType from tensorrt_llm.runtime.generation import CUASSERT -from ..attention_backend.trtllm import TrtllmAttention from ..distributed import Distributed from ..models.modeling_utils import DecoderModelForCausalLM from ..modules.decoder_layer import DecoderLayer from ..speculative.drafter import Drafter -from ..speculative.model_drafter import ModelDrafter from ..speculative.mtp import SampleStateTensorsMTP from ..speculative.speculation_gate import SpeculationGate from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem @@ -278,23 +276,8 @@ def __init__(self, if self.dist.pp_size > 1: self.event_loop = self._executor_loop_pp - elif self.disable_overlap_scheduler: - self.event_loop = self._executor_loop else: - # TODO: Overlap scheduler is not supported for below cases: - # 1. non-CDL is used - # 2. non-TrtllmAttention attention backend is used - overlap_not_supported = self.drafter is not None and isinstance( - self.drafter, ModelDrafter) and ( - not self.drafter.use_static_draft_loop or not issubclass( - self.draft_model_engine.attn_backend, TrtllmAttention)) - - if overlap_not_supported: - logger.warning( - "Overlap scheduler is disabled for draft model engine with non-CDL or non-TrtllmAttention attention backend." 
- ) - self.disable_overlap_scheduler = True - self.event_loop = self._executor_loop_overlap if not overlap_not_supported else self._executor_loop + self.event_loop = self._executor_loop if self.disable_overlap_scheduler else self._executor_loop_overlap if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"): self.event_loop = trace_func(self.event_loop) @@ -1529,7 +1512,7 @@ def _accept_draft_tokens( num_accepted_tokens = torch.zeros(batch_size, dtype=torch.int32, device=device) - # Handle case where there are no draft tokens + if has_draft_tokens: # Draft tokens exist, compute acceptance draft_tokens = target_inputs.next_draft_tokens # [batch_size, max_draft_len] diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index be407a47ed4..7bb8bd769cd 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -25,6 +25,7 @@ from tensorrt_llm.quantization import QuantAlgo from ..attention_backend.interface import AttentionRuntimeFeatures +from ..attention_backend.trtllm import TrtllmAttention from ..distributed import MPIDist, TorchDist from ..speculative import (get_num_extra_kv_tokens, get_spec_drafter, get_spec_resource_manager) @@ -390,6 +391,16 @@ def drafting_loop_wrapper(model): else: draft_model_engine = None + # TODO: Overlap scheduler is not supported for below cases: + # 1. non-CDL is used + # 2. non-TrtllmAttention attention backend is used + if has_draft_model_engine and (not use_chain_drafter or not issubclass( + draft_model_engine.attn_backend, TrtllmAttention)): + logger.warning( + "Overlap scheduler is not supported for non-CDL or non-TrtllmAttention backend." + ) + llm_args.disable_overlap_scheduler = True + # PyTorchModelEngine modifies these fields, update them model_engine_max_seq_len = model_engine.max_seq_len net_max_seq_len = model_engine_max_seq_len diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py index 7f1e7d57756..6064f915a77 100644 --- a/tensorrt_llm/_torch/speculative/model_drafter.py +++ b/tensorrt_llm/_torch/speculative/model_drafter.py @@ -440,10 +440,7 @@ def process_decoded_tokens( draft_position - 1] = req.get_last_tokens(0) target_model_req.py_draft_logits = req.py_result.generation_logits # forwards Nones - # Check against the accumulator length instead - accumulated_tokens_count = len( - self.draft_tokens_accumulator[req.py_request_id]) - if req.state != LlmRequestState.GENERATION_COMPLETE and accumulated_tokens_count < target_model_req.py_draft_pages_allocated: + if req.state != LlmRequestState.GENERATION_COMPLETE and draft_position < target_model_req.py_draft_pages_allocated: new_requests.append(req) else: if cleanup_resources: diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py index 094b2f8e506..08919f43a96 100644 --- a/tests/unittest/_torch/speculative/test_eagle3.py +++ b/tests/unittest/_torch/speculative/test_eagle3.py @@ -122,28 +122,22 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str, tok_ids.append(llm_spec.tokenizer.encode(prompts)) sampling_params = SamplingParams(max_tokens=128, temperature=0) - run_ar_test = True - # Overlap scheduler is disabled for non-CDL or non-TrtllmAttention attention backend, - # so it would fallback to the non-overlap scheduler. 
- if not disable_overlap_scheduler and (attn_backend != "TRTLLM" - or not use_chain_drafter): - run_ar_test = False - if run_ar_test: - for i in range(len(tok_ids)): - num_tokens = 0 - num_drafted = 0 - num_accepted = 0 - - for output in llm_spec.generate_async(tok_ids[i], - sampling_params, - streaming=True): - new_tokens = output.outputs[0].token_ids - num_drafted += max_draft_len - num_accepted += len(new_tokens) - num_tokens - 1 - num_tokens = len(new_tokens) - - accept_rate = num_accepted / num_drafted - assert accept_rate > 0.15 + + for i in range(len(tok_ids)): + num_tokens = 0 + num_drafted = 0 + num_accepted = 0 + + for output in llm_spec.generate_async(tok_ids[i], + sampling_params, + streaming=True): + new_tokens = output.outputs[0].token_ids + num_drafted += max_draft_len + num_accepted += len(new_tokens) - num_tokens - 1 + num_tokens = len(new_tokens) + + accept_rate = num_accepted / num_drafted + assert accept_rate > 0.15 # Output tests sampling_params = SamplingParams(max_tokens=10, temperature=0)
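The drafter hunks above replace the old return-based flow of generate_draft_tokens_with_overlap with a deferral scheme: the current draft forward is launched first, the previous iteration's results are processed while that forward runs, and the just-launched batch is stashed in previous_draft_batch / previous_draft_outputs / previous_scheduled_batch for the next iteration. Below is a minimal, self-contained sketch of that pattern; the class is a simplified, hypothetical stand-in for ModelDrafter (the model plumbing is stubbed out, and the body of _process_previous_draft_results is an assumption that mirrors the accumulator copy shown in the non-overlap path), not the patch's actual implementation.

# Sketch only: a hypothetical stand-in for ModelDrafter's deferral pattern.
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class OverlapDrafterSketch:
    req_id_to_old_request: Dict[int, Any] = field(default_factory=dict)
    draft_tokens_accumulator: Dict[int, list] = field(default_factory=dict)
    previous_draft_batch: Optional[Any] = None
    previous_draft_outputs: Optional[Any] = None
    previous_scheduled_batch: Optional[Any] = None

    # Stubs standing in for the real draft-model plumbing.
    def _setup_draft_batch(self, scheduled_batch):
        return scheduled_batch

    def _forward_draft_model(self, draft_batch, target_inputs):
        return {"new_draft_tokens": [], "draft_logits": None}

    def generate_draft_tokens_with_overlap(self, scheduled_batch, target_inputs):
        if target_inputs is None:
            return
        draft_batch = self._setup_draft_batch(scheduled_batch)
        if draft_batch is None:
            return

        # 1. Launch this iteration's draft forward first (asynchronous on GPU).
        outputs = self._forward_draft_model(draft_batch, target_inputs)

        # 2. Only now finalize the previous iteration's results, so the host-side
        #    bookkeeping overlaps with the forward pass launched above.
        self._process_previous_draft_results()

        # 3. Stash the current batch; it becomes "previous" on the next call.
        self.previous_draft_batch = draft_batch
        self.previous_draft_outputs = outputs
        self.previous_scheduled_batch = scheduled_batch

    def _process_previous_draft_results(self):
        if self.previous_draft_batch is None:
            return
        # Assumed finalization: copy accumulated draft tokens back onto the
        # target-model requests, as the non-overlap path does after its loop.
        for req_id, tokens in self.draft_tokens_accumulator.items():
            self.req_id_to_old_request[req_id].py_draft_tokens = tokens
        self.previous_draft_batch = None
        self.previous_draft_outputs = None
        self.previous_scheduled_batch = None


# Two consecutive calls: the second one finalizes the first batch's results
# while (in the real executor) the second forward is still in flight.
drafter = OverlapDrafterSketch()
drafter.generate_draft_tokens_with_overlap("batch0", target_inputs=object())
drafter.generate_draft_tokens_with_overlap("batch1", target_inputs=object())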
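The _accept_draft_tokens hunk only shows the start of the acceptance computation: num_accepted_tokens is zero-initialized on device and, when draft tokens exist, target_inputs.next_draft_tokens of shape [batch_size, max_draft_len] is read. The sketch below shows one common way such a per-request count can be produced entirely on the GPU, which is what allows num_accepted_tokens_device to be handed to the draft engine without a host sync; the function name and the assumed target-token layout ([batch_size, max_draft_len + 1], with position i verifying draft position i) are illustrative assumptions, not the exact implementation in this patch.

# Sketch only: device-side acceptance counting under the assumptions above.
import torch


def count_accepted_draft_tokens(target_new_tokens: torch.Tensor,
                                next_draft_tokens: torch.Tensor) -> torch.Tensor:
    # target_new_tokens: [batch_size, max_draft_len + 1] tokens sampled by the
    #     target model (assumed layout: position i verifies draft position i).
    # next_draft_tokens: [batch_size, max_draft_len] previously proposed drafts.
    # Returns a [batch_size] int32 tensor on the same device as the inputs,
    # so no host synchronization is needed.
    matches = (target_new_tokens[:, :next_draft_tokens.shape[1]]
               == next_draft_tokens).int()            # [batch, max_draft_len]
    accepted_prefix = torch.cumprod(matches, dim=1)   # zeros after 1st mismatch
    return accepted_prefix.sum(dim=1).to(torch.int32)


# Example (CPU tensors for brevity; the executor keeps these on CUDA):
# drafts [5, 7, 9] vs. target [5, 7, 2, 4] -> first two drafts accepted.
draft = torch.tensor([[5, 7, 9]])
target = torch.tensor([[5, 7, 2, 4]])
print(count_accepted_draft_tokens(target, draft))  # tensor([2], dtype=torch.int32)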
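The py_executor_creator.py hunk in the second commit moves the overlap-scheduler gate to executor-creation time: when a draft model engine is present but either the chain drafter (CDL) is not used or the draft engine's attention backend is not TrtllmAttention, llm_args.disable_overlap_scheduler is forced to True before PyExecutor is built, and PyExecutor then simply picks its event loop from that flag. A compressed sketch of the decision, using a hypothetical helper name (the patch inlines the check), is:

# Sketch only: the creator-time gate as a hypothetical standalone helper.
def should_disable_overlap_scheduler(draft_model_engine,
                                     use_chain_drafter: bool,
                                     trtllm_attention_cls: type) -> bool:
    if draft_model_engine is None:
        # No two-model speculation; nothing to gate here.
        return False
    # Two-model spec decoding currently supports overlap scheduling only with
    # the chain drafter (CDL) and the TRTLLM attention backend.
    return (not use_chain_drafter
            or not issubclass(draft_model_engine.attn_backend,
                              trtllm_attention_cls))


class _FakeAttention:  # hypothetical placeholders for a quick check
    pass


class _FakeDraftEngine:
    attn_backend = _FakeAttention


print(should_disable_overlap_scheduler(_FakeDraftEngine(), True, _FakeAttention))   # False
print(should_disable_overlap_scheduler(_FakeDraftEngine(), False, _FakeAttention))  # True

Forcing the flag before construction is also what lets the final test hunk revert to the unconditional acceptance-rate loop: unsupported combinations fall back to the non-overlap executor loop instead of being special-cased inside the test.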