@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import traceback
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
 
 import torch
 
@@ -73,14 +73,30 @@ def __init__(
         self.guided_decoder = guided_decoder
 
         self.use_static_draft_loop = draft_model_engine.model_is_wrapped
+        self.d2t: Optional[torch.Tensor] = None
         if self.use_static_draft_loop:
             # TODO: enable sampling/guided decoding on static draft loop
             assert guided_decoder is None
             assert spec_config._allow_greedy_draft_tokens
+        else:
+            # Handle d2t data if available. Static drafting loops should incorporate d2t
+            # in their implementations.
+            if hasattr(self.draft_model_engine.model.model, "d2t"):
+                self.d2t = self.draft_model_engine.model.model.d2t.data
+        self.d2t_host: Optional[torch.Tensor] = None
+        if self.d2t is not None:
+            self.d2t_host = self.d2t.to(device="cpu")
 
     def _create_draft_request(self, request: LlmRequest,
                               input_tokens: Optional[List]) -> LlmRequest:
         """Create a draft request with common parameters."""
+        needs_probs = self.sampler.should_provide_draft_probs(request)
+
+        # NB: Currently, "d2t" is applied to draft tokens, but not to draft logits,
+        # breaking _process_draft_tokens_rejection_sampling.
+        if self.d2t is not None and needs_probs:
+            raise ValueError("d2t does not yet support non-greedy sampling")
+
         return LlmRequest(
             input_tokens=input_tokens,
             request_id=request.py_request_id,
@@ -94,8 +110,7 @@ def _create_draft_request(self, request: LlmRequest,
             True,  # prepare_draft_tokens uses overlap scheduling
             is_draft=True,
             # NB: self.sampler is shared with PyExecutor
-            return_generation_logits=self.sampler.should_provide_draft_probs(
-                request))
+            return_generation_logits=needs_probs)
 
     def _initialize_draft_tokens(self, request: LlmRequest) -> Tuple[int, int]:
         """Initialize draft token tracking for a request."""
@@ -301,12 +316,6 @@ def forward_draft_model(
             resource_manager,
             new_tensors_device=previous_tensors)
 
-        # Handle d2t data if available. Static drafting loops should incorporate d2t
-        # in their implementations.
-        if not self.use_static_draft_loop and hasattr(
-                self.draft_model_engine.model.model, 'd2t'):
-            outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
-
         return outputs
 
     def sample_async(
@@ -365,6 +374,7 @@ def update_requests(
         """Update requests with sample state."""
         self.sampler.update_requests(sample_state, resource_manager)
 
+    @torch.inference_mode()
     def process_decoded_tokens(
             self, draft_batch: ScheduledRequests,
             req_id_to_old_request: Dict[int, LlmRequest]) -> List[LlmRequest]:
@@ -378,7 +388,12 @@ def process_decoded_tokens(
                 self.draft_seq_slot_manager.free_resources(req)
                 continue
 
-            target_model_req.py_draft_tokens.append(req.get_last_tokens(0))
+            draft_token: int = req.get_last_tokens(0)
+            if self.d2t_host is not None:
+                # NB: This is not batched over requests, but considered acceptable given
+                # that the code already loops over requests/tokens.
+                draft_token += cast(int, self.d2t_host[draft_token].item())
+            target_model_req.py_draft_tokens.append(draft_token)
             target_model_req.py_draft_logits = req.py_result.generation_logits  # forwards Nones
             if req.state != LlmRequestState.GENERATION_COMPLETE and len(
                     target_model_req.py_draft_tokens
@@ -591,8 +606,7 @@ def _execute_draft_iteration(
 
         if self.guided_decoder is not None:
             self.guided_decoder.add_batch(draft_batch)
-            self.guided_decoder.execute(outputs['logits'],
-                                        d2t=outputs.get('d2t'))
+            self.guided_decoder.execute(outputs['logits'], d2t=self.d2t)
 
         sample_state = self.sample_async(draft_batch, outputs, resource_manager)
         self.update_request_states(draft_batch)
@@ -726,8 +740,7 @@ def generate_draft_tokens_with_overlap(
         # Handle guided decoder and sampling for non-static loop
         if self.guided_decoder is not None:
             self.guided_decoder.add_batch(draft_batch)
-            self.guided_decoder.execute(outputs['logits'],
-                                        d2t=outputs.get('d2t'))
+            self.guided_decoder.execute(outputs['logits'], d2t=self.d2t)
         draft_sample_state = self.sample_async(draft_batch, outputs,
                                                resource_manager)
 
@@ -791,8 +804,7 @@ def prepare_draft_tokens(
 
             if self.guided_decoder is not None:
                 self.guided_decoder.add_batch(draft_batch)
-                self.guided_decoder.execute(outputs['logits'],
-                                            d2t=outputs.get('d2t'))
+                self.guided_decoder.execute(outputs['logits'], d2t=self.d2t)
             sample_state = self.sample_async(draft_batch, outputs,
                                              resource_manager)
             self.update_request_states(draft_batch)