@@ -247,7 +247,9 @@ def get_output(self) -> ModelRunnerOutput:
 
         max_gen_len = self._sampled_token_ids_cpu.shape[-1]
         if max_gen_len == 1:
-            valid_sampled_token_ids: list[np.ndarray] = [row for row in self._sampled_token_ids_cpu.numpy()]
+            valid_sampled_token_ids: list[np.ndarray] = [
+                row for row in self._sampled_token_ids_cpu.numpy()
+            ]
         else:
             valid_sampled_token_ids = RejectionSampler.parse_output(
                 self._sampled_token_ids_cpu,
@@ -596,7 +598,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             dtype=torch.int64,
             device="cpu",
             pin_memory=self.pin_memory,
-            )
+        )
         # Input Batch
         # NOTE(Chen): Ideally, we should initialize the input batch inside
         # `initialize_kv_cache` based on the kv cache config. However, as in
@@ -843,7 +845,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                     req_state.prev_num_draft_len = 0
                 else:
                     assert self.input_batch.prev_req_id_to_index is not None
-                    prev_req_index = self.input_batch.prev_req_id_to_index[req_id]
+                    prev_req_index = self.input_batch.prev_req_id_to_index[
+                        req_id]
                     num_accepted = valid_sampled_token_count[prev_req_index] - 1
                     num_rejected = req_state.prev_num_draft_len - num_accepted
                     num_computed_tokens -= num_rejected
@@ -935,15 +938,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
     def _get_valid_sampled_token_count(self) -> list[int]:
         # Wait until valid_sampled_tokens_count is copied to cpu,
         prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids
-        if (
-            self.valid_sampled_token_count_event is None
-            or prev_sampled_token_ids is None
-        ):
+        if (self.valid_sampled_token_count_event is None
+                or prev_sampled_token_ids is None):
             return []
 
         counts_cpu = self.valid_sampled_token_count_cpu
         self.valid_sampled_token_count_event.synchronize()
-        return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist()
+        return counts_cpu[:prev_sampled_token_ids.shape[0]].tolist()
 
     def _init_mrope_positions(self, req_state: CachedRequestState):
         assert supports_mrope(self.model), "MROPE is not supported"
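The function above shows the copy-then-synchronize pattern used throughout this file: an asynchronous device-to-host copy into a pinned buffer is recorded with an event, and the host blocks on that event only right before reading the result. A minimal standalone sketch of the same idea, using CUDA as a stand-in for the NPU stream APIs (requires a CUDA device); every name below is assumed for illustration and is not the repository's code:

import torch

def async_copy_counts_to_cpu(counts_gpu: torch.Tensor):
    # Pinned (page-locked) host buffer so the D2H copy can be non-blocking.
    counts_cpu = torch.empty(counts_gpu.shape, dtype=counts_gpu.dtype,
                             device="cpu", pin_memory=True)
    counts_cpu.copy_(counts_gpu, non_blocking=True)  # enqueued on the current stream
    event = torch.cuda.Event()
    event.record()  # marks the point at which the copy has completed
    return counts_cpu, event

def read_counts(counts_cpu, event, num_reqs: int) -> list[int]:
    event.synchronize()  # wait only for work recorded before the event
    return counts_cpu[:num_reqs].tolist()

Synchronizing on the event rather than the whole stream blocks the host only until the copy has finished, so kernels queued afterwards are unaffected.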
@@ -1278,7 +1279,8 @@ def _get_cumsum_and_arange(
 
         return cu_num_tokens, arange
 
-    def _prepare_input_ids(self, scheduler_output: "SchedulerOutput", total_num_scheduled_tokens: int,
+    def _prepare_input_ids(self, scheduler_output: "SchedulerOutput",
+                           total_num_scheduled_tokens: int,
                            cu_num_tokens: np.ndarray) -> None:
         """Prepare the input IDs for the current batch.
 
@@ -1295,7 +1297,7 @@ def _prepare_input_ids(self, scheduler_output: "SchedulerOutput", total_num_sche
             self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
             self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
             return
-
+
         # Async scheduling case, where some decode requests from the previous
         # iteration won't have entries in input_ids_cpu and need to be copied
         # on the NPU from prev_sampled_token_ids.
@@ -1322,23 +1324,22 @@ def _prepare_input_ids(self, scheduler_output: "SchedulerOutput", total_num_sche
                 # spec_flattened_indices = [1, 3, 4, 6, 7]
                 sample_flattened_indices.append(flattened_index - draft_len)
                 spec_flattened_indices.extend(
-                    range(flattened_index - draft_len + 1, flattened_index + 1)
-                )
+                    range(flattened_index - draft_len + 1,
+                          flattened_index + 1))
                 start = prev_index * self.num_spec_tokens
                 # prev_draft_token_indices is used to find which draft_tokens_id
                 # should be copied to input_ids
                 # example: prev draft_tokens_id [[1,2], [3,4], [5, 6]]
                 # flatten draft_tokens_id [1,2,3,4,5,6]
                 # draft_len of each request [1, 2, 1]
                 # then prev_draft_token_indices is [0, 2, 3, 4]
-                prev_draft_token_indices.extend(range(start, start + draft_len))
+                prev_draft_token_indices.extend(range(start,
+                                                      start + draft_len))
                 indices_match &= prev_index == flattened_index
                 max_flattened_index = max(max_flattened_index, flattened_index)
             num_commmon_tokens = len(sample_flattened_indices)
-            total_without_spec = (
-                total_num_scheduled_tokens
-                - total_num_spec_tokens
-            )
+            total_without_spec = (total_num_scheduled_tokens -
+                                  total_num_spec_tokens)
             if num_commmon_tokens < total_without_spec:
                 # If not all requests are decodes from the last iteration,
                 # We need to copy the input_ids_cpu to the NPU first.
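The comments in this hunk work through a small example of the three index lists. A standalone sketch that reproduces exactly those numbers; the inputs (cu_num_tokens, draft_lens, num_spec_tokens) are assumed for illustration and mirror the commented example rather than the repository's runtime data:

num_spec_tokens = 2
draft_lens = [1, 2, 1]        # draft_len of each request
cu_num_tokens = [2, 5, 8]     # cumulative scheduled tokens per request

sample_flattened_indices = []   # flattened slots for the previously sampled token
spec_flattened_indices = []     # flattened slots for the draft tokens
prev_draft_token_indices = []   # which flattened draft ids to copy

for prev_index, (end, draft_len) in enumerate(zip(cu_num_tokens, draft_lens)):
    flattened_index = end - 1
    sample_flattened_indices.append(flattened_index - draft_len)
    spec_flattened_indices.extend(
        range(flattened_index - draft_len + 1, flattened_index + 1))
    start = prev_index * num_spec_tokens
    prev_draft_token_indices.extend(range(start, start + draft_len))

assert sample_flattened_indices == [0, 2, 6]
assert spec_flattened_indices == [1, 3, 4, 6, 7]
assert prev_draft_token_indices == [0, 2, 3, 4]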
@@ -1365,17 +1366,18 @@ def _prepare_input_ids(self, scheduler_output: "SchedulerOutput", total_num_sche
             return
         # Upload the index tensors asynchronously so the scatter can be non-blocking.
         sampled_tokens_index_tensor = torch.tensor(
-            sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
-        ).to(self.device, non_blocking=True)
+            sample_flattened_indices,
+            dtype=torch.int64,
+            pin_memory=self.pin_memory).to(self.device, non_blocking=True)
         prev_common_req_indices_tensor = torch.tensor(
-            prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory
-        ).to(self.device, non_blocking=True)
+            prev_common_req_indices,
+            dtype=torch.int64,
+            pin_memory=self.pin_memory).to(self.device, non_blocking=True)
         self.input_ids.scatter_(
             dim=0,
             index=sampled_tokens_index_tensor,
             src=self.input_batch.prev_sampled_token_ids[
-                prev_common_req_indices_tensor, 0
-            ],
+                prev_common_req_indices_tensor, 0],
         )
 
         # scatter the draft tokens after the sampled tokens are scattered.
@@ -1384,11 +1386,13 @@ def _prepare_input_ids(self, scheduler_output: "SchedulerOutput", total_num_sche
 
         assert isinstance(self._draft_token_ids, torch.Tensor)
         draft_tokens_index_tensor = torch.tensor(
-            spec_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
-        ).to(self.device, non_blocking=True)
+            spec_flattened_indices,
+            dtype=torch.int64,
+            pin_memory=self.pin_memory).to(self.device, non_blocking=True)
         prev_draft_token_indices_tensor = torch.tensor(
-            prev_draft_token_indices, dtype=torch.int64, pin_memory=self.pin_memory
-        ).to(self.device, non_blocking=True)
+            prev_draft_token_indices,
+            dtype=torch.int64,
+            pin_memory=self.pin_memory).to(self.device, non_blocking=True)
 
         # because input_ids dtype is torch.int32,
         # so convert draft_token_ids to torch.int32 here.
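Both index tensors above are built on the host in pinned memory, moved to the device with non_blocking=True, and then consumed by scatter_ calls into input_ids. A self-contained sketch of that upload-and-scatter pattern with assumed shapes and values (CUDA stands in for the NPU; none of these names come from the repository):

import torch

device = torch.device("cuda")
input_ids = torch.zeros(8, dtype=torch.int32, device=device)
prev_sampled_token_ids = torch.tensor([[11], [22], [33]], device=device)

sample_flattened_indices = [0, 2, 6]   # flattened slots to fill
prev_common_req_indices = [0, 1, 2]    # rows of prev_sampled_token_ids to read

index = torch.tensor(sample_flattened_indices, dtype=torch.int64,
                     pin_memory=True).to(device, non_blocking=True)
rows = torch.tensor(prev_common_req_indices, dtype=torch.int64,
                    pin_memory=True).to(device, non_blocking=True)

# scatter_ requires src to match the destination dtype, hence the int32 cast.
input_ids.scatter_(dim=0, index=index,
                   src=prev_sampled_token_ids[rows, 0].to(torch.int32))

The same pattern is repeated for the draft tokens, which is why the int32 conversion mentioned in the comment above is needed before the second scatter.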
@@ -1672,9 +1676,8 @@ def _prepare_inputs(
         self.query_lens = torch.from_numpy(num_scheduled_tokens)
 
         # Copy the tensors to the NPU.
-        self._prepare_input_ids(
-            scheduler_output, total_num_scheduled_tokens, cu_num_tokens
-        )
+        self._prepare_input_ids(scheduler_output, total_num_scheduled_tokens,
+                                cu_num_tokens)
         self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
         self.positions[:num_input_tokens].copy_(
             self.positions_cpu[:num_input_tokens], non_blocking=True)
@@ -2122,8 +2125,9 @@ def _calc_spec_decode_metadata(
             cu_num_scheduled_tokens - num_sampled_tokens,
             num_sampled_tokens)
         logits_indices_pcp += arange
-        logits_indices_pcp = torch.from_numpy(logits_indices_pcp).pin_memory().to(
-            self.device, non_blocking=True)
+        logits_indices_pcp = torch.from_numpy(
+            logits_indices_pcp).pin_memory().to(self.device,
+                                                non_blocking=True)
 
         # Compute the bonus logits indices.
         bonus_logits_indices = cu_num_sampled_tokens - 1
@@ -2145,27 +2149,19 @@ def _calc_spec_decode_metadata(
 
         # TODO: Optimize the CPU -> NPU copy.
         cu_num_draft_tokens = (
-            torch.from_numpy(cu_num_draft_tokens)
-            .pin_memory()
-            .to(self.device, non_blocking=True)
-        )
+            torch.from_numpy(cu_num_draft_tokens).pin_memory().to(
+                self.device, non_blocking=True))
         cu_num_sampled_tokens = (
-            torch.from_numpy(cu_num_sampled_tokens)
-            .pin_memory()
-            .to(self.device, non_blocking=True)
-        )
-        logits_indices = (
-            torch.from_numpy(logits_indices)
-            .pin_memory()
-            .to(self.device, non_blocking=True)
-        )
+            torch.from_numpy(cu_num_sampled_tokens).pin_memory().to(
+                self.device, non_blocking=True))
+        logits_indices = (torch.from_numpy(logits_indices).pin_memory().to(
+            self.device, non_blocking=True))
         target_logits_indices = (
-            torch.from_numpy(target_logits_indices)
-            .pin_memory()
-            .to(self.device, non_blocking=True)
-        )
-        bonus_logits_indices = torch.from_numpy(bonus_logits_indices).pin_memory().to(
-            self.device, non_blocking=True)
+            torch.from_numpy(target_logits_indices).pin_memory().to(
+                self.device, non_blocking=True))
+        bonus_logits_indices = torch.from_numpy(
+            bonus_logits_indices).pin_memory().to(self.device,
+                                                  non_blocking=True)
 
         # Compute the draft token ids.
         # draft_token_indices: [ 1, 2, 3, 105, 106, 208]
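Each of these metadata arrays follows the same torch.from_numpy(...).pin_memory().to(device, non_blocking=True) chain. A small helper expressing that idea; the helper name and the sample array are assumptions for illustration only, not part of the change:

import numpy as np
import torch

def to_device_async(arr: np.ndarray, device: torch.device) -> torch.Tensor:
    # pin_memory() stages the data in page-locked host memory; only then can
    # the non_blocking copy actually overlap with other work on the stream.
    return torch.from_numpy(arr).pin_memory().to(device, non_blocking=True)

cu_num_draft_tokens = np.array([1, 3, 4], dtype=np.int32)
cu_on_device = to_device_async(cu_num_draft_tokens, torch.device("cuda"))

If the host tensor were pageable rather than pinned, non_blocking=True would generally have no effect and the copy would behave synchronously, which is why pin_memory() appears in every one of these transfers.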
@@ -2654,7 +2650,7 @@ def sample_tokens(
             # when preparing inputs.
             self.input_batch.prev_sampled_token_ids = sampled_token_ids
 
-
+
             self.input_batch.prev_sampled_token_ids_invalid_indices = \
                 invalid_req_indices_set
             self.input_batch.prev_req_id_to_index = {
@@ -2671,8 +2667,9 @@ def sample_tokens(
             for req_idx in range(num_sampled_tokens):
                 sampled_ids: np.ndarray | None
                 if self.use_async_scheduling:
-                    sampled_ids = (np.array([-1]) if req_idx
-                                   not in invalid_req_indices_set else None)
+                    sampled_ids = (np.array([
+                        -1
+                    ]) if req_idx not in invalid_req_indices_set else None)
                 else:
                     sampled_ids = valid_sampled_token_ids[req_idx]
                 if sampled_ids is None or sampled_ids.shape[0] == 0:
@@ -2685,16 +2682,17 @@ def sample_tokens(
                     f"Total number of tokens: {end_idx} > max_model_len: "
                     f"{self.model_config.max_model_len}")
 
-                self.input_batch.token_ids_cpu[req_idx,
-                                               start_idx:end_idx] = sampled_ids
+                self.input_batch.token_ids_cpu[
+                    req_idx, start_idx:end_idx] = sampled_ids
                 self.input_batch.is_token_ids[req_idx,
-                                             start_idx:end_idx] = True
+                                              start_idx:end_idx] = True
                 self.input_batch.num_tokens_no_spec[req_idx] = end_idx
                 self.input_batch.num_tokens[req_idx] = end_idx
                 req_id = self.input_batch.req_ids[req_idx]
                 req_state = self.requests[req_id]
                 req_state.output_token_ids.extend(sampled_ids.tolist())
             self.input_batch.prev_sampled_token_ids = None
+
         def propose_draft_token_ids(sampled_token_ids):
             assert self.spec_decode_common_attn_metadata is not None
             self._draft_token_ids = self.propose_draft_token_ids(