42 changes: 28 additions & 14 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -192,7 +192,8 @@ def __init__(self,
# enqueue and _fetch_new_requests used data
self.active = True
self.max_beam_width = max_beam_width
self.max_draft_len = max_draft_len
self.max_draft_len = max_draft_len # It's dynamic if draft_len_schedule is provided in spec_config (dynamic draft length based on runtime batch size is enabled). It's static in other cases.
self._static_max_draft_len = max_draft_len # It's always static
Collaborator
nit: I think the name of the variable is pretty self-explanatory, comment is unnecessary

self.max_num_tokens = model_engine.pytorch_backend_config.max_num_tokens
self.print_log = model_engine.pytorch_backend_config.print_iter_log
self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
@@ -1017,22 +1018,36 @@ def _prepare_and_schedule_batch(self):
self._pad_attention_dp_dummy_request()

if self.drafter is not None:
self.use_spec_decode = self.drafter.should_use_spec_decode(
self.active_requests, self.max_batch_size,
self.model_engine.max_num_tokens,
self.model_engine.spec_config.max_draft_len)
self.model_engine.enable_spec_decode = self.use_spec_decode
# Update draft_len based on schedule (if exists)
if self.drafter.draft_len_schedule is not None:
batch_size_input = len(self.active_requests)

self.max_draft_len = self.drafter.get_draft_len_for_batch_size(
batch_size_input)

self.drafter.update_max_draft_tokens(self.max_draft_len)

# Check if draft_len=0 → immediately disable
# max_draft_len==0 is only possible when draft_len_schedule is provided
# for example, draft_len_schedule = {1:4, 4:2, 8:0}, batch_size >= 8 will set self.max_draft_len = 0
if self.drafter.draft_len_schedule is not None and self.max_draft_len == 0:
self.use_spec_decode = False
self.model_engine.enable_spec_decode = False
else:
# Check should_use_spec_decode (max_concurrency logic)
self.use_spec_decode = self.drafter.should_use_spec_decode(
self.active_requests, self.max_batch_size,
self.model_engine.max_num_tokens, self.max_draft_len)
self.model_engine.enable_spec_decode = self.use_spec_decode

# Set up draft_tokens in active_requests, because they could be used in the scheduling stage.
for request in self.active_requests:
if request.state not in (
LlmRequestState.GENERATION_IN_PROGRESS,
LlmRequestState.DISAGG_GENERATION_INIT):
continue
max_draft_len = self.model_engine.spec_config.max_draft_len
request.draft_tokens = [
0
] * max_draft_len if max_draft_len > 0 else []
] * self.max_draft_len if self.max_draft_len > 0 else []

# When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch,
# we don't need to initialize py_draft_tokens at this stage because we haven't appended the accepted tokens to the request yet.
@@ -1203,11 +1218,10 @@ def _prepare_draft_requests(self):
continue

req.py_last_draft_tokens = req.py_draft_tokens
max_draft_len = self.model_engine.spec_config.max_draft_len

if max_draft_len > 0 and self.use_spec_decode:
req.py_draft_tokens = [0] * max_draft_len
req.py_draft_pages_allocated = max_draft_len
if self.max_draft_len > 0 and self.use_spec_decode:
req.py_draft_tokens = [0] * self.max_draft_len
req.py_draft_pages_allocated = self.max_draft_len
else:
req.py_draft_tokens = []
req.py_draft_pages_allocated = 0
@@ -1595,7 +1609,7 @@ def _pad_attention_dp_dummy_request(self):
request_ids=[0],
is_gen=True,
prepare_resource=True,
max_num_draft_tokens=self.max_draft_len,
max_num_draft_tokens=self._static_max_draft_len,
)[0]
llm_request.is_attention_dp_dummy = True
spec_resource_manager = self.resource_manager.get_resource_manager(
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -348,7 +348,9 @@ def create_py_executor(
use_chain_drafter = (
guided_decoding_config is None
and draft_spec_config._allow_greedy_draft_tokens
and pytorch_backend_config.attn_backend == "TRTLLM")
and pytorch_backend_config.attn_backend == "TRTLLM"
and draft_spec_config.draft_len_schedule is None
) # currently ChainDrafter does not support dynamic draft length
Collaborator
Unnecessary comment. I think it's clear from the context that all the skips are for stuff that the ChainDrafter does not support

else:
use_chain_drafter = False

72 changes: 68 additions & 4 deletions tensorrt_llm/_torch/speculative/drafter.py
@@ -1,5 +1,8 @@
from abc import ABC, abstractmethod
from typing import List, Optional, final
from bisect import bisect_right
from typing import Dict, List, Optional, final

from tensorrt_llm.logger import logger

from ..pyexecutor.llm_request import LlmRequest, get_draft_token_length
from ..pyexecutor.resource_manager import ResourceManager
@@ -9,8 +12,20 @@
class Drafter(ABC):
"""Abstract base class for all drafter implementations."""

def __init__(self, max_concurrency: Optional[int] = None) -> None:
def __init__(
self,
max_draft_tokens: int,
_static_max_draft_tokens: int,
Collaborator
Unused argument. Both fields are initialized with max_draft_tokens, which makes sense

max_concurrency: Optional[int] = None,
draft_len_schedule: Optional[Dict[int, int]] = None,
) -> None:
self.max_concurrency = max_concurrency
# Schedule is already validated and sorted by config validator
self.draft_len_schedule = draft_len_schedule
# It's dynamic if draft_len_schedule is provided in spec_config (dynamic draft length based on runtime batch size is enabled). It's static in other cases.
self.max_draft_tokens = max_draft_tokens
# It's always static
self._static_max_draft_tokens = max_draft_tokens
Comment on lines +15 to +28
Contributor
⚠️ Potential issue | 🔴 Critical

Critical: Fix assignment of _static_max_draft_tokens.

Line 28 assigns self._static_max_draft_tokens = max_draft_tokens, which sets the "always static" value to the potentially dynamic max_draft_tokens parameter. This contradicts the comment on line 27 stating "It's always static" and breaks the intended separation between dynamic and static max draft tokens.

Apply this diff to use the correct parameter:

     def __init__(
         self,
         max_draft_tokens: int,
         _static_max_draft_tokens: int,
         max_concurrency: Optional[int] = None,
         draft_len_schedule: Optional[Dict[int, int]] = None,
     ) -> None:
         self.max_concurrency = max_concurrency
         # Schedule is already validated and sorted by config validator
         self.draft_len_schedule = draft_len_schedule
         # It's dynamic if draft_len_schedule is provided in spec_config (dynamic draft length based on runtime batch size is enabled). It's static in other cases.
         self.max_draft_tokens = max_draft_tokens
         # It's always static
-        self._static_max_draft_tokens = max_draft_tokens
+        self._static_max_draft_tokens = _static_max_draft_tokens

@abstractmethod
def prepare_draft_tokens(
@@ -26,6 +41,39 @@ def prepare_draft_tokens(
"""
raise NotImplementedError

@final
def get_draft_len_for_batch_size(self, batch_size: int) -> int:
"""
Get the appropriate draft length for the given batch size using binary search.
Args:
batch_size: Current runtime batch size
Returns:
The draft length to use for this batch size
"""

# Binary search to find the largest threshold <= batch_size
# draft_len_schedule is already sorted by config validator
thresholds = list(self.draft_len_schedule.keys())

# bisect_right finds where to insert batch_size to keep list sorted
# The element before insertion point is the largest threshold <= batch_size
idx = bisect_right(thresholds, batch_size)

if idx == 0:
# batch_size is smaller than smallest threshold (batch_size smaller than 1)
# This shouldn't happen in practice, but handle defensively
logger.warning(
f"get_draft_len_for_batch_size called with batch_size={batch_size} < 1. "
f"This is unexpected. Disabling speculation (returning draft_len=0)."
)
return 0

# Return draft_len for the largest threshold <= batch_size
threshold = thresholds[idx - 1]
return self.draft_len_schedule[threshold]
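For readers skimming the diff, here is a minimal standalone sketch of the threshold lookup implemented above; the helper name and the schedule values are illustrative only and are not part of this PR.

from bisect import bisect_right

def lookup_draft_len(schedule: dict, batch_size: int) -> int:
    # schedule maps batch-size thresholds to draft lengths; keys are sorted ascending
    thresholds = list(schedule.keys())
    # bisect_right gives the insertion point; the element before it is the
    # largest threshold <= batch_size
    idx = bisect_right(thresholds, batch_size)
    if idx == 0:
        return 0  # batch_size below the smallest threshold: disable speculation defensively
    return schedule[thresholds[idx - 1]]

schedule = {1: 4, 4: 2, 8: 0}                # same shape as the examples in this PR
assert lookup_draft_len(schedule, 1) == 4    # largest threshold <= 1 is 1
assert lookup_draft_len(schedule, 5) == 2    # largest threshold <= 5 is 4
assert lookup_draft_len(schedule, 12) == 0   # batch_size >= 8 disables speculation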

@final
def should_use_spec_decode(self, requests: List[LlmRequest],
max_batch_size: int, max_num_tokens: int,
@@ -59,14 +107,19 @@ def pad_draft_tokens_for_cuda_graph(
"""
Pad draft tokens to the max draft length for CUDA graph compatibility.
Note: Always pads to the STATIC max_draft_len (not dynamic) because
CUDA graphs are compiled with fixed tensor shapes based on max_draft_len.
Comment on lines +110 to +112
Collaborator
We are planning on changing this in the near future with a follow up, right?

Args:
scheduled_requests: The scheduled requests to pad
"""
for req in scheduled_requests.generation_requests:
max_draft_tokens = self.max_draft_tokens
# Use static max_draft_tokens for CUDA graph compatibility
# CUDA graphs are sized for the maximum, even if we generate fewer tokens dynamically
num_draft_tokens = get_draft_token_length(req)
req.py_draft_tokens.extend(
0 for _ in range(max_draft_tokens - num_draft_tokens))
0 for _ in range(self._static_max_draft_tokens -
num_draft_tokens))
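A quick illustration of the static-padding note above, with made-up token values:

# With _static_max_draft_tokens = 4 and a request currently holding 2 draft tokens:
py_draft_tokens = [117, 42]                                          # dynamic-length drafts this step
py_draft_tokens.extend(0 for _ in range(4 - len(py_draft_tokens)))   # pad to the static maximum
assert py_draft_tokens == [117, 42, 0, 0]                            # CUDA graph always sees a fixed shape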

def run_drafter_post(
self,
@@ -79,3 +132,14 @@ def run_drafter_post(
this method can be overridden to do that.
Used in SaveHiddenStatesDrafter (to ensure correct input_ids)
"""

def update_max_draft_tokens(self, new_max_draft_tokens: int) -> None:
"""
Used when draft_len_schedule is provided in spec_config (dynamic draft length based on runtime batch size is enabled)
Update max_draft_tokens in drafter and propagate to any dependent components.
Subclasses can override to propagate to their resource managers if needed.
Args:
new_max_draft_tokens: The new max draft tokens
"""
self.max_draft_tokens = new_max_draft_tokens
42 changes: 39 additions & 3 deletions tensorrt_llm/_torch/speculative/model_drafter.py
@@ -52,22 +52,26 @@ def __init__(
spec_resource_manager: Optional[BaseResourceManager] = None,
guided_decoder: Optional[GuidedDecoder] = None,
):
super().__init__(spec_config.max_concurrency)

# Validate required parameters
if draft_model_engine is None:
raise ValueError("draft_model_engine cannot be None")
if max_draft_tokens < 0:
raise ValueError("max_draft_tokens must be >= 0")

super().__init__(
max_draft_tokens=spec_config.max_draft_len,
_static_max_draft_tokens=spec_config.max_draft_len,
max_concurrency=spec_config.max_concurrency,
draft_len_schedule=spec_config.draft_len_schedule,
)

# Model and resource management
self.draft_model_engine = draft_model_engine
self.draft_seq_slot_manager = draft_seq_slot_manager
self.spec_resource_manager = spec_resource_manager

# Configuration
self.spec_config = spec_config
self.max_draft_tokens = max_draft_tokens
# Sampling
self.sampler = sampler
self.guided_decoder = guided_decoder
@@ -78,6 +82,16 @@ def __init__(
assert guided_decoder is None
assert spec_config._allow_greedy_draft_tokens

# Currently dynamic draft length is not compatible with static draft loops
# TODO: support static draft loops with dynamic draft_len
if self.draft_len_schedule is not None:
raise ValueError(
"Dynamic draft length (draft_len_schedule) is not supported with "
"static draft loops (fused ChainDrafter/Eagle3). Static loops have "
"fixed iteration counts compiled into the model.\n"
"To use draft_len_schedule, please use ModelDrafter (2-model setup) "
"or NGramDrafter instead.")

def _create_draft_request(self, request: LlmRequest,
input_tokens: Optional[List]) -> LlmRequest:
"""Create a draft request with common parameters."""
@@ -682,6 +696,17 @@ def generate_draft_tokens_with_overlap(
- Updated target inputs or None
- Draft sample state or None
"""
# # Use pre-determined draft_len (set by executor BEFORE scheduling)
# if self.draft_len_schedule is not None and hasattr(self, '_current_batch_draft_len'):
# # Use pre-determined value from executor
# dynamic_draft_len = self._current_batch_draft_len

# # Override max_draft_tokens to the dynamic value
# self.max_draft_tokens = dynamic_draft_len

# # Note: If draft_len=0, this method won't be called anyway
# # (executor sets use_spec_decode=False and clears py_draft_tokens)

draft_batch, req_id_to_old_request = self._setup_draft_batch_and_resources(
scheduled_batch)
if draft_batch is None:
@@ -770,6 +795,17 @@ def prepare_draft_tokens(
if resource_manager is None:
raise ValueError("Resource manager is required")

# # Use pre-determined draft_len (set by executor BEFORE scheduling)
# if self.draft_len_schedule is not None and hasattr(self, '_current_batch_draft_len'):
# # Use pre-determined value from executor
# dynamic_draft_len = self._current_batch_draft_len

# # Override max_draft_tokens to the dynamic value
# self.max_draft_tokens = dynamic_draft_len

# # Note: If draft_len=0, this method won't be called anyway
# # (executor sets use_spec_decode=False and clears py_draft_tokens)
Comment on lines +799 to +807
Collaborator
Unused code, remove?


try:
draft_batch, req_id_to_old_request = self._setup_draft_batch_and_resources(
scheduled_requests)
21 changes: 17 additions & 4 deletions tensorrt_llm/_torch/speculative/ngram.py
@@ -26,7 +26,7 @@ class NGramPoolManager(BaseResourceManager):

Arguments:
max_draft_tokens: int
The length maximum of draft tokens (can be understood as length maximum of output draft tokens).
The length maximum of draft tokens (can be understood as length maximum of output draft tokens). If draft_len_schedule is provided in spec_config (dynamic draft length based on batch size is enabled), this value will be updated by the dynamic draft_len each step.

max_matching_ngram_size: int
The length maximum of searching tokens (can be understood as length maximum of input tokens to search).
@@ -51,7 +51,8 @@ class NGramPoolManager(BaseResourceManager):

def __init__(self, spec_config: "NGramDecodingConfig",
max_num_requests: int):
self.max_draft_tokens = spec_config.max_draft_len
self.max_draft_tokens = spec_config.max_draft_len # It's dynamic if draft_len_schedule is provided in spec_config (dynamic draft length based on runtime batch size is enabled). It's static in other cases.
self._static_max_draft_tokens = spec_config.max_draft_len # It's always static
self.max_matching_ngram_size = spec_config.max_matching_ngram_size
self.is_keep_all = spec_config.is_keep_all
self.is_use_oldest = spec_config.is_use_oldest # TODO: remove this if updating strategy is supported
@@ -167,17 +168,24 @@ def __init__(
spec_config: NGramDecodingConfig,
ngram_pool_manager: NGramPoolManager = None,
):
super().__init__(spec_config.max_concurrency)

super().__init__(
max_draft_tokens=spec_config.max_draft_len,
_static_max_draft_tokens=spec_config.max_draft_len,
max_concurrency=spec_config.max_concurrency,
draft_len_schedule=spec_config.draft_len_schedule,
)

assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
self.spec_config = spec_config
self.max_draft_tokens = spec_config.max_draft_len
self.spec_resource_manager = ngram_pool_manager

def prepare_draft_tokens(
self,
scheduled_requests: ScheduledRequests,
resource_manager: Optional[ResourceManager] = None,
) -> None:

# Sort by request_id when py_batch_idx is None as a fallback.
# This happens in the disagg case: for a set of new requests, we draft
# before forward_step, so py_batch_idx is not assigned.
@@ -197,3 +205,8 @@ def prepare_draft_tokens(
request.py_max_new_tokens,
)
request.py_draft_tokens = draft_tokens

def update_max_draft_tokens(self, new_max_draft_tokens: int) -> None:
"""Override to propagate to NGramPoolManager."""
super().update_max_draft_tokens(new_max_draft_tokens)
self.spec_resource_manager.max_draft_tokens = new_max_draft_tokens
46 changes: 46 additions & 0 deletions tensorrt_llm/llmapi/llm_args.py
@@ -364,11 +364,57 @@ class DecodingBaseConfig(StrictBaseModel):
# this value. Otherwise, speculation will always be on.
max_concurrency: Optional[int] = None

# Developer interface: dynamically adjust the draft length based on the active batch size at runtime.
# Maps batch size to draft lengths. For example:
# {1: 4, 4: 2, 8: 0} means:
# - batch_size >= 1: use draft_len=4
# - batch_size >= 4: use draft_len=2
# - batch_size >= 8: use draft_len=0 (disable speculation)
# draft_len_schedule must contain batch_size=1, and its corresponding draft_len must equal max_draft_len for consistency;
# for example, if max_draft_len=4, the schedule must contain {1: 4}
draft_len_schedule: Optional[dict[int, int]] = None

load_format: Optional[str] = None

# If set, drafting uses greedy sampling, irrespective of sampling parameters.
_allow_greedy_draft_tokens: bool = PrivateAttr(True)

@field_validator('draft_len_schedule')
@classmethod
def validate_draft_len_schedule_and_sort(cls, v, info):
"""Validate and sort draft_len_schedule by batch size thresholds."""
if v is not None:
# Validate values
for batch_size, draft_len in v.items():
if batch_size < 1:
raise ValueError(
f"draft_len_schedule: batch size threshold must be >= 1, got {batch_size}"
)
if draft_len < 0:
raise ValueError(
f"draft_len_schedule: draft length must be >= 0, got {draft_len}"
)

# Require batch_size=1 in schedule
if 1 not in v:
raise ValueError(
"draft_len_schedule must include batch_size=1. "
"All systems can have batch_size=1. Add {1: <max_draft_len>} to your schedule."
)

# Enforce schedule[1] == max_draft_len for consistency
max_draft_len = info.data.get('max_draft_len')
if max_draft_len is not None and v[1] != max_draft_len:
raise ValueError(
f"draft_len_schedule[1] must equal max_draft_len for consistency. "
f"Got schedule[1]={v[1]}, but max_draft_len={max_draft_len}. "
f"batch_size=1 should use maximum draft length.")

# Return sorted dict (by batch size thresholds)
# This ensures efficient lookup
return dict(sorted(v.items(), key=lambda x: x[0]))
Collaborator
Do we need to use collections.OrderedDict? I can't remember if this version of Python guarantees that the ordering will be preserved

return v

@classmethod
def from_dict(cls, data: dict):
# dispatch to the correct decoding config
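To make the validation rules above concrete, here is a small standalone sketch that mirrors them outside Pydantic; the function name is hypothetical and for illustration only. Note that plain dicts preserve insertion order in Python 3.7+, so the sorted dict returned by the validator keeps its threshold order without collections.OrderedDict.

def validate_schedule(schedule: dict, max_draft_len: int) -> dict:
    # Mirrors the checks in validate_draft_len_schedule_and_sort (illustrative re-implementation).
    for batch_size, draft_len in schedule.items():
        if batch_size < 1:
            raise ValueError(f"batch size threshold must be >= 1, got {batch_size}")
        if draft_len < 0:
            raise ValueError(f"draft length must be >= 0, got {draft_len}")
    if 1 not in schedule:
        raise ValueError("schedule must include batch_size=1")
    if schedule[1] != max_draft_len:
        raise ValueError("schedule[1] must equal max_draft_len")
    return dict(sorted(schedule.items()))

sorted_schedule = validate_schedule({8: 0, 1: 4, 4: 2}, max_draft_len=4)
assert list(sorted_schedule.keys()) == [1, 4, 8]    # ascending thresholds, order preserved
# validate_schedule({4: 2, 8: 0}, max_draft_len=4)  # would raise: missing batch_size=1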