
Commit a59e660

fix RejectionSampler.parse_output
Signed-off-by: Ronald1995 <[email protected]>
1 parent 682ab95 commit a59e660

File tree

tests/e2e/singlecard/test_async_scheduling.py
vllm_ascend/worker/model_runner_v1.py

2 files changed: +6 -4 lines changed

tests/e2e/singlecard/test_async_scheduling.py

Lines changed: 5 additions & 3 deletions

@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from itertools import repeat
 from typing import Any
+import os
 
 import pytest
 import torch._dynamo.config as dynamo_config
@@ -169,10 +170,11 @@ def run_test(
     spec_config: dict[str, Any] | None,
     test_prefill_chunking: bool,
 ):
+    os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
     spec_decoding = spec_config is not None
     cache_arg: dict[str, Any] = (
         # Force preemptions
-        dict(num_gpu_blocks_override=32) if test_preemption else dict(
+        dict(num_gpu_blocks_override=2) if test_preemption else dict(
             gpu_memory_utilization=0.9))
     spec_mml = (spec_config or {}).get("max_model_len")
     test_config = (f"executor={executor}, preemption={test_preemption}, "
@@ -199,7 +201,7 @@ def run_test(
     results = []
     acceptance_rates: list[float] | None = [] if spec_decoding else None
     for override_params in sampling_param_tests:
-        metrics_before = vllm_model.llm.get_metrics()
+        metrics_before = vllm_model.model.get_metrics()
         print(f"----------- RUNNING PARAMS: {override_params}")
         results.append(
             vllm_model.generate(
@@ -208,7 +210,7 @@ def run_test(
                 **override_params),
             return_logprobs=True,
         ))
-        metrics_after = vllm_model.llm.get_metrics()
+        metrics_after = vllm_model.model.get_metrics()
         if acceptance_rates is not None:
             acceptance_rate = _get_acceptance_rate(metrics_before,
                                                    metrics_after)
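A note on the test-side changes: the VLLM_WORKER_MULTIPROC_METHOD override only matters if it is set before the engine creates its worker processes, which is presumably why run_test writes it at the top instead of relying on shell configuration; the num_gpu_blocks_override drop from 32 to 2 simply shrinks the KV cache further so the preemption variant of the test is forced to preempt, and the vllm_model.llm -> vllm_model.model rename follows whatever attribute the test's conftest wrapper currently exposes. A minimal standalone sketch of the spawn pattern, using vLLM's offline LLM entry point with a placeholder model name (not taken from this test):

import os

# Set before the engine is constructed so worker processes use the
# spawn start method instead of fork.
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from vllm import LLM, SamplingParams  # imported after the env override on purpose

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model, not from this test
outputs = llm.generate(["Hello, world"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)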

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 1 deletion

@@ -252,7 +252,7 @@ def get_output(self) -> ModelRunnerOutput:
             for i in self._invalid_req_indices:
                 valid_sampled_token_ids[i].clear()
         else:
-            valid_sampled_token_ids = RejectionSampler.parse_output(
+            valid_sampled_token_ids, _ = RejectionSampler.parse_output(
                 self._sampled_token_ids_cpu,
                 self.vocab_size,
                 self._invalid_req_indices,
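This call site is the fix named in the commit title: vLLM's RejectionSampler.parse_output evidently returns a tuple now, with the per-request token-id lists as its first element, so binding the whole return value to valid_sampled_token_ids left downstream code operating on the wrong object. A small self-contained illustration of that failure mode, using a stand-in return value rather than the real sampler output:

# Stand-in for the new return shape: (per-request token id lists, extra value).
# The second element is illustrative only; the real contents come from vLLM.
parsed = ([[101, 102], [103]], None)

# Old pattern from before this commit: the whole tuple gets the list's name.
valid_sampled_token_ids = parsed
print(valid_sampled_token_ids[1])   # None -- not the second request's tokens

# Pattern after this commit: unpack and discard the extra value.
valid_sampled_token_ids, _ = parsed
print(valid_sampled_token_ids[1])   # [103] -- the second request's token ids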
