4 changes: 1 addition & 3 deletions tests/layers/vllm/test_attention.py
@@ -30,9 +30,7 @@
 # Number of attention heads (Key/Value) - for Grouped-Query Attention
 NUM_KV_HEADS = 4
 # Dimension of each attention head
-HEAD_DIM = 64
-# Padded head dimension
-PADDED_HEAD_DIM = 64
+HEAD_DIM = 128
 # Total number of blocks in the KV cache
 NUM_BLOCKS = 32
 # Number of tokens per block
15 changes: 12 additions & 3 deletions tpu_inference/runner/kv_cache.py
@@ -8,6 +8,7 @@
 from torchax.ops.mappings import t2j_dtype
 
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
+import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
@@ -22,10 +23,18 @@ def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
 
     model_cnt = mesh.shape["model"]
     assert actual_num_kv_heads % model_cnt == 0
+    # NOTE(chengjiyao): Currently, the attention kernel is tailored to the
+    # specific model, rather than being determined by the head_dim. If new
+    # models are introduced with a head_dim of 64, this will require additional
+    # model-specific adjustments.
+    get_kv_cache_shape_fn = (
+        rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
+        else rpa.get_kv_cache_shape
+    )
     shape = list(
-        rpa.get_kv_cache_shape(total_num_pages, page_size,
-                               actual_num_kv_heads // model_cnt,
-                               actual_head_dim, kv_dtype))
+        get_kv_cache_shape_fn(total_num_pages, page_size,
+                              actual_num_kv_heads // model_cnt,
+                              actual_head_dim, kv_dtype))
     shape[2] *= model_cnt
     return tuple(shape)
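
For reference, a minimal sketch of the head_dim-based dispatch introduced in this hunk. The module paths and the get_kv_cache_shape argument order are taken from the diff; pick_shape_fn and the concrete values below are hypothetical and for illustration only.

import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64

def pick_shape_fn(head_dim: int):
    # Hypothetical helper mirroring the diff: head_dim == 64 routes to the
    # hd64-specific ragged paged attention kernel; any other head_dim (e.g.
    # the HEAD_DIM = 128 now used in the test) keeps the generic kernel.
    return (rpa_hd64.get_kv_cache_shape
            if head_dim == 64 else rpa.get_kv_cache_shape)

# The selected helper is then called with the same arguments as before:
# (total_num_pages, page_size, num_kv_heads_per_shard, head_dim, kv_dtype).
assert pick_shape_fn(64) is rpa_hd64.get_kv_cache_shape
assert pick_shape_fn(128) is rpa.get_kv_cache_shape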
