Skip to content

Commit 682ab95

Browse files
committed
fix assert error of sampled_token_ids shape
Signed-off-by: Ronald1995 <[email protected]>
1 parent 8461e6c commit 682ab95

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2626,7 +2626,7 @@ def sample_tokens(
26262626

26272627
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
26282628
sampled_token_ids = sampler_output.sampled_token_ids
2629-
self.input_batch.prev_sampled_token_ids = None
2629+
26302630
if not self.use_async_scheduling:
26312631
# Get the valid generated tokens.
26322632
max_gen_len = sampled_token_ids.shape[-1]
@@ -2647,7 +2647,7 @@ def sample_tokens(
26472647
invalid_req_indices = discard_sampled_tokens_req_indices.tolist(
26482648
)
26492649
invalid_req_indices_set = set(invalid_req_indices)
2650-
if self.input_batch.prev_sampled_token_ids is None:
2650+
if self.num_spec_tokens <= 0:
26512651
assert sampled_token_ids.shape[-1] == 1
26522652
# Cache the sampled tokens on the NPU and avoid CPU sync.
26532653
# These will be copied into input_ids in the next step

0 commit comments

Comments
 (0)