Skip to content

Conversation

zheyuf
Copy link
Collaborator

@zheyuf zheyuf commented Oct 8, 2025

Summary by CodeRabbit

  • New Features

    • Dynamic draft length scheduling by batch size, automatically adjusting speculative decoding and draft token allocation at runtime.
    • New configuration option to define a draft_len_schedule with validation and sorting.
  • Behavior Changes

    • Chain Drafter now enabled only with a specific backend and when no schedule is provided.
    • Static draft loop is incompatible with dynamic scheduling; an error is raised if combined.
  • Tests

    • Added unit tests verifying correctness across batch sizes for both NGram and model-based drafters.

Description

We want to further improve speculative decoding performance by supporting dynamic draft length based on runtime active batch size. Currently, draft length in TensorRT-LLM is a static value that remains fixed after initialization from the spec config. However, a fixed draft length is sub-optimal because real-world workloads are not uniform. The optimal draft length changes depending on whether the system is memory-bound (small batches, favoring longer drafts) or compute-bound (large batches, favoring shorter drafts).

The feature is planned to be executed in three stages:
Stage 1 (This PR): Implement the developer interface and modify the drafting side for NGramDrafter and ModelDrafter (2-model setup). After this stage, we achieve real compute savings on the draft side by generating fewer tokens. However, draft-side-only savings have limited impact since the target model is larger and its potential savings are greater. Currently, there are no savings on the target side because it still processes the same padded token shapes from the draft side.

Stage 2 (Future): Implement target-side optimization by capturing target model CUDA graphs for each unique draft length and selecting the appropriate graph at runtime. This will eliminate wasteful operations spent on padding and achieve significant compute savings. The tradeoff is increased memory overhead and CUDA graph warmup time due to maintaining multiple graphs for different draft lengths.

Stage 3 (Future): Extend the feature to ChainDrafter, which were not addressed in Stage 1's drafting-side implementation. This will likely require multiple static drafters (or CUDA graphs) for different draft lengths. Extend the feature also to one-model Eagle3/MTP variants? (not sure)

We can determine the most user-friendly way to expose this functionality after implementation is complete (for example, we could conduct a performance study and add it to the AUTO heuristic).

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@zheyuf zheyuf requested a review from mikeiovine October 8, 2025 07:50
@zheyuf zheyuf marked this pull request as ready for review October 8, 2025 16:27
@zheyuf zheyuf requested review from a team as code owners October 8, 2025 16:27
@zheyuf zheyuf requested a review from hchings October 8, 2025 16:27
Copy link
Contributor

coderabbitai bot commented Oct 8, 2025

📝 Walkthrough

Walkthrough

Introduces dynamic draft length scheduling across speculative decoding components. PyExecutor computes and updates max_draft_len at runtime from a draft_len_schedule, toggles spec decode accordingly, and adjusts draft token allocation. Drafter and NGram components gain static vs dynamic max tracking and update propagation. Config adds draft_len_schedule with validation. ChainDrafter enablement is constrained. New tests cover schedule-driven correctness.

Changes

Cohort / File(s) Summary
PyExecutor dynamic draft length
tensorrt_llm/_torch/pyexecutor/py_executor.py
Adds runtime max_draft_len adjustment via Drafter.draft_len_schedule; maintains _static_max_draft_len; updates spec decode enablement; initializes/pads draft tokens using dynamic/static lengths as appropriate.
Executor creation constraints
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
Restricts ChainDrafter usage: requires attn_backend="TRTLLM" and no draft_len_schedule when allow-chain-drafter is set.
Speculative base drafter
tensorrt_llm/_torch/speculative/drafter.py
Constructor now accepts max_draft_tokens, _static_max_draft_tokens, optional draft_len_schedule; adds get_draft_len_for_batch_size and update_max_draft_tokens; pads CUDA graphs using static max.
Model drafter integration
tensorrt_llm/_torch/speculative/model_drafter.py
Initializes base Drafter with dynamic/static max fields and schedule; removes direct self.max_draft_tokens; raises ValueError when static draft loop is combined with a schedule.
NGram drafter and pool
tensorrt_llm/_torch/speculative/ngram.py
Adds _static_max_draft_tokens tracking; implements update_max_draft_tokens on NGramPoolManager and NGramDrafter; propagates dynamic max updates.
Config and validation
tensorrt_llm/llmapi/llm_args.py
Adds DecodingBaseConfig.draft_len_schedule with validator enforcing keys/values, presence of batch_size=1, consistency with max_draft_len, and sorted normalization.
Tests
tests/unittest/_torch/speculative/test_draft_len_schedule.py
New tests validating dynamic draft length scheduling across batch sizes for NGram and ModelDrafter paths against non-speculative outputs, with memory-based skips.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant U as User
  participant L as LLM (PyExecutor)
  participant D as Drafter
  participant T as Target Model
  participant S as Spec Decode

  rect rgba(230,240,255,0.5)
  note over L,D: Initialization
  D->>D: init(max_draft_tokens, _static_max_draft_tokens, draft_len_schedule?)
  end

  U->>L: generate(prompts, batch_size=B)
  alt draft_len_schedule provided
    L->>D: get_draft_len_for_batch_size(B)
    D-->>L: max_draft_len_for_B
    L->>D: update_max_draft_tokens(new_max_draft_len)
    alt max_draft_len_for_B == 0
      L->>S: disable
    else
      L->>S: enable (should_use_spec_decode)
    end
  else no schedule
    note over L: use static _static_max_draft_len
  end

  par Prepare batch
    L->>L: allocate draft_tokens length = current max_draft_len (>0)
    L->>T: run forward pass
    opt spec enabled
      L->>D: draft tokens (ngram/model)
      D-->>L: drafts (<= current max_draft_len)
      L->>S: speculative verify/accept
    end
  end

  L-->>U: outputs
