Commit 00e94d8

[TRTLLM-8084][feat] Enhance overlap scheduler for two-model spec decoding
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent 67208f1 commit 00e94d8

File tree

6 files changed: +816 −236 lines changed


tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 266 additions & 39 deletions
Large diffs are not rendered by default.

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 141 additions & 55 deletions
@@ -35,10 +35,13 @@
 from tensorrt_llm.mapping import CpType
 from tensorrt_llm.runtime.generation import CUASSERT

+from ..attention_backend.trtllm import TrtllmAttention
 from ..distributed import Distributed
 from ..models.modeling_utils import DecoderModelForCausalLM
 from ..modules.decoder_layer import DecoderLayer
 from ..speculative.drafter import Drafter
+from ..speculative.model_drafter import ModelDrafter
+from ..speculative.mtp import SampleStateTensorsMTP
 from ..speculative.speculation_gate import SpeculationGate
 from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
 from .guided_decoder import GuidedDecoder
@@ -271,8 +274,23 @@ def __init__(self,

         if self.dist.pp_size > 1:
             self.event_loop = self._executor_loop_pp
+        elif self.disable_overlap_scheduler:
+            self.event_loop = self._executor_loop
         else:
-            self.event_loop = self._executor_loop if disable_overlap_scheduler else self._executor_loop_overlap
+            # TODO: Overlap scheduler is not supported for below cases:
+            # 1. non-CDL is used
+            # 2. non-TrtllmAttention attention backend is used
+            overlap_not_supported = self.drafter is not None and isinstance(
+                self.drafter, ModelDrafter) and (
+                    not self.drafter.use_static_draft_loop or not issubclass(
+                        self.draft_model_engine.attn_backend, TrtllmAttention))
+
+            if overlap_not_supported:
+                logger.warning(
+                    "Overlap scheduler is disabled for draft model engine with non-CDL or non-TrtllmAttention attention backend."
+                )
+                self.disable_overlap_scheduler = True
+            self.event_loop = self._executor_loop_overlap if not overlap_not_supported else self._executor_loop
         if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
             self.event_loop = trace_func(self.event_loop)
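
The fallback added above gates the overlap loop on the draft-model configuration. A minimal, self-contained sketch of the same decision, using hypothetical stand-ins (DummyDrafter, select_event_loop) rather than the real PyExecutor/ModelDrafter/TrtllmAttention classes:

class DummyDrafter:

    def __init__(self, is_model_drafter, use_static_draft_loop,
                 uses_trtllm_attention):
        self.is_model_drafter = is_model_drafter
        self.use_static_draft_loop = use_static_draft_loop
        self.uses_trtllm_attention = uses_trtllm_attention


def select_event_loop(pp_size, disable_overlap, drafter):
    if pp_size > 1:
        return "pp"
    if disable_overlap:
        return "plain"
    # Two-model speculation keeps the overlap loop only when the draft model
    # runs the static draft loop (CDL) on the TRT-LLM attention backend.
    overlap_not_supported = (drafter is not None and drafter.is_model_drafter
                             and (not drafter.use_static_draft_loop
                                  or not drafter.uses_trtllm_attention))
    return "plain" if overlap_not_supported else "overlap"


# Static draft loop + TrtllmAttention keeps the overlap executor loop...
assert select_event_loop(1, False, DummyDrafter(True, True, True)) == "overlap"
# ...while a dynamic draft loop (or another attention backend) falls back.
assert select_event_loop(1, False, DummyDrafter(True, False, True)) == "plain"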

@@ -1045,14 +1063,11 @@ def _prepare_and_schedule_batch(self):
                    0
                ] * max_total_draft_tokens if max_total_draft_tokens > 0 else []

-            # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch,
-            # we don't need to initialize py_draft_tokens at this stage because we haven't append the accepted tokens to the request yet.
-            if not self.has_previous_draft_tokens:
-                # If speculation is off, this function sets py_draft_tokens to []
-                # for all active requests. If it's on, we initialize py_draft_tokens
-                # with dummy draft tokens to make the scheduler aware of the fact
-                # that speculation is about to happen.
-                self._prepare_draft_requests()
+            # If speculation is off, this function sets py_draft_tokens to []
+            # for all active requests. If it's on, we initialize py_draft_tokens
+            # with dummy draft tokens to make the scheduler aware of the fact
+            # that speculation is about to happen.
+            self._prepare_draft_requests()

         scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
         )
@@ -1256,6 +1271,8 @@ def _executor_loop_overlap(self):
         with self._profiler() as profile_step:
             iter_start_time = time.time()
             iter_stats = None
+            target_inputs = None
+            previous_tensors_device = None
             can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True
             while True:
                 profile_step()
@@ -1329,31 +1346,29 @@ def _executor_loop_overlap(self):
                        self.guided_decoder.init_disagg_gen_requests()

                previous_tensors = self.previous_batch and self.previous_batch.sample_state
-                target_inputs = None
-                draft_outputs = None
                # If there are previous draft tokens, we need to update the target requests to accept some draft tokens.
                # When there's any accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model,
                # so we'll set the target model's input to None and skip updating the target requests after target model forward.
                use_previous_draft_tokens = self.has_previous_draft_tokens
                if self.drafter is not None and (self.use_spec_decode or
                                                 use_previous_draft_tokens):
-                    target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding(
-                        scheduled_batch, previous_tensors)
+                    target_inputs = self._handle_speculative_decoding(
+                        scheduled_batch, previous_tensors,
+                        previous_tensors_device)

                # Use the draft_model's outputs if we've launched the draft model.
                # Otherwise, use the previous batch's outputs.
-                if target_inputs is not None or use_previous_draft_tokens:
+                if (target_inputs is not None
+                        and target_inputs.next_draft_tokens
+                        is not None) or use_previous_draft_tokens:
                    previous_tensors_device = target_inputs
                else:
                    previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device

                batch_outputs = self._forward_step(scheduled_batch,
                                                   previous_tensors_device)

-                if target_inputs is not None:
-                    self._process_draft_results(scheduled_batch,
-                                                draft_outputs, draft_batch)
-                elif self.previous_batch is not None and not use_previous_draft_tokens:
+                if self.previous_batch is not None:
                    self._update_requests(self.previous_batch.sample_state)

                if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
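
The branch above is the crux of the overlap change: the speculative path's tensors feed the next target forward only when draft tokens were actually produced; otherwise the loop falls back to the previous batch's sampled tensors. A condensed sketch under assumed stand-in types (SpecInputs and PrevState are illustrative only; the real code uses SampleStateTensorsMTP and the previous batch's sample state):

from dataclasses import dataclass
from typing import Optional


@dataclass
class SpecInputs:
    next_draft_tokens: Optional[list]


@dataclass
class PrevState:
    device: object


def pick_previous_tensors_device(spec_inputs, prev_state,
                                 use_previous_draft_tokens):
    # Use the speculative inputs only if the draft model produced draft tokens,
    # or if draft tokens from the last iteration are still pending; otherwise
    # reuse the previous batch's sampled device tensors.
    if (spec_inputs is not None and spec_inputs.next_draft_tokens
            is not None) or use_previous_draft_tokens:
        return spec_inputs
    return prev_state.device if prev_state else None


assert pick_previous_tensors_device(SpecInputs([1, 2]), PrevState("dev"),
                                    False).next_draft_tokens == [1, 2]
assert pick_previous_tensors_device(SpecInputs(None), PrevState("dev"),
                                    False) == "dev"
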
@@ -1368,6 +1383,10 @@ def _executor_loop_overlap(self):
                            (req, block_id,
                             self.ctx_in_transmission_counter))

+                if self.drafter is not None and self.use_spec_decode:
+                    # Cleanup previous draft resources used in the draft model
+                    self.drafter.cleanup_previous_draft_resources()
+
                if self.guided_decoder is not None:
                    # add_batch must be called again to have updated new tokens.
                    self.guided_decoder.add_batch(scheduled_batch)
@@ -1402,6 +1421,94 @@ def _executor_loop_overlap(self):

                self._kv_connector_terminate_requests()

+    def _accept_draft_tokens(
+            self, scheduled_batch: ScheduledRequests,
+            target_outputs: SampleStateTensors,
+            target_inputs: Optional[SampleStateTensors]
+    ) -> Tuple[SampleStateTensorsMTP, Optional[torch.Tensor]]:
+        """
+        Prepare target device inputs after computing draft token acceptance.
+
+        This function:
+        1. If draft tokens exist: compares sampled tokens with draft tokens to compute acceptance
+        2. If no draft tokens: directly uses the first sampled token
+        3. Creates new_tokens by extracting accepted tokens per request
+
+        Args:
+            scheduled_batch: The scheduled requests
+            target_outputs: Contains new_tokens [max_draft_len + 1, batch_size, beam_width]
+                or [1, batch_size, beam_width] if no draft tokens
+            target_inputs: Contains next_draft_tokens [batch_size, max_draft_len]
+        Returns:
+            Tuple of:
+            - SampleStateTensorsMTP with new_tokens set to accepted tokens,
+              new_tokens_lens and next_draft_tokens set to None
+            - num_accepted_tokens: [batch_size] tensor with acceptance counts per request,
+              or None if no draft tokens
+        """
+        has_draft_tokens = target_inputs is not None and isinstance(
+            target_inputs, SampleStateTensorsMTP
+        ) and target_inputs.next_draft_tokens is not None
+        target_tokens = target_outputs.new_tokens  # [max_draft_len + 1, batch_size, beam_width] or [1, batch_size, beam_width]
+        new_tokens = torch.zeros_like(target_tokens)
+
+        # Squeeze the beam dimension (beam_width=1 for greedy or single beam)
+        target_tokens = target_tokens.squeeze(
+            -1)  # [max_draft_len + 1, batch_size] or [1, batch_size]
+
+        batch_size = target_tokens.shape[1]
+        device = target_tokens.device
+        # Compute number of accepted tokens per request
+        num_accepted_tokens = torch.zeros(batch_size,
+                                          dtype=torch.int32,
+                                          device=device)
+        # Handle case where there are no draft tokens
+        if has_draft_tokens:
+            # Draft tokens exist, compute acceptance
+            draft_tokens = target_inputs.next_draft_tokens  # [batch_size, max_draft_len]
+            max_draft_len = draft_tokens.shape[1]
+
+            # Compute number of accepted tokens per request
+            # Generation requests: compare with draft tokens to find acceptance
+            num_contexts = len(scheduled_batch.context_requests)
+            if batch_size > num_contexts:
+                # Use .T to transpose: [max_draft_len + 1, num_gens] -> [num_gens, max_draft_len + 1]
+                gen_target_tokens = target_tokens[:,
+                                                  num_contexts:].T  # [num_gens, max_draft_len + 1]
+
+                # Compare draft tokens with target tokens to find acceptance
+                # Use cumprod to find the first rejection point
+                draft_tokens_gen = draft_tokens[
+                    num_contexts:, :].int()  # [num_gens, max_draft_len]
+                num_accepted_tokens[num_contexts:] += torch.cumprod(
+                    (draft_tokens_gen == gen_target_tokens[:, :max_draft_len]
+                     ).int(),
+                    dim=-1).sum(dim=1)
+
+            # Vectorized extraction using advanced indexing (no GPU-CPU sync)
+            # Use num_accepted_tokens as indices to gather the right tokens
+            batch_indices = torch.arange(batch_size, device=device)
+            new_tokens[0, :, 0] = target_tokens[num_accepted_tokens,
+                                                batch_indices]
+        else:
+            # No draft tokens to accept, just use the first (and only) sampled token
+            batch_indices = torch.arange(batch_size, device=device)
+            new_tokens[0, :, 0] = target_tokens[0, batch_indices]
+
+        # Create the updated SampleStateTensorsMTP
+        # new_tokens_lens and next_draft_tokens are left as None
+        result_tensors = SampleStateTensorsMTP(
+            new_tokens=new_tokens,
+            log_probs=target_outputs.log_probs,
+            new_tokens_lens=None,
+            next_draft_tokens=None)
+
+        # Copy logits if available
+        if hasattr(target_outputs, 'logits'):
+            result_tensors.logits = target_outputs.logits
+
+        return result_tensors, num_accepted_tokens
+
     def _process_previous_batch(self):
         if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
             for req in self.previous_batch.ctx_transmission_reqs:
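
The acceptance count in _accept_draft_tokens relies on torch.cumprod zeroing out every position after the first draft/target mismatch, and on advanced indexing to pick the target model's token at the rejection point. A toy example (not part of the commit) with one generation request and max_draft_len = 3:

import torch

draft_tokens = torch.tensor([[5, 7, 9]])            # [num_gens, max_draft_len]
target_tokens = torch.tensor([[5], [7], [3], [8]])  # [max_draft_len + 1, batch_size]

gen_target = target_tokens.T                         # [[5, 7, 3, 8]]
matches = (draft_tokens == gen_target[:, :3]).int()  # [[1, 1, 0]]: third draft token rejected
num_accepted = torch.cumprod(matches, dim=-1).sum(dim=1)  # tensor([2])

batch_indices = torch.arange(target_tokens.shape[1])
# The token carried into the next iteration is the target sample at the rejection point.
next_token = target_tokens[num_accepted, batch_indices]  # tensor([3])

Note that num_accepted is still a valid row index when every draft token is accepted, because the target model emits max_draft_len + 1 tokens per request.
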
@@ -2298,7 +2405,8 @@ def _remove_inflight_ids(self, scheduled_requests):
         for req in scheduled_requests.all_requests():
             self.inflight_req_ids.erase(req.request_id)

-    def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
+    def _handle_speculative_decoding(self, scheduled_batch, previous_tensors,
+                                     target_inputs):
         with request_context(is_draft=self.draft_model_engine is not None,
                              scheduled_requests=scheduled_batch):
             # Do an early checking to see if we need to forward the draft model.
@@ -2308,20 +2416,25 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
                self.previous_batch is not None and self.use_spec_decode
                and self.drafter.should_forward_draft_model(scheduled_batch))

-            if has_draft_batch or self.has_previous_draft_tokens:
-                self._update_requests(self.previous_batch.sample_state)
-                if self.has_previous_draft_tokens:
-                    self._prepare_draft_requests()
+            new_target_inputs = None
+            if has_draft_batch:
+                target_outputs = self.previous_batch.sample_state and self.previous_batch.sample_state.device
+                assert target_outputs is not None, "target_outputs should not be None"
+                new_target_inputs, num_accepted_tokens_device = self._accept_draft_tokens(
+                    scheduled_batch=scheduled_batch,
+                    target_inputs=target_inputs,
+                    target_outputs=target_outputs)

            if has_draft_batch:
-                target_inputs, draft_outputs, draft_batch = self.drafter.generate_draft_tokens_with_overlap(
+                self.drafter.generate_draft_tokens_with_overlap(
                    scheduled_batch, self.resource_manager,
-                    previous_tensors.device if previous_tensors else None)
+                    previous_tensors.device if previous_tensors else None,
+                    new_target_inputs, num_accepted_tokens_device)

-                self.has_previous_draft_tokens = target_inputs is not None and target_inputs.next_draft_tokens is not None
+                # Pad draft tokens to the max draft length for CUDA graph compatibility
+                self.has_previous_draft_tokens = new_target_inputs is not None and new_target_inputs.next_draft_tokens is not None
            else:
                self.has_previous_draft_tokens = False
-                target_inputs, draft_outputs, draft_batch = None, None, None
                # We are not running the draft model. Remove the draft tokens and turn off spec
                # decode so that the requests get handled correctly.
                # One corner case: when we have at least one context request, we have to keep spec
@@ -2334,34 +2447,7 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
                for request in scheduled_batch.all_requests():
                    request.py_draft_tokens = []

-            return target_inputs, draft_outputs, draft_batch
-
-    def _process_draft_results(self, scheduled_batch, draft_outputs,
-                               draft_batch):
-        """
-        Append the draft tokens to the target requests, and clean up the draft resources.
-        """
-        with request_context(is_draft=self.draft_model_engine is not None,
-                             scheduled_requests=scheduled_batch):
-            req_id_to_old_request = {
-                req.py_request_id: req
-                for req in scheduled_batch.all_requests()
-            }
-
-            if self.drafter.use_static_draft_loop:
-                self.drafter.process_static_draft_outputs(
-                    draft_outputs, draft_batch, req_id_to_old_request)
-            elif draft_outputs is not None:
-                self.drafter.process_dynamic_draft_outputs(
-                    draft_outputs, req_id_to_old_request)
-
-            # Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
-            self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
-            # add_batch must be called again to restore to target requests with updated draft tokens.
-            if self.guided_decoder is not None:
-                self.guided_decoder.add_batch(scheduled_batch)
-                if hasattr(self.drafter, "guided_decoder"):
-                    self.guided_decoder.rollback_draft_tokens()
+            return new_target_inputs


 class DisaggPPTerminationHandler:

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 3 additions & 1 deletion
@@ -277,7 +277,9 @@ def _group_requests_by_strategy_key(
     for req_index, req in enumerate(requests):
         strategy = _request_strategy(req, vocab_size=vocab_size)
         strategy_key = strategy_to_key(strategy)
-        speculation_needs_probs = req.py_draft_logits is not None and strategy is not GREEDY
+        # In the overlap path, py_draft_logits is not updated yet,
+        # so we use get_draft_token_length() for the checking.
+        speculation_needs_probs = get_draft_token_length(req) > 0 and strategy is not GREEDY
         group_dict_entry = group_dict[(strategy_key, speculation_needs_probs)]
         group_dict_entry[0].append(req_index)
         group_dict_entry[1].append(strategy)
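
The changed line matters because, on the overlap path, a request can already carry draft tokens while py_draft_logits has not been filled in yet, so the sampler has to key the grouping on the draft-token count instead. A small stand-in sketch of the grouping (FakeReq, strategy_of, and draft_len are hypothetical helpers, not the sampler's real API):

from collections import defaultdict


class FakeReq:

    def __init__(self, draft_tokens, greedy):
        self.py_draft_tokens = draft_tokens
        self.greedy = greedy


def draft_len(req):
    return len(req.py_draft_tokens or [])


def strategy_of(req):
    return "greedy" if req.greedy else "top_p"


def group_requests(requests):
    groups = defaultdict(list)
    for idx, req in enumerate(requests):
        # Probabilities are only needed for speculative requests that sample.
        needs_probs = draft_len(req) > 0 and strategy_of(req) != "greedy"
        groups[(strategy_of(req), needs_probs)].append(idx)
    return dict(groups)


reqs = [
    FakeReq([1, 2], greedy=False),  # speculative + sampling -> needs probs
    FakeReq([], greedy=False),      # sampling only
    FakeReq([3], greedy=True),      # greedy speculation -> no probs needed
]
print(group_requests(reqs))
# {('top_p', True): [0], ('top_p', False): [1], ('greedy', False): [2]}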
