
Commit 7b6803b

[TRTLLM-7769][chore] document the role of 'd2t' (#8174)
Signed-off-by: ixlmar <[email protected]>
1 parent ccd949e

1 file changed: +11 -11 lines changed

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 11 additions & 11 deletions
@@ -1247,6 +1247,12 @@ def sample_async(
             model_outputs: dict[str, torch.Tensor],
             num_context_logits_prefix_sum: list[int],
             resource_manager: Optional[ResourceManager] = None) -> SampleState:
+        # NB: The sampler is either called directly by PyExecutor, for the target model,
+        # or by ModelDrafter.prepare_draft_tokens(), for the draft model. In the former
+        # case there are 1 + get_draft_token_length(request) tokens per request. In the
+        # latter case, there is always only 1 token per request because draft
+        # tokens are sampled one-by-one.
+
         requests = scheduled_requests.all_requests()
         new_tokens = self.store.new_tokens
         log_probs_host = self.log_probs_host(scheduled_requests)
@@ -1396,8 +1402,6 @@ def _sample_batched_by_strategy(
             requests, pin_memory=True, vocab_size=logits_cuda.size(1))
         generator_cuda = self.get_generator(cuda_device)

-        # FIXME: This check should/could be performed in ModelDrafter.prepare_draft_tokens
-        #
         # NB: Currently, "d2t" is applied to draft tokens, but not to draft logits,
         # breaking _process_draft_tokens_rejection_sampling.
         needs_d2t = "d2t" in model_outputs
@@ -1523,15 +1527,12 @@ def _sample_batched_by_strategy(
         (batch_req_indices, batch_next_tokens_cuda_int,
          batch_softmax_cuda), = batched_results

-        # FIXME: This should be done in ModelDrafter.prepare_draft_tokens, but for performance
-        # parity py_draft_tokens might need to be replaced / backed by a torch.Tensor, so
-        # that d2t can be applied in a batched manner similar to the code below.
+        # NB: 'd2t' contains offsets for transforming draft vocab token IDs into
+        # the target vocab. This is used by Eagle3ForCausalLM, whose input domain
+        # is the target vocab, whereas the output logits correspond to the draft
+        # vocab. Since the inputs/outputs are linked by TorchSampler.update_requests,
+        # they currently need to be handled within TorchSampler.
         if needs_d2t:
-            # NB: The sampler is either called directly by PyExecutor, for the target model,
-            # or by ModelDrafter.prepare_draft_tokens(), for the draft model. In the former
-            # case there are 1 + get_draft_token_length(request) tokens per request. In the
-            # latter case, only there is always only 1 token per request because draft
-            # tokens are sampled one-by-one.
             self._apply_d2t(batch_next_tokens_cuda_int, model_outputs)

         return _BatchedSamplingResult(
@@ -1982,7 +1983,6 @@ def sample_async(
             num_context_logits_prefix_sum: list[int],
             resource_manager: Optional[ResourceManager] = None
     ) -> SampleStateTRTLLM:
-
         batch_size = scheduled_requests.batch_size
         beam_width = self.beam_width(scheduled_requests.all_requests())
         if (batch_size > 1 and beam_width > 1
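The NB comment added above describes 'd2t' as a table of offsets that shift draft-vocab token IDs into the target vocab, applied in a batched way by self._apply_d2t(batch_next_tokens_cuda_int, model_outputs). As a rough illustration of that idea, here is a minimal sketch assuming a 1-D offset tensor indexed by draft-vocab ID; the function name, example values, and tensor layout are assumptions for illustration, not the actual TorchSampler._apply_d2t implementation.

import torch

def apply_d2t_sketch(draft_token_ids: torch.Tensor,
                     d2t: torch.Tensor) -> torch.Tensor:
    # Batched lookup of the per-ID offset for every sampled draft token,
    # followed by a shift of the IDs into the target vocabulary.
    # (Sketch only: assumes d2t[i] is the additive offset for draft ID i.)
    return draft_token_ids + d2t[draft_token_ids]

# Hypothetical usage with a made-up 5-entry draft vocab:
d2t = torch.tensor([0, 3, 3, 7, 7])            # illustrative offsets, not real data
draft_ids = torch.tensor([1, 4, 0, 2])         # draft tokens sampled one-by-one
target_ids = apply_d2t_sketch(draft_ids, d2t)  # tensor([ 4, 11,  0,  5])

In the diff itself, the offsets come from model_outputs["d2t"] and are applied only when needs_d2t is true, i.e. when the draft model exposes a "d2t" output.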