Loading
sequenceDiagram
  autonumber
  participant C as Creator
  participant L as PyExecutor
  participant D as ChainDrafter

  C->>C: parse config (attn_backend, allow-chain-drafter, draft_len_schedule)
  alt allow-chain-drafter AND attn_backend=="TRTLLM" AND draft_len_schedule is None
    C->>D: construct ChainDrafter
    D-->>L: attach
  else
    note over C,L: ChainDrafter not used
  end
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Description Check ⚠️ Warning The PR description includes a clear Description section but leaves the Test Coverage section empty and does not provide the generated summary after the @coderabbitai summary directive, which diverges from the required template’s structure. Please replace the @coderabbitai summary directive with a concise summary of the changes and populate the Test Coverage section with the specific test files or test cases that exercise the new dynamic draft length functionality.
Docstring Coverage ⚠️ Warning Docstring coverage is 56.52% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
Check name Status Explanation
Title Check ✅ Passed The pull request title follows the repository’s conventions by including a valid JIRA ticket, a lowercase type, and concisely summarizing the main change (“Dynamic draft length in spec decode (stage 1)”), which accurately reflects the PR’s primary objective.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/py_executor.py (1)

1607-1614: Fix attribute mismatch for static draft length.

self.static_max_draft_len is referenced when padding the attention-DP dummy request, but the constructor only sets self._static_max_draft_len. This will raise AttributeError as soon as _pad_attention_dp_dummy_request() runs (spec decode or not). Rename the constructor field to self.static_max_draft_len (or add a property) so the attribute exists.

-        self._static_max_draft_len = max_draft_len  # It's always static
+        self.static_max_draft_len = max_draft_len  # It's always static
🧹 Nitpick comments (1)
tensorrt_llm/_torch/speculative/drafter.py (1)

64-71: Clarify edge-case handling in warning message.

The warning message "batch_size < 1" is confusing since the validator in llm_args.py (lines 389-392) already enforces batch_size >= 1 in the schedule. This edge case would only occur if get_draft_len_for_batch_size is called with an invalid batch_size argument (e.g., 0 or negative), which would be a programming error rather than a configuration issue.

Consider clarifying the message to indicate this is an unexpected programming error:

         if idx == 0:
-            # batch_size is smaller than smallest threshold (batch_size smaller than 1)
-            # This shouldn't happen in practice, but handle defensively
+            # batch_size is smaller than smallest threshold in schedule
+            # This should never happen since validator enforces batch_size >= 1 in schedule
             logger.warning(
-                f"get_draft_len_for_batch_size called with batch_size={batch_size} < 1. "
+                f"get_draft_len_for_batch_size called with invalid batch_size={batch_size}. "
                 f"This is unexpected. Disabling speculation (returning draft_len=0)."
             )
             return 0
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e986165 and bc061dd.

