Commit 303ab6f

[KV Cache] Override Number of KV cache blocks (#757)
Signed-off-by: Kyuyeun Kim <[email protected]>
1 parent f38db43 commit 303ab6f

File tree: 2 files changed (+71, -5 lines)

tpu_commons/runner/kv_cache.py

Lines changed: 46 additions & 4 deletions
@@ -1,8 +1,11 @@
-from typing import List
+from typing import Any, List

 import jax
 import jax.numpy as jnp
+import numpy as np
+from jax._src import dtypes
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from torchax.ops.mappings import t2j_dtype

 import tpu_commons.kernels.ragged_paged_attention.v3.kernel as rpa
 from tpu_commons.logger import init_logger
@@ -37,11 +40,11 @@ def create_kv_caches(
     cache_dtype: jnp.dtype = DEFAULT_KV_CACHE_DTYPE,
 ) -> List[jax.Array]:
     """
-    Creates the KV caches, a list of arrays, each array is for one attention layer.
+    Creates a list of KV caches, where each array maps to a single attention layer.

     The shape of the KV cache per layer is:
-    (num_blocks, block_size, cdiv(num_kv_heads * 2, packing), packing, head_size).
-    packing = (32 // dtype bits)
+    (num_blocks, block_size, cdiv(num_kv_heads * 2, packing), packing, head_dim)
+    where packing = (32 // dtype bits)

     Args:
         num_blocks: The number of blocks in the KV cache.
@@ -50,6 +53,7 @@ def create_kv_caches(
         head_size: The size of each head in the KV cache.
         mesh: The mesh to shard the KV caches across.
         layer_names: The names of the decoder layers in the model.
+        cache_dtype: The datatype of KV cache.

     Returns:
         A list of KV caches, one per each decoder layer in the model.
@@ -75,3 +79,41 @@ def _allocate() -> jax.Array:
     for _ in layer_names:
         kv_caches.append(sharded_allocate())
     return kv_caches
+
+
+def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
+    """
+    Calculate KV cache page size of RPA kernel.
+
+    Args:
+        mesh: The mesh to shard the KV caches across.
+        kv_cache_specs: Dictionary of KV cache specs.
+
+    Returns:
+        KV cache page size in bytes.
+    """
+
+    # Import it here to avoid circular import.
+    from vllm.v1.kv_cache_interface import AttentionSpec
+
+    page_size_bytes_set = set()
+    for kv_cache_spec in kv_cache_specs.values():
+        assert isinstance(kv_cache_spec, AttentionSpec)
+
+        dtype = t2j_dtype(kv_cache_spec.dtype)
+        bits = dtypes.bit_width(dtype)
+
+        kv_cache_shape = get_kv_cache_shape_with_mesh(
+            mesh=mesh,
+            total_num_pages=1,  # Pass 1 to get shape of a single page.
+            page_size=kv_cache_spec.block_size,
+            actual_num_kv_heads=kv_cache_spec.num_kv_heads,
+            actual_head_dim=kv_cache_spec.head_size,
+            kv_dtype=dtype,
+        )
+        page_size_bytes = (bits * np.prod(kv_cache_shape)) // 8
+        page_size_bytes_set.add(page_size_bytes)
+
+    # Ensure that page size is the same for all kv caches.
+    assert len(page_size_bytes_set) == 1
+    return page_size_bytes_set.pop()
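To make the page-size arithmetic concrete: for the shape documented in create_kv_caches, a single page holds block_size * cdiv(num_kv_heads * 2, packing) * packing * head_dim elements, each occupying dtype-bits / 8 bytes. The sketch below is not part of the commit; approx_page_size_bytes is a hypothetical helper that ignores any mesh- or padding-related adjustments get_kv_cache_shape_with_mesh may apply.

import math

def approx_page_size_bytes(block_size: int, num_kv_heads: int, head_dim: int,
                           dtype_bits: int) -> int:
    # packing = number of sub-32-bit values packed into one 32-bit word.
    packing = 32 // dtype_bits
    # Shape of one page: (1, block_size, cdiv(num_kv_heads * 2, packing),
    # packing, head_dim); the factor of 2 covers both K and V.
    num_elements = (block_size * math.ceil(num_kv_heads * 2 / packing) *
                    packing * head_dim)
    return (num_elements * dtype_bits) // 8

# Example: bfloat16 (16 bits) gives packing = 2, so a page with block_size=32,
# 8 KV heads and head_dim=128 is 32 * 8 * 2 * 128 * 2 bytes = 131072 bytes.
print(approx_page_size_bytes(32, 8, 128, 16))  # -> 131072

This also suggests why get_rpa_page_size_bytes asks the kernel's shape helper for a single page (total_num_pages=1) instead of redoing the formula by hand: the helper is the source of truth for the actual layout.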

tpu_commons/worker/tpu_worker_jax.py

Lines changed: 25 additions & 1 deletion
@@ -15,6 +15,7 @@
                               init_distributed_environment)
 from vllm.lora.request import LoRARequest
 from vllm.tasks import SupportedTask
+from vllm.v1.core.kv_cache_utils import get_num_blocks, get_uniform_page_size
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
@@ -27,6 +28,7 @@
 from tpu_commons.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                            get_node_id)
 from tpu_commons.logger import init_logger
+from tpu_commons.runner.kv_cache import get_rpa_page_size_bytes
 from tpu_commons.runner.tpu_jax_runner import TPUModelRunner
 from tpu_commons.worker._temporary_vllm_compat import (
     adapt_kv_cache_config_if_needed, adapt_lora_request_if_needed,
@@ -251,7 +253,29 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         # responsible for this translation. When vLLM can be modified, this
         # method should be changed to return `dict[str, AbstractKVCacheSpec]`,
         # and the vLLM side should be updated to handle the translation.
-        return self.model_runner.get_kv_cache_spec()
+        kv_cache_specs = self.model_runner.get_kv_cache_spec()
+
+        # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
+        # feature that allows overriding page_size_bytes of KVCacheSpec.
+        vllm_page_size_bytes = get_uniform_page_size(kv_cache_specs)
+        rpa_page_size_bytes = get_rpa_page_size_bytes(self.model_runner.mesh,
+                                                      kv_cache_specs)
+
+        if vllm_page_size_bytes != rpa_page_size_bytes:
+            logger.info(
+                f"KV cache page size calculated by vLLM "
+                f"({vllm_page_size_bytes} Bytes) does not match with actual "
+                f"page size used by RPA kernel ({rpa_page_size_bytes} Bytes). "
+                f"Recalculating number of KV blocks using actual page size.")
+
+            available_memory = self.determine_available_memory()
+            num_blocks = get_num_blocks(self.vllm_config, len(kv_cache_specs),
+                                        available_memory, rpa_page_size_bytes)
+
+            cache_config = self.vllm_config.cache_config
+            cache_config.num_gpu_blocks_override = num_blocks
+
+        return kv_cache_specs

     def initialize_from_config(
         self,
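The worker change only sets cache_config.num_gpu_blocks_override; the block count itself still comes from vLLM's get_num_blocks. As a rough mental model, the sketch below assumes get_num_blocks essentially divides available memory by the per-layer page footprint; the real helper may round or apply further caps, and approx_num_blocks is a hypothetical stand-in, not the vLLM implementation.

def approx_num_blocks(available_memory_bytes: int, num_layers: int,
                      page_size_bytes: int) -> int:
    # Each KV cache block is materialized once per attention layer, so the
    # per-block cost is num_layers * page_size_bytes.
    return available_memory_bytes // (num_layers * page_size_bytes)

# Example: 16 GiB of free HBM, 32 attention layers and 128 KiB RPA pages
# leave room for 16 * 2**30 // (32 * 131072) = 4096 blocks.
print(approx_num_blocks(16 * 2**30, 32, 131072))  # -> 4096

When the two page sizes differ, recomputing the block count with rpa_page_size_bytes keeps the total allocation within the memory the RPA kernel actually consumes, which is the intent of the override described in the log message above.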
