Commit c7c093f

implement async scheduling for mtp

Signed-off-by: Ronald1995 <[email protected]>
1 parent 84d7f5a

File tree: 6 files changed, +1813 -1213 lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 1 addition & 1 deletion

@@ -348,7 +348,7 @@ def build(
                              device=query_start_loc_cpu.device)
         ])

-        query_start_loc = query_start_loc_cpu.to(self.device,
+        query_start_loc = query_start_loc_cpu.pin_memory().to(self.device,
                                                  non_blocking=True)

        if get_ascend_device_type() == AscendDeviceType._310P:
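
Note on the hunk above: non_blocking=True only overlaps a host-to-device transfer when the source tensor lives in pinned (page-locked) memory; with a pageable tensor it silently degrades to a synchronous copy, which is what async scheduling is trying to avoid. A minimal standalone sketch of the pattern, assuming a CUDA/CPU fallback in place of the Ascend npu device used in the diff:

    import torch

    def h2d_async(cpu_tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
        # pin_memory() page-locks the host buffer so the DMA engine can copy it
        # while the host thread keeps running; without it, non_blocking=True is
        # effectively a normal blocking copy.
        if device.type == "cpu":
            return cpu_tensor
        return cpu_tensor.pin_memory().to(device, non_blocking=True)

    if __name__ == "__main__":
        dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        query_start_loc_cpu = torch.arange(0, 1024, 8, dtype=torch.int32)
        query_start_loc = h2d_async(query_start_loc_cpu, dev)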

vllm_ascend/attention/mla_v1.py

Lines changed: 22 additions & 14 deletions

@@ -566,10 +566,13 @@ def build(
                 out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
                 dtype=torch.int32,
             )
-            chunked_context_metadata = \
-                AscendMLAPrefillMetadata.ChunkedContextMetadata(
-                    cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
-                    starts=local_chunk_starts.to(device, non_blocking=True),
+            chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata(
+                cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
+                    device, non_blocking=True
+                ),
+                starts=local_chunk_starts.pin_memory().to(
+                    device, non_blocking=True
+                ),
                 seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
                 max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                 chunk_seq_lens=chunk_seq_lens,
@@ -578,22 +581,27 @@ def build(
                 padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
                 padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
                 local_context_lens_allranks=local_context_lens_allranks.tolist(),
-                padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to(
+                padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.pin_memory().to(
                     device, non_blocking=True
                 ),
                 cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
                 chunk_size=padded_local_max_context_chunk_across_ranks,
             )
         else:
-            chunked_context_metadata = \
+            chunked_context_metadata = (
                 AscendMLAPrefillMetadata.ChunkedContextMetadata(
-                    cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
-                    starts=chunk_starts.to(device, non_blocking=True),
-                    seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
-                    max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
-                    chunk_seq_lens=chunk_seq_lens,
-                    chunk_seq_lens_npu=chunk_seq_lens.npu(),
-                    workspace=self.chunked_prefill_workspace,
+                    cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
+                        device, non_blocking=True
+                    ),
+                    starts=chunk_starts.pin_memory().to(
+                        device, non_blocking=True
+                    ),
+                    seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
+                    max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
+                    chunk_seq_lens=chunk_seq_lens,
+                    chunk_seq_lens_npu=chunk_seq_lens.npu(),
+                    workspace=self.chunked_prefill_workspace,
+                )
             )
         prefill_input_positions = input_positions[tokens_start:]
         cos = self.cos_cache[
@@ -626,7 +634,7 @@ def build(
             cos = common_attn_metadata.cos
            sin = common_attn_metadata.sin
            # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
-            actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
+            actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes + 1].tolist()
            max_seq_lens = seq_lens[:num_decodes].max().item()
            seq_lens = seq_lens[:num_decodes]
            input_positions = input_positions[:num_decode_tokens]
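
Note on the last hunk: actual_seq_lengths_q is now sliced from the host-side query_start_loc_cpu, because calling .tolist() on a device tensor forces a blocking device-to-host read and stalls the async schedule. A small illustrative sketch of the difference, with made-up tensors and a CUDA/CPU fallback standing in for the NPU:

    import torch

    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_decodes = 4

    # Host-side mirror kept alongside the device tensor.
    query_start_loc_cpu = torch.tensor([0, 2, 4, 6, 8], dtype=torch.int32)
    query_start_loc = query_start_loc_cpu.to(dev)

    # Slow path: .tolist() on a device tensor synchronizes the stream.
    # actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()

    # Fast path: slice the CPU mirror, no device sync needed.
    actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes + 1].tolist()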

vllm_ascend/sample/rejection_sampler.py

Lines changed: 25 additions & 13 deletions

@@ -317,21 +317,27 @@ def rejection_greedy_sample_pytorch(
     draft_token_ids,  # [num_tokens]
     target_argmax,  # [num_tokens]
     bonus_token_ids,  # [batch_size]
-    draft_tokens_per_req,  # [batch_size], list
+    draft_tokens_per_req_cpu,  # [batch_size], list
     max_spec_len,
     is_greedy=None,  # [batch_size] or None
 ):
     batch_size = output_token_ids.size(0)
     num_tokens = draft_token_ids.size(0)
     device = output_token_ids.device
-    draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
-        device, non_blocking=True)
+    draft_tokens_per_req = (
+        torch.tensor(draft_tokens_per_req_cpu)
+        .pin_memory()
+        .to(device, non_blocking=True)
+    )
     if is_greedy is None:
         is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)

     start_indices = cu_num_draft_tokens - draft_tokens_per_req
     req_ids = torch.arange(batch_size, device=device)
-    token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
+    total_draft_tokens = sum(draft_tokens_per_req_cpu)
+    token_req_ids = torch.repeat_interleave(
+        req_ids, draft_tokens_per_req, output_size=total_draft_tokens
+    )
     token_positions = torch.arange(
         num_tokens, device=device) - start_indices[token_req_ids]

@@ -357,8 +363,11 @@ def rejection_greedy_sample_pytorch(
                                      max_spec_len * 2)
     first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
     no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
-    first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
-        no_mismatch_mask]
+    first_mismatch_pos_per_req = torch.where(
+        no_mismatch_mask,
+        draft_tokens_per_req,
+        first_mismatch_pos_per_req,
+    )

     # Copy matched target tokens into output.
     copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
@@ -369,16 +378,19 @@ def rejection_greedy_sample_pytorch(
     greedy_mask = is_greedy.unsqueeze(1)
     final_copy_mask = copy_mask & greedy_mask
     global_idx = start_indices.unsqueeze(1) + copy_indices
-    output_token_ids[final_copy_mask] = target_argmax[
-        global_idx[final_copy_mask]].to(output_token_ids.dtype)
+    output_token_ids_ = torch.where(
+        final_copy_mask,
+        target_argmax[global_idx].to(output_token_ids.dtype),
+        output_token_ids
+    )
+    output_token_ids.copy_(output_token_ids_)
     # Fill bonus token.
     needs_bonus = is_greedy & (first_mismatch_pos_per_req
                                >= draft_tokens_per_req)
-    if torch.any(needs_bonus):
-        bonus_rows = torch.where(needs_bonus)[0]
-        bonus_cols = draft_tokens_per_req[bonus_rows]
-        bonus_token_ids = bonus_token_ids.squeeze(1)
-        output_token_ids[bonus_rows, bonus_cols] = bonus_token_ids[bonus_rows]
+    bonus_rows = torch.where(needs_bonus)[0]
+    bonus_cols = draft_tokens_per_req[bonus_rows]
+    bonus_token_ids = bonus_token_ids.squeeze(1)
+    output_token_ids[bonus_rows, bonus_cols] = bonus_token_ids[bonus_rows]


 def rejection_random_sample_pytorch(
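
Note on this file: the rewritten greedy path drops the data-dependent operations that force a host round-trip. Boolean-mask assignment and repeat_interleave without output_size both need result sizes that only the device knows, which triggers a blocking sync; torch.where and a host-computed output_size keep every shape static. A toy sketch of the same trick, with invented counts rather than vLLM data:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    counts_cpu = [2, 3, 1]                        # known on the host
    counts = torch.tensor(counts_cpu, device=device)
    req_ids = torch.arange(len(counts_cpu), device=device)

    # output_size is computed on the host, so repeat_interleave never has to
    # read the device tensor back just to size its output.
    token_req_ids = torch.repeat_interleave(
        req_ids, counts, output_size=sum(counts_cpu))

    # torch.where keeps a fixed output shape, unlike mask-indexed assignment
    # (x[mask] = y[mask]), which runs nonzero() and synchronizes the device.
    x = torch.zeros(3, device=device)
    mask = counts > 1
    x = torch.where(mask, counts.float(), x)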

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 29 additions & 2 deletions

@@ -144,6 +144,9 @@ def __init__(
         self.arange = torch.arange(max_num_slots_for_arange,
                                    device=device,
                                    dtype=torch.int32)
+        self.arange_cpu = torch.arange(
+            max_num_slots_for_arange, device="cpu", dtype=torch.int32
+        )

         self.inputs_embeds = torch.zeros(
             (self.max_num_tokens, self.hidden_size),
@@ -159,6 +162,7 @@ def __init__(
         )
         self.use_sparse = hasattr(vllm_config.model_config.hf_config,
                                   "index_topk")
+        self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling

     def load_model(self, model) -> None:
         loader = get_model_loader(self.vllm_config.load_config)
@@ -342,6 +346,7 @@ def generate_token_ids(self,
             self.runner.discard_request_indices.gpu,
             self.runner.num_discarded_requests
         )
+        self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)

         req_scheduled_tokens = scheduler_output.num_scheduled_tokens
         if self.pcp_size > 1:
@@ -421,6 +426,24 @@ def generate_token_ids(self,
         )

         return draft_token_ids
+
+    def _copy_valid_sampled_token_count(
+        self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor
+    ) -> None:
+        if self.runner.valid_sampled_token_count_event is not None:
+            default_stream = torch.npu.current_stream()
+            # initialize a new stream to overlap the copy operation with
+            # prepare_input of draft model.
+            with torch.npu.stream(self.runner.valid_sampled_token_count_copy_stream):
+                self.runner.valid_sampled_token_count_copy_stream.wait_stream(
+                    default_stream
+                )  # type: ignore
+                self.runner.valid_sampled_token_count_cpu[
+                    : valid_sampled_tokens_count.shape[0]
+                ].copy_(valid_sampled_tokens_count, non_blocking=True)
+                self.runner.valid_sampled_token_count_event.record()
+
+        self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1)

     def _init_mtp_model(self):
         architecture = self.vllm_config.model_config.architecture
@@ -689,7 +712,11 @@ def _propose(
             uniform_decode=False)
         aclgraph_runtime_mode, batch_descriptor = \
             self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)
-
+        if self.use_async_scheduling:
+            # there is synchronize between mtp steps when enable aclgraph,
+            # disable aclgraph when use async scheduling to avoid the
+            # synchronize overhead.
+            aclgraph_runtime_mode = CUDAGraphMode.NONE
         if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
         ) and aclgraph_runtime_mode == CUDAGraphMode.FULL:
             graph_pad_size = num_input_tokens
@@ -795,7 +822,7 @@ def _propose(
         # When disable_padded_drafter_batch=False, it should not to be updating these params, maybe.
         if self.speculative_config.disable_padded_drafter_batch or \
             aclgraph_runtime_mode != CUDAGraphMode.FULL:
-            attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
+            attn_metadata_i.decode.actual_seq_lengths_q = self.arange_cpu[
                 1:batch_size + 1].tolist()
         if aclgraph_runtime_mode == CUDAGraphMode.FULL:
             attn_metadata_i.decode.actual_seq_lengths_q = \
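
Note on _copy_valid_sampled_token_count: the small device-to-host copy is pushed onto a dedicated stream so it overlaps with the drafter's input preparation, and an event is recorded so the consumer only waits when it actually reads the host buffer. A generic sketch of that stream/event pattern, written against the torch.cuda API for illustration (on Ascend the torch.npu equivalents in the diff apply; the buffer, stream, and event names below are hypothetical):

    import torch

    def overlap_d2h_copy(src: torch.Tensor,
                         dst_pinned_cpu: torch.Tensor,
                         copy_stream: torch.cuda.Stream,
                         done_event: torch.cuda.Event) -> None:
        """Copy `src` (device) into `dst_pinned_cpu` on a side stream."""
        default_stream = torch.cuda.current_stream()
        with torch.cuda.stream(copy_stream):
            # Make the side stream wait for the producer's writes to `src`.
            copy_stream.wait_stream(default_stream)
            dst_pinned_cpu[: src.shape[0]].copy_(src, non_blocking=True)
            # Record completion so the consumer can synchronize only when it
            # actually needs the host value.
            done_event.record()

    if __name__ == "__main__" and torch.cuda.is_available():
        copy_stream = torch.cuda.Stream()
        done_event = torch.cuda.Event()
        counts = torch.randint(0, 8, (16,), device="cuda")
        counts_cpu = torch.empty(32, dtype=counts.dtype, pin_memory=True)
        overlap_d2h_copy(counts, counts_cpu, copy_stream, done_event)
        done_event.synchronize()      # wait only when the value is needed
        print(counts_cpu[:16])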
