@@ -1247,6 +1247,12 @@ def sample_async(
            model_outputs: dict[str, torch.Tensor],
            num_context_logits_prefix_sum: list[int],
            resource_manager: Optional[ResourceManager] = None) -> SampleState:
+        # NB: The sampler is either called directly by PyExecutor, for the target model,
+        # or by ModelDrafter.prepare_draft_tokens(), for the draft model. In the former
+        # case there are 1 + get_draft_token_length(request) tokens per request. In the
+        # latter case, there is always only 1 token per request because draft
+        # tokens are sampled one-by-one.
+
        requests = scheduled_requests.all_requests()
        new_tokens = self.store.new_tokens
        log_probs_host = self.log_probs_host(scheduled_requests)
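The comment added above distinguishes the two call paths and their per-request token counts. The sketch below restates that rule as standalone Python and is illustrative only: get_draft_token_length mirrors the helper named in the comment (assumed here to count a list-like py_draft_tokens attribute), and the is_draft_model_call flag is a hypothetical stand-in for "called via ModelDrafter.prepare_draft_tokens()", not an argument of the real sampler.

def get_draft_token_length(request) -> int:
    # Assumed behavior: count the request's pending draft tokens, if any.
    draft_tokens = getattr(request, "py_draft_tokens", None)
    return len(draft_tokens) if draft_tokens else 0


def tokens_to_sample(request, is_draft_model_call: bool) -> int:
    # Draft-model path (ModelDrafter.prepare_draft_tokens): draft tokens are
    # sampled one-by-one, so exactly one token per request.
    if is_draft_model_call:
        return 1
    # Target-model path (PyExecutor): one new token plus one per draft token.
    return 1 + get_draft_token_length(request)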
@@ -1396,8 +1402,6 @@ def _sample_batched_by_strategy(
            requests, pin_memory=True, vocab_size=logits_cuda.size(1))
        generator_cuda = self.get_generator(cuda_device)

-        # FIXME: This check should/could be performed in ModelDrafter.prepare_draft_tokens
-        #
        # NB: Currently, "d2t" is applied to draft tokens, but not to draft logits,
        # breaking _process_draft_tokens_rejection_sampling.
        needs_d2t = "d2t" in model_outputs
@@ -1523,15 +1527,12 @@ def _sample_batched_by_strategy(
        (batch_req_indices, batch_next_tokens_cuda_int,
         batch_softmax_cuda), = batched_results

-        # FIXME: This should be done in ModelDrafter.prepare_draft_tokens, but for performance
-        # parity py_draft_tokens might need to be replaced / backed by a torch.Tensor, so
-        # that d2t can be applied in a batched manner similar to the code below.
+        # NB: 'd2t' contains offsets for transforming draft vocab token IDs into
+        # the target vocab. This is used by Eagle3ForCausalLM, whose input domain
+        # is the target vocab, whereas the output logits correspond to the draft
+        # vocab. Since the inputs/outputs are linked by TorchSampler.update_requests,
+        # they currently need to be handled within TorchSampler.
        if needs_d2t:
-            # NB: The sampler is either called directly by PyExecutor, for the target model,
-            # or by ModelDrafter.prepare_draft_tokens(), for the draft model. In the former
-            # case there are 1 + get_draft_token_length(request) tokens per request. In the
-            # latter case, only there is always only 1 token per request because draft
-            # tokens are sampled one-by-one.
            self._apply_d2t(batch_next_tokens_cuda_int, model_outputs)

        return _BatchedSamplingResult(
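For context on the comment added above: if 'd2t' is a 1-D tensor of offsets indexed by draft-vocab token ID (an assumption consistent with the comment, not something this diff shows), applying it in a batched manner reduces to a single gather-and-add over the sampled token tensor. A minimal sketch follows; the function name is illustrative and is not the actual _apply_d2t implementation.

import torch


def apply_d2t_sketch(draft_token_ids: torch.Tensor, d2t: torch.Tensor) -> torch.Tensor:
    # d2t[i] is assumed to hold the offset that maps draft-vocab id i to its
    # target-vocab id, so the whole batch converts in one indexed add.
    return draft_token_ids + d2t[draft_token_ids]

# Example: with d2t = torch.tensor([2, 0, 1]), draft ids [0, 1, 2] map to
# target ids [2, 1, 3].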
@@ -1982,7 +1983,6 @@ def sample_async(
            num_context_logits_prefix_sum: list[int],
            resource_manager: Optional[ResourceManager] = None
    ) -> SampleStateTRTLLM:
-
        batch_size = scheduled_requests.batch_size
        beam_width = self.beam_width(scheduled_requests.all_requests())
        if (batch_size > 1 and beam_width > 1