-
Notifications
You must be signed in to change notification settings - Fork 1.8k
[TRTLLM-8136][feat] Dynamic draft length in spec decode (stage 1). #8194
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -348,7 +348,9 @@ def create_py_executor( | |
use_chain_drafter = ( | ||
guided_decoding_config is None | ||
and draft_spec_config._allow_greedy_draft_tokens | ||
and pytorch_backend_config.attn_backend == "TRTLLM") | ||
and pytorch_backend_config.attn_backend == "TRTLLM" | ||
and draft_spec_config.draft_len_schedule is None | ||
) # currently ChainDrafter does not support dynamic draft length | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unnecessary comment. I think it's clear from the context that all the skips are for stuff that the ChainDrafter does not support |
||
else: | ||
use_chain_drafter = False | ||
|
||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,5 +1,8 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from abc import ABC, abstractmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from typing import List, Optional, final | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from bisect import bisect_right | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from typing import Dict, List, Optional, final | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from tensorrt_llm.logger import logger | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from ..pyexecutor.llm_request import LlmRequest, get_draft_token_length | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from ..pyexecutor.resource_manager import ResourceManager | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -9,8 +12,20 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class Drafter(ABC): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Abstract base class for all drafter implementations.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __init__(self, max_concurrency: Optional[int] = None) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
max_draft_tokens: int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
_static_max_draft_tokens: int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unused argument. Both fields are initialized with |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
max_concurrency: Optional[int] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
draft_len_schedule: Optional[Dict[int, int]] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.max_concurrency = max_concurrency | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Schedule is already validated and sorted by config validator | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.draft_len_schedule = draft_len_schedule | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# It's dynamic if draft_len_schedule is provided in spec_config (dynamic draft length based on runtime batch size is enabled). It's static in other cases. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.max_draft_tokens = max_draft_tokens | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# It's always static | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._static_max_draft_tokens = max_draft_tokens | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+15
to
+28
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Critical: Fix assignment of Line 28 assigns Apply this diff to use the correct parameter: def __init__(
self,
max_draft_tokens: int,
_static_max_draft_tokens: int,
max_concurrency: Optional[int] = None,
draft_len_schedule: Optional[Dict[int, int]] = None,
) -> None:
self.max_concurrency = max_concurrency
# Schedule is already validated and sorted by config validator
self.draft_len_schedule = draft_len_schedule
# It's dynamic if draft_len_schedule is provided in spec_config (dynamic draft length based on runtime batch size is enabled). It's static in other cases.
self.max_draft_tokens = max_draft_tokens
# It's always static
- self._static_max_draft_tokens = max_draft_tokens
+ self._static_max_draft_tokens = _static_max_draft_tokens 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@abstractmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def prepare_draft_tokens( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -26,6 +41,39 @@ def prepare_draft_tokens( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
raise NotImplementedError | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@final | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def get_draft_len_for_batch_size(self, batch_size: int) -> int: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Get the appropriate draft length for the given batch size using binary search. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
batch_size: Current batch size (has been sorted by config validator) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
The draft length to use for this batch size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Binary search to find the largest threshold <= batch_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# draft_len_schedule is already sorted by config validator | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
thresholds = list(self.draft_len_schedule.keys()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# bisect_right finds where to insert batch_size to keep list sorted | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# The element before insertion point is the largest threshold <= batch_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
idx = bisect_right(thresholds, batch_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if idx == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# batch_size is smaller than smallest threshold (batch_size smaller than 1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# This shouldn't happen in practice, but handle defensively | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.warning( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
f"get_draft_len_for_batch_size called with batch_size={batch_size} < 1. " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
f"This is unexpected. Disabling speculation (returning draft_len=0)." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Return draft_len for the largest threshold <= batch_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
threshold = thresholds[idx - 1] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self.draft_len_schedule[threshold] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@final | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def should_use_spec_decode(self, requests: List[LlmRequest], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
max_batch_size: int, max_num_tokens: int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -59,14 +107,19 @@ def pad_draft_tokens_for_cuda_graph( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Pad draft tokens to the max draft length for CUDA graph compatibility. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Note: Always pads to the STATIC max_draft_len (not dynamic) because | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
CUDA graphs are compiled with fixed tensor shapes based on max_draft_len. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+110
to
+112
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are planning on changing this in the near future with a follow up, right? |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
scheduled_requests: The scheduled requests to pad | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for req in scheduled_requests.generation_requests: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
max_draft_tokens = self.max_draft_tokens | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Use static max_draft_tokens for CUDA graph compatibility | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# CUDA graphs are sized for the maximum, even if we generate fewer tokens dynamically | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
num_draft_tokens = get_draft_token_length(req) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
req.py_draft_tokens.extend( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
0 for _ in range(max_draft_tokens - num_draft_tokens)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
0 for _ in range(self._static_max_draft_tokens - | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
num_draft_tokens)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def run_drafter_post( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -79,3 +132,14 @@ def run_drafter_post( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
this method can be overridden to do that. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Used in SaveHiddenStatesDrafter (to ensure correct input_ids) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def update_max_draft_tokens(self, new_max_draft_tokens: int) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Used when draft_len_schedule is provided in spec_config (dynamic draft length based on runtime batch size is enabled) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Update max_draft_tokens in drafter and propagate to any dependent components. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Subclasses can override to propagate to their resource managers if needed. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
new_max_draft_tokens: The new max draft tokens | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.max_draft_tokens = new_max_draft_tokens |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,22 +52,26 @@ def __init__( | |
spec_resource_manager: Optional[BaseResourceManager] = None, | ||
guided_decoder: Optional[GuidedDecoder] = None, | ||
): | ||
super().__init__(spec_config.max_concurrency) | ||
|
||
# Validate required parameters | ||
if draft_model_engine is None: | ||
raise ValueError("draft_model_engine cannot be None") | ||
if max_draft_tokens < 0: | ||
raise ValueError("max_draft_tokens must be >= 0") | ||
|
||
super().__init__( | ||
max_draft_tokens=spec_config.max_draft_len, | ||
_static_max_draft_tokens=spec_config.max_draft_len, | ||
max_concurrency=spec_config.max_concurrency, | ||
draft_len_schedule=spec_config.draft_len_schedule, | ||
) | ||
|
||
# Model and resource management | ||
self.draft_model_engine = draft_model_engine | ||
self.draft_seq_slot_manager = draft_seq_slot_manager | ||
self.spec_resource_manager = spec_resource_manager | ||
|
||
# Configuration | ||
self.spec_config = spec_config | ||
self.max_draft_tokens = max_draft_tokens | ||
# Sampling | ||
self.sampler = sampler | ||
self.guided_decoder = guided_decoder | ||
|
@@ -78,6 +82,16 @@ def __init__( | |
assert guided_decoder is None | ||
assert spec_config._allow_greedy_draft_tokens | ||
|
||
# Currently dynamic draft length is not compatible with static draft loops | ||
# TODO: support static draft loops with dynamic draft_len | ||
if self.draft_len_schedule is not None: | ||
raise ValueError( | ||
"Dynamic draft length (draft_len_schedule) is not supported with " | ||
"static draft loops (fused ChainDrafter/Eagle3). Static loops have " | ||
"fixed iteration counts compiled into the model.\n" | ||
"To use draft_len_schedule, please use ModelDrafter (2-model setup) " | ||
"or NGramDrafter instead.") | ||
|
||
def _create_draft_request(self, request: LlmRequest, | ||
input_tokens: Optional[List]) -> LlmRequest: | ||
"""Create a draft request with common parameters.""" | ||
|
@@ -682,6 +696,17 @@ def generate_draft_tokens_with_overlap( | |
- Updated target inputs or None | ||
- Draft sample state or None | ||
""" | ||
# # Use pre-determined draft_len (set by executor BEFORE scheduling) | ||
# if self.draft_len_schedule is not None and hasattr(self, '_current_batch_draft_len'): | ||
# # Use pre-determined value from executor | ||
# dynamic_draft_len = self._current_batch_draft_len | ||
|
||
# # Override max_draft_tokens to the dynamic value | ||
# self.max_draft_tokens = dynamic_draft_len | ||
|
||
# # Note: If draft_len=0, this method won't be called anyway | ||
# # (executor sets use_spec_decode=False and clears py_draft_tokens) | ||
|
||
draft_batch, req_id_to_old_request = self._setup_draft_batch_and_resources( | ||
scheduled_batch) | ||
if draft_batch is None: | ||
|
@@ -770,6 +795,17 @@ def prepare_draft_tokens( | |
if resource_manager is None: | ||
raise ValueError("Resource manager is required") | ||
|
||
# # Use pre-determined draft_len (set by executor BEFORE scheduling) | ||
# if self.draft_len_schedule is not None and hasattr(self, '_current_batch_draft_len'): | ||
# # Use pre-determined value from executor | ||
# dynamic_draft_len = self._current_batch_draft_len | ||
|
||
# # Override max_draft_tokens to the dynamic value | ||
# self.max_draft_tokens = dynamic_draft_len | ||
|
||
# # Note: If draft_len=0, this method won't be called anyway | ||
# # (executor sets use_spec_decode=False and clears py_draft_tokens) | ||
Comment on lines
+799
to
+807
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unused code, remove? |
||
|
||
try: | ||
draft_batch, req_id_to_old_request = self._setup_draft_batch_and_resources( | ||
scheduled_requests) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -364,11 +364,57 @@ class DecodingBaseConfig(StrictBaseModel): | |
# this value. Otherwise, speculation will always be on. | ||
max_concurrency: Optional[int] = None | ||
|
||
# Developer interface: dynamically adjust draft length based on active batch size in runtime. | ||
# Maps batch size to draft lengths. For example: | ||
# {1: 4, 4: 2, 8: 0} means: | ||
# - batch_size >= 1: use draft_len=4 | ||
# - batch_size >= 4: use draft_len=2 | ||
# - batch_size >= 8: use draft_len=0 (disable speculation) | ||
# draft_len_schedule is enforced to contain batch_size=1 and its according draft_len equals max_draft_len for consistency | ||
# for example, if max_draft_len=4, the schedule must contain {1: 4} | ||
draft_len_schedule: Optional[dict[int, int]] = None | ||
|
||
load_format: Optional[str] = None | ||
|
||
# If set, drafting uses greedy sampling, irrespective of sampling parameters. | ||
_allow_greedy_draft_tokens: bool = PrivateAttr(True) | ||
|
||
@field_validator('draft_len_schedule') | ||
@classmethod | ||
def validate_draft_len_schedule_and_sort(cls, v, info): | ||
"""Validate and sort draft_len_schedule by batch size thresholds.""" | ||
if v is not None: | ||
# Validate values | ||
for batch_size, draft_len in v.items(): | ||
if batch_size < 1: | ||
raise ValueError( | ||
f"draft_len_schedule: batch size threshold must be >= 1, got {batch_size}" | ||
) | ||
if draft_len < 0: | ||
raise ValueError( | ||
f"draft_len_schedule: draft length must be >= 0, got {draft_len}" | ||
) | ||
|
||
# Require batch_size=1 in schedule | ||
if 1 not in v: | ||
raise ValueError( | ||
"draft_len_schedule must include batch_size=1. " | ||
"All systems can have batch_size=1. Add {1: <max_draft_len>} to your schedule." | ||
) | ||
|
||
# Enforce schedule[1] == max_draft_len for consistency | ||
max_draft_len = info.data.get('max_draft_len') | ||
if max_draft_len is not None and v[1] != max_draft_len: | ||
raise ValueError( | ||
f"draft_len_schedule[1] must equal max_draft_len for consistency. " | ||
f"Got schedule[1]={v[1]}, but max_draft_len={max_draft_len}. " | ||
f"batch_size=1 should use maximum draft length.") | ||
|
||
# Return sorted dict (by batch size thresholds) | ||
# This ensures efficient lookup | ||
return dict(sorted(v.items(), key=lambda x: x[0])) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to use |
||
return v | ||
|
||
@classmethod | ||
def from_dict(cls, data: dict): | ||
# dispatch to the correct decoding config | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I think the name of the variable is pretty self-explanatory, comment is unnecessary