Commit 6580cb7

fix yapf error
Signed-off-by: Ronald1995 <[email protected]>
1 parent 152e671 commit 6580cb7

File tree

4 files changed: +96 -92 lines changed

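Every hunk in this commit only re-wraps lines so the files pass the project's yapf formatting check; there are no behavioral changes. As a rough illustration of reproducing such a check locally (not the project's exact tooling; the "pep8" style string below is a stand-in for whatever style configuration the repo actually ships), yapf can be driven from Python:

# Hedged sketch: reformat a snippet with yapf's Python API.
# Assumes `pip install yapf`; "pep8" is a placeholder style, not the repo's config.
from yapf.yapflib.yapf_api import FormatCode

snippet = (
    "query_start_loc = query_start_loc_cpu.pin_memory().to(self.device,\n"
    "                                                      non_blocking=True)\n"
)

formatted, changed = FormatCode(snippet, style_config="pep8")
print("changed by yapf:", changed)
print(formatted)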

vllm_ascend/attention/attention_v1.py

Lines changed: 2 additions & 2 deletions

@@ -348,8 +348,8 @@ def build(
                 device=query_start_loc_cpu.device)
         ])

-        query_start_loc = query_start_loc_cpu.pin_memory().to(self.device,
-                                                              non_blocking=True)
+        query_start_loc = query_start_loc_cpu.pin_memory().to(
+            self.device, non_blocking=True)

         if get_ascend_device_type() == AscendDeviceType._310P:
             if attn_state == AscendAttentionState.PrefillNoCache:
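
The re-wrapped statement above is the host-to-device copy idiom that recurs throughout this commit: stage the tensor in pinned (page-locked) host memory, then issue a non-blocking copy so the transfer can overlap other work. A minimal sketch of the same idiom, shown with torch.cuda because torch.npu is Ascend-specific (it falls back to a plain clone when no accelerator is available):

# Sketch of the pin_memory() + non_blocking copy pattern from the hunk above.
# Uses torch.cuda for illustration; vllm-ascend targets torch.npu instead.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
query_start_loc_cpu = torch.tensor([0, 3, 7, 12], dtype=torch.int32)

if device.type == "cuda":
    # Pinned host memory lets the driver perform the async copy directly,
    # so non_blocking=True can actually overlap with other device work.
    query_start_loc = query_start_loc_cpu.pin_memory().to(
        device, non_blocking=True)
else:
    query_start_loc = query_start_loc_cpu.clone()

print(query_start_loc)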

vllm_ascend/attention/mla_v1.py

Lines changed: 20 additions & 19 deletions

@@ -568,41 +568,41 @@ def build(
                 )
                 chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata(
                     cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
-                        device, non_blocking=True
-                    ),
+                        device, non_blocking=True),
                     starts=local_chunk_starts.pin_memory().to(
-                        device, non_blocking=True
-                    ),
-                    seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
+                        device, non_blocking=True),
+                    seq_tot=padded_local_chunk_seq_lens.sum(
+                        dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                     chunk_seq_lens=chunk_seq_lens,
                     chunk_seq_lens_npu=chunk_seq_lens.npu(),
                     workspace=self.chunked_prefill_workspace,
-                    padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
-                    padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
-                    local_context_lens_allranks=local_context_lens_allranks.tolist(),
-                    padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.pin_memory().to(
-                        device, non_blocking=True
-                    ),
+                    padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.
+                    npu(),
+                    padded_local_chunk_seq_lens=padded_local_chunk_seq_lens
+                    .tolist(),
+                    local_context_lens_allranks=local_context_lens_allranks
+                    .tolist(),
+                    padded_local_cu_seq_lens=
+                    padded_local_cu_chunk_seq_lens_cpu.pin_memory().to(
+                        device, non_blocking=True),
                     cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
                     chunk_size=padded_local_max_context_chunk_across_ranks,
                 )
             else:
                 chunked_context_metadata = (
                     AscendMLAPrefillMetadata.ChunkedContextMetadata(
                         cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
-                            device, non_blocking=True
-                        ),
+                            device, non_blocking=True),
                         starts=chunk_starts.pin_memory().to(
-                            device, non_blocking=True
-                        ),
+                            device, non_blocking=True),
                         seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
-                        max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
+                        max_seq_lens=chunk_seq_lens.max(
+                            dim=1).values.tolist(),
                         chunk_seq_lens=chunk_seq_lens,
                         chunk_seq_lens_npu=chunk_seq_lens.npu(),
                         workspace=self.chunked_prefill_workspace,
-                )
-            )
+                    ))
         prefill_input_positions = input_positions[tokens_start:]
         cos = self.cos_cache[
             prefill_input_positions].unsqueeze(  # type: ignore
@@ -634,7 +634,8 @@ def build(
         cos = common_attn_metadata.cos
         sin = common_attn_metadata.sin
         # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
-        actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes + 1].tolist()
+        actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
+                                                   1].tolist()
         max_seq_lens = seq_lens[:num_decodes].max().item()
         seq_lens = seq_lens[:num_decodes]
         input_positions = input_positions[:num_decode_tokens]
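
The ChunkedContextMetadata fields re-wrapped above are per-chunk reductions such as chunk_seq_lens.sum(dim=1).tolist() and chunk_seq_lens.max(dim=1).values.tolist(). A toy sketch of what those dim=1 reductions produce (tensor orientation and numbers are invented for illustration; in mla_v1.py they come from the scheduler state):

# Toy illustration of the dim=1 reductions used in the metadata above.
import torch

# A 2-D tensor of context lengths; values and orientation are made up here.
chunk_seq_lens = torch.tensor([[4, 2, 0],
                               [3, 5, 1]])

seq_tot = chunk_seq_lens.sum(dim=1).tolist()              # total per row
max_seq_lens = chunk_seq_lens.max(dim=1).values.tolist()  # largest entry per row

print(seq_tot)       # [6, 9]
print(max_seq_lens)  # [4, 5]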

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 19 additions & 14 deletions

@@ -144,9 +144,9 @@ def __init__(
         self.arange = torch.arange(max_num_slots_for_arange,
                                    device=device,
                                    dtype=torch.int32)
-        self.arange_cpu = torch.arange(
-            max_num_slots_for_arange, device="cpu", dtype=torch.int32
-        )
+        self.arange_cpu = torch.arange(max_num_slots_for_arange,
+                                       device="cpu",
+                                       dtype=torch.int32)

         self.inputs_embeds = torch.zeros(
             (self.max_num_tokens, self.hidden_size),
@@ -346,7 +346,8 @@ def generate_token_ids(self,
                 self.runner.discard_request_indices.gpu,
                 self.runner.num_discarded_requests
             )
-            self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)
+            self._copy_valid_sampled_token_count(next_token_ids,
+                                                 valid_sampled_tokens_count)

         req_scheduled_tokens = scheduler_output.num_scheduled_tokens
         if self.pcp_size > 1:
@@ -426,24 +427,28 @@ def generate_token_ids(self,
             )

         return draft_token_ids
-
+
     def _copy_valid_sampled_token_count(
-        self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor
-    ) -> None:
+            self, next_token_ids: torch.Tensor,
+            valid_sampled_tokens_count: torch.Tensor) -> None:
         if self.runner.valid_sampled_token_count_event is not None:
             default_stream = torch.npu.current_stream()
             # initialize a new stream to overlap the copy operation with
             # prepare_input of draft model.
-            with torch.npu.stream(self.runner.valid_sampled_token_count_copy_stream):
+            with torch.npu.stream(
+                    self.runner.valid_sampled_token_count_copy_stream):
                 self.runner.valid_sampled_token_count_copy_stream.wait_stream(
-                    default_stream
-                )  # type: ignore
-                self.runner.valid_sampled_token_count_cpu[
-                    : valid_sampled_tokens_count.shape[0]
-                ].copy_(valid_sampled_tokens_count, non_blocking=True)
+                    default_stream)  # type: ignore
+                self.runner.valid_sampled_token_count_cpu[:
+                                                          valid_sampled_tokens_count
+                                                          .shape[0]].copy_(
+                                                              valid_sampled_tokens_count,
+                                                              non_blocking=True
+                                                          )
                 self.runner.valid_sampled_token_count_event.record()

-        self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1)
+        self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(
+            1)

     def _init_mtp_model(self):
         architecture = self.vllm_config.model_config.architecture
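
The _copy_valid_sampled_token_count hunk above overlaps a device-to-host copy with the draft model's input preparation by issuing the copy on a dedicated stream and recording an event that the consumer synchronizes on later. A rough sketch of that pattern, written with torch.cuda since torch.npu is Ascend-specific (the assumption is that the NPU stream/event API mirrors the CUDA calls used here):

# Sketch of the side-stream copy + event pattern from _copy_valid_sampled_token_count.
# torch.cuda is used for illustration; vllm-ascend uses the analogous torch.npu calls.
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    counts_gpu = torch.randint(0, 4, (8,), device=device)
    counts_cpu = torch.zeros(8, dtype=counts_gpu.dtype, pin_memory=True)

    copy_stream = torch.cuda.Stream()
    done = torch.cuda.Event()
    default_stream = torch.cuda.current_stream()

    with torch.cuda.stream(copy_stream):
        # Make the side stream wait for work already queued on the default
        # stream, then launch the async device-to-host copy there.
        copy_stream.wait_stream(default_stream)
        counts_cpu[:counts_gpu.shape[0]].copy_(counts_gpu, non_blocking=True)
        done.record()

    # ... other work can be queued on the default stream here ...

    done.synchronize()  # like valid_sampled_token_count_event.synchronize()
    print(counts_cpu.tolist())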

vllm_ascend/worker/model_runner_v1.py

Lines changed: 55 additions & 57 deletions

@@ -247,7 +247,9 @@ def get_output(self) -> ModelRunnerOutput:

        max_gen_len = self._sampled_token_ids_cpu.shape[-1]
        if max_gen_len == 1:
-            valid_sampled_token_ids: list[np.ndarray] = [row for row in self._sampled_token_ids_cpu.numpy()]
+            valid_sampled_token_ids: list[np.ndarray] = [
+                row for row in self._sampled_token_ids_cpu.numpy()
+            ]
        else:
            valid_sampled_token_ids = RejectionSampler.parse_output(
                self._sampled_token_ids_cpu,
@@ -596,7 +598,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
            dtype=torch.int64,
            device="cpu",
            pin_memory=self.pin_memory,
-            )
+        )
        # Input Batch
        # NOTE(Chen): Ideally, we should initialize the input batch inside
        # `initialize_kv_cache` based on the kv cache config. However, as in
@@ -843,7 +845,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                    req_state.prev_num_draft_len = 0
                else:
                    assert self.input_batch.prev_req_id_to_index is not None
-                    prev_req_index = self.input_batch.prev_req_id_to_index[req_id]
+                    prev_req_index = self.input_batch.prev_req_id_to_index[
+                        req_id]
                    num_accepted = valid_sampled_token_count[prev_req_index] - 1
                    num_rejected = req_state.prev_num_draft_len - num_accepted
                    num_computed_tokens -= num_rejected
@@ -935,15 +938,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
    def _get_valid_sampled_token_count(self) -> list[int]:
        # Wait until valid_sampled_tokens_count is copied to cpu,
        prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids
-        if (
-            self.valid_sampled_token_count_event is None
-            or prev_sampled_token_ids is None
-        ):
+        if (self.valid_sampled_token_count_event is None
+                or prev_sampled_token_ids is None):
            return []

        counts_cpu = self.valid_sampled_token_count_cpu
        self.valid_sampled_token_count_event.synchronize()
-        return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist()
+        return counts_cpu[:prev_sampled_token_ids.shape[0]].tolist()

    def _init_mrope_positions(self, req_state: CachedRequestState):
        assert supports_mrope(self.model), "MROPE is not supported"
@@ -1278,7 +1279,8 @@ def _get_cumsum_and_arange(

        return cu_num_tokens, arange

-    def _prepare_input_ids(self, scheduler_output: "SchedulerOutput", total_num_scheduled_tokens: int,
+    def _prepare_input_ids(self, scheduler_output: "SchedulerOutput",
+                           total_num_scheduled_tokens: int,
                           cu_num_tokens: np.ndarray) -> None:
        """Prepare the input IDs for the current batch.
@@ -1295,7 +1297,7 @@ def _prepare_input_ids(self, scheduler_output: "SchedulerOutput", total_num_sche
            self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
            self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
            return
-
+
        # Async scheduling case, where some decode requests from the previous
        # iteration won't have entries in input_ids_cpu and need to be copied
        # on the NPU from prev_sampled_token_ids.
@@ -1322,23 +1324,22 @@ def _prepare_input_ids(self, scheduler_output: "SchedulerOutput", total_num_sche
                # spec_flattened_indices = [1, 3, 4, 6, 7]
                sample_flattened_indices.append(flattened_index - draft_len)
                spec_flattened_indices.extend(
-                    range(flattened_index - draft_len + 1, flattened_index + 1)
-                )
+                    range(flattened_index - draft_len + 1,
+                          flattened_index + 1))
                start = prev_index * self.num_spec_tokens
                # prev_draft_token_indices is used to find which draft_tokens_id
                # should be copied to input_ids
                # example: prev draft_tokens_id [[1,2], [3,4], [5, 6]]
                # flatten draft_tokens_id [1,2,3,4,5,6]
                # draft_len of each request [1, 2, 1]
                # then prev_draft_token_indices is [0, 2, 3, 4]
-            prev_draft_token_indices.extend(range(start, start + draft_len))
+            prev_draft_token_indices.extend(range(start,
+                                                  start + draft_len))
            indices_match &= prev_index == flattened_index
            max_flattened_index = max(max_flattened_index, flattened_index)
        num_commmon_tokens = len(sample_flattened_indices)
-        total_without_spec = (
-            total_num_scheduled_tokens
-            - total_num_spec_tokens
-        )
+        total_without_spec = (total_num_scheduled_tokens -
+                              total_num_spec_tokens)
        if num_commmon_tokens < total_without_spec:
            # If not all requests are decodes from the last iteration,
            # We need to copy the input_ids_cpu to the NPU first.
@@ -1365,17 +1366,18 @@ def _prepare_input_ids(self, scheduler_output: "SchedulerOutput", total_num_sche
            return
        # Upload the index tensors asynchronously so the scatter can be non-blocking.
        sampled_tokens_index_tensor = torch.tensor(
-            sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
-        ).to(self.device, non_blocking=True)
+            sample_flattened_indices,
+            dtype=torch.int64,
+            pin_memory=self.pin_memory).to(self.device, non_blocking=True)
        prev_common_req_indices_tensor = torch.tensor(
-            prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory
-        ).to(self.device, non_blocking=True)
+            prev_common_req_indices,
+            dtype=torch.int64,
+            pin_memory=self.pin_memory).to(self.device, non_blocking=True)
        self.input_ids.scatter_(
            dim=0,
            index=sampled_tokens_index_tensor,
            src=self.input_batch.prev_sampled_token_ids[
-                prev_common_req_indices_tensor, 0
-            ],
+                prev_common_req_indices_tensor, 0],
        )

        # scatter the draft tokens after the sampled tokens are scattered.
@@ -1384,11 +1386,13 @@ def _prepare_input_ids(self, scheduler_output: "SchedulerOutput", total_num_sche

        assert isinstance(self._draft_token_ids, torch.Tensor)
        draft_tokens_index_tensor = torch.tensor(
-            spec_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
-        ).to(self.device, non_blocking=True)
+            spec_flattened_indices,
+            dtype=torch.int64,
+            pin_memory=self.pin_memory).to(self.device, non_blocking=True)
        prev_draft_token_indices_tensor = torch.tensor(
-            prev_draft_token_indices, dtype=torch.int64, pin_memory=self.pin_memory
-        ).to(self.device, non_blocking=True)
+            prev_draft_token_indices,
+            dtype=torch.int64,
+            pin_memory=self.pin_memory).to(self.device, non_blocking=True)

        # because input_ids dtype is torch.int32,
        # so convert draft_token_ids to torch.int32 here.
@@ -1672,9 +1676,8 @@ def _prepare_inputs(
        self.query_lens = torch.from_numpy(num_scheduled_tokens)

        # Copy the tensors to the NPU.
-        self._prepare_input_ids(
-            scheduler_output, total_num_scheduled_tokens, cu_num_tokens
-        )
+        self._prepare_input_ids(scheduler_output, total_num_scheduled_tokens,
+                                cu_num_tokens)
        self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
        self.positions[:num_input_tokens].copy_(
            self.positions_cpu[:num_input_tokens], non_blocking=True)
@@ -2122,8 +2125,9 @@ def _calc_spec_decode_metadata(
                cu_num_scheduled_tokens - num_sampled_tokens,
                num_sampled_tokens)
            logits_indices_pcp += arange
-        logits_indices_pcp = torch.from_numpy(logits_indices_pcp).pin_memory().to(
-            self.device, non_blocking=True)
+        logits_indices_pcp = torch.from_numpy(
+            logits_indices_pcp).pin_memory().to(self.device,
+                                                non_blocking=True)

        # Compute the bonus logits indices.
        bonus_logits_indices = cu_num_sampled_tokens - 1
@@ -2145,27 +2149,19 @@ def _calc_spec_decode_metadata(

        # TODO: Optimize the CPU -> NPU copy.
        cu_num_draft_tokens = (
-            torch.from_numpy(cu_num_draft_tokens)
-            .pin_memory()
-            .to(self.device, non_blocking=True)
-        )
+            torch.from_numpy(cu_num_draft_tokens).pin_memory().to(
+                self.device, non_blocking=True))
        cu_num_sampled_tokens = (
-            torch.from_numpy(cu_num_sampled_tokens)
-            .pin_memory()
-            .to(self.device, non_blocking=True)
-        )
-        logits_indices = (
-            torch.from_numpy(logits_indices)
-            .pin_memory()
-            .to(self.device, non_blocking=True)
-        )
+            torch.from_numpy(cu_num_sampled_tokens).pin_memory().to(
+                self.device, non_blocking=True))
+        logits_indices = (torch.from_numpy(logits_indices).pin_memory().to(
+            self.device, non_blocking=True))
        target_logits_indices = (
-            torch.from_numpy(target_logits_indices)
-            .pin_memory()
-            .to(self.device, non_blocking=True)
-        )
-        bonus_logits_indices = torch.from_numpy(bonus_logits_indices).pin_memory().to(
-            self.device, non_blocking=True)
+            torch.from_numpy(target_logits_indices).pin_memory().to(
+                self.device, non_blocking=True))
+        bonus_logits_indices = torch.from_numpy(
+            bonus_logits_indices).pin_memory().to(self.device,
+                                                  non_blocking=True)

        # Compute the draft token ids.
        # draft_token_indices: [ 1, 2, 3, 105, 106, 208]
@@ -2654,7 +2650,7 @@ def sample_tokens(
            # when preparing inputs.
            self.input_batch.prev_sampled_token_ids = sampled_token_ids

-
+
            self.input_batch.prev_sampled_token_ids_invalid_indices = \
                invalid_req_indices_set
            self.input_batch.prev_req_id_to_index = {
@@ -2671,8 +2667,9 @@ def sample_tokens(
        for req_idx in range(num_sampled_tokens):
            sampled_ids: np.ndarray | None
            if self.use_async_scheduling:
-                sampled_ids = (np.array([-1]) if req_idx
-                               not in invalid_req_indices_set else None)
+                sampled_ids = (np.array([
+                    -1
+                ]) if req_idx not in invalid_req_indices_set else None)
            else:
                sampled_ids = valid_sampled_token_ids[req_idx]
            if sampled_ids is None or sampled_ids.shape[0] == 0:
@@ -2685,16 +2682,17 @@ def sample_tokens(
                    f"Total number of tokens: {end_idx} > max_model_len: "
                    f"{self.model_config.max_model_len}")

-            self.input_batch.token_ids_cpu[req_idx,
-                                           start_idx:end_idx] = sampled_ids
+            self.input_batch.token_ids_cpu[
+                req_idx, start_idx:end_idx] = sampled_ids
            self.input_batch.is_token_ids[req_idx,
-                                         start_idx:end_idx] = True
+                                          start_idx:end_idx] = True
            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
            self.input_batch.num_tokens[req_idx] = end_idx
            req_id = self.input_batch.req_ids[req_idx]
            req_state = self.requests[req_id]
            req_state.output_token_ids.extend(sampled_ids.tolist())
        self.input_batch.prev_sampled_token_ids = None
+
        def propose_draft_token_ids(sampled_token_ids):
            assert self.spec_decode_common_attn_metadata is not None
            self._draft_token_ids = self.propose_draft_token_ids(
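
Several of the _prepare_input_ids hunks above scatter the previous iteration's sampled and draft token ids into the flattened input_ids buffer via index tensors. A toy sketch of that scatter (indices and token values are invented for illustration):

# Toy sketch of the scatter used in _prepare_input_ids: place tokens sampled on
# the previous step at their flattened positions in input_ids.
import torch

input_ids = torch.zeros(8, dtype=torch.int32)

# Flattened positions of three decode requests and the tokens they sampled.
sample_flattened_indices = torch.tensor([0, 3, 5], dtype=torch.int64)
prev_sampled_token_ids = torch.tensor([101, 202, 303], dtype=torch.int32)

input_ids.scatter_(dim=0,
                   index=sample_flattened_indices,
                   src=prev_sampled_token_ids)

print(input_ids.tolist())  # [101, 0, 0, 202, 0, 303, 0, 0]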