📒 Files selected for processing (7)
  • tensorrt_llm/_torch/pyexecutor/py_executor.py (4 hunks)
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1 hunks)
  • tensorrt_llm/_torch/speculative/drafter.py (5 hunks)
  • tensorrt_llm/_torch/speculative/model_drafter.py (4 hunks)
  • tensorrt_llm/_torch/speculative/ngram.py (4 hunks)
  • tensorrt_llm/llmapi/llm_args.py (1 hunks)
  • tests/unittest/_torch/speculative/test_draft_len_schedule.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/speculative/model_drafter.py
  • tensorrt_llm/_torch/speculative/ngram.py
  • tensorrt_llm/llmapi/llm_args.py
  • tests/unittest/_torch/speculative/test_draft_len_schedule.py
  • tensorrt_llm/_torch/speculative/drafter.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/speculative/model_drafter.py
  • tensorrt_llm/_torch/speculative/ngram.py
  • tensorrt_llm/llmapi/llm_args.py
  • tests/unittest/_torch/speculative/test_draft_len_schedule.py
  • tensorrt_llm/_torch/speculative/drafter.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/speculative/model_drafter.py
  • tensorrt_llm/_torch/speculative/ngram.py
  • tensorrt_llm/llmapi/llm_args.py
  • tests/unittest/_torch/speculative/test_draft_len_schedule.py
  • tensorrt_llm/_torch/speculative/drafter.py
🧬 Code graph analysis (5)
tensorrt_llm/_torch/pyexecutor/py_executor.py (4)
tensorrt_llm/_torch/speculative/drafter.py (3)
  • get_draft_len_for_batch_size (45-75)
  • update_max_draft_tokens (136-145)
  • should_use_spec_decode (78-102)
tensorrt_llm/_torch/speculative/ngram.py (1)
  • update_max_draft_tokens (209-212)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1)
  • enable_spec_decode (90-91)
cpp/include/tensorrt_llm/batch_manager/llmRequest.h (1)
  • LlmRequestState (47-210)
tensorrt_llm/_torch/speculative/model_drafter.py (1)
tensorrt_llm/runtime/generation.py (1)
  • max_draft_tokens (1283-1286)
tensorrt_llm/_torch/speculative/ngram.py (2)
tensorrt_llm/_torch/speculative/drafter.py (1)
  • update_max_draft_tokens (136-145)
tensorrt_llm/_torch/pyexecutor/scheduler.py (1)
  • ScheduledRequests (18-39)
tests/unittest/_torch/speculative/test_draft_len_schedule.py (2)
tensorrt_llm/llmapi/llm_args.py (4)
  • DraftTargetDecodingConfig (705-714)
  • KvCacheConfig (1152-1286)
  • NGramDecodingConfig (670-702)
  • speculative_model_dir (1621-1622)
tensorrt_llm/_torch/pyexecutor/py_executor.py (1)
  • shutdown (446-459)
tensorrt_llm/_torch/speculative/drafter.py (5)
tensorrt_llm/runtime/generation.py (1)
  • max_draft_tokens (1283-1286)
tensorrt_llm/_torch/pyexecutor/scheduler.py (1)
  • batch_size (35-36)
tensorrt_llm/logger.py (1)
  • warning (132-133)
tensorrt_llm/_torch/pyexecutor/llm_request.py (1)
  • get_draft_token_length (703-714)
tensorrt_llm/_torch/speculative/ngram.py (1)
  • update_max_draft_tokens (209-212)
🪛 Ruff (0.13.3)
tensorrt_llm/_torch/speculative/model_drafter.py

88-93: Avoid specifying long messages outside the exception class

(TRY003)

tensorrt_llm/_torch/speculative/ngram.py

186-186: Unused method argument: resource_manager

(ARG002)

tensorrt_llm/llmapi/llm_args.py

390-392: Avoid specifying long messages outside the exception class

(TRY003)


394-396: Avoid specifying long messages outside the exception class

(TRY003)


400-403: Avoid specifying long messages outside the exception class

(TRY003)


408-411: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (12)
tensorrt_llm/_torch/speculative/drafter.py (3)

44-75: LGTM: Binary search implementation is correct.

The binary search logic using bisect_right correctly finds the largest threshold <= batch_size. The returned draft length appropriately matches the configuration's intent for dynamic scheduling.


105-122: LGTM: CUDA graph padding correctly uses static max.

The updated padding logic correctly uses _static_max_draft_tokens to ensure compatibility with CUDA graphs, which require fixed tensor shapes. The comments clearly explain the rationale.


136-145: LGTM: Update method provides correct extension point.

The update_max_draft_tokens method provides a clean extension point for subclasses to propagate updates to dependent components (as demonstrated by NGramDrafter in lines 209-212 of ngram.py).

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)

347-353: LGTM: ChainDrafter correctly disabled for dynamic draft length.

