 from __future__ import annotations
 
 import traceback
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
 
 import torch
 
@@ -73,14 +73,30 @@ def __init__(
         self.guided_decoder = guided_decoder
 
         self.use_static_draft_loop = draft_model_engine.model_is_wrapped
+        self.d2t: Optional[torch.Tensor] = None
         if self.use_static_draft_loop:
             # TODO: enable sampling/guided decoding on static draft loop
             assert guided_decoder is None
             assert spec_config._allow_greedy_draft_tokens
+        else:
+            # Handle d2t data if available. Static drafting loops should incorporate d2t
+            # in their implementations.
+            if hasattr(self.draft_model_engine.model.model, "d2t"):
+                self.d2t = self.draft_model_engine.model.model.d2t.data
+        self.d2t_host: Optional[torch.Tensor] = None
+        if self.d2t is not None:
+            self.d2t_host = self.d2t.to(device="cpu")
 
     def _create_draft_request(self, request: LlmRequest,
                               input_tokens: Optional[List]) -> LlmRequest:
         """Create a draft request with common parameters."""
+        needs_probs = self.sampler.should_provide_draft_probs(request)
+
+        # NB: Currently, "d2t" is applied to draft tokens, but not to draft logits,
+        # breaking _process_draft_tokens_rejection_sampling.
+        if self.d2t is not None and needs_probs:
+            raise ValueError("d2t does not yet support non-greedy sampling")
+
         return LlmRequest(
             input_tokens=input_tokens,
             request_id=request.py_request_id,
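For context, a minimal sketch (not part of the patch) of how the cached d2t tensor is assumed to behave: a 1-D integer tensor of per-token offsets that maps a draft-model token id to the corresponding target-model token id, which is why a host copy (d2t_host) is kept for the per-token lookups later in the file. The vocabulary size and offset values below are invented for illustration.

    import torch

    # Hypothetical tiny draft vocabulary; d2t[i] is the offset that turns draft id i
    # into the matching target-model id: target_id = i + d2t[i].
    draft_vocab_size = 8
    d2t = torch.zeros(draft_vocab_size, dtype=torch.long)
    d2t[5] = 100  # made-up mapping: draft token 5 -> target token 105

    d2t_host = d2t.to(device="cpu")  # cached once, mirroring __init__ above

    draft_token = 5
    target_token = draft_token + int(d2t_host[draft_token].item())
    assert target_token == 105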
@@ -94,8 +110,7 @@ def _create_draft_request(self, request: LlmRequest,
             True,  # prepare_draft_tokens uses overlap scheduling
             is_draft=True,
             # NB: self.sampler is shared with PyExecutor
-            return_generation_logits=self.sampler.should_provide_draft_probs(
-                request))
+            return_generation_logits=needs_probs)
 
     def _initialize_draft_tokens(self, request: LlmRequest) -> Tuple[int, int]:
         """Initialize draft token tracking for a request."""
@@ -301,12 +316,6 @@ def forward_draft_model(
             resource_manager,
             new_tensors_device=previous_tensors)
 
-        # Handle d2t data if available. Static drafting loops should incorporate d2t
-        # in their implementations.
-        if not self.use_static_draft_loop and hasattr(
-                self.draft_model_engine.model.model, 'd2t'):
-            outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
-
         return outputs
 
     def sample_async(
@@ -365,6 +374,7 @@ def update_requests(
         """Update requests with sample state."""
         self.sampler.update_requests(sample_state, resource_manager)
 
+    @torch.inference_mode()
     def process_decoded_tokens(
             self, draft_batch: ScheduledRequests,
             req_id_to_old_request: Dict[int, LlmRequest]) -> List[LlmRequest]:
@@ -378,7 +388,16 @@ def process_decoded_tokens(
                 self.draft_seq_slot_manager.free_resources(req)
                 continue
 
-            target_model_req.py_draft_tokens.append(req.get_last_tokens(0))
+            draft_tokens: List[int] = req.get_last_tokens(0)
+            if self.d2t_host is not None:
+                # NB: This is not batched over requests, but considered acceptable given
+                # that the code already loops over requests and there are few draft
+                # tokens per request.
+                draft_tokens = [
+                    tok + cast(int, self.d2t_host[tok].item())
+                    for tok in draft_tokens
+                ]
+            target_model_req.py_draft_tokens.append(draft_tokens)
             target_model_req.py_draft_logits = req.py_result.generation_logits  # forwards Nones
             if req.state != LlmRequestState.GENERATION_COMPLETE and len(
                     target_model_req.py_draft_tokens
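The list comprehension above remaps tokens one at a time on the host. A batched equivalent is sketched below, under the same assumption that d2t_host holds additive offsets indexed by draft token id; it is not what the patch does, but it illustrates the trade-off the NB comment mentions: since each request carries only a handful of draft tokens, the simple per-token loop was considered acceptable.

    from typing import List
    import torch

    def remap_draft_tokens(draft_tokens: List[int], d2t_host: torch.Tensor) -> List[int]:
        # Gather all offsets with one indexing op instead of a Python loop.
        idx = torch.tensor(draft_tokens, dtype=torch.long)
        return (idx + d2t_host[idx]).tolist()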
@@ -591,8 +610,7 @@ def _execute_draft_iteration(
 
         if self.guided_decoder is not None:
             self.guided_decoder.add_batch(draft_batch)
-            self.guided_decoder.execute(outputs['logits'],
-                                        d2t=outputs.get('d2t'))
+            self.guided_decoder.execute(outputs['logits'], d2t=self.d2t)
 
         sample_state = self.sample_async(draft_batch, outputs, resource_manager)
         self.update_request_states(draft_batch)
@@ -726,8 +744,7 @@ def generate_draft_tokens_with_overlap(
             # Handle guided decoder and sampling for non-static loop
             if self.guided_decoder is not None:
                 self.guided_decoder.add_batch(draft_batch)
-                self.guided_decoder.execute(outputs['logits'],
-                                            d2t=outputs.get('d2t'))
+                self.guided_decoder.execute(outputs['logits'], d2t=self.d2t)
             draft_sample_state = self.sample_async(draft_batch, outputs,
                                                    resource_manager)
 
@@ -791,8 +808,7 @@ def prepare_draft_tokens(
 
             if self.guided_decoder is not None:
                 self.guided_decoder.add_batch(draft_batch)
-                self.guided_decoder.execute(outputs['logits'],
-                                            d2t=outputs.get('d2t'))
+                self.guided_decoder.execute(outputs['logits'], d2t=self.d2t)
             sample_state = self.sample_async(draft_batch, outputs,
                                              resource_manager)
             self.update_request_states(draft_batch)
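Taken together, the changes move d2t from a per-forward entry in outputs to a member cached in __init__, so every guided_decoder.execute call receives the same tensor via self.d2t. Below is a rough sketch of the greedy path with made-up sizes and values (not the engine's actual API), showing why applying d2t only to the sampled token ids is enough when no draft probabilities are requested.

    import torch

    draft_vocab_size = 8  # hypothetical size for illustration
    d2t = torch.zeros(draft_vocab_size, dtype=torch.long)
    d2t[3] = 7  # made up: draft id 3 corresponds to target id 10

    logits = torch.randn(1, draft_vocab_size)         # draft-model logits
    draft_id = int(logits.argmax(dim=-1).item())      # greedy draft token
    target_id = draft_id + int(d2t[draft_id].item())  # id stored in py_draft_tokens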