diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py index a08a45a8cb7..e18a8c8cf15 100644 --- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py +++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py @@ -49,7 +49,9 @@ def __init__(self, engine: "PyTorchModelEngine"): Callable[[], Optional[torch.Tensor]]] = {} self.graph_metadata: Dict[Tuple[int, int, int], Dict[str, Any]] = {} self.memory_pool = engine._cuda_graph_mem_pool - self.padding_dummy_request: Optional["Request"] = None + # Stage 2: Pre-allocate one padding dummy per unique draft_len for zero overhead + self.padding_dummies: Dict[int, + "Request"] = {} # draft_len -> dummy_request self.shared_static_tensors: Dict[str, torch.Tensor] = {} if self.enabled: @@ -98,14 +100,38 @@ def max_possible_draft_len(self): def get_graph_key( self, batch_size, - spec_resource_manager: Optional[BaseResourceManager] = None): + spec_resource_manager: Optional[BaseResourceManager] = None, + runtime_draft_len: Optional[int] = None): + """ + Get the CUDA graph key for the given batch and draft configuration. + + Stage 2 Dynamic Draft Length: When runtime_draft_len is provided and + a draft_len_schedule exists, use the runtime draft length instead of + max_draft_len. This enables selecting the appropriate CUDA graph for + the current draft length. + + Args: + batch_size: Batch size for the graph + spec_resource_manager: Optional resource manager for spec decoding + runtime_draft_len: Optional runtime draft length (Stage 2 feature) + + Returns: + Tuple of (batch_size, draft_len, is_first_draft) + """ engine = self._get_engine() if engine.is_draft_model and spec_resource_manager is not None and isinstance( spec_resource_manager, Eagle3ResourceManager): draft_len = engine.original_max_draft_len if spec_resource_manager.is_first_draft else 0 key = (batch_size, draft_len, spec_resource_manager.is_first_draft) else: - draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0 + # Stage 2: Use runtime draft length if provided and schedule exists + if (runtime_draft_len is not None and self.spec_config + and hasattr(self.spec_config, 'draft_len_schedule') + and self.spec_config.draft_len_schedule is not None): + draft_len = runtime_draft_len + else: + # Legacy behavior: use max_draft_len + draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0 key = (batch_size, draft_len, False) return key @@ -135,10 +161,14 @@ def _get_engine(self) -> "PyTorchModelEngine": def maybe_get_cuda_graph( self, batch: ScheduledRequests, - spec_resource_manager: Optional[BaseResourceManager] = None): + spec_resource_manager: Optional[BaseResourceManager] = None, + runtime_draft_len: Optional[int] = None): """ Determines if the current batch can be run with a CUDA graph. + Stage 2 Dynamic Draft Length: When runtime_draft_len is provided, + selects the CUDA graph corresponding to that draft length. + Returns a tuple containing: - A boolean indicating if a graph can be used. - The attn_metadata for the graph, if applicable. 
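For readers following the key-selection change above: with a draft_len_schedule configured, the target model's graph key moves from the static (batch_size, max_draft_len, False) to (batch_size, runtime_draft_len, False). The snippet below is a minimal sketch of that lookup rule, not code from this patch; the helper name resolve_graph_key and the example schedule {1: 4, 4: 2, 8: 0} are assumptions for illustration.

    from bisect import bisect_right

    def resolve_graph_key(batch_size, schedule, max_draft_len,
                          enable_spec_decode=True):
        """Mimic the target-model key shape: (batch_size, draft_len, is_first_draft)."""
        if not enable_spec_decode:
            return (batch_size, 0, False)
        if schedule is not None:
            # Largest threshold <= batch_size; the config validator guarantees key 1 exists.
            thresholds = sorted(schedule)
            draft_len = schedule[thresholds[bisect_right(thresholds, batch_size) - 1]]
        else:
            draft_len = max_draft_len  # legacy behavior: always pad to max_draft_len
        return (batch_size, draft_len, False)

    # With schedule {1: 4, 4: 2, 8: 0}:
    assert resolve_graph_key(2, {1: 4, 4: 2, 8: 0}, 4) == (2, 4, False)
    assert resolve_graph_key(6, {1: 4, 4: 2, 8: 0}, 4) == (6, 2, False)
    assert resolve_graph_key(9, {1: 4, 4: 2, 8: 0}, 4) == (9, 0, False)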
@@ -168,7 +198,8 @@ def maybe_get_cuda_graph( if not self.enabled or not can_run_cuda_graph: return False, None, None, None - key = self.get_graph_key(batch_size, spec_resource_manager) + key = self.get_graph_key(batch_size, spec_resource_manager, + runtime_draft_len) if key in self.graphs: return True, self.graph_metadata[key][ @@ -342,29 +373,41 @@ def _get_padded_batch(self, batch: ScheduledRequests, if padding_size + batch.batch_size > engine.batch_size: return 0 - # No padding if it would create too many concurrent requests. - # This is not strictly required, but we should probably - # respect the requirement just in case that changes in the future. - if self.padding_dummy_request is None: + # Stage 2: Get or create padding dummy for current runtime draft length + # Pre-allocation strategy: Create one dummy per unique draft_len (very small cost) + # This avoids recreation overhead while preserving Stage 2 benefits + runtime_draft_len = engine.max_draft_len # Current draft_len for this iteration + + # Check if dummy for this draft_len already exists + if runtime_draft_len not in self.padding_dummies: available_blocks = kv_cache_manager.get_num_free_blocks() # No padding if not enough KV cache space if available_blocks < 1: return 0 - self.padding_dummy_request = kv_cache_manager.add_dummy_requests( - [CUDA_GRAPH_DUMMY_REQUEST_ID], + # Create dummy for this specific draft_len (happens once per unique draft_len) + # Use unique request ID per draft_len to avoid conflicts + dummy_req_id = CUDA_GRAPH_DUMMY_REQUEST_ID + runtime_draft_len + + dummy = kv_cache_manager.add_dummy_requests( + [dummy_req_id], is_gen=True, - max_num_draft_tokens=engine.runtime_draft_len, + max_num_draft_tokens=runtime_draft_len, use_mrope=engine.use_mrope, max_beam_width=engine.max_beam_width)[0] - self.padding_dummy_request.is_cuda_graph_dummy = True + dummy.is_cuda_graph_dummy = True + spec_res_mgr = resource_manager.get_resource_manager( ResourceManagerType.SPEC_RESOURCE_MANAGER) if spec_res_mgr: - spec_res_mgr.add_dummy_requests([CUDA_GRAPH_DUMMY_REQUEST_ID]) + spec_res_mgr.add_dummy_requests([dummy_req_id]) + + # Store for reuse + self.padding_dummies[runtime_draft_len] = dummy - batch.generation_requests.extend([self.padding_dummy_request] * - padding_size) + # Select the appropriate dummy for current draft_len (zero overhead!) + padding_dummy = self.padding_dummies[runtime_draft_len] + batch.generation_requests.extend([padding_dummy] * padding_size) return padding_size def _round_up_batch_size(self, batch_size: int) -> int: @@ -397,7 +440,7 @@ def clear(self): self.graphs.clear() self.graph_outputs.clear() self.graph_metadata.clear() - self.padding_dummy_request = None + self.padding_dummies.clear() del self.memory_pool self.memory_pool = None torch.cuda.empty_cache() diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 320b62e9bc2..86229ad401a 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -551,6 +551,124 @@ def _run_autotuner_warmup(self, resource_manager: ResourceManager): ) AutoTuner.get().print_profiling_cache() + def _get_runtime_draft_len(self, + scheduled_requests: ScheduledRequests) -> int: + """ + Get the runtime draft length for the current batch. + + Stage 2 Dynamic Draft Length: Returns the max_draft_tokens that was set + for this batch based on batch size. This is the value that all requests + in the batch will be padded to. 
+ + Note: Individual requests may have fewer draft tokens due to NGram + mismatches or early stopping, but they will all be padded to this value + for CUDA graph compatibility. + + Args: + scheduled_requests: The scheduled requests for this iteration + + Returns: + The runtime max_draft_tokens for this batch, or 0 if spec decode disabled + """ + if not self.enable_spec_decode: + return 0 + + # Return the current max_draft_len (which was set based on batch size) + # This is the value that drafter.max_draft_tokens is set to + return self.runtime_draft_len + + def _get_graphs_to_capture(self) -> List[Tuple[int, int]]: + """ + Determine which (batch_size, draft_len) combinations to capture CUDA graphs for. + + Stage 2 Optimization: When draft_len_schedule is provided, only capture + graphs that will actually be used based on the schedule and batch size padding. + This avoids over-capturing and saves memory + warmup time. + + Returns: + List of (batch_size, draft_len) tuples to capture graphs for. + """ + spec_resource_manager = self.resource_manager.get_resource_manager( + ResourceManagerType.SPEC_RESOURCE_MANAGER) if hasattr( + self, 'resource_manager') else None + + # Draft model logic (unchanged) + if self.is_draft_model: + if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None: + from ..speculative.eagle3 import Eagle3ResourceManager + if isinstance(spec_resource_manager, Eagle3ResourceManager): + draft_len = self.original_max_draft_len + return [(bs, draft_len) + for bs in self._cuda_graph_batch_sizes] + draft_len = self.max_draft_len + return [(bs, draft_len) for bs in self._cuda_graph_batch_sizes] + + # Target model with schedule: compute exact reachable set + if (self.spec_config and hasattr(self.spec_config, 'draft_len_schedule') + and self.spec_config.draft_len_schedule is not None): + + graphs_needed = self._compute_reachable_graphs() + logger.info( + f"Stage 2 Dynamic Draft Length: Capturing {len(graphs_needed)} CUDA graphs " + f"(from schedule {self.spec_config.draft_len_schedule}): {sorted(graphs_needed)}" + ) + return sorted(graphs_needed) + + # Legacy: all batch sizes with same draft_len(s) + draft_lengths = [] + if (self.max_draft_len > 0 + and not self.spec_config.spec_dec_mode.use_one_engine() + and self.spec_config.max_concurrency is not None): + draft_lengths.append(0) + draft_lengths.append(self.max_draft_len) + + graphs = [] + for bs in self._cuda_graph_batch_sizes: + for draft_len in draft_lengths: + graphs.append((bs, draft_len)) + return graphs + + def _compute_reachable_graphs(self) -> set: + """ + Compute the set of (batch_size, draft_len) pairs that are actually reachable. + + Takes into account: + 1. Schedule: which draft_len for each actual batch size + 2. 
Batch padding: actual batch size might be padded to larger graph size + + Returns: + Set of (batch_size, draft_len) tuples + """ + graphs_needed = set() + schedule = self.spec_config.draft_len_schedule + + # For each possible actual batch size + for actual_bs in range(1, self.batch_size + 1): + # Determine draft_len for this actual batch size using same logic as drafter + # Use bisect_right to find the largest threshold <= actual_bs + from bisect import bisect_right + thresholds = list(schedule.keys()) + idx = bisect_right(thresholds, actual_bs) + if idx == 0: + draft_len = 0 # Defensive - shouldn't happen with valid schedules + else: + draft_len = schedule[thresholds[idx - 1]] + + # Determine padded batch size + padded_bs = self._round_up_to_graph_size(actual_bs) + + if padded_bs > 0: # Valid graph size exists + graphs_needed.add((padded_bs, draft_len)) + + return graphs_needed + + def _round_up_to_graph_size(self, actual_bs: int) -> int: + """Round up actual batch size to nearest CUDA graph batch size.""" + for graph_bs in sorted(self._cuda_graph_batch_sizes): + if actual_bs <= graph_bs: + return graph_bs + return 0 # Too large, no graph available + def _run_cuda_graph_warmup(self, resource_manager: ResourceManager): """Captures CUDA graphs for various batch sizes and draft lengths.""" if not (self.cuda_graph_runner.enabled @@ -572,55 +690,36 @@ def _capture_generation_cuda_graphs(self, spec_resource_manager = resource_manager.get_resource_manager( ResourceManagerType.SPEC_RESOURCE_MANAGER) + # Stage 2 Optimization: Only capture graphs that will actually be used + graphs_to_capture = self._get_graphs_to_capture() + # Reverse order so smaller graphs can reuse memory from larger ones - cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes, - reverse=True) - # Create CUDA graphs for different draft lengths - draft_lengths = [] - if self.is_draft_model: - if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance( - spec_resource_manager, Eagle3ResourceManager): - # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop. - draft_lengths.append(self.original_max_draft_len) - else: - draft_lengths.append(self.max_draft_len) - else: - # For non-draft model, we also capture the CUDA graph instance for draft length 0, - # so that when we disable spec decode at runtime, we can still run the captured graph. - # Note that for one engine mode, we are not able to turn off spec decode at runtime. - if (self.max_draft_len > 0 - and not self.spec_config.spec_dec_mode.use_one_engine() - # Assume that speculation is always on if the user didn't give us a max_concurrency - # value. This will save on memory. 
- and self.spec_config.max_concurrency is not None): - draft_lengths.append(0) - draft_lengths = [self.max_draft_len] - - for bs in cuda_graph_batch_sizes: - if bs > self.batch_size: + graphs_to_capture = sorted(graphs_to_capture, reverse=True) + + for batch_size, draft_len in graphs_to_capture: + if batch_size > self.batch_size: continue - for draft_len in draft_lengths: - warmup_request = self._create_cuda_graph_warmup_request( - resource_manager, bs, draft_len) - with self._release_batch_context(warmup_request, - resource_manager) as batch: - if batch is None: - # No KV cache space, cannot continue capturing graphs - return + warmup_request = self._create_cuda_graph_warmup_request( + resource_manager, batch_size, draft_len) + with self._release_batch_context(warmup_request, + resource_manager) as batch: + if batch is None: + # No KV cache space, cannot continue capturing graphs + return - logger.info( - f"Run generation-only CUDA graph warmup for batch size={bs}, draft_len={draft_len}" - ) + logger.info( + f"Run generation-only CUDA graph warmup for batch size={batch_size}, draft_len={draft_len}" + ) - self.enable_spec_decode = draft_len > 0 or self.is_draft_model - self._update_draft_inference_state_for_warmup( - batch, draft_len > 0, resource_manager) + self.enable_spec_decode = draft_len > 0 or self.is_draft_model + self._update_draft_inference_state_for_warmup( + batch, draft_len > 0, resource_manager) - self.forward(batch, - new_tensors_device=None, - resource_manager=resource_manager) - torch.cuda.synchronize() + self.forward(batch, + new_tensors_device=None, + resource_manager=resource_manager) + torch.cuda.synchronize() def _capture_piecewise_cuda_graphs(self, resource_manager: ResourceManager): """Captures piecewise CUDA graphs for context/prefill steps via torch.compile.""" @@ -695,8 +794,9 @@ def _create_warmup_request( spec_resource_manager = resource_manager.get_resource_manager( ResourceManagerType.SPEC_RESOURCE_MANAGER) + # Warmup uses static max since it happens before dynamic updates available_tokens = kv_cache_manager.get_num_available_tokens( - self.runtime_draft_len) + self.original_max_draft_len) available_blocks = kv_cache_manager.get_num_free_blocks() if num_tokens > self.max_num_tokens or num_tokens > available_tokens: return None @@ -736,11 +836,12 @@ def _create_warmup_request( if num_left_over_tokens > 0: ctx_token_nums.append(num_left_over_tokens) + # Warmup dummy requests use static max since warmup happens before dynamic updates ctx_requests = kv_cache_manager.add_dummy_requests( list(range(num_ctx_requests)), token_nums=ctx_token_nums, is_gen=False, - max_num_draft_tokens=self.runtime_draft_len, + max_num_draft_tokens=self.original_max_draft_len, use_mrope=self.use_mrope) if spec_resource_manager is not None: @@ -2275,8 +2376,12 @@ def forward( with self.cuda_graph_runner.pad_batch( scheduled_requests, resource_manager) as padded_requests: + # Stage 2 Dynamic Draft Length: Get runtime draft length from the batch + runtime_draft_len = self._get_runtime_draft_len( + padded_requests) if not self.is_draft_model else None + maybe_graph, maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph( - padded_requests, spec_resource_manager) + padded_requests, spec_resource_manager, runtime_draft_len) if maybe_graph: attn_metadata = maybe_attn_metadata spec_metadata = maybe_spec_metadata diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 2ddc7c27800..b043608111f 100644 --- 
a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -195,6 +195,7 @@ def __init__(self, self.active = True self.max_beam_width = max_beam_width self.max_draft_len = max_draft_len + self._static_max_draft_len = max_draft_len self.max_num_tokens = model_engine.pytorch_backend_config.max_num_tokens self.print_log = model_engine.pytorch_backend_config.print_iter_log self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats @@ -1033,27 +1034,37 @@ def _prepare_and_schedule_batch(self): self._pad_attention_dp_dummy_request() if self.drafter is not None: - # Honor permanent disable flag based on rolling acceptance first - if getattr(self, 'speculation_permanently_disabled', False): + # Update draft_len based on schedule (if exists) + if self.drafter.draft_len_schedule is not None: + batch_size_input = len(self.active_requests) + + self.max_draft_len = self.drafter.get_draft_len_for_batch_size( + batch_size_input) + + self.drafter.update_max_draft_tokens(self.max_draft_len) + self.model_engine.max_draft_len = self.max_draft_len + + # Check if draft_len=0 → immediately disable + # max_draft_len==0 is only possible when draft_len_schedule is provided + # for example, draft_len_schedule = {1:4, 4:2, 8:0}, batch_size >= 8 will set self.max_draft_len = 0 + if self.drafter.draft_len_schedule is not None and self.max_draft_len == 0: + self.use_spec_decode = False + elif getattr(self, 'speculation_permanently_disabled', False): self.use_spec_decode = False else: self.use_spec_decode = self.drafter.should_use_spec_decode( self.active_requests, self.max_batch_size, - self.model_engine.max_num_tokens, - self.model_engine.spec_config.max_draft_len) - logger.debug(f"Use spec decode: {self.use_spec_decode}") + self.model_engine.max_num_tokens, self.max_draft_len) self.model_engine.enable_spec_decode = self.use_spec_decode - # Set up draft_tokens in active_requests, because they could be used in the scheduling stage. for request in self.active_requests: if request.state not in ( LlmRequestState.GENERATION_IN_PROGRESS, LlmRequestState.DISAGG_GENERATION_INIT): continue - max_draft_len = self.model_engine.spec_config.max_draft_len request.draft_tokens = [ 0 - ] * max_draft_len if max_draft_len > 0 else [] + ] * self.max_draft_len if self.max_draft_len > 0 else [] # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch, # we don't need to initialize py_draft_tokens at this stage because we haven't append the accepted tokens to the request yet. 
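The hunk above is the heart of the runtime path: once per iteration, the executor resolves the draft length from the schedule and pushes it to the drafter and the model engine before scheduling. The following is a condensed sketch of that ordering, not part of the patch; apply_draft_len_schedule is a stand-in name, and executor, drafter, and model_engine are placeholders for the objects wired together above.

    def apply_draft_len_schedule(executor, drafter, model_engine):
        """Resolve this iteration's draft length and propagate it (sketch)."""
        if drafter is None or drafter.draft_len_schedule is None:
            return
        batch_size = len(executor.active_requests)
        draft_len = drafter.get_draft_len_for_batch_size(batch_size)

        executor.max_draft_len = draft_len          # seeds request.draft_tokens below
        drafter.update_max_draft_tokens(draft_len)  # NGram forwards this to its pool manager
        model_engine.max_draft_len = draft_len      # consumed by the CUDA graph runner

        if draft_len == 0:
            # A 0 entry in the schedule disables speculation for this iteration;
            # otherwise should_use_spec_decode() decides as before.
            executor.use_spec_decode = False
            model_engine.enable_spec_decode = False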
@@ -1224,11 +1235,10 @@ def _prepare_draft_requests(self): continue req.py_last_draft_tokens = req.py_draft_tokens - max_draft_len = self.model_engine.spec_config.max_draft_len - if max_draft_len > 0 and self.use_spec_decode: - req.py_draft_tokens = [0] * max_draft_len - req.py_draft_pages_allocated = max_draft_len + if self.max_draft_len > 0 and self.use_spec_decode: + req.py_draft_tokens = [0] * self.max_draft_len + req.py_draft_pages_allocated = self.max_draft_len else: req.py_draft_tokens = [] req.py_draft_pages_allocated = 0 @@ -1616,7 +1626,7 @@ def _pad_attention_dp_dummy_request(self): request_ids=[0], is_gen=True, prepare_resource=True, - max_num_draft_tokens=self.max_draft_len, + max_num_draft_tokens=self._static_max_draft_len, )[0] llm_request.is_attention_dp_dummy = True spec_resource_manager = self.resource_manager.get_resource_manager( diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index c8aafeff429..9799e9baa90 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -350,7 +350,8 @@ def create_py_executor( guided_decoding_config is None and draft_spec_config._allow_chain_drafter and draft_spec_config._allow_greedy_draft_tokens - and pytorch_backend_config.attn_backend == "TRTLLM") + and pytorch_backend_config.attn_backend == "TRTLLM" + and draft_spec_config.draft_len_schedule is None) logger.debug(f"USE CHAIN DRAFTER: {use_chain_drafter}") if use_chain_drafter: diff --git a/tensorrt_llm/_torch/speculative/drafter.py b/tensorrt_llm/_torch/speculative/drafter.py index 485934f7b5c..1372ec1b658 100644 --- a/tensorrt_llm/_torch/speculative/drafter.py +++ b/tensorrt_llm/_torch/speculative/drafter.py @@ -1,5 +1,8 @@ from abc import ABC, abstractmethod -from typing import List, Optional, final +from bisect import bisect_right +from typing import Dict, List, Optional, final + +from tensorrt_llm.logger import logger from ..pyexecutor.llm_request import LlmRequest, get_draft_token_length from ..pyexecutor.resource_manager import ResourceManager @@ -9,8 +12,19 @@ class Drafter(ABC): """Abstract base class for all drafter implementations.""" - def __init__(self, max_concurrency: Optional[int] = None) -> None: + def __init__( + self, + max_draft_tokens: int, + max_concurrency: Optional[int] = None, + draft_len_schedule: Optional[Dict[int, int]] = None, + ) -> None: self.max_concurrency = max_concurrency + # Schedule is already validated and sorted by config validator + self.draft_len_schedule = draft_len_schedule + # It's dynamic if draft_len_schedule is provided in spec_config (dynamic draft length based on runtime batch size is enabled). It's static in other cases. + self.max_draft_tokens = max_draft_tokens + # It's always static + self._static_max_draft_tokens = max_draft_tokens @abstractmethod def prepare_draft_tokens( @@ -26,6 +40,39 @@ def prepare_draft_tokens( """ raise NotImplementedError + @final + def get_draft_len_for_batch_size(self, batch_size: int) -> int: + """ + Get the appropriate draft length for the given batch size using binary search. 
+ + Args: + batch_size: Current batch size (has been sorted by config validator) + + Returns: + The draft length to use for this batch size + """ + + # Binary search to find the largest threshold <= batch_size + # draft_len_schedule is already sorted by config validator + thresholds = list(self.draft_len_schedule.keys()) + + # bisect_right finds where to insert batch_size to keep list sorted + # The element before insertion point is the largest threshold <= batch_size + idx = bisect_right(thresholds, batch_size) + + if idx == 0: + # batch_size is smaller than smallest threshold (batch_size smaller than 1) + # This shouldn't happen in practice, but handle defensively + logger.warning( + f"get_draft_len_for_batch_size called with batch_size={batch_size} < 1. " + f"This is unexpected. Disabling speculation (returning draft_len=0)." + ) + return 0 + + # Return draft_len for the largest threshold <= batch_size + threshold = thresholds[idx - 1] + return self.draft_len_schedule[threshold] + @final def should_use_spec_decode(self, requests: List[LlmRequest], max_batch_size: int, max_num_tokens: int, @@ -57,16 +104,27 @@ def should_use_spec_decode(self, requests: List[LlmRequest], def pad_draft_tokens_for_cuda_graph( self, scheduled_requests: ScheduledRequests) -> None: """ - Pad draft tokens to the max draft length for CUDA graph compatibility. + Pad draft tokens for CUDA graph compatibility. + + CUDA graphs require all requests in a batch to have the same tensor shape. + Individual requests may generate fewer draft tokens (e.g., NGram mismatches, + early stopping), but all must be padded to the same length. Args: scheduled_requests: The scheduled requests to pad """ for req in scheduled_requests.generation_requests: - max_draft_tokens = self.max_draft_tokens num_draft_tokens = get_draft_token_length(req) - req.py_draft_tokens.extend( - 0 for _ in range(max_draft_tokens - num_draft_tokens)) + + if self.draft_len_schedule is not None: + # Pad to current iteration's (dynamic) max_draft_tokens if dynamic draft length is enabled + target_len = self.max_draft_tokens + else: + target_len = self._static_max_draft_tokens + + if num_draft_tokens < target_len: + req.py_draft_tokens.extend( + 0 for _ in range(target_len - num_draft_tokens)) def run_drafter_post( self, @@ -79,3 +137,15 @@ def run_drafter_post( this method can be overridden to do that. Used in SaveHiddenStatesDrafter (to ensure correct input_ids) """ + + def update_max_draft_tokens(self, new_max_draft_tokens: int) -> None: + """ + Update the dynamic max_draft_tokens based on current batch size. + + Used when draft_len_schedule is provided in spec_config (dynamic draft length + based on runtime batch size is enabled). 
+ + Args: + new_max_draft_tokens: The new max draft tokens for the current batch size + """ + self.max_draft_tokens = new_max_draft_tokens diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py index 974e13130e8..ce2c503d560 100644 --- a/tensorrt_llm/_torch/speculative/model_drafter.py +++ b/tensorrt_llm/_torch/speculative/model_drafter.py @@ -57,14 +57,18 @@ def __init__( spec_resource_manager: Optional[BaseResourceManager] = None, guided_decoder: Optional[GuidedDecoder] = None, ): - super().__init__(spec_config.max_concurrency) - # Validate required parameters if draft_model_engine is None: raise ValueError("draft_model_engine cannot be None") if max_draft_tokens < 0: raise ValueError("max_draft_tokens must be >= 0") + super().__init__( + max_draft_tokens=spec_config.max_draft_len, + max_concurrency=spec_config.max_concurrency, + draft_len_schedule=spec_config.draft_len_schedule, + ) + # Model and resource management self.draft_model_engine = draft_model_engine self.draft_seq_slot_manager = draft_seq_slot_manager @@ -72,7 +76,6 @@ def __init__( # Configuration self.spec_config = spec_config - self.max_draft_tokens = max_draft_tokens # Sampling self.sampler = sampler self.guided_decoder = guided_decoder @@ -83,6 +86,16 @@ def __init__( assert guided_decoder is None assert spec_config._allow_greedy_draft_tokens + # Currently dynamic draft length is not compatible with static draft loops + # TODO: support static draft loops with dynamic draft_len + if self.draft_len_schedule is not None: + raise ValueError( + "Dynamic draft length (draft_len_schedule) is not supported with " + "static draft loops (fused ChainDrafter/Eagle3). Static loops have " + "fixed iteration counts compiled into the model.\n" + "To use draft_len_schedule, please use ModelDrafter (2-model setup) " + "or NGramDrafter instead.") + def _create_draft_request(self, request: LlmRequest, input_tokens: Optional[List]) -> LlmRequest: """Create a draft request with common parameters.""" @@ -696,6 +709,7 @@ def generate_draft_tokens_with_overlap( - Updated target inputs or None - Draft sample state or None """ + draft_batch, req_id_to_old_request = self._setup_draft_batch_and_resources( scheduled_batch) if draft_batch is None: diff --git a/tensorrt_llm/_torch/speculative/ngram.py b/tensorrt_llm/_torch/speculative/ngram.py index dc23270945b..a9f7fb719e0 100644 --- a/tensorrt_llm/_torch/speculative/ngram.py +++ b/tensorrt_llm/_torch/speculative/ngram.py @@ -26,7 +26,7 @@ class NGramPoolManager(BaseResourceManager): Arguments: max_draft_tokens: int - The length maximum of draft tokens (can be understood as length maximum of output draft tokens). + The length maximum of draft tokens (can be understood as length maximum of output draft tokens). If draft_len_schedule is provided in spec_config (dynamic draft length based on batch size is enabled), this value will be updated by the dynamic draft_len each step. max_matching_ngram_size: int The length maximum of searching tokens (can be understood as length maximum of input tokens to search). @@ -167,10 +167,15 @@ def __init__( spec_config: NGramDecodingConfig, ngram_pool_manager: NGramPoolManager = None, ): - super().__init__(spec_config.max_concurrency) + + super().__init__( + max_draft_tokens=spec_config.max_draft_len, + max_concurrency=spec_config.max_concurrency, + draft_len_schedule=spec_config.draft_len_schedule, + ) + assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool." 
self.spec_config = spec_config - self.max_draft_tokens = spec_config.max_draft_len self.spec_resource_manager = ngram_pool_manager def prepare_draft_tokens( @@ -178,6 +183,7 @@ def prepare_draft_tokens( scheduled_requests: ScheduledRequests, resource_manager: Optional[ResourceManager] = None, ) -> None: + # Sort by request_id when py_batch_idx is None as a fallback. # This happens in the disagg case: for a set of new requests, we draft # before forward_step, so py_batch_idx is not assigned. @@ -197,3 +203,8 @@ def prepare_draft_tokens( request.py_max_new_tokens, ) request.py_draft_tokens = draft_tokens + + def update_max_draft_tokens(self, new_max_draft_tokens: int) -> None: + """Override to propagate to NGramPoolManager.""" + super().update_max_draft_tokens(new_max_draft_tokens) + self.spec_resource_manager.max_draft_tokens = new_max_draft_tokens diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 35d02350e9c..b9525bb0ca7 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -421,6 +421,16 @@ class DecodingBaseConfig(StrictBaseModel): # this value. Otherwise, speculation will always be on. max_concurrency: Optional[int] = None + # Developer interface: dynamically adjust draft length based on active batch size in runtime. + # Maps batch size to draft lengths. For example: + # {1: 4, 4: 2, 8: 0} means: + # - batch_size >= 1: use draft_len=4 + # - batch_size >= 4: use draft_len=2 + # - batch_size >= 8: use draft_len=0 (disable speculation) + # draft_len_schedule is enforced to contain batch_size=1 and its according draft_len equals max_draft_len for consistency + # for example, if max_draft_len=4, the schedule must contain {1: 4} + draft_len_schedule: Optional[dict[int, int]] = None + load_format: Optional[str] = None # PyTorch only. # Rolling average window size (N) for acceptance length across completed requests. @@ -458,6 +468,51 @@ def _validate_acceptance_length_threshold(cls, v: Optional[float]): # If set, drafting uses greedy sampling, irrespective of sampling parameters. _allow_greedy_draft_tokens: bool = PrivateAttr(True) + @field_validator('draft_len_schedule') + @classmethod + def validate_draft_len_schedule_and_sort(cls, v, info): + """Validate and sort draft_len_schedule by batch size thresholds.""" + if v is not None: + # Validate values + for batch_size, draft_len in v.items(): + if batch_size < 1: + raise ValueError( + f"draft_len_schedule: batch size threshold must be >= 1, got {batch_size}" + ) + if draft_len < 0: + raise ValueError( + f"draft_len_schedule: draft length must be >= 0, got {draft_len}" + ) + + # Require batch_size=1 in schedule + if 1 not in v: + raise ValueError( + "draft_len_schedule must include batch_size=1. " + "All systems can have batch_size=1. Add {1: } to your schedule." + ) + + # Enforce schedule[1] == max_draft_len for consistency + max_draft_len = info.data.get('max_draft_len') + if max_draft_len is not None and v[1] != max_draft_len: + raise ValueError( + f"draft_len_schedule[1] must equal max_draft_len for consistency. " + f"Got schedule[1]={v[1]}, but max_draft_len={max_draft_len}. " + f"batch_size=1 should use maximum draft length.") + + # Enforce all draft lengths <= max_draft_len + if max_draft_len is not None: + for batch_size, draft_len in v.items(): + if draft_len > max_draft_len: + raise ValueError( + f"draft_len_schedule: all draft lengths must be <= max_draft_len. 
" + f"Got draft_len={draft_len} for batch_size={batch_size}, " + f"but max_draft_len={max_draft_len}.") + + # Return sorted dict (by batch size thresholds) + # This ensures efficient lookup + return dict(sorted(v.items(), key=lambda x: x[0])) + return v + @classmethod def from_dict(cls, data: dict): # dispatch to the correct decoding config diff --git a/tests/unittest/_torch/speculative/test_draft_len_schedule.py b/tests/unittest/_torch/speculative/test_draft_len_schedule.py new file mode 100644 index 00000000000..b512969911e --- /dev/null +++ b/tests/unittest/_torch/speculative/test_draft_len_schedule.py @@ -0,0 +1,356 @@ +import os +import sys + +import pytest +import torch + +from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm.llmapi import (DraftTargetDecodingConfig, KvCacheConfig, + NGramDecodingConfig) + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.llm_data import llm_models_root +from utils.util import similar + + +# # ============================================================================ +# # Fixture: Force single-worker mode for all tests in this module +# # ============================================================================ +@pytest.fixture(scope="module", autouse=True) +def enforce_single_worker(): + """Force single-worker mode for all tests in this module.""" + import os + os.environ["TLLM_WORKER_USE_SINGLE_PROCESS"] = "1" + yield + if "TLLM_WORKER_USE_SINGLE_PROCESS" in os.environ: + del os.environ["TLLM_WORKER_USE_SINGLE_PROCESS"] + + +# # ============================================================================ +# # test 1: Generation correctness check +# # ============================================================================ +@pytest.mark.parametrize("drafter_type,schedule", [ + ("ngram", { + 1: 3, + 4: 2, + 8: 1 + }), + ("model_drafter", { + 1: 3, + 4: 2, + 8: 1 + }), +]) +@pytest.mark.high_cuda_memory +@pytest.mark.no_xdist +def test_correctness_across_batch_sizes(drafter_type: str, schedule: dict): + """ + Test output correctness with various schedules and batch sizes. 
+ + This is the primary correctness test that validates: + - Multiple different schedules work correctly + - Output with draft_len_schedule matches output with fixed draft_len + - Works across different batch size transitions + - Both NGram and ModelDrafter function correctly + """ + total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 + memory_required = 30 if drafter_type == "model_drafter" else 20 + if total_mem_gb < memory_required: + pytest.skip( + f"Not enough memory (need {memory_required}GB, have {total_mem_gb:.1f}GB)" + ) + + models_path = llm_models_root() + target_model = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct" + draft_model = f"{models_path}/llama-3.2-models/Llama-3.2-3B-Instruct" + + max_batch_size = 4 + max_draft_len = max(schedule.values()) # Use max from schedule + + kv_cache_config = KvCacheConfig(enable_block_reuse=False, + enable_partial_reuse=False, + max_tokens=8192) + + llm_common_config = dict( + model=target_model, + backend='pytorch', + attn_backend="TRTLLM", + disable_overlap_scheduler=True, + max_batch_size=max_batch_size, + kv_cache_config=kv_cache_config, + max_num_tokens=2048, + ) + + if drafter_type == "ngram": + spec_config = NGramDecodingConfig( + max_draft_len=max_draft_len, + max_matching_ngram_size=2, + draft_len_schedule=schedule, + is_keep_all=True, + is_use_oldest=True, + is_public_pool=False, + ) + else: + spec_config = DraftTargetDecodingConfig( + max_draft_len=max_draft_len, + speculative_model_dir=str(draft_model), + draft_len_schedule=schedule, + ) + + prompts = [ + "The capital of France is", + "The president of the United States is", + "Machine learning is", + "The future of AI", + "What is the capital of Australia?", + "Explain in one sentence why the sky is blue.", + "Who wrote the book 'Pride and Prejudice'?", + "List three U.S. 
national holidays in the year 2025.", + ] + + # Give each request different max_tokens so they finish at different times + # This creates batch size transitions to test draft_len_schedule + # Use deterministic sampling settings to maximize similarity with non-spec baseline + sampling_params_list = [ + SamplingParams( + max_tokens=i + 3, + temperature=0, + seed=42, + ignore_eos=True, # Prevent early stopping differences + top_k=1, + top_p=1.0, + ) for i in range(len(prompts)) + ] + + # With dynamic draft_len_schedule + llm_with_schedule = LLM(**llm_common_config, speculative_config=spec_config) + results_with_schedule = llm_with_schedule.generate(prompts, + sampling_params_list) + generated_text_with_schedule = [ + result.outputs[0].text for result in results_with_schedule + ] + llm_with_schedule.shutdown() + + # Reference: spec decode with fixed max_draft_len (no schedule) + if drafter_type == "ngram": + spec_config_fixed = NGramDecodingConfig( + max_draft_len=max_draft_len, + max_matching_ngram_size=2, + draft_len_schedule=None, # No schedule - fixed draft length + is_keep_all=True, + is_use_oldest=True, + is_public_pool=False, + ) + else: + spec_config_fixed = DraftTargetDecodingConfig( + max_draft_len=max_draft_len, + speculative_model_dir=str(draft_model), + draft_len_schedule=None, # No schedule - fixed draft length + ) + + llm_fixed = LLM(**llm_common_config, speculative_config=spec_config_fixed) + results_fixed = llm_fixed.generate(prompts, sampling_params_list) + generated_text_fixed = [result.outputs[0].text for result in results_fixed] + llm_fixed.shutdown() + + # Verify correctness: spec decode with schedule should match spec decode without schedule + for text_schedule, text_fixed in zip(generated_text_with_schedule, + generated_text_fixed): + assert similar(text_schedule, text_fixed), \ + f"{drafter_type} output with draft_len_schedule should match output with fixed draft_len. Got:\n" \ + f"With schedule: {text_schedule}\n" \ + f"Fixed: {text_fixed}" + + +# # ============================================================================ +# # test 2: Drafting side functionality check +# # ============================================================================ +@pytest.mark.parametrize("drafter_type,spec_config_factory", [ + ("ngram", lambda: NGramDecodingConfig( + max_draft_len=5, + max_matching_ngram_size=2, + draft_len_schedule={ + 1: 5, + 4: 4, + 5: 3, + 6: 2, + 7: 1, + 8: 0 + }, + )), + ("model_drafter", lambda: DraftTargetDecodingConfig( + max_draft_len=5, + speculative_model_dir=str(llm_models_root() / "llama-3.2-models" / + "Llama-3.2-3B-Instruct"), + draft_len_schedule={ + 1: 5, + 4: 4, + 5: 3, + 6: 2, + 7: 1, + 8: 0 + }, + )), +]) +@pytest.mark.high_cuda_memory +@pytest.mark.no_xdist +def test_draft_len_schedule_functionality(drafter_type: str, + spec_config_factory): + """ + Test that draft_len=0 in schedule properly disables speculation. 
+ + Verifies: + - When schedule maps to draft_len=0, speculation is disabled + - System falls back to normal generation + - Output is still correct + - max_draft_tokens is set to 0 when batch_size triggers draft_len=0 + - Works for both NGram and ModelDrafter + """ + + total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 + if drafter_type == "model_drafter" and total_mem_gb < 30: + pytest.skip("Not enough memory for 2-model setup") + elif total_mem_gb < 20: + pytest.skip("Not enough memory") + max_batch_size = 8 + + llm_common_config = dict( + model=llm_models_root() / "llama-3.1-model" / "Meta-Llama-3.1-8B", + backend='pytorch', + attn_backend="TRTLLM", + disable_overlap_scheduler=True, + max_batch_size=max_batch_size, + max_num_tokens=2048, + ) + spec_config = spec_config_factory() + prompts = [f"Prompt {i}: The answer is" for i in range(8)] + # Give each request different max_tokens so they finish at different times + # This creates batch size transitions: 8 -> 7 -> 6 -> 5 -> 4 -> 3 -> 2 -> 1 + sampling_params_list = [ + SamplingParams( + max_tokens=20 * (i + 1), + temperature=0, + seed=42, + ignore_eos=True, # Prevent early stopping + top_k=1, + top_p=1.0, + ) for i in range(8) + ] + llm_spec = LLM(**llm_common_config, speculative_config=spec_config) + + drafter = llm_spec._executor.engine.drafter + executor = llm_spec._executor.engine + + iteration_data = [] + + # Store original methods + original_update_max_draft_tokens = drafter.update_max_draft_tokens + original_prepare_draft = drafter.prepare_draft_tokens + original_should_use_spec_decode = drafter.should_use_spec_decode + + # 1. Mock should_use_spec_decode to always return True + # This isolates draft_len_schedule testing from max_concurrency logic + def mock_should_use_spec_decode(*args, **kwargs): + return True # Always allow speculation (draft_len_schedule controls it) + + drafter.should_use_spec_decode = mock_should_use_spec_decode + + # 2. Instrument update_max_draft_tokens to capture when draft_len changes + def instrumented_update_max_draft_tokens(new_max_draft_tokens: int): + batch_size_active = len(executor.active_requests) + + original_update_max_draft_tokens(new_max_draft_tokens) + + iteration_data.append({ + 'batch_size_active': batch_size_active, + 'drafter_max_draft_tokens': new_max_draft_tokens, + 'use_spec_decode': + None, # Will be filled after _prepare_and_schedule_batch completes + 'actual_draft_lens': + [], # Will be filled after prepare_draft_tokens + }) + + drafter.update_max_draft_tokens = instrumented_update_max_draft_tokens + + # 3. 
Instrument prepare_draft_tokens - where actual draft tokens are produced + def instrumented_prepare_draft(scheduled_batch, resource_manager): + result = original_prepare_draft(scheduled_batch, resource_manager) + + if iteration_data and len(iteration_data) > 0: + iteration_data[-1]['use_spec_decode'] = executor.use_spec_decode + + actual_draft_lens = [] + for req in scheduled_batch.generation_requests: + draft_len = len( + req.py_draft_tokens) if req.py_draft_tokens else 0 + actual_draft_lens.append(draft_len) + + iteration_data[-1]['actual_draft_lens'] = actual_draft_lens + + # Filter out context-phase iterations (no generation requests = no draft tokens) + # This happens when all requests are still in prefill/context phase + if len(scheduled_batch.generation_requests) == 0: + iteration_data.pop() + + return result + + drafter.prepare_draft_tokens = instrumented_prepare_draft + + try: + llm_spec.generate(prompts, sampling_params_list) + finally: + # Restore methods in finally block to ensure cleanup even if generate() fails + drafter.update_max_draft_tokens = original_update_max_draft_tokens + drafter.prepare_draft_tokens = original_prepare_draft + drafter.should_use_spec_decode = original_should_use_spec_decode + llm_spec.shutdown() + + # ======================================================================== + # Verification Rule 1: batch_size_active → drafter_max_draft_tokens mapping + # ======================================================================== + # Hardcoded expected mapping (floor lookup): + # This matches what the authentic code does: len(executor.active_requests) + expected_mapping = { + 1: 5, + 2: 5, + 3: 5, + 4: 4, + 5: 3, + 6: 2, + 7: 1, + 8: 0, # >= 8 + } + + for idx, it in enumerate(iteration_data): + bs = it['batch_size_active'] + drafter_tokens = it['drafter_max_draft_tokens'] + + expected = expected_mapping.get(bs, 0) + assert drafter_tokens == expected, \ + f"Iter {idx}: batch_size_gen={bs} → expected {expected} tokens, got {drafter_tokens}" + + if drafter_tokens == 0: + assert not it['use_spec_decode'], \ + f"Iter {idx}: drafter_max_draft_tokens=0 but use_spec_decode={it['use_spec_decode']}" + + # ======================================================================== + # Verification Rule 2: actual_draft_lens (req.py_draft_tokens) vs drafter_max_draft_tokens + # ======================================================================== + if drafter_type == "ngram": + # NGram: all actual_draft_lens <= drafter_max_draft_tokens (because ngram drafting length not necessary to be the same as drafter_max_draft_tokens) + for idx, it in enumerate(iteration_data): + drafter_tokens = it['drafter_max_draft_tokens'] + for req_idx, actual_len in enumerate(it['actual_draft_lens']): + assert actual_len <= drafter_tokens, \ + f"Iter {idx}, req {req_idx}: NGram produced {actual_len} > max {drafter_tokens}" + + elif drafter_type == "model_drafter": + # ModelDrafter: the drafter should produce full draft_len all the time + for idx, it in enumerate(iteration_data): + drafter_tokens = it['drafter_max_draft_tokens'] + actual_lens = it['actual_draft_lens'] + + if drafter_tokens > 0: # Only count when speculation is active (draft_len > 0) + for req_idx, actual_len in enumerate(actual_lens): + assert actual_len == drafter_tokens, \ + f"Iter {idx}, req {req_idx}: ModelDrafter produced {actual_len} != max_draft_tokens {drafter_tokens}" diff --git a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py index 
3018c904256..dde37b836e4 100644 --- a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py +++ b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py @@ -189,7 +189,7 @@ def prepare_draft_tokens(self, resource_manager=None) -> None: return - drafter = _DummyDrafter(max_concurrency=6) + drafter = _DummyDrafter(max_draft_tokens=1, max_concurrency=6) # Compare min(len(requests), max_batch_size, token_cap) with max_concurrency
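Usage sketch (not part of the patch): the new developer-facing knob is draft_len_schedule on the speculative config. The example below mirrors a subset of the configuration exercised in test_draft_len_schedule.py; the model path is a placeholder and other NGram options are left at their defaults.

    from tensorrt_llm import LLM, SamplingParams
    from tensorrt_llm.llmapi import NGramDecodingConfig

    spec_config = NGramDecodingConfig(
        max_draft_len=4,
        max_matching_ngram_size=2,
        # batch_size >= 1 -> draft_len 4, >= 4 -> 2, >= 8 -> speculation disabled.
        # The validator requires an entry for batch_size=1 equal to max_draft_len.
        draft_len_schedule={1: 4, 4: 2, 8: 0},
    )

    llm = LLM(model="/path/to/Llama-3.1-8B-Instruct",
              backend="pytorch",
              speculative_config=spec_config)
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=32, temperature=0))
    print(outputs[0].outputs[0].text)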