diff --git a/tests/e2e/singlecard/test_async_scheduling.py b/tests/e2e/singlecard/test_async_scheduling.py new file mode 100644 index 00000000000..3bfbd0c9874 --- /dev/null +++ b/tests/e2e/singlecard/test_async_scheduling.py @@ -0,0 +1,213 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +from itertools import repeat +from typing import Any + +import pytest +import torch._dynamo.config as dynamo_config +from vllm import SamplingParams +from vllm.v1.metrics.reader import Metric + +from tests.e2e.conftest import VllmRunner +from tests.e2e.model_utils import check_outputs_equal + +MODEL = "Qwen/Qwen3-0.6B" + +first_prompt = ("The following numbers of the sequence " + + ", ".join(str(i) for i in range(10)) + " are:") +example_prompts = [first_prompt, "In one word, the capital of France is " + ] + [f"Tell me about the number {i}: " for i in range(32)] + +default_params = dict( + temperature=0.0, # greedy + max_tokens=23, + min_tokens=18, +) + + +def test_without_spec_decoding(monkeypatch: pytest.MonkeyPatch, ): + """Test consistency of combos of async scheduling, preemption, + uni/multiproc executor, prefill chunking.""" + test_sampling_params: list[dict[str, Any]] = [ + dict(), + ] + + # test_preemption, executor, async_scheduling, + # spec_config, test_prefill_chunking + test_configs = [ + (False, "mp", False, None, False), + (False, "mp", True, None, False), + (False, "uni", True, None, False), + ] + + run_tests(monkeypatch, MODEL, test_configs, test_sampling_params) + + +@dynamo_config.patch(cache_size_limit=16) +def run_tests( + monkeypatch: pytest.MonkeyPatch, + model: str, + test_configs: list[tuple], + test_sampling_params: list[dict[str, Any]], +): + """Test consistency of combos of async scheduling, preemption, + uni/multiproc executor with spec decoding.""" + + with monkeypatch.context(): + # avoid precision errors + outputs: list[tuple[str, list, list]] = [] + for n, ( + test_preemption, + executor, + async_scheduling, + spec_config, + test_prefill_chunking, + ) in enumerate(test_configs, 1): + test_str = f"{n}/{len(test_configs)}" + test_results = run_test( + model, + test_str, + test_sampling_params, + test_preemption, + executor, + async_scheduling, + spec_config, + test_prefill_chunking=test_prefill_chunking, + ) + outputs.append(test_results) + + baseline_config, baseline_tests, _ = outputs[0] + _, _, baseline_acceptances = next((o for o in outputs if o[2] is not None), + (None, None, None)) + + print( + f"BASELINE: config=[{baseline_config}], accept_rates={baseline_acceptances}" + ) + + failure = None + for test_config, test_outputs, test_acceptance_rates in outputs[1:]: + for base_outs, base_acceptance_rate, test_outs, test_acceptance_rate, params in zip( + baseline_tests, + baseline_acceptances or repeat(None), + test_outputs, + test_acceptance_rates or repeat(None), + test_sampling_params, + ): + try: + check_outputs_equal( + outputs_0_lst=base_outs, + outputs_1_lst=test_outs, + name_0=f"baseline=[{baseline_config}], params={params}", + name_1=f"config=[{test_config}], params={params}", + ) + + if (base_acceptance_rate is not None + and test_acceptance_rate is not None): + if "spec_mml=None" in test_config: + assert (test_acceptance_rate > base_acceptance_rate + or test_acceptance_rate == pytest.approx( + base_acceptance_rate, rel=5e-2)) + else: + # Currently the reported acceptance rate is expected to be + # lower when we sometimes skip drafting altogether. 
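# Illustrative sketch (not part of this patch): the acceptance-rate comparison
# above accepts a config when its rate is either strictly higher than the
# baseline or within 5% relative tolerance of it; pytest.approx(rel=...)
# provides that tolerance check.
import pytest

def acceptance_rate_ok(test_rate: float, base_rate: float) -> bool:
    return test_rate > base_rate or test_rate == pytest.approx(base_rate,
                                                               rel=5e-2)

assert acceptance_rate_ok(0.79, 0.80)      # within 5% of the baseline
assert not acceptance_rate_ok(0.70, 0.80)  # regression larger than 5%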
+ assert test_acceptance_rate > 0.1 + print(f"PASSED: config=[{test_config}], params={params}" + f" accept_rate={test_acceptance_rate}") + except AssertionError as e: + print(f"FAILED: config=[{test_config}], params={params}" + f" accept_rate={test_acceptance_rate}") + if failure is None: + failure = e + + if failure is not None: + raise failure + + +def run_test( + model: str, + test_str: str, + sampling_param_tests: list[dict[str, Any]], + test_preemption: bool, + executor: str, + async_scheduling: bool, + spec_config: dict[str, Any] | None, + test_prefill_chunking: bool, +): + os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + spec_decoding = spec_config is not None + cache_arg: dict[str, Any] = ( + # Force preemptions + dict(num_gpu_blocks_override=2) if test_preemption else dict( + gpu_memory_utilization=0.9)) + spec_mml = (spec_config or {}).get("max_model_len") + test_config = (f"executor={executor}, preemption={test_preemption}, " + f"async_sched={async_scheduling}, " + f"chunk_prefill={test_prefill_chunking}, " + f"spec_decoding={spec_decoding}, spec_mml={spec_mml}") + print("-" * 80) + print(f"---- TESTING {test_str}: {test_config}") + print("-" * 80) + with VllmRunner( + model, + max_model_len=512, + enable_chunked_prefill=test_prefill_chunking, + # Force prefill chunking + max_num_batched_tokens=48 if test_prefill_chunking else None, + enforce_eager=True, + async_scheduling=async_scheduling, + distributed_executor_backend=executor, + dtype="float16", # avoid precision errors + speculative_config=spec_config, + disable_log_stats=False, + **cache_arg, + ) as vllm_model: + results = [] + acceptance_rates: list[float] | None = [] if spec_decoding else None + for override_params in sampling_param_tests: + metrics_before = vllm_model.model.get_metrics() + print(f"----------- RUNNING PARAMS: {override_params}") + results.append( + vllm_model.generate( + example_prompts, + sampling_params=SamplingParams(**default_params, + **override_params), + )) + metrics_after = vllm_model.model.get_metrics() + if acceptance_rates is not None: + acceptance_rate = _get_acceptance_rate(metrics_before, + metrics_after) + acceptance_rates.append(acceptance_rate) + print(f"ACCEPTANCE RATE {acceptance_rate}") + + if test_preemption: + preemptions = _get_count(metrics_before, metrics_after, + "vllm:num_preemptions") + assert preemptions > 0, "preemption test had no preemptions" + + if len(results) > 1: + # First check that the different parameter configs + # actually result in different output. 
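# Illustrative sketch (not part of this patch): the loop below inverts an
# equality assertion with pytest.raises, so it only passes when the two
# sampling-parameter configs really do produce different generations. A
# minimal stand-in for check_outputs_equal keeps the sketch self-contained.
import pytest

def _outputs_equal_stub(outputs_0_lst, outputs_1_lst, name_0, name_1):
    assert outputs_0_lst == outputs_1_lst, f"{name_0} differs from {name_1}"

baseline_outs = [(["1", "2"], ["one two"])]
variant_outs = [(["1", "3"], ["one three"])]
with pytest.raises(AssertionError):
    _outputs_equal_stub(baseline_outs, variant_outs, "baseline", "variant")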
+ for other_test_outs, params in zip(results[1:], + sampling_param_tests[1:]): + with pytest.raises(AssertionError): + check_outputs_equal( + outputs_0_lst=results[0][0], + outputs_1_lst=other_test_outs, + name_0=f"baseline params={params}", + name_1=f"other params={params}", + ) + + return test_config, results, acceptance_rates + + +def _get_acceptance_rate(before: list[Metric], after: list[Metric]) -> float: + draft = _get_count(before, after, "vllm:spec_decode_num_draft_tokens") + accept = _get_count(before, after, "vllm:spec_decode_num_accepted_tokens") + return accept / draft if draft > 0 else 0.0 + + +def _get_count(before: list[Metric], after: list[Metric], name: str) -> int: + before_val = next(m.value for m in before if m.name == name) + after_val = next(m.value for m in after if m.name == name) + return after_val - before_val diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py index af179b37122..5923a021026 100644 --- a/tests/ut/attention/test_attention_v1.py +++ b/tests/ut/attention/test_attention_v1.py @@ -87,6 +87,7 @@ def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp, self.mock_vllm_config.scheduler_config.decode_max_num_seqs = 10 self.mock_vllm_config.scheduler_config.chunked_prefill_enabled = False self.mock_device = 'cpu:0' + torch.Tensor.pin_memory = lambda x: x # noqa self.builder = AscendAttentionMetadataBuilder(None, None, self.mock_vllm_config, self.mock_device) diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index fbb90aa3963..e2f746558fb 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -299,6 +299,7 @@ def test_ascend_mla_metadata_builder_build_full_graph( mock_vllm_config.scheduler_config.decode_max_num_seqs = 4 mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_device = 'cpu' + torch.Tensor.pin_memory = lambda x: x # noqa mock_dcp.world_size = 1 dcp_group = MagicMock(spec=GroupCoordinator) @@ -534,6 +535,7 @@ def test_build_prefix_no_cache_metadata(self, mock_npu_available, mock_get_pcp_group): mock_npu_available.return_value = False mock_dcp_world_size.return_value = 1 + torch.Tensor.pin_memory = lambda x: x # noqa pcp_group = MagicMock(spec=GroupCoordinator) pcp_group.world_size = 1 mock_get_pcp_group.return_value = pcp_group @@ -599,6 +601,7 @@ def test_build_chunked_prefix_metadata(self, mock_npu_available, mock_get_pcp_group): mock_npu_available.return_value = False mock_dcp_world_size.return_value = 1 + torch.Tensor.pin_memory = lambda x: x # noqa pcp_group = MagicMock(spec=GroupCoordinator) pcp_group.world_size = 1 mock_get_pcp_group.return_value = pcp_group @@ -660,6 +663,8 @@ def test_build_decode_only_metadata(self, mock_get_ascend_config, mock_dcp_world_size, mock_get_pcp_group): mock_dcp_world_size.return_value = 1 + torch.Tensor.pin_memory = lambda x: x # noqa + pcp_group = MagicMock(spec=GroupCoordinator) pcp_group.world_size = 1 mock_get_pcp_group.return_value = pcp_group @@ -713,6 +718,8 @@ def test_build_for_graph_capture_decode_only(self, mock_get_ascend_config, mock_dcp_world_size, mock_get_pcp_group): mock_dcp_world_size.return_value = 1 + torch.Tensor.pin_memory = lambda x: x # noqa + pcp_group = MagicMock(spec=GroupCoordinator) pcp_group.world_size = 1 mock_get_pcp_group.return_value = pcp_group @@ -767,6 +774,7 @@ def test_build_for_graph_capture_prefill(self, mock_get_ascend_config, mock_dcp_world_size, mock_get_pcp_group): mock_dcp_world_size.return_value = 1 + 
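# Illustrative sketch (not part of this patch): the unit tests above stub
# torch.Tensor.pin_memory with an identity function so that builder code
# calling .pin_memory() still runs on CPU-only CI hosts, where pinning host
# memory may fail without an accelerator runtime.
import torch

_original_pin_memory = torch.Tensor.pin_memory
torch.Tensor.pin_memory = lambda self: self  # no-op stand-in, as in the tests
try:
    t = torch.arange(4, dtype=torch.int32)
    assert t.pin_memory() is t  # identity under the stub
finally:
    torch.Tensor.pin_memory = _original_pin_memory  # restore the real method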
torch.Tensor.pin_memory = lambda x: x # noqa pcp_group = MagicMock(spec=GroupCoordinator) pcp_group.world_size = 1 mock_get_pcp_group.return_value = pcp_group diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index cc7b77a2d72..7ba77449d08 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -317,8 +317,8 @@ def build( query_start_loc_cpu.device).to(query_start_loc_cpu.dtype) ]) - query_start_loc = query_start_loc_cpu.to(self.device, - non_blocking=True) + query_start_loc = query_start_loc_cpu.pin_memory().to( + self.device, non_blocking=True) attn_metadata = AscendMetadata( num_actual_tokens=num_actual_tokens, diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 610a6c2abb7..23ee0692cd6 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -556,35 +556,43 @@ def build( out=padded_local_cu_chunk_seq_lens_cpu[:, 1:], dtype=torch.int32, ) - chunked_context_metadata = \ - AscendMLAPrefillMetadata.ChunkedContextMetadata( - cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), - starts=local_chunk_starts.to(device, non_blocking=True), - seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(), + chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata( + cu_seq_lens=cu_seq_lens_cpu.pin_memory().to( + device, non_blocking=True), + starts=local_chunk_starts.pin_memory().to( + device, non_blocking=True), + seq_tot=padded_local_chunk_seq_lens.sum( + dim=1).tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), chunk_seq_lens=chunk_seq_lens, chunk_seq_lens_npu=chunk_seq_lens.npu(), workspace=self.chunked_prefill_workspace, - padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(), - padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(), - local_context_lens_allranks=local_context_lens_allranks.tolist(), - padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to( - device, non_blocking=True - ), + padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens. 
+ npu(), + padded_local_chunk_seq_lens=padded_local_chunk_seq_lens + .tolist(), + local_context_lens_allranks=local_context_lens_allranks + .tolist(), + padded_local_cu_seq_lens= + padded_local_cu_chunk_seq_lens_cpu.pin_memory().to( + device, non_blocking=True), cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), chunk_size=padded_local_max_context_chunk_across_ranks, ) else: - chunked_context_metadata = \ + chunked_context_metadata = ( AscendMLAPrefillMetadata.ChunkedContextMetadata( - cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), - starts=chunk_starts.to(device, non_blocking=True), - seq_tot=chunk_seq_lens.sum(dim=1).tolist(), - max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), - chunk_seq_lens=chunk_seq_lens, - chunk_seq_lens_npu=chunk_seq_lens.npu(), - workspace=self.chunked_prefill_workspace, - ) + cu_seq_lens=cu_seq_lens_cpu.pin_memory().to( + device, non_blocking=True), + starts=chunk_starts.pin_memory().to( + device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max( + dim=1).values.tolist(), + chunk_seq_lens=chunk_seq_lens, + chunk_seq_lens_npu=chunk_seq_lens.npu(), + workspace=self.chunked_prefill_workspace, + )) prefill_input_positions = input_positions[tokens_start:] cos = self.cos_cache[ prefill_input_positions].unsqueeze( # type: ignore @@ -616,7 +624,8 @@ def build( cos = common_attn_metadata.cos sin = common_attn_metadata.sin # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario - actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist() + actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes + + 1].tolist() max_seq_lens = seq_lens[:num_decodes].max().item() seq_lens = seq_lens[:num_decodes] input_positions = input_positions[:num_decode_tokens] diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 4b7bfad91d2..b8923a7954c 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -142,6 +142,9 @@ def __init__( self.arange = torch.arange(max_num_slots_for_arange, device=device, dtype=torch.int32) + self.arange_cpu = torch.arange(max_num_slots_for_arange, + device="cpu", + dtype=torch.int32) self.inputs_embeds = torch.zeros( (self.max_num_tokens, self.hidden_size), @@ -157,6 +160,7 @@ def __init__( ) self.use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk") + self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling def load_model(self, model) -> None: loader = get_model_loader(self.vllm_config.load_config) @@ -351,6 +355,8 @@ def generate_token_ids(self, self.runner.discard_request_indices.gpu, self.runner.num_discarded_requests ) + self._copy_valid_sampled_token_count(next_token_ids, + valid_sampled_tokens_count) req_scheduled_tokens = scheduler_output.num_scheduled_tokens if self.pcp_size > 1: @@ -430,6 +436,28 @@ def generate_token_ids(self, return draft_token_ids + def _copy_valid_sampled_token_count( + self, next_token_ids: torch.Tensor, + valid_sampled_tokens_count: torch.Tensor) -> None: + if self.runner.valid_sampled_token_count_event is not None: + default_stream = torch.npu.current_stream() + # initialize a new stream to overlap the copy operation with + # prepare_input of draft model. 
+ with torch.npu.stream( + self.runner.valid_sampled_token_count_copy_stream): + self.runner.valid_sampled_token_count_copy_stream.wait_stream( + default_stream) # type: ignore + self.runner.valid_sampled_token_count_cpu[: + valid_sampled_tokens_count + .shape[0]].copy_( + valid_sampled_tokens_count, + non_blocking=True + ) + self.runner.valid_sampled_token_count_event.record() + + self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze( + 1) + def _init_mtp_model(self): architecture = self.vllm_config.model_config.architecture target_device = self.vllm_config.device_config.device @@ -696,6 +724,11 @@ def _propose( has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0 aclgraph_runtime_mode, batch_descriptor = \ self.runner.aclgraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora) + if self.use_async_scheduling: + # there is synchronize between mtp steps when enable aclgraph, + # disable aclgraph when use async scheduling to avoid the + # synchronize overhead. + aclgraph_runtime_mode = CUDAGraphMode.NONE if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs( ) and aclgraph_runtime_mode == CUDAGraphMode.FULL: @@ -822,7 +855,7 @@ def _propose( # When disable_padded_drafter_batch=False, it should not to be updating these params, maybe. if decode_metadata is not None and (self.speculative_config.disable_padded_drafter_batch or \ aclgraph_runtime_mode != CUDAGraphMode.FULL): - decode_metadata.actual_seq_lengths_q = attn_metadata_i.query_start_loc[ + decode_metadata.actual_seq_lengths_q = self.arange_cpu[ 1:batch_size + 1].tolist() if aclgraph_runtime_mode == CUDAGraphMode.FULL: decode_metadata.actual_seq_lengths_q = \ @@ -847,7 +880,9 @@ def _propose( clamped_positions = torch.where(exceeds_max_model_len, 0, positions[:batch_size]) # Increment the sequence lengths. - attn_metadata_i.seq_lens[:batch_size] += 1 + # This is an out-of-place operation to avoid modifying the original tensor + # when enable async_scheduling. + attn_metadata_i.seq_lens = attn_metadata_i.seq_lens + 1 # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. exceeds_max_model_len_cpu = exceeds_max_model_len.to( diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index c66914c7150..907a6e0707c 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -97,6 +97,7 @@ make_empty_encoder_model_runner_output) from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer @@ -213,6 +214,7 @@ def __init__( sampled_token_ids: torch.Tensor, invalid_req_indices: list[int], async_output_copy_stream: torch.npu.Stream, + vocab_size: int, ): self._model_runner_output = model_runner_output self._invalid_req_indices = invalid_req_indices @@ -223,7 +225,7 @@ def __init__( # Keep a reference to the device tensor to avoid it being # deallocated until we finish copying it to the host. self._sampled_token_ids = sampled_token_ids - + self.vocab_size = vocab_size # Initiate the copy on a separate stream, but do not synchronize it. 
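# Illustrative sketch (not part of this patch): the copy pattern used here
# issues the device-to-host transfer on a dedicated stream that first waits on
# the default stream, records an event, and only synchronizes that event when
# the host values are actually needed. torch.cuda is used as a stand-in for
# torch.npu, and the block is skipped on hosts without a GPU.
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    copy_stream = torch.cuda.Stream()
    ready_event = torch.cuda.Event()
    sampled = torch.randint(0, 1000, (8, 1), device=device)
    host_buf = torch.empty(sampled.shape, dtype=sampled.dtype,
                           device="cpu", pin_memory=True)

    default_stream = torch.cuda.current_stream()
    with torch.cuda.stream(copy_stream):
        copy_stream.wait_stream(default_stream)     # order after the producer
        host_buf.copy_(sampled, non_blocking=True)  # async D2H copy
        ready_event.record()                        # mark copy completion

    # ... later, when the host values are needed:
    ready_event.synchronize()
    token_ids = host_buf.tolist()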
default_stream = torch.npu.current_stream() with torch.npu.stream(async_output_copy_stream): @@ -242,10 +244,17 @@ def get_output(self) -> ModelRunnerOutput: # Release the device tensor once the copy has completed del self._sampled_token_ids - valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist() - for i in self._invalid_req_indices: - valid_sampled_token_ids[i].clear() - + max_gen_len = self._sampled_token_ids_cpu.shape[-1] + if max_gen_len == 1: + valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist() + for i in self._invalid_req_indices: + valid_sampled_token_ids[i].clear() + else: + valid_sampled_token_ids, _ = RejectionSampler.parse_output( + self._sampled_token_ids_cpu, + self.vocab_size, + self._invalid_req_indices, + return_cu_num_tokens=False) output = self._model_runner_output output.sampled_token_ids = valid_sampled_token_ids return output @@ -567,6 +576,20 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.use_async_scheduling = self.scheduler_config.async_scheduling self.async_output_copy_stream = torch.npu.Stream() if \ self.use_async_scheduling else None + self.num_spec_tokens = 0 + if self.speculative_config: + self.num_spec_tokens = self.speculative_config.num_speculative_tokens # noqa + self.valid_sampled_token_count_event: torch.npu.Event | None = None + self.valid_sampled_token_count_copy_stream: torch.npu.Stream | None = None + if self.use_async_scheduling and self.num_spec_tokens: + self.valid_sampled_token_count_event = torch.npu.Event() + self.valid_sampled_token_count_copy_stream = torch.npu.Stream() + self.valid_sampled_token_count_cpu = torch.empty( + self.max_num_reqs, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) # Input Batch # NOTE(Chen): Ideally, we should initialize the input batch inside # `initialize_kv_cache` based on the kv cache config. However, as in @@ -791,13 +814,40 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the states of the running/resumed requests. is_last_rank = get_pp_group().is_last_rank req_data = scheduler_output.scheduled_cached_reqs + # wait until valid_sampled_tokens_count is copied to cpu, + # then use it to update actual num_computed_tokens of each request. + valid_sampled_token_count = self._get_valid_sampled_token_count() for i, req_id in enumerate(req_data.req_ids): req_state = self.requests[req_id] num_computed_tokens = req_data.num_computed_tokens[i] new_block_ids = req_data.new_block_ids[i] - resumed_from_preemption = req_data.resumed_from_preemption[i] - - # Update the cached states. + resumed_from_preemption = req_id in req_data.resumed_req_ids + num_output_tokens = req_data.num_output_tokens[i] + req_index = self.input_batch.req_id_to_index.get(req_id) + # prev_num_draft_len is used in async scheduling mode with + # spec decode. it indicates if need to update num_computed_tokens + # of the request. for example: + # fist step: num_computed_tokens = 0, spec_tokens = [], + # prev_num_draft_len = 0. + # second step: num_computed_tokens = 100(prompt length), + # spec_tokens = [a,b], prev_num_draft_len = 0. + # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d], + # prev_num_draft_len = 2. + # num_computed_tokens in first step and second step doesn't contain + # the spec tokens length, but in third step it contains the + # spec tokens length. we only need to update num_computed_tokens + # when prev_num_draft_len > 0. 
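# Illustrative sketch (not part of this patch): with async scheduling the
# scheduler's num_computed_tokens already includes all previously scheduled
# draft tokens, so the runner subtracts the drafts that turn out to be
# rejected once the real acceptance count arrives from the device.
def corrected_num_computed_tokens(num_computed_tokens: int,
                                  prev_num_draft_len: int,
                                  valid_sampled_count: int) -> int:
    # valid_sampled_count = 1 sampled/bonus token + accepted draft tokens
    num_accepted = valid_sampled_count - 1
    num_rejected = prev_num_draft_len - num_accepted
    return num_computed_tokens - num_rejected

# Example from the comment above: 100 prompt tokens plus 2 scheduled draft
# tokens, but only 1 draft was accepted -> one token is rolled back.
assert corrected_num_computed_tokens(102, 2, 2) == 101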
+ if req_state.prev_num_draft_len: + if req_index is None: + req_state.prev_num_draft_len = 0 + else: + assert self.input_batch.prev_req_id_to_index is not None + prev_req_index = self.input_batch.prev_req_id_to_index[ + req_id] + num_accepted = valid_sampled_token_count[prev_req_index] - 1 + num_rejected = req_state.prev_num_draft_len - num_accepted + num_computed_tokens -= num_rejected + req_state.output_token_ids.extend([-1] * num_accepted) req_state.num_computed_tokens = num_computed_tokens if not is_last_rank: @@ -828,12 +878,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # The request is resumed from preemption. # Replace the existing block IDs with the new ones. req_state.block_ids = new_block_ids - - req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is None: # The request is not in the persistent batch. - # The request was either preempted and resumed later, or was not - # scheduled in the previous step and needs to be added again. + # The request was either preempted and resumed later, or was + # not scheduled in the previous step and needs to be added + # again. + + if self.use_async_scheduling and num_output_tokens > 0: + # We must recover the output token ids for resumed requests + # in the async scheduling case, so that correct input_ids + # are obtained. + resumed_token_ids = req_data.all_token_ids[req_id] + req_state.output_token_ids = resumed_token_ids[ + -num_output_tokens:] + req_ids_to_add.append(req_id) continue @@ -860,8 +918,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Add spec_token_ids to token_ids_cpu. spec_token_ids = ( scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) - if spec_token_ids: - num_spec_tokens = len(spec_token_ids) + num_spec_tokens = len(spec_token_ids) + if self.use_async_scheduling: + req_state.prev_num_draft_len = num_spec_tokens + if num_spec_tokens: start_index = self.input_batch.num_tokens_no_spec[req_index] end_token_index = start_index + num_spec_tokens self.input_batch.token_ids_cpu[ @@ -882,6 +942,17 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() + def _get_valid_sampled_token_count(self) -> list[int]: + # Wait until valid_sampled_tokens_count is copied to cpu, + prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids + if (self.valid_sampled_token_count_event is None + or prev_sampled_token_ids is None): + return [] + + counts_cpu = self.valid_sampled_token_count_cpu + self.valid_sampled_token_count_event.synchronize() + return counts_cpu[:prev_sampled_token_ids.shape[0]].tolist() + def _init_mrope_positions(self, req_state: CachedRequestState): assert supports_mrope(self.model), "MROPE is not supported" req_state.mrope_positions, req_state.mrope_position_delta = \ @@ -901,26 +972,25 @@ def _sync_metadata_across_dp( # immediately once the other two flags are no longer needed. if self.dp_size == 1: return num_tokens, None, with_prefill - # Sync num_tokens, with_prefill across dp ranks num_tokens_tensor = torch.tensor([ num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size) ], dtype=torch.int32, - device="npu") + device="cpu") flags_tensor = torch.tensor([int(with_prefill)], dtype=torch.int32, - device="npu") + device="cpu") packed_tensor = torch.cat([num_tokens_tensor, flags_tensor]) - - dist.all_reduce(packed_tensor, group=get_dp_group().device_group) + # use cpu_group to avoid cpu synchronization issue. 
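# Illustrative sketch (not part of this patch): each DP rank contributes its
# token count in its own slot plus a with_prefill flag; a sum all-reduce over
# the CPU (gloo) group then yields every rank's count and whether any rank has
# a prefill, without touching the NPU stream. The collective is simulated here
# by summing the per-rank tensors directly.
import torch

dp_size = 4
per_rank_tokens = [128, 32, 0, 64]
per_rank_prefill = [False, True, False, False]

packed_per_rank = [
    torch.tensor([t if i == rank else 0 for i in range(dp_size)] +
                 [int(per_rank_prefill[rank])], dtype=torch.int32)
    for rank, t in enumerate(per_rank_tokens)
]
reduced = torch.stack(packed_per_rank).sum(dim=0)  # what all_reduce(SUM) gives

num_tokens_across_dp = reduced[:-1]
global_with_prefill = bool(reduced[-1] > 0)
assert num_tokens_across_dp.tolist() == per_rank_tokens
assert global_with_prefill is True
assert int(num_tokens_across_dp.max()) == 128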
+ # it can be overlapped with main moell execution on npu. + dist.all_reduce(packed_tensor, group=get_dp_group().cpu_group) # Unpack the results num_tokens_across_dp = packed_tensor[:-1] synced_flags = packed_tensor[-1:] - max_tokens_across_dp = torch.max(num_tokens_across_dp).item() global_with_prefill = bool(synced_flags[0]) @@ -1195,7 +1265,8 @@ def _get_cumsum_and_arange( return cu_num_tokens, arange - def _prepare_input_ids(self, total_num_scheduled_tokens: int, + def _prepare_input_ids(self, scheduler_output: "SchedulerOutput", + total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray) -> None: """Prepare the input IDs for the current batch. @@ -1218,21 +1289,44 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, # on the NPU from prev_sampled_token_ids. prev_req_id_to_index = self.input_batch.prev_req_id_to_index assert prev_req_id_to_index is not None - flattened_indices = [] - prev_common_req_indices = [] + sample_flattened_indices: list[int] = [] + spec_flattened_indices: list[int] = [] + prev_common_req_indices: list[int] = [] + prev_draft_token_indices: list[int] = [] indices_match = True max_flattened_index = -1 + total_num_spec_tokens = 0 + scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens for req_id, cur_index in self.input_batch.req_id_to_index.items(): if (prev_index := prev_req_id_to_index.get(req_id)) is not None: prev_common_req_indices.append(prev_index) # We need to compute the flattened input_ids index of the # last token in each common request. + draft_len = len(scheduled_spec_tokens.get(req_id, ())) + total_num_spec_tokens += draft_len flattened_index = cu_num_tokens[cur_index].item() - 1 - flattened_indices.append(flattened_index) - indices_match &= (prev_index == flattened_index) + # example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2] + # sample_flattened_indices = [0, 2, 5] + # spec_flattened_indices = [1, 3, 4, 6, 7] + sample_flattened_indices.append(flattened_index - draft_len) + spec_flattened_indices.extend( + range(flattened_index - draft_len + 1, + flattened_index + 1)) + start = prev_index * self.num_spec_tokens + # prev_draft_token_indices is used to find which draft_tokens_id + # should be copied to input_ids + # example: prev draft_tokens_id [[1,2], [3,4], [5, 6]] + # flatten draft_tokens_id [1,2,3,4,5,6] + # draft_len of each request [1, 2, 1] + # then prev_draft_token_indices is [0, 2, 3, 4] + prev_draft_token_indices.extend(range(start, + start + draft_len)) + indices_match &= prev_index == flattened_index max_flattened_index = max(max_flattened_index, flattened_index) - num_commmon_tokens = len(flattened_indices) - if num_commmon_tokens < total_num_scheduled_tokens: + num_commmon_tokens = len(sample_flattened_indices) + total_without_spec = (total_num_scheduled_tokens - + total_num_spec_tokens) + if num_commmon_tokens < total_without_spec: # If not all requests are decodes from the last iteration, # We need to copy the input_ids_cpu to the NPU first. self.input_ids[:total_num_scheduled_tokens].copy_( @@ -1256,21 +1350,45 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, non_blocking=True) self.is_token_ids.gpu[:num_commmon_tokens] = True return - # Upload the index tensors asynchronously - # so the scatter can be non-blocking. - input_ids_index_tensor = torch.tensor(flattened_indices, - dtype=torch.int64, - pin_memory=self.pin_memory).to( - self.device, - non_blocking=True) + # Upload the index tensors asynchronously so the scatter can be non-blocking. 
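# Illustrative sketch (not part of this patch): reproducing the index math
# from the comments above. For each request that was also present in the
# previous step, the last draft_len + 1 positions of its token span are split
# into one slot for the previously sampled token and draft_len slots for the
# previously proposed draft tokens.
cu_num_tokens = [2, 5, 8]   # cumulative scheduled tokens per request
draft_lens = [1, 2, 2]      # scheduled draft tokens per request

sample_flattened_indices = []
spec_flattened_indices = []
for cur_index, draft_len in enumerate(draft_lens):
    flattened_index = cu_num_tokens[cur_index] - 1
    sample_flattened_indices.append(flattened_index - draft_len)
    spec_flattened_indices.extend(
        range(flattened_index - draft_len + 1, flattened_index + 1))

assert sample_flattened_indices == [0, 2, 5]
assert spec_flattened_indices == [1, 3, 4, 6, 7]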
+ sampled_tokens_index_tensor = torch.tensor( + sample_flattened_indices, + dtype=torch.int64, + pin_memory=self.pin_memory).to(self.device, non_blocking=True) prev_common_req_indices_tensor = torch.tensor( prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory).to(self.device, non_blocking=True) - self.input_ids.scatter_(dim=0, - index=input_ids_index_tensor, - src=self.input_batch.prev_sampled_token_ids[ - prev_common_req_indices_tensor, 0]) + self.input_ids.scatter_( + dim=0, + index=sampled_tokens_index_tensor, + src=self.input_batch.prev_sampled_token_ids[ + prev_common_req_indices_tensor, 0], + ) + + # scatter the draft tokens after the sampled tokens are scattered. + if self._draft_token_ids is None or not spec_flattened_indices: + return + + assert isinstance(self._draft_token_ids, torch.Tensor) + draft_tokens_index_tensor = torch.tensor( + spec_flattened_indices, + dtype=torch.int64, + pin_memory=self.pin_memory).to(self.device, non_blocking=True) + prev_draft_token_indices_tensor = torch.tensor( + prev_draft_token_indices, + dtype=torch.int64, + pin_memory=self.pin_memory).to(self.device, non_blocking=True) + + # because input_ids dtype is torch.int32, + # so convert draft_token_ids to torch.int32 here. + draft_token_ids = self._draft_token_ids.to(dtype=torch.int32) + self._draft_token_ids = None + self.input_ids.scatter_( + dim=0, + index=draft_tokens_index_tensor, + src=draft_token_ids.flatten()[prev_draft_token_indices_tensor], + ) def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: """ @@ -1544,7 +1662,8 @@ def _prepare_inputs( self.query_lens = torch.from_numpy(num_scheduled_tokens) # Copy the tensors to the NPU. - self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens) + self._prepare_input_ids(scheduler_output, total_num_scheduled_tokens, + cu_num_tokens) self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_() self.positions[:num_input_tokens].copy_( self.positions_cpu[:num_input_tokens], non_blocking=True) @@ -1993,8 +2112,9 @@ def _calc_spec_decode_metadata( cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) logits_indices_pcp += arange - logits_indices_pcp = torch.from_numpy(logits_indices_pcp).to( - self.device, non_blocking=True) + logits_indices_pcp = torch.from_numpy( + logits_indices_pcp).pin_memory().to(self.device, + non_blocking=True) # Compute the bonus logits indices. bonus_logits_indices = cu_num_sampled_tokens - 1 @@ -2015,16 +2135,20 @@ def _calc_spec_decode_metadata( target_logits_indices += arange # TODO: Optimize the CPU -> NPU copy. 
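# Illustrative sketch (not part of this patch): staging a NumPy-backed CPU
# tensor in pinned (page-locked) memory before the non_blocking copy lets the
# host-to-device transfer proceed asynchronously instead of degrading to a
# synchronous copy from pageable memory. torch.cuda is used here as a stand-in
# for torch.npu; the device branch is skipped on hosts without a GPU.
import numpy as np
import torch

indices_np = np.arange(16, dtype=np.int64)
if torch.cuda.is_available():
    staged = torch.from_numpy(indices_np).pin_memory()  # pinned host buffer
    device_indices = staged.to("cuda", non_blocking=True)
else:
    device_indices = torch.from_numpy(indices_np)  # CPU-only fallback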
- cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( - self.device, non_blocking=True) - cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to( - self.device, non_blocking=True) - logits_indices = torch.from_numpy(logits_indices).to(self.device, - non_blocking=True) - target_logits_indices = torch.from_numpy(target_logits_indices).to( - self.device, non_blocking=True) - bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to( - self.device, non_blocking=True) + cu_num_draft_tokens = ( + torch.from_numpy(cu_num_draft_tokens).pin_memory().to( + self.device, non_blocking=True)) + cu_num_sampled_tokens = ( + torch.from_numpy(cu_num_sampled_tokens).pin_memory().to( + self.device, non_blocking=True)) + logits_indices = (torch.from_numpy(logits_indices).pin_memory().to( + self.device, non_blocking=True)) + target_logits_indices = ( + torch.from_numpy(target_logits_indices).pin_memory().to( + self.device, non_blocking=True)) + bonus_logits_indices = torch.from_numpy( + bonus_logits_indices).pin_memory().to(self.device, + non_blocking=True) # Compute the draft token ids. # draft_token_indices: [ 1, 2, 3, 105, 106, 208] @@ -2466,7 +2590,6 @@ def sample_tokens( sampler_output.sampled_token_ids = output_token_ids if self.need_accepted_tokens: self._update_states_after_model_execute(output_token_ids) - discard_sampled_tokens_req_indices = \ self.discard_request_indices.np[:self.num_discarded_requests] for i in discard_sampled_tokens_req_indices: @@ -2494,6 +2617,7 @@ def sample_tokens( num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] sampled_token_ids = sampler_output.sampled_token_ids + if not self.use_async_scheduling: # Get the valid generated tokens. max_gen_len = sampled_token_ids.shape[-1] @@ -2514,13 +2638,14 @@ def sample_tokens( invalid_req_indices = discard_sampled_tokens_req_indices.tolist( ) invalid_req_indices_set = set(invalid_req_indices) - assert sampled_token_ids.shape[-1] == 1 + if self.num_spec_tokens <= 0: + assert sampled_token_ids.shape[-1] == 1 + # Cache the sampled tokens on the NPU and avoid CPU sync. + # These will be copied into input_ids in the next step + # when preparing inputs. + self.input_batch.prev_sampled_token_ids = sampled_token_ids + - # Cache the sampled tokens on the NPU and avoid CPU sync. - # These will be copied into input_ids in the next step - # when preparing inputs. - self.input_batch.prev_sampled_token_ids = \ - sampled_token_ids self.input_batch.prev_sampled_token_ids_invalid_indices = \ invalid_req_indices_set self.input_batch.prev_req_id_to_index = { @@ -2629,6 +2754,7 @@ def propose_draft_token_ids(sampled_token_ids): sampled_token_ids=sampled_token_ids, invalid_req_indices=invalid_req_indices, async_output_copy_stream=self.async_output_copy_stream, + vocab_size=self.input_batch.vocab_size, ) def take_draft_token_ids(self) -> Optional[DraftTokenIds]: diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py index 471c150ba62..ad4f525ffbb 100644 --- a/vllm_ascend/worker/npu_input_batch.py +++ b/vllm_ascend/worker/npu_input_batch.py @@ -68,6 +68,8 @@ class CachedRequestState: lora_request: Optional[LoRARequest] = None prompt_embeds: Optional[torch.Tensor] = None + prev_num_draft_len: int = 0 # previous number of draft tokens + def __post_init__(self): self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( self.prompt_token_ids, self.prompt_embeds)
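# Illustrative sketch (not part of this patch): a simplified, assumed stand-in
# for how get_output() turns a [num_reqs, max_gen_len] sampled_token_ids
# matrix into per-request token lists when spec decoding is enabled. Rejected
# slots are assumed to hold ids outside [0, vocab_size); rows for invalid
# requests are cleared, mirroring the single-token path above. The real code
# delegates this to RejectionSampler.parse_output.
def parse_sampled_rows(sampled_rows, vocab_size, invalid_req_indices):
    invalid = set(invalid_req_indices)
    out = []
    for i, row in enumerate(sampled_rows):
        if i in invalid:
            out.append([])
            continue
        out.append([t for t in row if 0 <= t < vocab_size])
    return out

rows = [[11, 12, -1], [7, -1, -1], [3, 4, 5]]
assert parse_sampled_rows(rows, vocab_size=1000, invalid_req_indices=[1]) == \
    [[11, 12], [], [3, 4, 5]]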