Commit f78c246

chore: document the role of 'd2t'
Signed-off-by: ixlmar <[email protected]>
1 parent 98b3af4 commit f78c246

File tree: 1 file changed (+15 -11 lines)


tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 15 additions & 11 deletions
@@ -1183,6 +1183,12 @@ def sample_async(
         model_outputs: dict[str, torch.Tensor],
         num_context_logits_prefix_sum: list[int],
         resource_manager: Optional[ResourceManager] = None) -> SampleState:
+        # NB: The sampler is either called directly by PyExecutor, for the target model,
+        # or by ModelDrafter.prepare_draft_tokens(), for the draft model. In the former
+        # case there are 1 + get_draft_token_length(request) tokens per request. In the
+        # latter case, there is always only 1 token per request because draft
+        # tokens are sampled one-by-one.
+
         requests = scheduled_requests.all_requests()
         new_tokens = self.store.new_tokens
         log_probs_host = self.log_probs_host(scheduled_requests)
@@ -1332,8 +1338,6 @@ def _sample_batched_by_strategy(
             requests, pin_memory=True)
         generator_cuda = self.get_generator(cuda_device)
 
-        # FIXME: This check should/could be performed in ModelDrafter.prepare_draft_tokens
-        #
         # NB: Currently, "d2t" is applied to draft tokens, but not to draft logits,
         # breaking _process_draft_tokens_rejection_sampling.
         needs_d2t = "d2t" in model_outputs
@@ -1459,15 +1463,16 @@ def _sample_batched_by_strategy(
         (batch_req_indices, batch_next_tokens_cuda_int,
          batch_softmax_cuda), = batched_results
 
-        # FIXME: This should be done in ModelDrafter.prepare_draft_tokens, but for performance
-        # parity py_draft_tokens might need to be replaced / backed by a torch.Tensor, so
-        # that d2t can be applied in a batched manner similar to the code below.
+        # NB: 'd2t' contains offsets for transforming draft vocab token IDs into
+        # the target vocab. This is used by Eagle3ForCausalLM, whose input domain
+        # is the target vocab, whereas the output logits correspond to the draft
+        # vocab. Since the inputs/outputs are linked by TorchSampler.update_requests,
+        # they currently need to be handled within TorchSampler. Changing the model
+        # outputs to use the target vocab would require inflating the logit tensors,
+        # which is inefficient. Changing the inputs to use the draft vocab might
+        # be cleaner, but would require applying 'd2t' in multiple locations:
+        # Prefill, Eagle3ForCausalLM embeddings, ModelDrafter
         if needs_d2t:
-            # NB: The sampler is either called directly by PyExecutor, for the target model,
-            # or by ModelDrafter.prepare_draft_tokens(), for the draft model. In the former
-            # case there are 1 + get_draft_token_length(request) tokens per request. In the
-            # latter case, only there is always only 1 token per request because draft
-            # tokens are sampled one-by-one.
             self._apply_d2t(batch_next_tokens_cuda_int, model_outputs)
 
         return _BatchedSamplingResult(
@@ -1909,7 +1914,6 @@ def sample_async(
         num_context_logits_prefix_sum: list[int],
         resource_manager: Optional[ResourceManager] = None
     ) -> SampleStateTRTLLM:
-
         batch_size = scheduled_requests.batch_size
         beam_width = self.beam_width(scheduled_requests.all_requests())
         if (batch_size > 1 and beam_width > 1
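Note: the comment added at the top of sample_async distinguishes two call paths with different per-request token counts. The sketch below merely restates that bookkeeping for clarity; the function name and parameters are illustrative and not part of the repository.

def expected_sampled_tokens(is_target_model: bool, draft_token_length: int) -> int:
    """Illustrative only: per-request token count described in the new comment."""
    if is_target_model:
        # Called directly by PyExecutor: one newly sampled token plus one token
        # per draft token attached to the request (get_draft_token_length(request)).
        return 1 + draft_token_length
    # Called via ModelDrafter.prepare_draft_tokens(): draft tokens are sampled
    # one-by-one, so each call handles exactly one token per request.
    return 1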

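Note: the new NB in _sample_batched_by_strategy describes 'd2t' as a table of offsets that translates draft-vocab token IDs into target-vocab IDs, applied in a batched manner via self._apply_d2t. The following is a rough, self-contained sketch of that kind of batched offset lookup; the tensor names, shapes, and values are assumptions, not the repository's actual Eagle3 tables or _apply_d2t implementation.

import torch

def apply_d2t_offsets(draft_tokens: torch.Tensor, d2t: torch.Tensor) -> torch.Tensor:
    # d2t is assumed to be a 1-D tensor indexed by draft-vocab token ID, holding
    # the additive offset that maps each draft-vocab ID to its target-vocab ID.
    # Gathering the offsets and adding them converts the whole batch at once.
    return draft_tokens + d2t[draft_tokens]

# Hypothetical example: a draft vocab of size 4 embedded in a larger target vocab.
d2t = torch.tensor([0, 2, 5, 9])             # made-up offsets per draft-vocab ID
draft_tokens = torch.tensor([1, 3, 0, 2])    # sampled draft-vocab token IDs
print(apply_d2t_offsets(draft_tokens, d2t))  # tensor([ 3, 12,  0,  7])

Such a sketch also makes concrete the other NB in the diff: translating the sampled tokens but not the stored logits leaves the two in different vocabularies, which is what breaks _process_draft_tokens_rejection_sampling.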
0 commit comments
