Commit 5920ecb

committed
[TRTLLM-8084][feat] Enhance overlap scheduler for two-model spec decoding
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent eeb56c2 commit 5920ecb

File tree

6 files changed, +816 -236 lines changed


tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 266 additions & 39 deletions
Large diffs are not rendered by default.

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 141 additions & 55 deletions
@@ -35,10 +35,13 @@
 from tensorrt_llm.mapping import CpType
 from tensorrt_llm.runtime.generation import CUASSERT

+from ..attention_backend.trtllm import TrtllmAttention
 from ..distributed import Distributed
 from ..models.modeling_utils import DecoderModelForCausalLM
 from ..modules.decoder_layer import DecoderLayer
 from ..speculative.drafter import Drafter
+from ..speculative.model_drafter import ModelDrafter
+from ..speculative.mtp import SampleStateTensorsMTP
 from ..speculative.speculation_gate import SpeculationGate
 from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
 from .guided_decoder import GuidedDecoder
@@ -275,8 +278,23 @@ def __init__(self,

         if self.dist.pp_size > 1:
             self.event_loop = self._executor_loop_pp
+        elif self.disable_overlap_scheduler:
+            self.event_loop = self._executor_loop
         else:
-            self.event_loop = self._executor_loop if disable_overlap_scheduler else self._executor_loop_overlap
+            # TODO: Overlap scheduler is not supported for below cases:
+            # 1. non-CDL is used
+            # 2. non-TrtllmAttention attention backend is used
+            overlap_not_supported = self.drafter is not None and isinstance(
+                self.drafter, ModelDrafter) and (
+                    not self.drafter.use_static_draft_loop or not issubclass(
+                        self.draft_model_engine.attn_backend, TrtllmAttention))
+
+            if overlap_not_supported:
+                logger.warning(
+                    "Overlap scheduler is disabled for draft model engine with non-CDL or non-TrtllmAttention attention backend."
+                )
+                self.disable_overlap_scheduler = True
+            self.event_loop = self._executor_loop_overlap if not overlap_not_supported else self._executor_loop
         if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
             self.event_loop = trace_func(self.event_loop)

@@ -1060,14 +1078,11 @@ def _prepare_and_schedule_batch(self):
                 0
             ] * max_total_draft_tokens if max_total_draft_tokens > 0 else []

-            # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch,
-            # we don't need to initialize py_draft_tokens at this stage because we haven't append the accepted tokens to the request yet.
-            if not self.has_previous_draft_tokens:
-                # If speculation is off, this function sets py_draft_tokens to []
-                # for all active requests. If it's on, we initialize py_draft_tokens
-                # with dummy draft tokens to make the scheduler aware of the fact
-                # that speculation is about to happen.
-                self._prepare_draft_requests()
+            # If speculation is off, this function sets py_draft_tokens to []
+            # for all active requests. If it's on, we initialize py_draft_tokens
+            # with dummy draft tokens to make the scheduler aware of the fact
+            # that speculation is about to happen.
+            self._prepare_draft_requests()

         scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
         )
@@ -1317,6 +1332,8 @@ def _executor_loop_overlap(self):
         with self._profiler() as profile_step:
             iter_start_time = time.time()
             iter_stats = None
+            target_inputs = None
+            previous_tensors_device = None
             can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True
             while True:
                 profile_step()
@@ -1397,31 +1414,29 @@ def _executor_loop_overlap(self):
                         self.guided_decoder.init_disagg_gen_requests()

                 previous_tensors = self.previous_batch and self.previous_batch.sample_state
-                target_inputs = None
-                draft_outputs = None
                 # If there are previous draft tokens, we need to update the target requests to accept some draft tokens.
                 # When there's any accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model,
                 # so we'll set the target model's input to None and skip updating the target requests after target model forward.
                 use_previous_draft_tokens = self.has_previous_draft_tokens
                 if self.drafter is not None and (self.use_spec_decode or
                                                  use_previous_draft_tokens):
-                    target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding(
-                        scheduled_batch, previous_tensors)
+                    target_inputs = self._handle_speculative_decoding(
+                        scheduled_batch, previous_tensors,
+                        previous_tensors_device)

                 # Use the draft_model's outputs if we've launched the draft model.
                 # Otherwise, use the previous batch's outputs.
-                if target_inputs is not None or use_previous_draft_tokens:
+                if (target_inputs is not None
+                        and target_inputs.next_draft_tokens
+                        is not None) or use_previous_draft_tokens:
                     previous_tensors_device = target_inputs
                 else:
                     previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device

                 batch_outputs = self._forward_step(scheduled_batch,
                                                    previous_tensors_device)

-                if target_inputs is not None:
-                    self._process_draft_results(scheduled_batch,
-                                                draft_outputs, draft_batch)
-                elif self.previous_batch is not None and not use_previous_draft_tokens:
+                if self.previous_batch is not None:
                     self._update_requests(self.previous_batch.sample_state)

                 if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
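Note on the hunk above: the choice of which device tensors feed the next target forward can be restated as a small helper. The sketch below is an illustration of that decision only; _Tensors and pick_forward_inputs are made-up stand-ins for SampleStateTensorsMTP and the inline logic, not executor APIs.

from dataclasses import dataclass
from typing import Optional

@dataclass
class _Tensors:
    # Stand-in for the sample-state tensors the loop passes around.
    next_draft_tokens: Optional[list] = None

def pick_forward_inputs(target_inputs: Optional[_Tensors],
                        use_previous_draft_tokens: bool,
                        previous_device_tensors: Optional[_Tensors]) -> Optional[_Tensors]:
    # If the draft path produced inputs that carry fresh draft tokens (or we are
    # still consuming tokens drafted in the previous iteration), use those.
    if (target_inputs is not None
            and target_inputs.next_draft_tokens is not None) or use_previous_draft_tokens:
        return target_inputs
    # Otherwise fall back to the previous batch's target-only device outputs.
    return previous_device_tensors

drafted = _Tensors(next_draft_tokens=[[1, 2, 3]])
assert pick_forward_inputs(drafted, False, None) is drafted
prev = _Tensors()
assert pick_forward_inputs(None, False, prev) is prev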
@@ -1436,6 +1451,10 @@ def _executor_loop_overlap(self):
                                              (req, block_id,
                                               self.ctx_in_transmission_counter))

+                if self.drafter is not None and self.use_spec_decode:
+                    # Cleanup previous draft resources used in the draft model
+                    self.drafter.cleanup_previous_draft_resources()
+
                 if self.guided_decoder is not None:
                     # add_batch must be called again to have updated new tokens.
                     self.guided_decoder.add_batch(scheduled_batch)
@@ -1470,6 +1489,94 @@ def _executor_loop_overlap(self):

             self._kv_connector_terminate_requests()

+    def _accept_draft_tokens(
+            self, scheduled_batch: ScheduledRequests,
+            target_outputs: SampleStateTensors,
+            target_inputs: Optional[SampleStateTensors]
+    ) -> Tuple[SampleStateTensorsMTP, Optional[torch.Tensor]]:
+        """
+        Prepare target device inputs after computing draft token acceptance.
+
+        This function:
+        1. If draft tokens exist: compares sampled tokens with draft tokens to compute acceptance
+        2. If no draft tokens: directly uses the first sampled token
+        3. Creates new_tokens by extracting accepted tokens per request
+
+        Args:
+            scheduled_batch: The scheduled requests
+            target_outputs: Contains new_tokens [max_draft_len + 1, batch_size, beam_width]
+                            or [1, batch_size, beam_width] if no draft tokens
+            target_inputs: Contains next_draft_tokens [batch_size, max_draft_len]
+        Returns:
+            Tuple of:
+            - SampleStateTensorsMTP with new_tokens set to accepted tokens,
+              new_tokens_lens and next_draft_tokens set to None
+            - num_accepted_tokens: [batch_size] tensor with acceptance counts per request,
+              or None if no draft tokens
+        """
+        has_draft_tokens = target_inputs is not None and isinstance(
+            target_inputs, SampleStateTensorsMTP
+        ) and target_inputs.next_draft_tokens is not None
+        target_tokens = target_outputs.new_tokens  # [max_draft_len + 1, batch_size, beam_width] or [1, batch_size, beam_width]
+        new_tokens = torch.zeros_like(target_tokens)
+
+        # Squeeze the beam dimension (beam_width=1 for greedy or single beam)
+        target_tokens = target_tokens.squeeze(
+            -1)  # [max_draft_len + 1, batch_size] or [1, batch_size]
+
+        batch_size = target_tokens.shape[1]
+        device = target_tokens.device
+        # Compute number of accepted tokens per request
+        num_accepted_tokens = torch.zeros(batch_size,
+                                          dtype=torch.int32,
+                                          device=device)
+        # Handle case where there are no draft tokens
+        if has_draft_tokens:
+            # Draft tokens exist, compute acceptance
+            draft_tokens = target_inputs.next_draft_tokens  # [batch_size, max_draft_len]
+            max_draft_len = draft_tokens.shape[1]
+
+            # Compute number of accepted tokens per request
+            # Generation requests: compare with draft tokens to find acceptance
+            num_contexts = len(scheduled_batch.context_requests)
+            if batch_size > num_contexts:
+                # Use .T to transpose: [max_draft_len + 1, num_gens] -> [num_gens, max_draft_len + 1]
+                gen_target_tokens = target_tokens[:,
+                                                  num_contexts:].T  # [num_gens, max_draft_len + 1]

+                # Compare draft tokens with target tokens to find acceptance
+                # Use cumprod to find the first rejection point
+                draft_tokens_gen = draft_tokens[
+                    num_contexts:, :].int()  # [num_gens, max_draft_len]
+                num_accepted_tokens[num_contexts:] += torch.cumprod(
+                    (draft_tokens_gen == gen_target_tokens[:, :max_draft_len]
+                     ).int(),
+                    dim=-1).sum(dim=1)
+
+            # Vectorized extraction using advanced indexing (no GPU-CPU sync)
+            # Use num_accepted_tokens as indices to gather the right tokens
+            batch_indices = torch.arange(batch_size, device=device)
+            new_tokens[0, :, 0] = target_tokens[num_accepted_tokens,
+                                                batch_indices]
+        else:
+            # No draft tokens to accept, just use the first (and only) sampled token
+            batch_indices = torch.arange(batch_size, device=device)
+            new_tokens[0, :, 0] = target_tokens[0, batch_indices]
+
+        # Create the updated SampleStateTensorsMTP
+        # new_tokens_lens and next_draft_tokens are left as None
+        result_tensors = SampleStateTensorsMTP(
+            new_tokens=new_tokens,
+            log_probs=target_outputs.log_probs,
+            new_tokens_lens=None,
+            next_draft_tokens=None)
+
+        # Copy logits if available
+        if hasattr(target_outputs, 'logits'):
+            result_tensors.logits = target_outputs.logits
+
+        return result_tensors, num_accepted_tokens
+
     def _process_previous_batch(self):
         if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
             for req in self.previous_batch.ctx_transmission_reqs:
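To make the acceptance rule in _accept_draft_tokens concrete, the standalone snippet below walks the same cumprod arithmetic on a toy batch after the beam dimension has been squeezed. It does not call into the executor; all tensor values are made up for illustration.

import torch

# Toy example: 2 generation requests, max_draft_len = 3.
# draft_tokens[b]    : tokens proposed by the draft model for request b
# target_tokens[:, b]: tokens sampled by the target model (max_draft_len + 1 of them)
draft_tokens = torch.tensor([[5, 7, 9],
                             [4, 4, 4]])                  # [num_gens, max_draft_len]
target_tokens = torch.tensor([[5, 4],
                              [7, 8],
                              [2, 6],
                              [3, 1]])                    # [max_draft_len + 1, num_gens]

matches = (draft_tokens == target_tokens[:3, :].T).int()  # [[1, 1, 0], [1, 0, 0]]
num_accepted = torch.cumprod(matches, dim=-1).sum(dim=1)  # tensor([2, 1])

# The token fed to the next iteration is the one sampled right after the last
# accepted draft token, i.e. target_tokens[num_accepted, batch_index].
batch_indices = torch.arange(draft_tokens.shape[0])
next_tokens = target_tokens[num_accepted, batch_indices]  # tensor([2, 8])
print(num_accepted.tolist(), next_tokens.tolist())        # [2, 1] [2, 8]

Request 0 accepts its first two draft tokens and carries the token sampled at position 2 forward; request 1 is rejected at the second draft token and accepts only one.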
@@ -2366,7 +2473,8 @@ def _remove_inflight_ids(self, scheduled_requests):
         for req in scheduled_requests.all_requests():
             self.inflight_req_ids.erase(req.request_id)

-    def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
+    def _handle_speculative_decoding(self, scheduled_batch, previous_tensors,
+                                     target_inputs):
         with request_context(is_draft=self.draft_model_engine is not None,
                              scheduled_requests=scheduled_batch):
             # Do an early checking to see if we need to forward the draft model.
@@ -2376,20 +2484,25 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
                 self.previous_batch is not None and self.use_spec_decode
                 and self.drafter.should_forward_draft_model(scheduled_batch))

-            if has_draft_batch or self.has_previous_draft_tokens:
-                self._update_requests(self.previous_batch.sample_state)
-                if self.has_previous_draft_tokens:
-                    self._prepare_draft_requests()
+            new_target_inputs = None
+            if has_draft_batch:
+                target_outputs = self.previous_batch.sample_state and self.previous_batch.sample_state.device
+                assert target_outputs is not None, "target_outputs should not be None"
+                new_target_inputs, num_accepted_tokens_device = self._accept_draft_tokens(
+                    scheduled_batch=scheduled_batch,
+                    target_inputs=target_inputs,
+                    target_outputs=target_outputs)

             if has_draft_batch:
-                target_inputs, draft_outputs, draft_batch = self.drafter.generate_draft_tokens_with_overlap(
+                self.drafter.generate_draft_tokens_with_overlap(
                     scheduled_batch, self.resource_manager,
-                    previous_tensors.device if previous_tensors else None)
+                    previous_tensors.device if previous_tensors else None,
+                    new_target_inputs, num_accepted_tokens_device)

-                self.has_previous_draft_tokens = target_inputs is not None and target_inputs.next_draft_tokens is not None
+                # Pad draft tokens to the max draft length for CUDA graph compatibility
+                self.has_previous_draft_tokens = new_target_inputs is not None and new_target_inputs.next_draft_tokens is not None
             else:
                 self.has_previous_draft_tokens = False
-                target_inputs, draft_outputs, draft_batch = None, None, None
                 # We are not running the draft model. Remove the draft tokens and turn off spec
                 # decode so that the requests get handled correctly.
                 # One corner case: when we have at least one context request, we have to keep spec
@@ -2402,34 +2515,7 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
                 for request in scheduled_batch.all_requests():
                     request.py_draft_tokens = []

-        return target_inputs, draft_outputs, draft_batch
-
-    def _process_draft_results(self, scheduled_batch, draft_outputs,
-                               draft_batch):
-        """
-        Append the draft tokens to the target requests, and clean up the draft resources.
-        """
-        with request_context(is_draft=self.draft_model_engine is not None,
-                             scheduled_requests=scheduled_batch):
-            req_id_to_old_request = {
-                req.py_request_id: req
-                for req in scheduled_batch.all_requests()
-            }
-
-            if self.drafter.use_static_draft_loop:
-                self.drafter.process_static_draft_outputs(
-                    draft_outputs, draft_batch, req_id_to_old_request)
-            elif draft_outputs is not None:
-                self.drafter.process_dynamic_draft_outputs(
-                    draft_outputs, req_id_to_old_request)
-
-            # Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
-            self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
-            # add_batch must be called again to restore to target requests with updated draft tokens.
-            if self.guided_decoder is not None:
-                self.guided_decoder.add_batch(scheduled_batch)
-                if hasattr(self.drafter, "guided_decoder"):
-                    self.guided_decoder.rollback_draft_tokens()
+        return new_target_inputs

     def reset_prefix_cache(self):
         self.kv_cache_manager.reset_reuse_state()
tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 3 additions & 1 deletion
@@ -277,7 +277,9 @@ def _group_requests_by_strategy_key(
    for req_index, req in enumerate(requests):
        strategy = _request_strategy(req, vocab_size=vocab_size)
        strategy_key = strategy_to_key(strategy)
-        speculation_needs_probs = req.py_draft_logits is not None and strategy is not GREEDY
+        # In the overlap path, py_draft_logits is not updated yet,
+        # so we use get_draft_token_length() for the checking.
+        speculation_needs_probs = get_draft_token_length(req) > 0 and strategy is not GREEDY
        group_dict_entry = group_dict[(strategy_key, speculation_needs_probs)]
        group_dict_entry[0].append(req_index)
        group_dict_entry[1].append(strategy)
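The sampler change above relies on the draft-token count being populated earlier than the draft logits in the overlap path. Below is a minimal sketch of that check, assuming get_draft_token_length simply counts the draft tokens attached to a request; DummyRequest and its fields are illustrative stand-ins for the real LlmRequest, not the sampler's types.

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class DummyRequest:
    # Stand-ins for the request fields the sampler consults.
    py_draft_tokens: List[int] = field(default_factory=list)
    py_draft_logits: Optional[object] = None

def draft_token_length(request: DummyRequest) -> int:
    # Assumed behavior of get_draft_token_length(): count the attached draft tokens.
    return len(request.py_draft_tokens) if request.py_draft_tokens is not None else 0

# In the overlap path the draft logits are not attached yet, but the (padded)
# draft tokens already are, so the length check still detects speculation.
req = DummyRequest(py_draft_tokens=[0, 0, 0], py_draft_logits=None)
assert req.py_draft_logits is None     # the old check would miss this request
assert draft_token_length(req) > 0     # the new check still groups it for probs

Grouping on the token count rather than on the presence of logits keeps non-greedy speculated requests in the probability-needing group even before the draft outputs are attached.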
