
Commit 3480094

support async mtp (#4511)
### What this PR does / why we need it?
This PR adds async scheduling support for MTP, following vLLM PR vllm-project/vllm#24799, and fixes some synchronization problems in vllm-ascend.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

---------
Signed-off-by: Ronald1995 <[email protected]>
Co-authored-by: wangxiyuan <[email protected]>
1 parent f067623 commit 3480094
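
For context, the new e2e test in this commit drives both features through the existing VllmRunner helper. The following is only a minimal sketch of how async scheduling and an MTP speculative config are switched on together; the model path and the speculative_config values are illustrative assumptions, not taken from this PR:

    from vllm import SamplingParams
    from tests.e2e.conftest import VllmRunner

    # Hypothetical MTP speculative config; method name and token count are
    # assumptions for illustration, not values shipped in this commit.
    spec_config = {"method": "deepseek_mtp", "num_speculative_tokens": 1}

    with VllmRunner(
            "path/to/your-mtp-model",        # placeholder: a model with an MTP head
            max_model_len=512,
            enforce_eager=True,
            async_scheduling=True,           # the scheduling mode this PR targets
            speculative_config=spec_config,  # draft tokens come from the MTP head
    ) as runner:
        outputs = runner.generate(["Hello, my name is"],
                                  sampling_params=SamplingParams(max_tokens=8))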

File tree

8 files changed: +477 additions, -83 deletions

Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from itertools import repeat
from typing import Any

import pytest
import torch._dynamo.config as dynamo_config
from vllm import SamplingParams
from vllm.v1.metrics.reader import Metric

from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal

MODEL = "Qwen/Qwen3-0.6B"

first_prompt = ("The following numbers of the sequence " +
                ", ".join(str(i) for i in range(10)) + " are:")
example_prompts = [first_prompt, "In one word, the capital of France is "
                   ] + [f"Tell me about the number {i}: " for i in range(32)]

default_params = dict(
    temperature=0.0,  # greedy
    max_tokens=23,
    min_tokens=18,
)


def test_without_spec_decoding(monkeypatch: pytest.MonkeyPatch, ):
    """Test consistency of combos of async scheduling, preemption,
    uni/multiproc executor, prefill chunking."""
    test_sampling_params: list[dict[str, Any]] = [
        dict(),
    ]

    # test_preemption, executor, async_scheduling,
    # spec_config, test_prefill_chunking
    test_configs = [
        (False, "mp", False, None, False),
        (False, "mp", True, None, False),
        (False, "uni", True, None, False),
    ]

    run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)


@dynamo_config.patch(cache_size_limit=16)
def run_tests(
    monkeypatch: pytest.MonkeyPatch,
    model: str,
    test_configs: list[tuple],
    test_sampling_params: list[dict[str, Any]],
):
    """Test consistency of combos of async scheduling, preemption,
    uni/multiproc executor with spec decoding."""

    with monkeypatch.context():
        # avoid precision errors
        outputs: list[tuple[str, list, list]] = []
        for n, (
                test_preemption,
                executor,
                async_scheduling,
                spec_config,
                test_prefill_chunking,
        ) in enumerate(test_configs, 1):
            test_str = f"{n}/{len(test_configs)}"
            test_results = run_test(
                model,
                test_str,
                test_sampling_params,
                test_preemption,
                executor,
                async_scheduling,
                spec_config,
                test_prefill_chunking=test_prefill_chunking,
            )
            outputs.append(test_results)

    baseline_config, baseline_tests, _ = outputs[0]
    _, _, baseline_acceptances = next((o for o in outputs if o[2] is not None),
                                      (None, None, None))

    print(
        f"BASELINE: config=[{baseline_config}], accept_rates={baseline_acceptances}"
    )

    failure = None
    for test_config, test_outputs, test_acceptance_rates in outputs[1:]:
        for base_outs, base_acceptance_rate, test_outs, test_acceptance_rate, params in zip(
                baseline_tests,
                baseline_acceptances or repeat(None),
                test_outputs,
                test_acceptance_rates or repeat(None),
                test_sampling_params,
        ):
            try:
                check_outputs_equal(
                    outputs_0_lst=base_outs,
                    outputs_1_lst=test_outs,
                    name_0=f"baseline=[{baseline_config}], params={params}",
                    name_1=f"config=[{test_config}], params={params}",
                )

                if (base_acceptance_rate is not None
                        and test_acceptance_rate is not None):
                    if "spec_mml=None" in test_config:
                        assert (test_acceptance_rate > base_acceptance_rate
                                or test_acceptance_rate == pytest.approx(
                                    base_acceptance_rate, rel=5e-2))
                    else:
                        # Currently the reported acceptance rate is expected to be
                        # lower when we sometimes skip drafting altogether.
                        assert test_acceptance_rate > 0.1
                print(f"PASSED: config=[{test_config}], params={params}"
                      f" accept_rate={test_acceptance_rate}")
            except AssertionError as e:
                print(f"FAILED: config=[{test_config}], params={params}"
                      f" accept_rate={test_acceptance_rate}")
                if failure is None:
                    failure = e

    if failure is not None:
        raise failure


def run_test(
    model: str,
    test_str: str,
    sampling_param_tests: list[dict[str, Any]],
    test_preemption: bool,
    executor: str,
    async_scheduling: bool,
    spec_config: dict[str, Any] | None,
    test_prefill_chunking: bool,
):
    os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
    spec_decoding = spec_config is not None
    cache_arg: dict[str, Any] = (
        # Force preemptions
        dict(num_gpu_blocks_override=2) if test_preemption else dict(
            gpu_memory_utilization=0.9))
    spec_mml = (spec_config or {}).get("max_model_len")
    test_config = (f"executor={executor}, preemption={test_preemption}, "
                   f"async_sched={async_scheduling}, "
                   f"chunk_prefill={test_prefill_chunking}, "
                   f"spec_decoding={spec_decoding}, spec_mml={spec_mml}")
    print("-" * 80)
    print(f"---- TESTING {test_str}: {test_config}")
    print("-" * 80)
    with VllmRunner(
            model,
            max_model_len=512,
            enable_chunked_prefill=test_prefill_chunking,
            # Force prefill chunking
            max_num_batched_tokens=48 if test_prefill_chunking else None,
            enforce_eager=True,
            async_scheduling=async_scheduling,
            distributed_executor_backend=executor,
            dtype="float16",  # avoid precision errors
            speculative_config=spec_config,
            disable_log_stats=False,
            **cache_arg,
    ) as vllm_model:
        results = []
        acceptance_rates: list[float] | None = [] if spec_decoding else None
        for override_params in sampling_param_tests:
            metrics_before = vllm_model.model.get_metrics()
            print(f"----------- RUNNING PARAMS: {override_params}")
            results.append(
                vllm_model.generate(
                    example_prompts,
                    sampling_params=SamplingParams(**default_params,
                                                   **override_params),
                ))
            metrics_after = vllm_model.model.get_metrics()
            if acceptance_rates is not None:
                acceptance_rate = _get_acceptance_rate(metrics_before,
                                                       metrics_after)
                acceptance_rates.append(acceptance_rate)
                print(f"ACCEPTANCE RATE {acceptance_rate}")

        if test_preemption:
            preemptions = _get_count(metrics_before, metrics_after,
                                     "vllm:num_preemptions")
            assert preemptions > 0, "preemption test had no preemptions"

        if len(results) > 1:
            # First check that the different parameter configs
            # actually result in different output.
            for other_test_outs, params in zip(results[1:],
                                               sampling_param_tests[1:]):
                with pytest.raises(AssertionError):
                    check_outputs_equal(
                        outputs_0_lst=results[0][0],
                        outputs_1_lst=other_test_outs,
                        name_0=f"baseline params={params}",
                        name_1=f"other params={params}",
                    )

    return test_config, results, acceptance_rates


def _get_acceptance_rate(before: list[Metric], after: list[Metric]) -> float:
    draft = _get_count(before, after, "vllm:spec_decode_num_draft_tokens")
    accept = _get_count(before, after, "vllm:spec_decode_num_accepted_tokens")
    return accept / draft if draft > 0 else 0.0


def _get_count(before: list[Metric], after: list[Metric], name: str) -> int:
    before_val = next(m.value for m in before if m.name == name)
    after_val = next(m.value for m in after if m.name == name)
    return after_val - before_val

tests/ut/attention/test_attention_v1.py

Lines changed: 1 addition & 0 deletions
@@ -87,6 +87,7 @@ def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
         self.mock_vllm_config.scheduler_config.decode_max_num_seqs = 10
         self.mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         self.mock_device = 'cpu:0'
+        torch.Tensor.pin_memory = lambda x: x  # noqa
         self.builder = AscendAttentionMetadataBuilder(None, None,
                                                       self.mock_vllm_config,
                                                       self.mock_device)
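
The added line stubs out torch.Tensor.pin_memory for the unit tests: the metadata builders now pin host tensors before copying them to the device (see the attention_v1.py change below), but the unit tests run on a plain 'cpu:0' device, where pinning would require an accelerator backend (my reading of the change, not stated in the PR). A standalone sketch of the same stub under that assumption:

    import torch

    # In a CPU-only test environment pin_memory() needs an accelerator, so the
    # unit tests replace it with an identity function before building metadata.
    torch.Tensor.pin_memory = lambda x: x  # noqa

    t = torch.arange(8, dtype=torch.int32)
    assert t.pin_memory() is t  # with the stub, pinning is a no-op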

tests/ut/attention/test_mla_v1.py

Lines changed: 8 additions & 0 deletions
@@ -299,6 +299,7 @@ def test_ascend_mla_metadata_builder_build_full_graph(
         mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         mock_device = 'cpu'
+        torch.Tensor.pin_memory = lambda x: x  # noqa

         mock_dcp.world_size = 1
         dcp_group = MagicMock(spec=GroupCoordinator)
@@ -534,6 +535,7 @@ def test_build_prefix_no_cache_metadata(self, mock_npu_available,
                                             mock_get_pcp_group):
         mock_npu_available.return_value = False
         mock_dcp_world_size.return_value = 1
+        torch.Tensor.pin_memory = lambda x: x  # noqa
         pcp_group = MagicMock(spec=GroupCoordinator)
         pcp_group.world_size = 1
         mock_get_pcp_group.return_value = pcp_group
@@ -599,6 +601,7 @@ def test_build_chunked_prefix_metadata(self, mock_npu_available,
                                            mock_get_pcp_group):
         mock_npu_available.return_value = False
         mock_dcp_world_size.return_value = 1
+        torch.Tensor.pin_memory = lambda x: x  # noqa
         pcp_group = MagicMock(spec=GroupCoordinator)
         pcp_group.world_size = 1
         mock_get_pcp_group.return_value = pcp_group
@@ -660,6 +663,8 @@ def test_build_decode_only_metadata(self, mock_get_ascend_config,
                                         mock_dcp_world_size,
                                         mock_get_pcp_group):
         mock_dcp_world_size.return_value = 1
+        torch.Tensor.pin_memory = lambda x: x  # noqa
+
         pcp_group = MagicMock(spec=GroupCoordinator)
         pcp_group.world_size = 1
         mock_get_pcp_group.return_value = pcp_group
@@ -713,6 +718,8 @@ def test_build_for_graph_capture_decode_only(self, mock_get_ascend_config,
                                                  mock_dcp_world_size,
                                                  mock_get_pcp_group):
         mock_dcp_world_size.return_value = 1
+        torch.Tensor.pin_memory = lambda x: x  # noqa
+
         pcp_group = MagicMock(spec=GroupCoordinator)
         pcp_group.world_size = 1
         mock_get_pcp_group.return_value = pcp_group
@@ -767,6 +774,7 @@ def test_build_for_graph_capture_prefill(self, mock_get_ascend_config,
                                              mock_dcp_world_size,
                                              mock_get_pcp_group):
         mock_dcp_world_size.return_value = 1
+        torch.Tensor.pin_memory = lambda x: x  # noqa
         pcp_group = MagicMock(spec=GroupCoordinator)
         pcp_group.world_size = 1
         mock_get_pcp_group.return_value = pcp_group

vllm_ascend/attention/attention_v1.py

Lines changed: 2 additions & 2 deletions
@@ -317,8 +317,8 @@ def build(
             query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
         ])

-        query_start_loc = query_start_loc_cpu.to(self.device,
-                                                 non_blocking=True)
+        query_start_loc = query_start_loc_cpu.pin_memory().to(
+            self.device, non_blocking=True)

         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
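
Pinning the host tensor before the non_blocking copy is part of the synchronization fix mentioned in the description: a non_blocking transfer from ordinary pageable host memory may fall back to a blocking copy or race with later host-side writes, while pinned memory keeps the host-to-device copy genuinely asynchronous. A minimal, self-contained sketch of the pattern, written against plain PyTorch with CUDA as the generic accelerator (the PR itself targets Ascend NPUs):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    query_start_loc_cpu = torch.tensor([0, 4, 9, 16], dtype=torch.int32)

    if device == "cpu":
        # No accelerator available: pinning is unnecessary (and may be unsupported).
        query_start_loc = query_start_loc_cpu
    else:
        # Pinning the host tensor lets the non_blocking copy overlap with other
        # work instead of silently degrading to a synchronous transfer.
        query_start_loc = query_start_loc_cpu.pin_memory().to(
            device, non_blocking=True)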

vllm_ascend/attention/mla_v1.py

Lines changed: 30 additions & 21 deletions
@@ -556,35 +556,43 @@ def build(
                     out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
                     dtype=torch.int32,
                 )
-                chunked_context_metadata = \
-                    AscendMLAPrefillMetadata.ChunkedContextMetadata(
-                    cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
-                    starts=local_chunk_starts.to(device, non_blocking=True),
-                    seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
+                chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata(
+                    cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
+                        device, non_blocking=True),
+                    starts=local_chunk_starts.pin_memory().to(
+                        device, non_blocking=True),
+                    seq_tot=padded_local_chunk_seq_lens.sum(
+                        dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                     chunk_seq_lens=chunk_seq_lens,
                     chunk_seq_lens_npu=chunk_seq_lens.npu(),
                     workspace=self.chunked_prefill_workspace,
-                    padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
-                    padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
-                    local_context_lens_allranks=local_context_lens_allranks.tolist(),
-                    padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to(
-                        device, non_blocking=True
-                    ),
+                    padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.
+                    npu(),
+                    padded_local_chunk_seq_lens=padded_local_chunk_seq_lens
+                    .tolist(),
+                    local_context_lens_allranks=local_context_lens_allranks
+                    .tolist(),
+                    padded_local_cu_seq_lens=
+                    padded_local_cu_chunk_seq_lens_cpu.pin_memory().to(
+                        device, non_blocking=True),
                     cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
                     chunk_size=padded_local_max_context_chunk_across_ranks,
                 )
             else:
-                chunked_context_metadata = \
+                chunked_context_metadata = (
                     AscendMLAPrefillMetadata.ChunkedContextMetadata(
-                        cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
-                        starts=chunk_starts.to(device, non_blocking=True),
-                        seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
-                        max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
-                        chunk_seq_lens=chunk_seq_lens,
-                        chunk_seq_lens_npu=chunk_seq_lens.npu(),
-                        workspace=self.chunked_prefill_workspace,
-                    )
+                        cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
+                            device, non_blocking=True),
+                        starts=chunk_starts.pin_memory().to(
+                            device, non_blocking=True),
+                        seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
+                        max_seq_lens=chunk_seq_lens.max(
+                            dim=1).values.tolist(),
+                        chunk_seq_lens=chunk_seq_lens,
+                        chunk_seq_lens_npu=chunk_seq_lens.npu(),
+                        workspace=self.chunked_prefill_workspace,
+                    ))
             prefill_input_positions = input_positions[tokens_start:]
             cos = self.cos_cache[
                 prefill_input_positions].unsqueeze(  # type: ignore
@@ -616,7 +624,8 @@ def build(
             cos = common_attn_metadata.cos
             sin = common_attn_metadata.sin
             # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
-            actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
+            actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
+                                                       1].tolist()
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decodes]
             input_positions = input_positions[:num_decode_tokens]
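
The second hunk switches the slice from the device-side query_start_loc to its CPU copy, which reads as another part of the synchronization fix: calling .tolist() on a device tensor forces a blocking device-to-host sync, which defeats async scheduling, while reading the host copy does not. A small illustrative sketch (CUDA used as the generic accelerator here; the PR targets NPU):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_decodes = 3

    query_start_loc_cpu = torch.tensor([0, 1, 2, 3, 7], dtype=torch.int32)
    query_start_loc = query_start_loc_cpu.to(device)

    synced = query_start_loc[1:num_decodes + 1].tolist()        # implicit device sync
    unsynced = query_start_loc_cpu[1:num_decodes + 1].tolist()  # pure host read
    assert synced == unsynced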
