77 changes: 60 additions & 17 deletions tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -49,7 +49,9 @@ def __init__(self, engine: "PyTorchModelEngine"):
Callable[[], Optional[torch.Tensor]]] = {}
self.graph_metadata: Dict[Tuple[int, int, int], Dict[str, Any]] = {}
self.memory_pool = engine._cuda_graph_mem_pool
self.padding_dummy_request: Optional["Request"] = None
# Stage 2: Pre-allocate one padding dummy per unique draft_len so none is built on the hot path
self.padding_dummies: Dict[int,
"Request"] = {} # draft_len -> dummy_request

self.shared_static_tensors: Dict[str, torch.Tensor] = {}
if self.enabled:
@@ -98,14 +100,38 @@ def max_possible_draft_len(self):
def get_graph_key(
self,
batch_size,
spec_resource_manager: Optional[BaseResourceManager] = None):
spec_resource_manager: Optional[BaseResourceManager] = None,
runtime_draft_len: Optional[int] = None):
"""
Get the CUDA graph key for the given batch and draft configuration.

Stage 2 Dynamic Draft Length: When runtime_draft_len is provided and
a draft_len_schedule exists, use the runtime draft length instead of
max_draft_len. This enables selecting the appropriate CUDA graph for
the current draft length.

Args:
batch_size: Batch size for the graph
spec_resource_manager: Optional resource manager for spec decoding
runtime_draft_len: Optional runtime draft length (Stage 2 feature)

Returns:
Tuple of (batch_size, draft_len, is_first_draft)
"""
engine = self._get_engine()
if engine.is_draft_model and spec_resource_manager is not None and isinstance(
spec_resource_manager, Eagle3ResourceManager):
draft_len = engine.original_max_draft_len if spec_resource_manager.is_first_draft else 0
key = (batch_size, draft_len, spec_resource_manager.is_first_draft)
else:
draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0
# Stage 2: Use runtime draft length if provided and schedule exists
if (runtime_draft_len is not None and self.spec_config
and hasattr(self.spec_config, 'draft_len_schedule')
and self.spec_config.draft_len_schedule is not None):
draft_len = runtime_draft_len
else:
# Legacy behavior: use max_draft_len
draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0
key = (batch_size, draft_len, False)
return key
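
As a quick illustration of the key selection above, here is a minimal, self-contained sketch mirroring the target-model branch; the function, parameter names, and numeric values are stand-ins for this example, not the PR's API:

```python
from typing import Optional, Tuple

def sketch_graph_key(batch_size: int,
                     max_draft_len: int,
                     runtime_draft_len: Optional[int] = None,
                     has_schedule: bool = False,
                     spec_decode_enabled: bool = True) -> Tuple[int, int, bool]:
    """Simplified stand-in for the target-model branch of get_graph_key."""
    if runtime_draft_len is not None and has_schedule:
        # Stage 2: the key follows the per-iteration draft length
        draft_len = runtime_draft_len
    else:
        # Legacy behavior: always key on max_draft_len
        draft_len = max_draft_len if spec_decode_enabled else 0
    return (batch_size, draft_len, False)

# With a schedule, a batch of 8 running 2 draft tokens selects the (8, 2, False) graph:
assert sketch_graph_key(8, max_draft_len=4, runtime_draft_len=2, has_schedule=True) == (8, 2, False)
# Without a schedule it falls back to the graph captured for max_draft_len:
assert sketch_graph_key(8, max_draft_len=4) == (8, 4, False)
```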

@@ -135,10 +161,14 @@ def _get_engine(self) -> "PyTorchModelEngine":
def maybe_get_cuda_graph(
self,
batch: ScheduledRequests,
spec_resource_manager: Optional[BaseResourceManager] = None):
spec_resource_manager: Optional[BaseResourceManager] = None,
runtime_draft_len: Optional[int] = None):
"""
Determines if the current batch can be run with a CUDA graph.

Stage 2 Dynamic Draft Length: When runtime_draft_len is provided,
selects the CUDA graph corresponding to that draft length.

Returns a tuple containing:
- A boolean indicating if a graph can be used.
- The attn_metadata for the graph, if applicable.
@@ -168,7 +198,8 @@ def maybe_get_cuda_graph(

if not self.enabled or not can_run_cuda_graph:
return False, None, None, None
key = self.get_graph_key(batch_size, spec_resource_manager)
key = self.get_graph_key(batch_size, spec_resource_manager,
runtime_draft_len)

if key in self.graphs:
return True, self.graph_metadata[key][
@@ -342,29 +373,41 @@ def _get_padded_batch(self, batch: ScheduledRequests,
if padding_size + batch.batch_size > engine.batch_size:
return 0

# No padding if it would create too many concurrent requests.
# This is not strictly required, but we should probably
# respect the requirement just in case that changes in the future.
if self.padding_dummy_request is None:
# Stage 2: Get or create padding dummy for current runtime draft length
# Pre-allocation strategy: Create one dummy per unique draft_len (very small cost)
# This avoids recreation overhead while preserving Stage 2 benefits
runtime_draft_len = engine.max_draft_len # Current draft_len for this iteration

# Check if dummy for this draft_len already exists
if runtime_draft_len not in self.padding_dummies:
available_blocks = kv_cache_manager.get_num_free_blocks()
# No padding if not enough KV cache space
if available_blocks < 1:
return 0

self.padding_dummy_request = kv_cache_manager.add_dummy_requests(
[CUDA_GRAPH_DUMMY_REQUEST_ID],
# Create dummy for this specific draft_len (happens once per unique draft_len)
# Use unique request ID per draft_len to avoid conflicts
dummy_req_id = CUDA_GRAPH_DUMMY_REQUEST_ID + runtime_draft_len

dummy = kv_cache_manager.add_dummy_requests(
[dummy_req_id],
is_gen=True,
max_num_draft_tokens=engine.runtime_draft_len,
max_num_draft_tokens=runtime_draft_len,
use_mrope=engine.use_mrope,
max_beam_width=engine.max_beam_width)[0]
self.padding_dummy_request.is_cuda_graph_dummy = True
dummy.is_cuda_graph_dummy = True

spec_res_mgr = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
if spec_res_mgr:
spec_res_mgr.add_dummy_requests([CUDA_GRAPH_DUMMY_REQUEST_ID])
spec_res_mgr.add_dummy_requests([dummy_req_id])

# Store for reuse
self.padding_dummies[runtime_draft_len] = dummy

batch.generation_requests.extend([self.padding_dummy_request] *
padding_size)
# Select the dummy for the current draft_len (a dictionary lookup, no allocation)
padding_dummy = self.padding_dummies[runtime_draft_len]
batch.generation_requests.extend([padding_dummy] * padding_size)
return padding_size
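
The caching idea above reduces to a small lazily-filled dictionary keyed by draft length. A minimal sketch under assumed types: `DummyRequest` and the base ID below are placeholders for the real request object and `CUDA_GRAPH_DUMMY_REQUEST_ID`.

```python
from typing import Dict

DUMMY_REQUEST_ID_BASE = 1 << 30  # placeholder; the real constant lives in the executor

class DummyRequest:
    """Stand-in for the real padding request object."""
    def __init__(self, request_id: int, draft_len: int):
        self.request_id = request_id
        self.draft_len = draft_len
        self.is_cuda_graph_dummy = True

class PaddingDummyCache:
    """One padding dummy per unique draft_len: created lazily, reused afterwards."""
    def __init__(self) -> None:
        self._dummies: Dict[int, DummyRequest] = {}

    def get(self, runtime_draft_len: int) -> DummyRequest:
        if runtime_draft_len not in self._dummies:
            # Unique ID per draft_len so two dummies never collide in the KV cache
            req_id = DUMMY_REQUEST_ID_BASE + runtime_draft_len
            self._dummies[runtime_draft_len] = DummyRequest(req_id, runtime_draft_len)
        return self._dummies[runtime_draft_len]

cache = PaddingDummyCache()
assert cache.get(2) is cache.get(2)      # reused across iterations
assert cache.get(0) is not cache.get(2)  # one dummy per draft length
```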

def _round_up_batch_size(self, batch_size: int) -> int:
@@ -397,7 +440,7 @@ def clear(self):
self.graphs.clear()
self.graph_outputs.clear()
self.graph_metadata.clear()
self.padding_dummy_request = None
self.padding_dummies.clear()
del self.memory_pool
self.memory_pool = None
torch.cuda.empty_cache()
197 changes: 151 additions & 46 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -551,6 +551,124 @@ def _run_autotuner_warmup(self, resource_manager: ResourceManager):
)
AutoTuner.get().print_profiling_cache()

def _get_runtime_draft_len(self,
scheduled_requests: ScheduledRequests) -> int:
"""
Get the runtime draft length for the current batch.

Stage 2 Dynamic Draft Length: Returns the max_draft_tokens that was set
for this batch based on batch size. This is the value that all requests
in the batch will be padded to.

Note: Individual requests may have fewer draft tokens due to NGram
mismatches or early stopping, but they will all be padded to this value
for CUDA graph compatibility.

Args:
scheduled_requests: The scheduled requests for this iteration

Returns:
The runtime max_draft_tokens for this batch, or 0 if spec decode disabled
"""
if not self.enable_spec_decode:
return 0

# Return the current max_draft_len (which was set based on batch size)
# This is the value that drafter.max_draft_tokens is set to
return self.runtime_draft_len

def _get_graphs_to_capture(self) -> List[Tuple[int, int]]:
"""
Determine which (batch_size, draft_len) combinations to capture CUDA graphs for.

Stage 2 Optimization: When draft_len_schedule is provided, only capture
graphs that will actually be used based on the schedule and batch size padding.
This avoids over-capturing and saves memory + warmup time.

Returns:
List of (batch_size, draft_len) tuples to capture graphs for.
"""
spec_resource_manager = self.resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER) if hasattr(
self, 'resource_manager') else None

# Draft model logic (unchanged)
if self.is_draft_model:
if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None:
from ..speculative.eagle3 import Eagle3ResourceManager
if isinstance(spec_resource_manager, Eagle3ResourceManager):
draft_len = self.original_max_draft_len
return [(bs, draft_len)
for bs in self._cuda_graph_batch_sizes]
draft_len = self.max_draft_len
return [(bs, draft_len) for bs in self._cuda_graph_batch_sizes]

# Target model with schedule: compute exact reachable set
if (self.spec_config and hasattr(self.spec_config, 'draft_len_schedule')
and self.spec_config.draft_len_schedule is not None):

graphs_needed = self._compute_reachable_graphs()
logger.info(
f"Stage 2 Dynamic Draft Length: Capturing {len(graphs_needed)} CUDA graphs "
f"(from schedule {self.spec_config.draft_len_schedule}): {sorted(graphs_needed)}"
)
return sorted(graphs_needed)

# Legacy: all batch sizes with same draft_len(s)
draft_lengths = []
if (self.max_draft_len > 0
and not self.spec_config.spec_dec_mode.use_one_engine()
and self.spec_config.max_concurrency is not None):
draft_lengths.append(0)
draft_lengths.append(self.max_draft_len)

graphs = []
for bs in self._cuda_graph_batch_sizes:
for draft_len in draft_lengths:
graphs.append((bs, draft_len))
return graphs

def _compute_reachable_graphs(self) -> set:
"""
Compute the set of (batch_size, draft_len) pairs that are actually reachable.

Takes into account:
1. Schedule: which draft_len for each actual batch size
2. Batch padding: actual batch size might be padded to larger graph size

Returns:
Set of (batch_size, draft_len) tuples
"""
graphs_needed = set()
schedule = self.spec_config.draft_len_schedule

from bisect import bisect_right

# bisect_right requires sorted thresholds; compute them once outside the loop
thresholds = sorted(schedule.keys())

# For each possible actual batch size
for actual_bs in range(1, self.batch_size + 1):
# Determine draft_len for this actual batch size using the same logic as the
# drafter: bisect_right finds the largest threshold <= actual_bs
idx = bisect_right(thresholds, actual_bs)
if idx == 0:
draft_len = 0  # Defensive - shouldn't happen with valid schedules
else:
draft_len = schedule[thresholds[idx - 1]]

# Determine padded batch size
padded_bs = self._round_up_to_graph_size(actual_bs)

if padded_bs > 0: # Valid graph size exists
graphs_needed.add((padded_bs, draft_len))

return graphs_needed

def _round_up_to_graph_size(self, actual_bs: int) -> int:
"""Round up actual batch size to nearest CUDA graph batch size."""
for graph_bs in sorted(self._cuda_graph_batch_sizes):
if actual_bs <= graph_bs:
return graph_bs
return 0 # Too large, no graph available
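
Taken together, the two helpers above prune the capture set considerably. A worked example under an assumed schedule format of {batch-size threshold: draft_len} with ascending thresholds; the schedule, graph sizes, and max batch size below are illustrative only:

```python
from bisect import bisect_right

draft_len_schedule = {1: 4, 5: 2, 17: 0}      # <=4 reqs: 4 drafts, <=16: 2, above: 0
cuda_graph_batch_sizes = [1, 2, 4, 8, 16, 32]
max_batch_size = 32

def draft_len_for(actual_bs: int) -> int:
    thresholds = sorted(draft_len_schedule.keys())
    idx = bisect_right(thresholds, actual_bs)   # largest threshold <= actual_bs
    return 0 if idx == 0 else draft_len_schedule[thresholds[idx - 1]]

def round_up_to_graph_size(actual_bs: int) -> int:
    for graph_bs in sorted(cuda_graph_batch_sizes):
        if actual_bs <= graph_bs:
            return graph_bs
    return 0                                    # too large, no graph available

reachable = set()
for actual_bs in range(1, max_batch_size + 1):
    padded_bs = round_up_to_graph_size(actual_bs)
    if padded_bs > 0:
        reachable.add((padded_bs, draft_len_for(actual_bs)))

# Six graphs instead of one per (batch_size, draft_len) combination:
print(sorted(reachable))
# [(1, 4), (2, 4), (4, 4), (8, 2), (16, 2), (32, 0)]
```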

def _run_cuda_graph_warmup(self, resource_manager: ResourceManager):
"""Captures CUDA graphs for various batch sizes and draft lengths."""
if not (self.cuda_graph_runner.enabled
@@ -572,55 +690,36 @@ def _capture_generation_cuda_graphs(self,
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)

# Stage 2 Optimization: Only capture graphs that will actually be used
graphs_to_capture = self._get_graphs_to_capture()

# Reverse order so smaller graphs can reuse memory from larger ones
cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes,
reverse=True)
# Create CUDA graphs for different draft lengths
draft_lengths = []
if self.is_draft_model:
if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance(
spec_resource_manager, Eagle3ResourceManager):
# The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
draft_lengths.append(self.original_max_draft_len)
else:
draft_lengths.append(self.max_draft_len)
else:
# For non-draft model, we also capture the CUDA graph instance for draft length 0,
# so that when we disable spec decode at runtime, we can still run the captured graph.
# Note that for one engine mode, we are not able to turn off spec decode at runtime.
if (self.max_draft_len > 0
and not self.spec_config.spec_dec_mode.use_one_engine()
# Assume that speculation is always on if the user didn't give us a max_concurrency
# value. This will save on memory.
and self.spec_config.max_concurrency is not None):
draft_lengths.append(0)
draft_lengths = [self.max_draft_len]

for bs in cuda_graph_batch_sizes:
if bs > self.batch_size:
graphs_to_capture = sorted(graphs_to_capture, reverse=True)

for batch_size, draft_len in graphs_to_capture:
if batch_size > self.batch_size:
continue

for draft_len in draft_lengths:
warmup_request = self._create_cuda_graph_warmup_request(
resource_manager, bs, draft_len)
with self._release_batch_context(warmup_request,
resource_manager) as batch:
if batch is None:
# No KV cache space, cannot continue capturing graphs
return
warmup_request = self._create_cuda_graph_warmup_request(
resource_manager, batch_size, draft_len)
with self._release_batch_context(warmup_request,
resource_manager) as batch:
if batch is None:
# No KV cache space, cannot continue capturing graphs
return

logger.info(
f"Run generation-only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"
)
logger.info(
f"Run generation-only CUDA graph warmup for batch size={batch_size}, draft_len={draft_len}"
)

self.enable_spec_decode = draft_len > 0 or self.is_draft_model
self._update_draft_inference_state_for_warmup(
batch, draft_len > 0, resource_manager)
self.enable_spec_decode = draft_len > 0 or self.is_draft_model
self._update_draft_inference_state_for_warmup(
batch, draft_len > 0, resource_manager)

self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
torch.cuda.synchronize()
self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
torch.cuda.synchronize()
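
The reverse-sorted capture order lets smaller graphs allocate out of the memory pool created for the largest one instead of growing a new pool. A minimal PyTorch sketch of that pool-sharing pattern, independent of this runner (the linear layer and shapes are placeholders; requires a CUDA device):

```python
import torch

model = torch.nn.Linear(256, 256).cuda()
static_in_large = torch.zeros(32, 256, device="cuda")
static_in_small = torch.zeros(8, 256, device="cuda")

# Warm up on a side stream so lazy cuBLAS/cuDNN init happens outside capture,
# as recommended by the PyTorch CUDA graph documentation.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_in_large)
        model(static_in_small)
torch.cuda.current_stream().wait_stream(s)

# Capture the largest batch first; its pool owns the biggest allocations.
g_large = torch.cuda.CUDAGraph()
with torch.cuda.graph(g_large):
    out_large = model(static_in_large)

# The smaller batch is captured into the same pool, reusing that memory.
g_small = torch.cuda.CUDAGraph()
with torch.cuda.graph(g_small, pool=g_large.pool()):
    out_small = model(static_in_small)

# Replay: copy fresh data into the static input, then launch the captured graph.
static_in_small.copy_(torch.randn(8, 256, device="cuda"))
g_small.replay()
print(out_small.shape)  # torch.Size([8, 256])
```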

def _capture_piecewise_cuda_graphs(self, resource_manager: ResourceManager):
"""Captures piecewise CUDA graphs for context/prefill steps via torch.compile."""
@@ -695,8 +794,9 @@ def _create_warmup_request(
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)

# Warmup uses static max since it happens before dynamic updates
available_tokens = kv_cache_manager.get_num_available_tokens(
self.runtime_draft_len)
self.original_max_draft_len)
available_blocks = kv_cache_manager.get_num_free_blocks()
if num_tokens > self.max_num_tokens or num_tokens > available_tokens:
return None
@@ -736,11 +836,12 @@ def _create_warmup_request(
if num_left_over_tokens > 0:
ctx_token_nums.append(num_left_over_tokens)

# Warmup dummy requests use static max since warmup happens before dynamic updates
ctx_requests = kv_cache_manager.add_dummy_requests(
list(range(num_ctx_requests)),
token_nums=ctx_token_nums,
is_gen=False,
max_num_draft_tokens=self.runtime_draft_len,
max_num_draft_tokens=self.original_max_draft_len,
use_mrope=self.use_mrope)

if spec_resource_manager is not None:
@@ -2275,8 +2376,12 @@ def forward(
with self.cuda_graph_runner.pad_batch(
scheduled_requests, resource_manager) as padded_requests:

# Stage 2 Dynamic Draft Length: Get runtime draft length from the batch
runtime_draft_len = self._get_runtime_draft_len(
padded_requests) if not self.is_draft_model else None

maybe_graph, maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph(
padded_requests, spec_resource_manager)
padded_requests, spec_resource_manager, runtime_draft_len)
if maybe_graph:
attn_metadata = maybe_attn_metadata
spec_metadata = maybe_spec_metadata