
Commit 55a6c7a

[TRTLLM-8084][feat] Enhance overlap scheduler for two-model spec decoding
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent 67208f1 commit 55a6c7a

File tree

6 files changed, +810 -236 lines changed


tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 266 additions & 39 deletions
Large diffs are not rendered by default.

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 135 additions & 55 deletions
@@ -35,10 +35,12 @@
 from tensorrt_llm.mapping import CpType
 from tensorrt_llm.runtime.generation import CUASSERT
 
+from ..attention_backend.trtllm import TrtllmAttention
 from ..distributed import Distributed
 from ..models.modeling_utils import DecoderModelForCausalLM
 from ..modules.decoder_layer import DecoderLayer
 from ..speculative.drafter import Drafter
+from ..speculative.mtp import SampleStateTensorsMTP
 from ..speculative.speculation_gate import SpeculationGate
 from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
 from .guided_decoder import GuidedDecoder
@@ -271,8 +273,18 @@ def __init__(self,
 
         if self.dist.pp_size > 1:
             self.event_loop = self._executor_loop_pp
+        elif self.disable_overlap_scheduler:
+            self.event_loop = self._executor_loop
+        elif self.drafter is not None and (
+                not self.drafter.use_static_draft_loop or not issubclass(
+                    self.draft_model_engine.attn_backend, TrtllmAttention)):
+            logger.warning(
+                "Overlap scheduler is disabled for draft model engine with non-CDL or non-TrtllmAttention attention backend."
+            )
+            self.disable_overlap_scheduler = True
+            self.event_loop = self._executor_loop
         else:
-            self.event_loop = self._executor_loop if disable_overlap_scheduler else self._executor_loop_overlap
+            self.event_loop = self._executor_loop_overlap
         if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
             self.event_loop = trace_func(self.event_loop)
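
The event-loop selection above can be read as the following decision order. This is an illustrative restatement of the diff, not code from the commit; `select_event_loop` is a hypothetical helper, and the attribute and class names (including `TrtllmAttention`, imported at the top of this file) mirror the ones used in the hunk above.

    def select_event_loop(executor):
        # 1. Pipeline parallelism always takes the PP loop.
        if executor.dist.pp_size > 1:
            return executor._executor_loop_pp
        # 2. An explicitly disabled overlap scheduler takes the plain loop.
        if executor.disable_overlap_scheduler:
            return executor._executor_loop
        # 3. Two-model spec decode keeps overlap only when the drafter uses the
        #    static draft loop (CDL) and the draft engine runs TrtllmAttention;
        #    otherwise overlap is force-disabled with a warning.
        if executor.drafter is not None and (
                not executor.drafter.use_static_draft_loop
                or not issubclass(executor.draft_model_engine.attn_backend,
                                  TrtllmAttention)):
            executor.disable_overlap_scheduler = True
            return executor._executor_loop
        # 4. Everything else runs the overlap loop.
        return executor._executor_loop_overlap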

@@ -1045,14 +1057,11 @@ def _prepare_and_schedule_batch(self):
             0
         ] * max_total_draft_tokens if max_total_draft_tokens > 0 else []
 
-        # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch,
-        # we don't need to initialize py_draft_tokens at this stage because we haven't append the accepted tokens to the request yet.
-        if not self.has_previous_draft_tokens:
-            # If speculation is off, this function sets py_draft_tokens to []
-            # for all active requests. If it's on, we initialize py_draft_tokens
-            # with dummy draft tokens to make the scheduler aware of the fact
-            # that speculation is about to happen.
-            self._prepare_draft_requests()
+        # If speculation is off, this function sets py_draft_tokens to []
+        # for all active requests. If it's on, we initialize py_draft_tokens
+        # with dummy draft tokens to make the scheduler aware of the fact
+        # that speculation is about to happen.
+        self._prepare_draft_requests()
 
         scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
         )
@@ -1256,6 +1265,8 @@ def _executor_loop_overlap(self):
         with self._profiler() as profile_step:
             iter_start_time = time.time()
             iter_stats = None
+            target_inputs = None
+            previous_tensors_device = None
             can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True
             while True:
                 profile_step()
@@ -1329,31 +1340,29 @@ def _executor_loop_overlap(self):
                     self.guided_decoder.init_disagg_gen_requests()
 
                 previous_tensors = self.previous_batch and self.previous_batch.sample_state
-                target_inputs = None
-                draft_outputs = None
                 # If there are previous draft tokens, we need to update the target requests to accept some draft tokens.
                 # When there's any accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model,
                 # so we'll set the target model's input to None and skip updating the target requests after target model forward.
                 use_previous_draft_tokens = self.has_previous_draft_tokens
                 if self.drafter is not None and (self.use_spec_decode or
                                                  use_previous_draft_tokens):
-                    target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding(
-                        scheduled_batch, previous_tensors)
+                    target_inputs = self._handle_speculative_decoding(
+                        scheduled_batch, previous_tensors,
+                        previous_tensors_device)
 
                 # Use the draft_model's outputs if we've launched the draft model.
                 # Otherwise, use the previous batch's outputs.
-                if target_inputs is not None or use_previous_draft_tokens:
+                if (target_inputs is not None
+                        and target_inputs.next_draft_tokens
+                        is not None) or use_previous_draft_tokens:
                     previous_tensors_device = target_inputs
                 else:
                     previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device
 
                 batch_outputs = self._forward_step(scheduled_batch,
                                                    previous_tensors_device)
 
-                if target_inputs is not None:
-                    self._process_draft_results(scheduled_batch,
-                                                draft_outputs, draft_batch)
-                elif self.previous_batch is not None and not use_previous_draft_tokens:
+                if self.previous_batch is not None:
                     self._update_requests(self.previous_batch.sample_state)
 
                 if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
@@ -1368,6 +1377,10 @@ def _executor_loop_overlap(self):
                                 (req, block_id,
                                  self.ctx_in_transmission_counter))
 
+                if self.drafter is not None and self.use_spec_decode:
+                    # Cleanup previous draft resources used in the draft model
+                    self.drafter.cleanup_previous_draft_resources()
+
                 if self.guided_decoder is not None:
                     # add_batch must be called again to have updated new tokens.
                     self.guided_decoder.add_batch(scheduled_batch)
@@ -1402,6 +1415,94 @@ def _executor_loop_overlap(self):
 
                 self._kv_connector_terminate_requests()
 
+    def _accept_draft_tokens(
+            self, scheduled_batch: ScheduledRequests,
+            target_outputs: SampleStateTensors,
+            target_inputs: Optional[SampleStateTensors]
+    ) -> Tuple[SampleStateTensorsMTP, Optional[torch.Tensor]]:
+        """
+        Prepare target device inputs after computing draft token acceptance.
+
+        This function:
+        1. If draft tokens exist: compares sampled tokens with draft tokens to compute acceptance
+        2. If no draft tokens: directly uses the first sampled token
+        3. Creates new_tokens by extracting accepted tokens per request
+
+        Args:
+            scheduled_batch: The scheduled requests
+            target_outputs: Contains new_tokens [max_draft_len + 1, batch_size, beam_width]
+                or [1, batch_size, beam_width] if no draft tokens
+            target_inputs: Contains next_draft_tokens [batch_size, max_draft_len]
+        Returns:
+            Tuple of:
+            - SampleStateTensorsMTP with new_tokens set to accepted tokens,
+              new_tokens_lens and next_draft_tokens set to None
+            - num_accepted_tokens: [batch_size] tensor with acceptance counts per request,
+              or None if no draft tokens
+        """
+        has_draft_tokens = target_inputs is not None and isinstance(
+            target_inputs, SampleStateTensorsMTP
+        ) and target_inputs.next_draft_tokens is not None
+        target_tokens = target_outputs.new_tokens  # [max_draft_len + 1, batch_size, beam_width] or [1, batch_size, beam_width]
+        new_tokens = torch.zeros_like(target_tokens)
+
+        # Squeeze the beam dimension (beam_width=1 for greedy or single beam)
+        target_tokens = target_tokens.squeeze(
+            -1)  # [max_draft_len + 1, batch_size] or [1, batch_size]
+
+        batch_size = target_tokens.shape[1]
+        device = target_tokens.device
+        # Compute number of accepted tokens per request
+        num_accepted_tokens = torch.zeros(batch_size,
+                                          dtype=torch.int32,
+                                          device=device)
+        # Handle case where there are no draft tokens
+        if has_draft_tokens:
+            # Draft tokens exist, compute acceptance
+            draft_tokens = target_inputs.next_draft_tokens  # [batch_size, max_draft_len]
+            max_draft_len = draft_tokens.shape[1]
+
+            # Compute number of accepted tokens per request
+            # Generation requests: compare with draft tokens to find acceptance
+            num_contexts = len(scheduled_batch.context_requests)
+            if batch_size > num_contexts:
+                # Use .T to transpose: [max_draft_len + 1, num_gens] -> [num_gens, max_draft_len + 1]
+                gen_target_tokens = target_tokens[:,
+                                                  num_contexts:].T  # [num_gens, max_draft_len + 1]
+
+                # Compare draft tokens with target tokens to find acceptance
+                # Use cumprod to find the first rejection point
+                draft_tokens_gen = draft_tokens[
+                    num_contexts:, :].int()  # [num_gens, max_draft_len]
+                num_accepted_tokens[num_contexts:] += torch.cumprod(
+                    (draft_tokens_gen == gen_target_tokens[:, :max_draft_len]
+                     ).int(),
+                    dim=-1).sum(dim=1)
+
+            # Vectorized extraction using advanced indexing (no GPU-CPU sync)
+            # Use num_accepted_tokens as indices to gather the right tokens
+            batch_indices = torch.arange(batch_size, device=device)
+            new_tokens[0, :, 0] = target_tokens[num_accepted_tokens,
+                                                batch_indices]
+        else:
+            # No draft tokens to accept, just use the first (and only) sampled token
+            batch_indices = torch.arange(batch_size, device=device)
+            new_tokens[0, :, 0] = target_tokens[0, batch_indices]
+
+        # Create the updated SampleStateTensorsMTP
+        # new_tokens_lens and next_draft_tokens are left as None
+        result_tensors = SampleStateTensorsMTP(
+            new_tokens=new_tokens,
+            log_probs=target_outputs.log_probs,
+            new_tokens_lens=None,
+            next_draft_tokens=None)
+
+        # Copy logits if available
+        if hasattr(target_outputs, 'logits'):
+            result_tensors.logits = target_outputs.logits
+
+        return result_tensors, num_accepted_tokens
+
     def _process_previous_batch(self):
         if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
             for req in self.previous_batch.ctx_transmission_reqs:
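
To see why the cumprod in `_accept_draft_tokens` yields the accepted-prefix length, here is a small self-contained illustration on toy tensors (illustrative only; the values are made up and the executor's types are not used):

    import torch

    # Two generation requests, max_draft_len = 3. Each row of gen_target_tokens
    # holds the target model's sampled token at every draft position plus one
    # bonus token; each row of draft_tokens_gen holds the proposed draft tokens.
    gen_target_tokens = torch.tensor([[5, 7, 9, 2],
                                      [4, 4, 4, 4]])
    draft_tokens_gen = torch.tensor([[5, 7, 1],    # first two drafts match
                                     [4, 4, 4]])   # all three drafts match

    max_draft_len = draft_tokens_gen.shape[1]
    # cumprod over the match mask zeroes everything after the first mismatch,
    # so the row sum is the length of the accepted prefix.
    matches = (draft_tokens_gen == gen_target_tokens[:, :max_draft_len]).int()
    num_accepted = torch.cumprod(matches, dim=-1).sum(dim=1)
    print(num_accepted)  # tensor([2, 3])

Indexing `target_tokens[num_accepted, batch_indices]` then picks, per request, the first token the target model produced after its accepted prefix, which is what the function stores in `new_tokens[0, :, 0]`.
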
@@ -2298,7 +2399,8 @@ def _remove_inflight_ids(self, scheduled_requests):
         for req in scheduled_requests.all_requests():
             self.inflight_req_ids.erase(req.request_id)
 
-    def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
+    def _handle_speculative_decoding(self, scheduled_batch, previous_tensors,
+                                     target_inputs):
         with request_context(is_draft=self.draft_model_engine is not None,
                              scheduled_requests=scheduled_batch):
             # Do an early checking to see if we need to forward the draft model.
@@ -2308,20 +2410,25 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
                 self.previous_batch is not None and self.use_spec_decode
                 and self.drafter.should_forward_draft_model(scheduled_batch))
 
-            if has_draft_batch or self.has_previous_draft_tokens:
-                self._update_requests(self.previous_batch.sample_state)
-                if self.has_previous_draft_tokens:
-                    self._prepare_draft_requests()
+            new_target_inputs = None
+            if has_draft_batch:
+                target_outputs = self.previous_batch.sample_state and self.previous_batch.sample_state.device
+                assert target_outputs is not None, "target_outputs should not be None"
+                new_target_inputs, num_accepted_tokens_device = self._accept_draft_tokens(
+                    scheduled_batch=scheduled_batch,
+                    target_inputs=target_inputs,
+                    target_outputs=target_outputs)
 
             if has_draft_batch:
-                target_inputs, draft_outputs, draft_batch = self.drafter.generate_draft_tokens_with_overlap(
+                self.drafter.generate_draft_tokens_with_overlap(
                     scheduled_batch, self.resource_manager,
-                    previous_tensors.device if previous_tensors else None)
+                    previous_tensors.device if previous_tensors else None,
+                    new_target_inputs, num_accepted_tokens_device)
 
-                self.has_previous_draft_tokens = target_inputs is not None and target_inputs.next_draft_tokens is not None
+                # Pad draft tokens to the max draft length for CUDA graph compatibility
+                self.has_previous_draft_tokens = new_target_inputs is not None and new_target_inputs.next_draft_tokens is not None
             else:
                 self.has_previous_draft_tokens = False
-                target_inputs, draft_outputs, draft_batch = None, None, None
                 # We are not running the draft model. Remove the draft tokens and turn off spec
                 # decode so that the requests get handled correctly.
                 # One corner case: when we have at least one context request, we have to keep spec
@@ -2334,34 +2441,7 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
                 for request in scheduled_batch.all_requests():
                     request.py_draft_tokens = []
 
-            return target_inputs, draft_outputs, draft_batch
-
-    def _process_draft_results(self, scheduled_batch, draft_outputs,
-                               draft_batch):
-        """
-        Append the draft tokens to the target requests, and clean up the draft resources.
-        """
-        with request_context(is_draft=self.draft_model_engine is not None,
-                             scheduled_requests=scheduled_batch):
-            req_id_to_old_request = {
-                req.py_request_id: req
-                for req in scheduled_batch.all_requests()
-            }
-
-            if self.drafter.use_static_draft_loop:
-                self.drafter.process_static_draft_outputs(
-                    draft_outputs, draft_batch, req_id_to_old_request)
-            elif draft_outputs is not None:
-                self.drafter.process_dynamic_draft_outputs(
-                    draft_outputs, req_id_to_old_request)
-
-            # Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
-            self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
-            # add_batch must be called again to restore to target requests with updated draft tokens.
-            if self.guided_decoder is not None:
-                self.guided_decoder.add_batch(scheduled_batch)
-                if hasattr(self.drafter, "guided_decoder"):
-                    self.guided_decoder.rollback_draft_tokens()
+            return new_target_inputs
 
 
 class DisaggPPTerminationHandler:
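
Putting the pieces together, one iteration of the revised overlap loop with an active drafter now proceeds roughly as below. This is a simplified sketch of the control flow in this diff (it omits guided decoding, KV-cache transfer, and the surrounding guard conditions) rather than a verbatim excerpt of the executor:

    def overlap_iteration(executor, scheduled_batch, previous_tensors,
                          previous_tensors_device):
        """Simplified sketch of one _executor_loop_overlap iteration with a drafter."""
        # 1. Accept the previous batch's draft tokens on-device and launch the
        #    draft model; the result becomes the target model's next input.
        target_inputs = executor._handle_speculative_decoding(
            scheduled_batch, previous_tensors, previous_tensors_device)

        # 2. Feed the draft outputs (if any were produced) or the previous
        #    batch's sampled tokens into the target model forward.
        if target_inputs is not None and target_inputs.next_draft_tokens is not None:
            previous_tensors_device = target_inputs
        else:
            previous_tensors_device = (executor.previous_batch
                                       and executor.previous_batch.sample_state
                                       and executor.previous_batch.sample_state.device)
        batch_outputs = executor._forward_step(scheduled_batch,
                                               previous_tensors_device)

        # 3. Update host-side request state from the previous batch, then free
        #    the draft resources used by the draft model in this iteration.
        if executor.previous_batch is not None:
            executor._update_requests(executor.previous_batch.sample_state)
        if executor.drafter is not None and executor.use_spec_decode:
            executor.drafter.cleanup_previous_draft_resources()
        return batch_outputs, previous_tensors_device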

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 3 additions & 1 deletion
@@ -277,7 +277,9 @@ def _group_requests_by_strategy_key(
     for req_index, req in enumerate(requests):
         strategy = _request_strategy(req, vocab_size=vocab_size)
         strategy_key = strategy_to_key(strategy)
-        speculation_needs_probs = req.py_draft_logits is not None and strategy is not GREEDY
+        # In the overlap path, py_draft_logits is not updated yet,
+        # so we use get_draft_token_length() for the checking.
+        speculation_needs_probs = get_draft_token_length(req) > 0 and strategy is not GREEDY
         group_dict_entry = group_dict[(strategy_key, speculation_needs_probs)]
         group_dict_entry[0].append(req_index)
         group_dict_entry[1].append(strategy)
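
For context on the sampler change: in the overlap path, `py_draft_logits` has not been updated by the time requests are grouped, so the grouping key now relies on the request's draft-token count instead. A minimal sketch of what `get_draft_token_length` is assumed to do (hypothetical simplification, not the library's exact definition):

    def get_draft_token_length(request) -> int:
        # Assumed behavior: count the draft tokens currently attached to the
        # request; an empty or missing list means no speculation this step.
        draft_tokens = getattr(request, "py_draft_tokens", None)
        return len(draft_tokens) if draft_tokens else 0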
