
Commit 30a1aea

yaochengjisixiang-google authored and committed
Fix kv cache shape for head_dim=64 (#976)
Signed-off-by: Chengji Yao <[email protected]>
1 parent cb68b51 commit 30a1aea

2 files changed: +13 -6 lines changed

tests/layers/vllm/test_attention.py

Lines changed: 1 addition & 3 deletions
@@ -30,9 +30,7 @@
 # Number of attention heads (Key/Value) - for Grouped-Query Attention
 NUM_KV_HEADS = 4
 # Dimension of each attention head
-HEAD_DIM = 64
-# Padded head dimension
-PADDED_HEAD_DIM = 64
+HEAD_DIM = 128
 # Total number of blocks in the KV cache
 NUM_BLOCKS = 32
 # Number of tokens per block

tpu_inference/runner/kv_cache.py

Lines changed: 12 additions & 3 deletions
@@ -8,6 +8,7 @@
 from torchax.ops.mappings import t2j_dtype

 import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
+import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
 from tpu_inference.logger import init_logger

 logger = init_logger(__name__)
@@ -22,10 +23,18 @@ def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,

     model_cnt = mesh.shape["model"]
     assert actual_num_kv_heads % model_cnt == 0
+    # NOTE(chengjiyao): Currently, the attention kernel is tailored to the
+    # specific model, rather than being determined by the head_dim. If new
+    # models are introduced with a head_dim of 64, this will require additional
+    # model-specific adjustments.
+    get_kv_cache_shape_fn = (
+        rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
+        else rpa.get_kv_cache_shape
+    )
     shape = list(
-        rpa.get_kv_cache_shape(total_num_pages, page_size,
-                               actual_num_kv_heads // model_cnt,
-                               actual_head_dim, kv_dtype))
+        get_kv_cache_shape_fn(total_num_pages, page_size,
+                              actual_num_kv_heads // model_cnt,
+                              actual_head_dim, kv_dtype))
     shape[2] *= model_cnt
     return tuple(shape)
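
For readers skimming the diff, here is a minimal, self-contained sketch of the pattern this change introduces: pick a kernel-specific shape function based on head_dim, compute the per-shard KV cache shape with the KV heads divided across the model mesh axis, then scale the head axis back up. The stand-in shape functions and the example numbers below are hypothetical placeholders; the real layouts come from rpa.get_kv_cache_shape and rpa_hd64.get_kv_cache_shape in the repository.

from typing import Tuple

# Hypothetical stand-ins for rpa.get_kv_cache_shape and rpa_hd64.get_kv_cache_shape.
# The real kernels own their page layouts; only the argument order matches the diff.
def _default_kv_cache_shape(total_num_pages, page_size, num_kv_heads,
                            head_dim, kv_dtype) -> Tuple[int, ...]:
    # K and V are stored together, hence the combined-head axis at index 2.
    return (total_num_pages, page_size, num_kv_heads * 2, head_dim)

def _hd64_kv_cache_shape(total_num_pages, page_size, num_kv_heads,
                         head_dim, kv_dtype) -> Tuple[int, ...]:
    # A head_dim=64 kernel may pack pages differently; this placeholder keeps
    # the same axis order so the model-axis rescaling below stays valid.
    return (total_num_pages, page_size, num_kv_heads * 2, head_dim)

def kv_cache_shape_with_model_sharding(model_cnt: int, total_num_pages: int,
                                       page_size: int, actual_num_kv_heads: int,
                                       actual_head_dim: int,
                                       kv_dtype=None) -> Tuple[int, ...]:
    """Mirrors the selection logic added to get_kv_cache_shape_with_mesh."""
    assert actual_num_kv_heads % model_cnt == 0
    # Dispatch on head_dim, as the commit does for the hd64 kernel.
    get_shape = (_hd64_kv_cache_shape if actual_head_dim == 64
                 else _default_kv_cache_shape)
    shape = list(get_shape(total_num_pages, page_size,
                           actual_num_kv_heads // model_cnt,
                           actual_head_dim, kv_dtype))
    # The shape is computed per model shard, then the combined KV-head axis is
    # multiplied back by the shard count to describe the full cache.
    shape[2] *= model_cnt
    return tuple(shape)

# Example with 4 KV heads sharded across 2 model shards:
print(kv_cache_shape_with_model_sharding(2, 32, 16, 4, 64))   # -> (32, 16, 8, 64)
print(kv_cache_shape_with_model_sharding(2, 32, 16, 4, 128))  # -> (32, 16, 8, 128)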