Skip to content

Commit

Permalink
[Core] Multi-Step + Single Step Prefills via Chunked Prefill code path (
Browse files Browse the repository at this point in the history
vllm-project#8378)

Co-authored-by: Varun Sundar Rabindranath <[email protected]>
  • Loading branch information
2 people authored and siddharth9820 committed Sep 30, 2024
1 parent cfd055c commit 3ad6a87
Show file tree
Hide file tree
Showing 19 changed files with 514 additions and 109 deletions.
2 changes: 1 addition & 1 deletion csrc/prepare_inputs/advance_step.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ __global__ void advance_step_flashattn_kernel(
slot_mapping_ptr[cur_query_id] = slot_num;
}

inline void verify_tensor(std::string const& name, torch::Tensor& t,
inline void verify_tensor(std::string const& name, torch::Tensor const& t,
int64_t const size_0, int64_t const size_1,
c10::ScalarType const type) {
bool size_0_cond = true;
Expand Down
9 changes: 9 additions & 0 deletions tests/multi_step/test_correctness_async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("is_async", [True])
@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
@pytest.mark.asyncio
async def test_multi_step(
example_prompts,
Expand All @@ -49,6 +50,7 @@ async def test_multi_step(
is_async: bool,
num_logprobs: Optional[int],
attention_backend: str,
enable_chunked_prefill: bool,
monkeypatch,
) -> None:
"""Test vLLM engine with multi-step scheduling in an OpenAI-protocol
Expand All @@ -74,6 +76,10 @@ async def test_multi_step(
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> no logprobs
"""
if enable_chunked_prefill and \
(pp_size > 1 or attention_backend != "FLASH_ATTN"):
pytest.skip("Multi-step with Chunked-Prefill only supports"
"PP=1 and FLASH_ATTN backend")

override_backend_env_variable(monkeypatch, attention_backend)

Expand All @@ -93,6 +99,9 @@ async def test_multi_step(
if eager_mode:
ms_server_args.append("--enforce-eager")

if enable_chunked_prefill:
ms_server_args.append("--enable-chunked-prefill")

distributed_args = [
"--tensor-parallel-size",
str(tp_size),
Expand Down
4 changes: 4 additions & 0 deletions tests/multi_step/test_correctness_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
Expand All @@ -28,6 +29,7 @@ def test_multi_step_llm(
model: str,
dtype: str,
tp_size: int,
enable_chunked_prefill: bool,
max_tokens: int,
enforce_eager: int,
num_scheduler_steps: int,
Expand All @@ -51,6 +53,7 @@ def test_multi_step_llm(
model: model under test (same for single- and multi-step engines)
dtype: tensor datatype for engine to utilize
tp_size: degree of tensor-parallelism
enable_chunked_prefill: chunked-prefill on/off
max_tokens: the maximum number of tokens to generate
enforce_eager
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
Expand All @@ -73,6 +76,7 @@ def test_multi_step_llm(
gpu_memory_utilization=0.7,
tensor_parallel_size=tp_size,
use_v2_block_manager=True,
enable_chunked_prefill=enable_chunked_prefill,
num_scheduler_steps=num_scheduler_steps,
) as vllm_model:
vllm_outputs = (vllm_model.generate_greedy(prompts, max_tokens)
Expand Down
32 changes: 27 additions & 5 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,13 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
)
return self._cached_decode_metadata

def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int):
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
Expand All @@ -355,6 +359,23 @@ def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
assert num_seqs > num_queries
assert self.use_cuda_graph

if turn_prefills_into_decodes:
# When Mutli-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
self.num_prefill_tokens = 0
self.max_prefill_seq_len = 0
self.max_query_len = 1

self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens is not None
assert self.max_decode_seq_len == max(self.seq_lens)

assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
Expand All @@ -366,7 +387,6 @@ def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0
assert self.max_decode_seq_len == max(self.seq_lens)

assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
Expand Down Expand Up @@ -706,8 +726,10 @@ def forward(

num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa

# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
Expand Down
20 changes: 12 additions & 8 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,18 +410,22 @@ def decode_metadata(self) -> Optional["FlashInferMetadata"]:

return self

def advance_step(
self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
):
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""

assert not turn_prefills_into_decodes, \
("Chunked prefill is not supported with flashinfer yet."
"turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
"specific parameter.")

assert num_seqs > 0
assert num_queries > 0
assert model_input.attn_metadata is not None
Expand Down
13 changes: 10 additions & 3 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,9 +983,16 @@ def __init__(self,
policy: str = "fcfs") -> None:
if max_num_batched_tokens is None:
if enable_chunked_prefill:
# It is the values that have the best balance between ITL
# and TTFT on A100. Note it is not optimized for throughput.
max_num_batched_tokens = 512
if num_scheduler_steps > 1:
# Multi-step Chunked-Prefill doesn't allow prompt-chunking
# for now. Have max_num_batched_tokens set to max_model_len
# so we don't reject sequences on account of a short
# max_num_batched_tokens.
max_num_batched_tokens = max(max_model_len, 2048)
else:
# It is the values that have the best balance between ITL
# and TTFT on A100. Note it is not optimized for throughput.
max_num_batched_tokens = 512
else:
# If max_model_len is too short, use 2048 as the default value
# for higher throughput.
Expand Down
13 changes: 9 additions & 4 deletions vllm/core/block/block_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,12 @@ def __init__(
self._num_full_slots = self._get_num_token_ids()

@staticmethod
def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
def get_num_required_blocks(token_ids: List[int],
block_size: int,
num_lookahead_slots: int = 0) -> int:
"""Calculates the minimum number of blocks required to store a given
sequence of token IDs.
sequence of token IDs along with any look-ahead slots that may be
required (like in multi-step + chunked-prefill).
This assumes worst-case scenario, where every block requires a new
allocation (e.g. ignoring prefix caching).
Expand All @@ -66,12 +69,14 @@ def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
token_ids (List[int]): The sequence of token IDs to be stored.
block_size (int): The maximum number of tokens that can be stored in
a single block.
num_lookahead_slots (int): look-ahead slots that the sequence may
require.
Returns:
int: The minimum number of blocks required to store the given
sequence of token IDs.
sequence of token IDs along with any required look-ahead slots.
"""
return cdiv(len(token_ids), block_size)
return cdiv(len(token_ids) + num_lookahead_slots, block_size)

def allocate(self,
token_ids: List[int],
Expand Down
7 changes: 6 additions & 1 deletion vllm/core/block_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,15 @@ def __init__(
def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int:
return 0 if seq is None else seq.n_blocks

def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
def can_allocate(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.

assert (num_lookahead_slots == 0
), "lookahead allocation not supported in BlockSpaceManagerV1"

check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)

self_num_required_blocks = self._get_seq_num_required_blocks(
Expand Down
5 changes: 4 additions & 1 deletion vllm/core/block_manager_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def __init__(
self._last_access_blocks_tracker = LastAccessBlocksTracker(
self.block_allocator)

def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
def can_allocate(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.

Expand All @@ -117,6 +119,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
num_required_blocks = BlockTable.get_num_required_blocks(
seq.get_token_ids(),
block_size=self.block_size,
num_lookahead_slots=num_lookahead_slots,
)

if seq_group.is_encoder_decoder():
Expand Down
4 changes: 3 additions & 1 deletion vllm/core/embedding_model_block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def __init__(
) -> None:
pass

def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
def can_allocate(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
# Always return OK for dummy purposes
return AllocStatus.OK

Expand Down
4 changes: 3 additions & 1 deletion vllm/core/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def get_block_space_manager_class(version: str):
raise ValueError(f"Unknown version {version=}")

@abstractmethod
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
def can_allocate(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
pass

@abstractmethod
Expand Down
Loading

0 comments on commit 3ad6a87

Please sign in to comment.