The additional condition draft_spec_config.draft_len_schedule is None appropriately prevents ChainDrafter usage when dynamic draft length scheduling is enabled. The inline comment clearly explains this limitation aligns with the validation in model_drafter.py (lines 87-93) that prevents static draft loops with dynamic scheduling.

tensorrt_llm/_torch/speculative/ngram.py (3)

54-55: LGTM: Static max draft tokens correctly tracked.

The addition of _static_max_draft_tokens alongside max_draft_tokens correctly implements the dual tracking pattern introduced in the base Drafter class, enabling dynamic draft length support while preserving the static maximum.


172-177: LGTM: Constructor correctly initializes base class.

The constructor properly passes both max_draft_tokens and _static_max_draft_tokens to the base Drafter class, along with max_concurrency and draft_len_schedule, fully supporting the dynamic draft length feature.


209-212: LGTM: Update propagation correctly implemented.

The update_max_draft_tokens override properly propagates changes to NGramPoolManager, ensuring consistency when dynamic draft length adjusts the maximum at runtime.

tensorrt_llm/llmapi/llm_args.py (2)

367-375: LGTM: Field documentation is clear and helpful.

The draft_len_schedule field documentation clearly explains the dynamic draft length feature with a concrete example showing how batch size thresholds map to draft lengths, making the configuration intent obvious to users.


382-416: LGTM: Validator enforces critical consistency constraints.

The validator correctly enforces:

  1. Valid ranges (batch_size >= 1, draft_len >= 0)
  2. Required batch_size=1 entry (all systems can have batch_size=1)
  3. Consistency between schedule[1] and max_draft_len (prevents configuration errors)
  4. Sorted output for efficient binary search in get_draft_len_for_batch_size

The error messages are clear and actionable, guiding users to fix configuration issues.

tensorrt_llm/_torch/speculative/model_drafter.py (3)

61-66: LGTM: Constructor correctly integrates with base class.

The constructor properly passes both max_draft_tokens (from spec_config.max_draft_len) and _static_max_draft_tokens to the base Drafter class, along with max_concurrency and draft_len_schedule, enabling full dynamic draft length support for ModelDrafter (two-model setup).


85-93: LGTM: Validation prevents incompatible configuration.

The validation correctly prevents combining static draft loops (fused ChainDrafter/Eagle3) with dynamic draft length scheduling, since static loops have fixed iteration counts compiled into the model. The error message clearly explains the constraint and suggests alternatives (ModelDrafter or NGramDrafter).


699-708: Commented code is appropriately marked for Stage 2.

The commented-out code for dynamic draft length is clearly marked and aligns with the PR objectives stating that Stage 1 (this PR) focuses on the developer interface and drafting-side updates, while Stage 2 will handle target-side optimization. The comments indicate the intended integration points for when this feature is enabled.

Also applies to: 798-807

Comment on lines +15 to +28
def __init__(
self,
max_draft_tokens: int,
_static_max_draft_tokens: int,
max_concurrency: Optional[int] = None,
draft_len_schedule: Optional[Dict[int, int]] = None,
) -> None:
self.max_concurrency = max_concurrency
# Schedule is already validated and sorted by config validator
self.draft_len_schedule = draft_len_schedule
# It's dynamic if draft_len_schedule is provided in spec_config (dynamic draft length based on runtime batch size is enabled). It's static in other cases.
self.max_draft_tokens = max_draft_tokens
# It's always static
self._static_max_draft_tokens = max_draft_tokens
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Critical: Fix assignment of _static_max_draft_tokens.

Line 28 assigns self._static_max_draft_tokens = max_draft_tokens, which sets the "always static" value to the potentially dynamic max_draft_tokens parameter. This contradicts the comment on line 27 stating "It's always static" and breaks the intended separation between dynamic and static max draft tokens.

Apply this diff to use the correct parameter:

     def __init__(
         self,
         max_draft_tokens: int,
         _static_max_draft_tokens: int,
         max_concurrency: Optional[int] = None,
         draft_len_schedule: Optional[Dict[int, int]] = None,
     ) -> None:
         self.max_concurrency = max_concurrency
         # Schedule is already validated and sorted by config validator
         self.draft_len_schedule = draft_len_schedule
         # It's dynamic if draft_len_schedule is provided in spec_config (dynamic draft length based on runtime batch size is enabled). It's static in other cases.
         self.max_draft_tokens = max_draft_tokens
         # It's always static
