
Commit a1f626c

chore: apply d2t in ModelDrafter, not TorchSampler
Signed-off-by: ixlmar <[email protected]>
1 parent 98b3af4 commit a1f626c

File tree: 3 files changed, +32 -51 lines

  tensorrt_llm/_torch/pyexecutor/sampler.py
  tensorrt_llm/_torch/speculative/model_drafter.py
  tensorrt_llm/_torch/speculative/utils.py

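For context: the Eagle3 draft model emits token ids in its own (typically reduced) draft vocabulary, and the d2t table stores, for each draft token id, the offset to the corresponding target-model token id. Translating a drafted token is therefore a table lookup plus an add. A minimal illustration with made-up values (not part of this commit):

import torch

# Hypothetical draft-to-target offset table: entry i converts draft token id i
# into the matching target-vocabulary id via an additive offset.
d2t = torch.tensor([0, 3, 3, 10])

draft_token = 2
target_token = draft_token + int(d2t[draft_token])
print(target_token)  # 5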

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 0 additions & 34 deletions
@@ -1055,9 +1055,6 @@ def _tree_sampling_batch(self, requests: list[LlmRequest],
                 seq_slots, top_k_list_cumsum[i] -
                 top_k_list_i:top_k_list_cumsum[i]] = indices[request_index]
 
-        # 5) Append eagle3 d2t.
-        self._apply_d2t(new_draft_tokens_cuda, model_outputs)
-
         # 6) Copy back to the output tensor.
         int_new_draft_tokens = new_draft_tokens_cuda.transpose(0, 1).to(
             torch.int, non_blocking=True).unsqueeze(dim=-1)
@@ -1206,16 +1203,6 @@ def sample_async(
             host=SampleStateTensors(new_tokens=new_tokens_host),
             sampler_event=sampler_event)
 
-    @staticmethod
-    def _apply_d2t(tokens: torch.Tensor, model_outputs) -> None:
-        """Applies draft-to-target token translation table.
-
-        Modifies tokens in-place.
-        """
-        if "d2t" in model_outputs:
-            d2t = model_outputs["d2t"][tokens]
-            tokens += d2t
-
     @staticmethod
     def _apply_embedding_bias(
             logits: torch.Tensor,
@@ -1332,16 +1319,6 @@ def _sample_batched_by_strategy(
             requests, pin_memory=True)
         generator_cuda = self.get_generator(cuda_device)
 
-        # FIXME: This check should/could be performed in ModelDrafter.prepare_draft_tokens
-        #
-        # NB: Currently, "d2t" is applied to draft tokens, but not to draft logits,
-        #     breaking _process_draft_tokens_rejection_sampling.
-        needs_d2t = "d2t" in model_outputs
-        if needs_d2t and (len(requests_by_strategy) > 1 or
-                          (requests_by_strategy
-                           and next(iter(requests_by_strategy)) != GREEDY)):
-            raise ValueError("d2t does not yet support non-greedy sampling")
-
         # Indexer for accessing tokens in 'logits_cuda', corresponding to the
         # requests in 'requests'.
         logits_cuda_indexer = _PackedStepIndexer(
@@ -1459,17 +1436,6 @@ def _sample_batched_by_strategy(
         (batch_req_indices, batch_next_tokens_cuda_int,
          batch_softmax_cuda), = batched_results
 
-        # FIXME: This should be done in ModelDrafter.prepare_draft_tokens, but for performance
-        #        parity py_draft_tokens might need to be replaced / backed by a torch.Tensor, so
-        #        that d2t can be applied in a batched manner similar to the code below.
-        if needs_d2t:
-            # NB: The sampler is either called directly by PyExecutor, for the target model,
-            #     or by ModelDrafter.prepare_draft_tokens(), for the draft model. In the former
-            #     case there are 1 + get_draft_token_length(request) tokens per request. In the
-            #     latter case, there is always only 1 token per request because draft
-            #     tokens are sampled one-by-one.
-            self._apply_d2t(batch_next_tokens_cuda_int, model_outputs)
-
         return _BatchedSamplingResult(
             batch_req_indices=batch_req_indices,
             batch_next_tokens_cuda_int=batch_next_tokens_cuda_int,
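The deleted _apply_d2t helper translated sampled tokens in-place on the batched tensor via advanced indexing, and the sampler rejected non-greedy strategies whenever a d2t table was present. A standalone sketch of the removed behaviour, with made-up tensor values (inputs are illustrative only):

import torch


def apply_d2t(tokens: torch.Tensor, model_outputs: dict) -> None:
    """Applies the draft-to-target token translation table in-place, if present."""
    if "d2t" in model_outputs:
        tokens += model_outputs["d2t"][tokens]


# Illustrative values only: one sampled token per draft request.
model_outputs = {"d2t": torch.tensor([0, 0, 5, 5])}
batch_next_tokens = torch.tensor([2, 0, 3])
apply_d2t(batch_next_tokens, model_outputs)
print(batch_next_tokens.tolist())  # [7, 0, 8]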

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 32 additions & 16 deletions
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import traceback
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
 
 import torch
 
@@ -73,14 +73,30 @@ def __init__(
         self.guided_decoder = guided_decoder
 
         self.use_static_draft_loop = draft_model_engine.model_is_wrapped
+        self.d2t: Optional[torch.Tensor] = None
         if self.use_static_draft_loop:
             # TODO: enable sampling/guided decoding on static draft loop
             assert guided_decoder is None
             assert spec_config._allow_greedy_draft_tokens
+        else:
+            # Handle d2t data if available. Static drafting loops should incorporate d2t
+            # in their implementations.
+            if hasattr(self.draft_model_engine.model.model, "d2t"):
+                self.d2t = self.draft_model_engine.model.model.d2t.data
+        self.d2t_host: Optional[torch.Tensor] = None
+        if self.d2t is not None:
+            self.d2t_host = self.d2t.to(device="cpu")
 
     def _create_draft_request(self, request: LlmRequest,
                               input_tokens: Optional[List]) -> LlmRequest:
         """Create a draft request with common parameters."""
+        needs_probs = self.sampler.should_provide_draft_probs(request)
+
+        # NB: Currently, "d2t" is applied to draft tokens, but not to draft logits,
+        #     breaking _process_draft_tokens_rejection_sampling.
+        if self.d2t is not None and needs_probs:
+            raise ValueError("d2t does not yet support non-greedy sampling")
+
         return LlmRequest(
             input_tokens=input_tokens,
             request_id=request.py_request_id,
@@ -94,8 +110,7 @@ def _create_draft_request(self, request: LlmRequest,
             True,  # prepare_draft_tokens uses overlap scheduling
             is_draft=True,
             # NB: self.sampler is shared with PyExecutor
-            return_generation_logits=self.sampler.should_provide_draft_probs(
-                request))
+            return_generation_logits=needs_probs)
 
     def _initialize_draft_tokens(self, request: LlmRequest) -> Tuple[int, int]:
         """Initialize draft token tracking for a request."""
@@ -301,12 +316,6 @@ def forward_draft_model(
             resource_manager,
             new_tensors_device=previous_tensors)
 
-        # Handle d2t data if available. Static drafting loops should incorporate d2t
-        # in their implementations.
-        if not self.use_static_draft_loop and hasattr(
-                self.draft_model_engine.model.model, 'd2t'):
-            outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
-
         return outputs
 
     def sample_async(
@@ -365,6 +374,7 @@ def update_requests(
         """Update requests with sample state."""
         self.sampler.update_requests(sample_state, resource_manager)
 
+    @torch.inference_mode()
     def process_decoded_tokens(
             self, draft_batch: ScheduledRequests,
             req_id_to_old_request: Dict[int, LlmRequest]) -> List[LlmRequest]:
@@ -378,7 +388,16 @@ def process_decoded_tokens(
                 self.draft_seq_slot_manager.free_resources(req)
                 continue
 
-            target_model_req.py_draft_tokens.append(req.get_last_tokens(0))
+            draft_tokens: List[int] = req.get_last_tokens(0)
+            if self.d2t_host is not None:
+                # NB: This is not batched over requests, but considered acceptable given
+                #     that the code already loops over requests and there are few draft
+                #     tokens per request.
+                draft_tokens = [
+                    tok + cast(int, self.d2t_host[tok].item())
+                    for tok in draft_tokens
+                ]
+            target_model_req.py_draft_tokens.append(draft_tokens)
             target_model_req.py_draft_logits = req.py_result.generation_logits  # forwards Nones
             if req.state != LlmRequestState.GENERATION_COMPLETE and len(
                     target_model_req.py_draft_tokens
@@ -591,8 +610,7 @@ def _execute_draft_iteration(
 
         if self.guided_decoder is not None:
             self.guided_decoder.add_batch(draft_batch)
-            self.guided_decoder.execute(outputs['logits'],
-                                        d2t=outputs.get('d2t'))
+            self.guided_decoder.execute(outputs['logits'], d2t=self.d2t)
 
         sample_state = self.sample_async(draft_batch, outputs, resource_manager)
         self.update_request_states(draft_batch)
@@ -726,8 +744,7 @@ def generate_draft_tokens_with_overlap(
             # Handle guided decoder and sampling for non-static loop
             if self.guided_decoder is not None:
                 self.guided_decoder.add_batch(draft_batch)
-                self.guided_decoder.execute(outputs['logits'],
-                                            d2t=outputs.get('d2t'))
+                self.guided_decoder.execute(outputs['logits'], d2t=self.d2t)
             draft_sample_state = self.sample_async(draft_batch, outputs,
                                                    resource_manager)
 
@@ -791,8 +808,7 @@ def prepare_draft_tokens(
 
             if self.guided_decoder is not None:
                 self.guided_decoder.add_batch(draft_batch)
-                self.guided_decoder.execute(outputs['logits'],
-                                            d2t=outputs.get('d2t'))
+                self.guided_decoder.execute(outputs['logits'], d2t=self.d2t)
             sample_state = self.sample_async(draft_batch, outputs,
                                              resource_manager)
             self.update_request_states(draft_batch)
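After this change the translation happens in ModelDrafter.process_decoded_tokens instead, per request and per token, against a CPU copy of the table (self.d2t_host). A rough standalone rendition of that path with made-up values (the real code obtains d2t from the draft model and loops over the scheduled requests):

from typing import List, Optional, cast

import torch

# Hypothetical CPU-side copy of the table, as kept in self.d2t_host.
d2t_host: Optional[torch.Tensor] = torch.tensor([0, 2, 2, 4])

draft_tokens: List[int] = [1, 3]  # tokens drafted for one request
if d2t_host is not None:
    # Per-token translation on the host; acceptable since there are only a few
    # draft tokens per request.
    draft_tokens = [tok + cast(int, d2t_host[tok].item()) for tok in draft_tokens]

print(draft_tokens)  # [3, 7]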

tensorrt_llm/_torch/speculative/utils.py

Lines changed: 0 additions & 1 deletion
@@ -159,7 +159,6 @@ def get_spec_decoder(sampler_args: TorchSampler.Args,
                      nextn=spec_config.num_nextn_predict_layers)
     if spec_config.spec_dec_mode.is_eagle3(
     ) or spec_config.spec_dec_mode.is_mtp_eagle():
-        # TorchSampler handles Eagle3 gracefully, by integrating d2t into the sampling process
         return TorchSampler(sampler_args)
     if spec_config.spec_dec_mode.is_eagle3_one_model():
         return Eagle3OneModelSampler(sampler_args)
