From cd01510d009495415bd03364e35f29661b3653fe Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Tue, 7 Oct 2025 23:12:27 +0000 Subject: [PATCH 1/2] Draft of dynamic draft length stage 1. Signed-off-by: Zheyu Fu --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 41 ++++-- tensorrt_llm/_torch/speculative/drafter.py | 76 +++++++++- .../_torch/speculative/model_drafter.py | 42 +++++- tensorrt_llm/_torch/speculative/ngram.py | 28 +++- tensorrt_llm/llmapi/llm_args.py | 47 ++++++ .../speculative/test_draft_len_schedule.py | 138 ++++++++++++++++++ 6 files changed, 347 insertions(+), 25 deletions(-) create mode 100644 tests/unittest/_torch/speculative/test_draft_len_schedule.py diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 3ccd466eb63..0a441186e32 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -192,7 +192,8 @@ def __init__(self, # enqueue and _fetch_new_requests used data self.active = True self.max_beam_width = max_beam_width - self.max_draft_len = max_draft_len + self.max_draft_len = max_draft_len # Dynamic, if dynamic draft length is enabled (it will be dynamically updated before each scheduling step). Otherwise, it will be static. + self._static_max_draft_len = max_draft_len # Static, never changes self.max_num_tokens = model_engine.pytorch_backend_config.max_num_tokens self.print_log = model_engine.pytorch_backend_config.print_iter_log self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats @@ -1017,22 +1018,35 @@ def _prepare_and_schedule_batch(self): self._pad_attention_dp_dummy_request() if self.drafter is not None: - self.use_spec_decode = self.drafter.should_use_spec_decode( - self.active_requests, self.max_batch_size, - self.model_engine.max_num_tokens, - self.model_engine.spec_config.max_draft_len) - self.model_engine.enable_spec_decode = self.use_spec_decode + # Update draft_len based on schedule (if exists) + if self.drafter.draft_len_schedule is not None: + batch_size_input = len(self.active_requests) + + self.max_draft_len = self.drafter.get_draft_len_for_batch_size( + batch_size_input, + self.model_engine.spec_config.max_draft_len) + + self.drafter.update_max_draft_tokens(self.max_draft_len) + + # Check if draft_len=0 → immediately disable + if self.max_draft_len == 0: + self.use_spec_decode = False + self.model_engine.enable_spec_decode = False + else: + # Check should_use_spec_decode (max_concurrency logic) + self.use_spec_decode = self.drafter.should_use_spec_decode( + self.active_requests, self.max_batch_size, + self.model_engine.max_num_tokens, self.max_draft_len) + self.model_engine.enable_spec_decode = self.use_spec_decode - # Set up draft_tokens in active_requests, because they could be used in the scheduling stage. for request in self.active_requests: if request.state not in ( LlmRequestState.GENERATION_IN_PROGRESS, LlmRequestState.DISAGG_GENERATION_INIT): continue - max_draft_len = self.model_engine.spec_config.max_draft_len request.draft_tokens = [ 0 - ] * max_draft_len if max_draft_len > 0 else [] + ] * self.max_draft_len if self.max_draft_len > 0 else [] # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch, # we don't need to initialize py_draft_tokens at this stage because we haven't append the accepted tokens to the request yet. 
@@ -1203,11 +1217,10 @@ def _prepare_draft_requests(self):
                     continue
 
                 req.py_last_draft_tokens = req.py_draft_tokens
-                max_draft_len = self.model_engine.spec_config.max_draft_len
 
-                if max_draft_len > 0 and self.use_spec_decode:
-                    req.py_draft_tokens = [0] * max_draft_len
-                    req.py_draft_pages_allocated = max_draft_len
+                if self.max_draft_len > 0 and self.use_spec_decode:
+                    req.py_draft_tokens = [0] * self.max_draft_len
+                    req.py_draft_pages_allocated = self.max_draft_len
                 else:
                     req.py_draft_tokens = []
                     req.py_draft_pages_allocated = 0
@@ -1595,7 +1608,7 @@ def _pad_attention_dp_dummy_request(self):
                 request_ids=[0],
                 is_gen=True,
                 prepare_resource=True,
-                max_num_draft_tokens=self.max_draft_len,
+                max_num_draft_tokens=self._static_max_draft_len,
             )[0]
             llm_request.is_attention_dp_dummy = True
             spec_resource_manager = self.resource_manager.get_resource_manager(
diff --git a/tensorrt_llm/_torch/speculative/drafter.py b/tensorrt_llm/_torch/speculative/drafter.py
index 485934f7b5c..3a007f90a7e 100644
--- a/tensorrt_llm/_torch/speculative/drafter.py
+++ b/tensorrt_llm/_torch/speculative/drafter.py
@@ -1,5 +1,8 @@
 from abc import ABC, abstractmethod
-from typing import List, Optional, final
+from bisect import bisect_right
+from typing import Dict, List, Optional, final
+
+from tensorrt_llm.logger import logger
 
 from ..pyexecutor.llm_request import LlmRequest, get_draft_token_length
 from ..pyexecutor.resource_manager import ResourceManager
@@ -9,8 +12,20 @@ class Drafter(ABC):
     """Abstract base class for all drafter implementations."""
 
-    def __init__(self, max_concurrency: Optional[int] = None) -> None:
+    def __init__(
+        self,
+        max_draft_tokens: int,
+        _static_max_draft_tokens: int,
+        max_concurrency: Optional[int] = None,
+        draft_len_schedule: Optional[Dict[int, int]] = None,
+    ) -> None:
         self.max_concurrency = max_concurrency
+        # Schedule is already validated and sorted by config validator
+        self.draft_len_schedule = draft_len_schedule
+        # It will be updated before each scheduling step by the executor if dynamic draft length is enabled, otherwise it stays the same as the static value
+        self.max_draft_tokens = max_draft_tokens
+        # original max_draft_tokens value from decode config
+        self._static_max_draft_tokens = max_draft_tokens
 
     @abstractmethod
     def prepare_draft_tokens(
@@ -26,6 +41,43 @@ def prepare_draft_tokens(
         """
         raise NotImplementedError
 
+    @final
+    def get_draft_len_for_batch_size(self, batch_size: int,
+                                     max_draft_len: int) -> int:
+        """
+        Get the appropriate draft length for the given batch size.
+
+        Args:
+            batch_size: Current batch size
+            max_draft_len: Maximum draft length (fallback if schedule not provided)
+
+        Returns:
+            The draft length to use for this batch size
+        """
+        if self.draft_len_schedule is None:
+            return max_draft_len
+
+        # Binary search to find the largest threshold <= batch_size
+        # draft_len_schedule is already sorted by config validator
+        thresholds = list(self.draft_len_schedule.keys())
+
+        # bisect_right finds where to insert batch_size to keep list sorted
+        # The element before insertion point is the largest threshold <= batch_size
+        idx = bisect_right(thresholds, batch_size)
+
+        if idx == 0:
+            # batch_size is smaller than the smallest threshold (i.e. batch_size < 1)
+            # This shouldn't happen in practice, but handle defensively
+            logger.warning(
+                f"get_draft_len_for_batch_size called with batch_size={batch_size} < 1. "
+                f"This is unexpected. Disabling speculation (returning draft_len=0)."
+ ) + return 0 + + # Return draft_len for the largest threshold <= batch_size + threshold = thresholds[idx - 1] + return self.draft_len_schedule[threshold] + @final def should_use_spec_decode(self, requests: List[LlmRequest], max_batch_size: int, max_num_tokens: int, @@ -59,14 +111,19 @@ def pad_draft_tokens_for_cuda_graph( """ Pad draft tokens to the max draft length for CUDA graph compatibility. + Note: Always pads to the STATIC max_draft_len (not dynamic) because + CUDA graphs are compiled with fixed tensor shapes based on max_draft_len. + Args: scheduled_requests: The scheduled requests to pad """ for req in scheduled_requests.generation_requests: - max_draft_tokens = self.max_draft_tokens + # Use static max_draft_tokens for CUDA graph compatibility + # CUDA graphs are sized for the maximum, even if we generate fewer tokens dynamically num_draft_tokens = get_draft_token_length(req) req.py_draft_tokens.extend( - 0 for _ in range(max_draft_tokens - num_draft_tokens)) + 0 for _ in range(self._static_max_draft_tokens - + num_draft_tokens)) def run_drafter_post( self, @@ -79,3 +136,14 @@ def run_drafter_post( this method can be overridden to do that. Used in SaveHiddenStatesDrafter (to ensure correct input_ids) """ + + def update_max_draft_tokens(self, new_max_draft_tokens: int) -> None: + """ + Used when dynamic draft length based on batch size is enabled. + Update max_draft_tokens in drafter and propagate to any dependent components. + Subclasses can override to propagate to their resource managers if needed. + + Args: + new_max_draft_tokens: The new max draft tokens + """ + self.max_draft_tokens = new_max_draft_tokens diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py index 0a1a58d8575..1435598e982 100644 --- a/tensorrt_llm/_torch/speculative/model_drafter.py +++ b/tensorrt_llm/_torch/speculative/model_drafter.py @@ -52,14 +52,19 @@ def __init__( spec_resource_manager: Optional[BaseResourceManager] = None, guided_decoder: Optional[GuidedDecoder] = None, ): - super().__init__(spec_config.max_concurrency) - # Validate required parameters if draft_model_engine is None: raise ValueError("draft_model_engine cannot be None") if max_draft_tokens < 0: raise ValueError("max_draft_tokens must be >= 0") + super().__init__( + max_draft_tokens=spec_config.max_draft_len, + _static_max_draft_tokens=spec_config.max_draft_len, + max_concurrency=spec_config.max_concurrency, + draft_len_schedule=spec_config.draft_len_schedule, + ) + # Model and resource management self.draft_model_engine = draft_model_engine self.draft_seq_slot_manager = draft_seq_slot_manager @@ -67,7 +72,6 @@ def __init__( # Configuration self.spec_config = spec_config - self.max_draft_tokens = max_draft_tokens # Sampling self.sampler = sampler self.guided_decoder = guided_decoder @@ -78,6 +82,16 @@ def __init__( assert guided_decoder is None assert spec_config._allow_greedy_draft_tokens + # Currently dynamic draft length is not compatible with static draft loops + # TODO: support static draft loops with dynamic draft_len + if self.draft_len_schedule is not None: + raise ValueError( + "Dynamic draft length (draft_len_schedule) is not supported with " + "static draft loops (fused ChainDrafter/Eagle3). 
Static loops have " + "fixed iteration counts compiled into the model.\n" + "To use draft_len_schedule, please use ModelDrafter (2-model setup) " + "or NGramDrafter instead.") + def _create_draft_request(self, request: LlmRequest, input_tokens: Optional[List]) -> LlmRequest: """Create a draft request with common parameters.""" @@ -681,6 +695,17 @@ def generate_draft_tokens_with_overlap( - Updated target inputs or None - Draft sample state or None """ + # # Use pre-determined draft_len (set by executor BEFORE scheduling) + # if self.draft_len_schedule is not None and hasattr(self, '_current_batch_draft_len'): + # # Use pre-determined value from executor + # dynamic_draft_len = self._current_batch_draft_len + + # # Override max_draft_tokens to the dynamic value + # self.max_draft_tokens = dynamic_draft_len + + # # Note: If draft_len=0, this method won't be called anyway + # # (executor sets use_spec_decode=False and clears py_draft_tokens) + draft_batch, req_id_to_old_request = self._setup_draft_batch_and_resources( scheduled_batch) if draft_batch is None: @@ -769,6 +794,17 @@ def prepare_draft_tokens( if resource_manager is None: raise ValueError("Resource manager is required") + # # Use pre-determined draft_len (set by executor BEFORE scheduling) + # if self.draft_len_schedule is not None and hasattr(self, '_current_batch_draft_len'): + # # Use pre-determined value from executor + # dynamic_draft_len = self._current_batch_draft_len + + # # Override max_draft_tokens to the dynamic value + # self.max_draft_tokens = dynamic_draft_len + + # # Note: If draft_len=0, this method won't be called anyway + # # (executor sets use_spec_decode=False and clears py_draft_tokens) + try: draft_batch, req_id_to_old_request = self._setup_draft_batch_and_resources( scheduled_requests) diff --git a/tensorrt_llm/_torch/speculative/ngram.py b/tensorrt_llm/_torch/speculative/ngram.py index dc23270945b..16d420ccdd0 100644 --- a/tensorrt_llm/_torch/speculative/ngram.py +++ b/tensorrt_llm/_torch/speculative/ngram.py @@ -26,7 +26,7 @@ class NGramPoolManager(BaseResourceManager): Arguments: max_draft_tokens: int - The length maximum of draft tokens (can be understood as length maximum of output draft tokens). + The length maximum of draft tokens (can be understood as length maximum of output draft tokens). If dynamic draft length based on batch size is enabled, this value will be overridden by the dynamic draft_len each step. max_matching_ngram_size: int The length maximum of searching tokens (can be understood as length maximum of input tokens to search). 
@@ -51,7 +51,8 @@ class NGramPoolManager(BaseResourceManager): def __init__(self, spec_config: "NGramDecodingConfig", max_num_requests: int): - self.max_draft_tokens = spec_config.max_draft_len + self.max_draft_tokens = spec_config.max_draft_len # Dynamic, can be updated during execution + self._static_max_draft_tokens = spec_config.max_draft_len # Static, never changes self.max_matching_ngram_size = spec_config.max_matching_ngram_size self.is_keep_all = spec_config.is_keep_all self.is_use_oldest = spec_config.is_use_oldest # TODO: remove this if updating strategy is supported @@ -167,10 +168,16 @@ def __init__( spec_config: NGramDecodingConfig, ngram_pool_manager: NGramPoolManager = None, ): - super().__init__(spec_config.max_concurrency) + + super().__init__( + max_draft_tokens=spec_config.max_draft_len, + _static_max_draft_tokens=spec_config.max_draft_len, + max_concurrency=spec_config.max_concurrency, + draft_len_schedule=spec_config.draft_len_schedule, + ) + assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool." self.spec_config = spec_config - self.max_draft_tokens = spec_config.max_draft_len self.spec_resource_manager = ngram_pool_manager def prepare_draft_tokens( @@ -178,6 +185,14 @@ def prepare_draft_tokens( scheduled_requests: ScheduledRequests, resource_manager: Optional[ResourceManager] = None, ) -> None: + # # Override max_draft_tokens if dynamic draft_len is provided + # if self.draft_len_schedule is not None and hasattr(self, '_current_batch_draft_len'): + # # Use pre-determined value from executor (set BEFORE scheduling) + # self.max_draft_tokens = self._current_batch_draft_len + # # CRITICAL: Also update the pool manager's max_draft_tokens + # # This ensures pool building and start_index calculation use the correct dynamic length + # self.spec_resource_manager.max_draft_tokens = self._current_batch_draft_len + # Sort by request_id when py_batch_idx is None as a fallback. # This happens in the disagg case: for a set of new requests, we draft # before forward_step, so py_batch_idx is not assigned. @@ -197,3 +212,8 @@ def prepare_draft_tokens( request.py_max_new_tokens, ) request.py_draft_tokens = draft_tokens + + def update_max_draft_tokens(self, new_max_draft_tokens: int) -> None: + """Override to propagate to NGramPoolManager.""" + super().update_max_draft_tokens(new_max_draft_tokens) + self.spec_resource_manager.max_draft_tokens = new_max_draft_tokens diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 5bfcced072d..b9ab2a5a233 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -364,11 +364,58 @@ class DecodingBaseConfig(StrictBaseModel): # this value. Otherwise, speculation will always be on. max_concurrency: Optional[int] = None + # Developer interface: dynamically adjust draft length based on pre-scheduled batch size in runtime. + # Maps batch size to draft lengths. For example: + # {1: 4, 4: 2, 8: 0} means: + # - batch_size >= 1: use draft_len=4 + # - batch_size >= 4: use draft_len=2 + # - batch_size >= 8: use draft_len=0 (disable speculation) + # If not specified, this feature is disabled and will use max_draft_len for all batch sizes. 
+ # draft_len_schedule is enforced to contain batch_size=1 and draft_len=max_draft_len for consistency + # for example, if max_draft_len=4, the schedule must contain {1: 4} + draft_len_schedule: Optional[dict[int, int]] = None + load_format: Optional[str] = None # If set, drafting uses greedy sampling, irrespective of sampling parameters. _allow_greedy_draft_tokens: bool = PrivateAttr(True) + @field_validator('draft_len_schedule') + @classmethod + def validate_draft_len_schedule_and_sort(cls, v, info): + """Validate and sort draft_len_schedule by batch size thresholds.""" + if v is not None: + # Validate values + for batch_size, draft_len in v.items(): + if batch_size < 1: + raise ValueError( + f"draft_len_schedule: batch size threshold must be >= 1, got {batch_size}" + ) + if draft_len < 0: + raise ValueError( + f"draft_len_schedule: draft length must be >= 0, got {draft_len}" + ) + + # Require batch_size=1 in schedule + if 1 not in v: + raise ValueError( + "draft_len_schedule must include batch_size=1. " + "All systems can have batch_size=1. Add {1: } to your schedule." + ) + + # Enforce schedule[1] == max_draft_len for consistency + max_draft_len = info.data.get('max_draft_len') + if max_draft_len is not None and v[1] != max_draft_len: + raise ValueError( + f"draft_len_schedule[1] must equal max_draft_len for consistency. " + f"Got schedule[1]={v[1]}, but max_draft_len={max_draft_len}. " + f"batch_size=1 should use maximum draft length.") + + # Return sorted dict (by batch size thresholds) + # This ensures efficient lookup + return dict(sorted(v.items(), key=lambda x: x[0])) + return v + @classmethod def from_dict(cls, data: dict): # dispatch to the correct decoding config diff --git a/tests/unittest/_torch/speculative/test_draft_len_schedule.py b/tests/unittest/_torch/speculative/test_draft_len_schedule.py new file mode 100644 index 00000000000..d1d28233ed4 --- /dev/null +++ b/tests/unittest/_torch/speculative/test_draft_len_schedule.py @@ -0,0 +1,138 @@ +""" +test_draft_len_schedule.py + +Tests for dynamic draft length (draft_len_schedule) feature - Stage 1. + +Stage 1 covers: +- NGramDrafter with dynamic draft_len +- ModelDrafter (2-model) with dynamic draft_len +- Draft-side compute savings only (target model still processes padded tokens) + +Not covered in Stage 1: +- ChainDrafter/Eagle3 static loops (Stage 3) +- Target model compute savings (Stage 2) +""" + +import os +import sys + +import pytest +import torch + +from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm.llmapi import (DraftTargetDecodingConfig, KvCacheConfig, + NGramDecodingConfig) + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from utils.llm_data import llm_models_root +from utils.util import similar + + +# ============================================================================ +# P0-1: Correctness check - generation quality doesn't change +# ============================================================================ +@pytest.mark.parametrize("drafter_type,schedule", [ + ("ngram", { + 1: 3, + 4: 2, + 8: 1 + }), + ("model_drafter", { + 1: 3, + 4: 2, + 8: 1 + }), +]) +@pytest.mark.high_cuda_memory +def test_correctness_across_batch_sizes(drafter_type: str, schedule: dict): + """ + Test output correctness with various schedules and batch sizes. 
+ + This is the primary correctness test that validates: + - Multiple different schedules work correctly + - Output matches non-speculative baseline + - Works across different batch size transitions + - Both NGram and ModelDrafter function correctly + + This test replaces separate basic tests for each drafter type. + """ + total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 + memory_required = 30 if drafter_type == "model_drafter" else 20 + if total_mem_gb < memory_required: + pytest.skip( + f"Not enough memory (need {memory_required}GB, have {total_mem_gb:.1f}GB)" + ) + + models_path = llm_models_root() + target_model = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct" + draft_model = f"{models_path}/llama-3.2-models/Llama-3.2-3B-Instruct" + + max_batch_size = 4 + max_draft_len = max(schedule.values()) # Use max from schedule + kv_cache_config = KvCacheConfig(enable_block_reuse=False, max_tokens=8192) + + llm_common_config = dict( + model=target_model, + backend='pytorch', + attn_backend="TRTLLM", + disable_overlap_scheduler=True, + max_batch_size=max_batch_size, + kv_cache_config=kv_cache_config, + max_num_tokens=2048, + ) + + if drafter_type == "ngram": + spec_config = NGramDecodingConfig( + max_draft_len=max_draft_len, + max_matching_ngram_size=2, + draft_len_schedule=schedule, + is_keep_all=True, + is_use_oldest=True, + is_public_pool=False, + ) + else: + spec_config = DraftTargetDecodingConfig( + max_draft_len=max_draft_len, + speculative_model_dir=str( + draft_model), # Use smaller 1B model as draft + draft_len_schedule=schedule, + ) + + prompts = [ + "The capital of France is", + "The president of the United States is", + "Machine learning is", + "The future of AI", + "What is the capital of Australia?", + "Explain in one sentence why the sky is blue.", + "Who wrote the book 'Pride and Prejudice'?", + "List three U.S. national holidays in the year 2025.", + "Who painted the Mona Lisa?", + ] + sampling_params = SamplingParams( + max_tokens=32, + temperature=0, + seed=42, + ) + + # With dynamic draft_len + llm_spec = LLM(**llm_common_config, speculative_config=spec_config) + results_spec = llm_spec.generate(prompts, sampling_params) + generated_text_spec = [result.outputs[0].text for result in results_spec] + llm_spec.shutdown() + + # Reference without speculation + llm_ref = LLM(**llm_common_config) + results_ref = llm_ref.generate(prompts, sampling_params) + generated_text_ref = [result.outputs[0].text for result in results_ref] + llm_ref.shutdown() + + # Verify correctness + if drafter_type == "ngram": + for text_spec, text_ref in zip(generated_text_spec, generated_text_ref): + assert similar(text_spec, text_ref), \ + f"NGram output should be similar. Got:\nSpec: {text_spec}\nRef: {text_ref}" + else: + for text_spec, text_ref in zip(generated_text_spec, generated_text_ref): + assert similar(text_spec, text_ref), \ + f"ModelDrafter output should be similar. Got:\nSpec: {text_spec}\nRef: {text_ref}" From 7568c4dcd576641deb1802710ddc73135a8ddc10 Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Wed, 8 Oct 2025 00:43:37 +0000 Subject: [PATCH 2/2] Clean. 
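
A minimal usage sketch of the draft_len_schedule interface introduced in patch 1,
adapted from the unit test above; the model path and schedule values are
illustrative only. With this schedule, batch sizes 1-3 draft 4 tokens,
4-7 draft 2 tokens, and 8 or more disable speculation, mirroring the example
in the llm_args.py comment:

    from tensorrt_llm import LLM, SamplingParams
    from tensorrt_llm.llmapi import NGramDecodingConfig

    spec_config = NGramDecodingConfig(
        max_draft_len=4,
        max_matching_ngram_size=2,
        # The schedule must include {1: max_draft_len}; draft_len shrinks as the batch grows.
        draft_len_schedule={1: 4, 4: 2, 8: 0},
    )
    llm = LLM(model="<path/to/target/model>",  # placeholder path
              backend="pytorch",
              speculative_config=spec_config)
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=32, temperature=0))
    print(outputs[0].outputs[0].text)
    llm.shutdown()
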
Signed-off-by: Zheyu Fu --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 11 ++++++----- .../_torch/pyexecutor/py_executor_creator.py | 4 +++- tensorrt_llm/_torch/speculative/drafter.py | 16 ++++++---------- tensorrt_llm/_torch/speculative/ngram.py | 13 +++---------- tensorrt_llm/llmapi/llm_args.py | 5 ++--- .../speculative/test_draft_len_schedule.py | 4 +--- 6 files changed, 21 insertions(+), 32 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 0a441186e32..42e71adc4fe 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -192,8 +192,8 @@ def __init__(self, # enqueue and _fetch_new_requests used data self.active = True self.max_beam_width = max_beam_width - self.max_draft_len = max_draft_len # Dynamic, if dynamic draft length is enabled (it will be dynamically updated before each scheduling step). Otherwise, it will be static. - self._static_max_draft_len = max_draft_len # Static, never changes + self.max_draft_len = max_draft_len # It's dynamic if draft_len_schedule is provided in spec_config (dynamic draft length based on runtime batch size is enabled). It's static in other cases. + self._static_max_draft_len = max_draft_len # It's always static self.max_num_tokens = model_engine.pytorch_backend_config.max_num_tokens self.print_log = model_engine.pytorch_backend_config.print_iter_log self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats @@ -1023,13 +1023,14 @@ def _prepare_and_schedule_batch(self): batch_size_input = len(self.active_requests) self.max_draft_len = self.drafter.get_draft_len_for_batch_size( - batch_size_input, - self.model_engine.spec_config.max_draft_len) + batch_size_input) self.drafter.update_max_draft_tokens(self.max_draft_len) # Check if draft_len=0 → immediately disable - if self.max_draft_len == 0: + # max_draft_len==0 is only possible when draft_len_schedule is provided + # for example, draft_len_schedule = {1:4, 4:2, 8:0}, batch_size >= 8 will set self.max_draft_len = 0 + if self.drafter.draft_len_schedule is not None and self.max_draft_len == 0: self.use_spec_decode = False self.model_engine.enable_spec_decode = False else: diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index aa8c6435dbe..191d3e34ed2 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -348,7 +348,9 @@ def create_py_executor( use_chain_drafter = ( guided_decoding_config is None and draft_spec_config._allow_greedy_draft_tokens - and pytorch_backend_config.attn_backend == "TRTLLM") + and pytorch_backend_config.attn_backend == "TRTLLM" + and draft_spec_config.draft_len_schedule is None + ) # currently ChainDrafter does not support dynamic draft length else: use_chain_drafter = False diff --git a/tensorrt_llm/_torch/speculative/drafter.py b/tensorrt_llm/_torch/speculative/drafter.py index 3a007f90a7e..a147f018f75 100644 --- a/tensorrt_llm/_torch/speculative/drafter.py +++ b/tensorrt_llm/_torch/speculative/drafter.py @@ -22,9 +22,9 @@ def __init__( self.max_concurrency = max_concurrency # Schedule is already validated and sorted by config validator self.draft_len_schedule = draft_len_schedule - # It will be updated before each scheduling step by the executor if dynamic draft length is enabled, otherwise it stays the same as the static value + # It's dynamic if draft_len_schedule 
is provided in spec_config; otherwise it stays static.
         self.max_draft_tokens = max_draft_tokens
-        # original max_draft_tokens value from decode config
+        # It's always static
         self._static_max_draft_tokens = max_draft_tokens
 
     @abstractmethod
@@ -42,20 +42,16 @@ def prepare_draft_tokens(
         raise NotImplementedError
 
     @final
-    def get_draft_len_for_batch_size(self, batch_size: int,
-                                     max_draft_len: int) -> int:
+    def get_draft_len_for_batch_size(self, batch_size: int) -> int:
         """
-        Get the appropriate draft length for the given batch size.
+        Get the appropriate draft length for the given batch size using binary search.
 
         Args:
-            batch_size: Current batch size
-            max_draft_len: Maximum draft length (fallback if schedule not provided)
+            batch_size: Current batch size (number of active requests)
 
         Returns:
             The draft length to use for this batch size
         """
-        if self.draft_len_schedule is None:
-            return max_draft_len
 
         # Binary search to find the largest threshold <= batch_size
         # draft_len_schedule is already sorted by config validator
@@ -139,7 +135,7 @@ def run_drafter_post(
 
     def update_max_draft_tokens(self, new_max_draft_tokens: int) -> None:
         """
-        Used when dynamic draft length based on batch size is enabled.
+        Used when draft_len_schedule is provided in spec_config (dynamic draft length based on runtime batch size is enabled).
         Update max_draft_tokens in drafter and propagate to any dependent components.
         Subclasses can override to propagate to their resource managers if needed.
 
diff --git a/tensorrt_llm/_torch/speculative/ngram.py b/tensorrt_llm/_torch/speculative/ngram.py
index 16d420ccdd0..c474eea7a70 100644
--- a/tensorrt_llm/_torch/speculative/ngram.py
+++ b/tensorrt_llm/_torch/speculative/ngram.py
@@ -26,7 +26,7 @@ class NGramPoolManager(BaseResourceManager):
     Arguments:
         max_draft_tokens: int
-            The length maximum of draft tokens (can be understood as length maximum of output draft tokens). If dynamic draft length based on batch size is enabled, this value will be overridden by the dynamic draft_len each step.
+            The length maximum of draft tokens (can be understood as length maximum of output draft tokens). If draft_len_schedule is provided in spec_config, this value is updated to the dynamic draft_len at each step.
 
         max_matching_ngram_size: int
             The length maximum of searching tokens (can be understood as length maximum of input tokens to search).
 
@@ -51,8 +51,8 @@ class NGramPoolManager(BaseResourceManager):
 
     def __init__(self, spec_config: "NGramDecodingConfig",
                  max_num_requests: int):
-        self.max_draft_tokens = spec_config.max_draft_len  # Dynamic, can be updated during execution
-        self._static_max_draft_tokens = spec_config.max_draft_len  # Static, never changes
+        self.max_draft_tokens = spec_config.max_draft_len  # It's dynamic if draft_len_schedule is provided in spec_config; otherwise it stays static.
+        self._static_max_draft_tokens = spec_config.max_draft_len  # It's always static
         self.max_matching_ngram_size = spec_config.max_matching_ngram_size
         self.is_keep_all = spec_config.is_keep_all
         self.is_use_oldest = spec_config.is_use_oldest  # TODO: remove this if updating strategy is supported
@@ -185,13 +185,6 @@ def prepare_draft_tokens(
         scheduled_requests: ScheduledRequests,
         resource_manager: Optional[ResourceManager] = None,
     ) -> None:
-        # # Override max_draft_tokens if dynamic draft_len is provided
-        # if self.draft_len_schedule is not None and hasattr(self, '_current_batch_draft_len'):
-        #     # Use pre-determined value from executor (set BEFORE scheduling)
-        #     self.max_draft_tokens = self._current_batch_draft_len
-        #     # CRITICAL: Also update the pool manager's max_draft_tokens
-        #     # This ensures pool building and start_index calculation use the correct dynamic length
-        #     self.spec_resource_manager.max_draft_tokens = self._current_batch_draft_len
 
         # Sort by request_id when py_batch_idx is None as a fallback.
         # This happens in the disagg case: for a set of new requests, we draft
         # before forward_step, so py_batch_idx is not assigned.
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 832713beaf6..6af0ae558c7 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -364,14 +364,13 @@ class DecodingBaseConfig(StrictBaseModel):
     # this value. Otherwise, speculation will always be on.
     max_concurrency: Optional[int] = None
 
-    # Developer interface: dynamically adjust draft length based on pre-scheduled batch size in runtime.
+    # Developer interface: dynamically adjust the draft length based on the active batch size at runtime.
     # Maps batch size to draft lengths. For example:
     # {1: 4, 4: 2, 8: 0} means:
     # - batch_size >= 1: use draft_len=4
     # - batch_size >= 4: use draft_len=2
     # - batch_size >= 8: use draft_len=0 (disable speculation)
-    # If not specified, this feature is disabled and will use max_draft_len for all batch sizes.
-    # draft_len_schedule is enforced to contain batch_size=1 and draft_len=max_draft_len for consistency
+    # draft_len_schedule must contain batch_size=1, and its corresponding draft_len must equal max_draft_len for consistency
     # for example, if max_draft_len=4, the schedule must contain {1: 4}
     draft_len_schedule: Optional[dict[int, int]] = None
 
diff --git a/tests/unittest/_torch/speculative/test_draft_len_schedule.py b/tests/unittest/_torch/speculative/test_draft_len_schedule.py
index d1d28233ed4..824a5d53042 100644
--- a/tests/unittest/_torch/speculative/test_draft_len_schedule.py
+++ b/tests/unittest/_torch/speculative/test_draft_len_schedule.py
@@ -28,9 +28,7 @@
 from utils.util import similar
 
 
-# ============================================================================
-# P0-1: Correctness check - generation quality doesn't change
-# ============================================================================
+# Generation correctness check
 @pytest.mark.parametrize("drafter_type,schedule", [
     ("ngram", {
         1: 3,