-        self._static_max_draft_tokens = max_draft_tokens
+        self._static_max_draft_tokens = _static_max_draft_tokens
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def __init__(
self,
max_draft_tokens: int,
_static_max_draft_tokens: int,
max_concurrency: Optional[int] = None,
draft_len_schedule: Optional[Dict[int, int]] = None,
) -> None:
self.max_concurrency = max_concurrency
# Schedule is already validated and sorted by config validator
self.draft_len_schedule = draft_len_schedule
# It's dynamic if draft_len_schedule is provided in spec_config (dynamic draft length based on runtime batch size is enabled). It's static in other cases.
self.max_draft_tokens = max_draft_tokens
# It's always static
self._static_max_draft_tokens = max_draft_tokens
def __init__(
self,
max_draft_tokens: int,
_static_max_draft_tokens: int,
max_concurrency: Optional[int] = None,
draft_len_schedule: Optional[Dict[int, int]] = None,
) -> None:
self.max_concurrency = max_concurrency
# Schedule is already validated and sorted by config validator
self.draft_len_schedule = draft_len_schedule
# It's dynamic if draft_len_schedule is provided in spec_config (dynamic draft length based on runtime batch size is enabled). It's static in other cases.
self.max_draft_tokens = max_draft_tokens
# It's always static
self._static_max_draft_tokens = _static_max_draft_tokens
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/speculative/drafter.py around lines 15 to 28, the
constructor incorrectly sets self._static_max_draft_tokens = max_draft_tokens (a
possibly dynamic value); change the assignment to use the provided
_static_max_draft_tokens parameter instead (self._static_max_draft_tokens =
_static_max_draft_tokens) so the "always static" value remains the intended
static parameter.

Comment on lines +1 to +14
"""
test_draft_len_schedule.py
Tests for dynamic draft length (draft_len_schedule) feature - Stage 1.
Stage 1 covers:
- NGramDrafter with dynamic draft_len
- ModelDrafter (2-model) with dynamic draft_len
- Draft-side compute savings only (target model still processes padded tokens)
Not covered in Stage 1:
- ChainDrafter/Eagle3 static loops (Stage 3)
- Target model compute savings (Stage 2)
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Add required NVIDIA Apache-2.0 header.

Per project guidelines, every new Python source must start with the NVIDIA Apache-2.0 copyright header for the current year. Please prepend the standard header block above the module docstring.

🤖 Prompt for AI Agents
In tests/unittest/_torch/speculative/test_draft_len_schedule.py around lines 1
to 14, the file is missing the required NVIDIA Apache-2.0 copyright header;
prepend the standard NVIDIA Apache-2.0 header block (updated with the current
year) at the very top of the file above the module docstring so the file begins
with the full copyright and license header followed by the existing
triple-quoted module docstring.

Copy link
Collaborator

@mikeiovine mikeiovine left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks! Only have a few nits

self.max_beam_width = max_beam_width
self.max_draft_len = max_draft_len
self.max_draft_len = max_draft_len # It's dynamic if draft_len_schedule is provided in spec_config (dynamic draft length based on runtime batch size is enabled). It's static in other cases.
self._static_max_draft_len = max_draft_len # It's always static
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think the name of the variable is pretty self-explanatory, comment is unnecessary


# Return sorted dict (by batch size thresholds)
# This ensures efficient lookup
return dict(sorted(v.items(), key=lambda x: x[0]))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to use collections.OrderedDict? I can't remember if this version of Python guarantees that the ordering will be preserved

Comment on lines +799 to +807
# if self.draft_len_schedule is not None and hasattr(self, '_current_batch_draft_len'):
# # Use pre-determined value from executor
# dynamic_draft_len = self._current_batch_draft_len

# # Override max_draft_tokens to the dynamic value
# self.max_draft_tokens = dynamic_draft_len

# # Note: If draft_len=0, this method won't be called anyway
# # (executor sets use_spec_decode=False and clears py_draft_tokens)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused code, remove?

Comment on lines +110 to +112
Note: Always pads to the STATIC max_draft_len (not dynamic) because
CUDA graphs are compiled with fixed tensor shapes based on max_draft_len.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are planning on changing this in the near future with a follow up, right?

def __init__(
self,
max_draft_tokens: int,
_static_max_draft_tokens: int,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused argument. Both fields are initialized with max_draft_tokens, which makes sense

and pytorch_backend_config.attn_backend == "TRTLLM")
and pytorch_backend_config.attn_backend == "TRTLLM"
and draft_spec_config.draft_len_schedule is None
) # currently ChainDrafter does not support dynamic draft length
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary comment. I think it's clear from the context that all the skips are for stuff that the ChainDrafter does not support

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants