diff --git a/tests/layers/vllm/test_attention.py b/tests/layers/vllm/test_attention.py
index df1f0395f..002f2641a 100644
--- a/tests/layers/vllm/test_attention.py
+++ b/tests/layers/vllm/test_attention.py
@@ -30,9 +30,7 @@
 # Number of attention heads (Key/Value) - for Grouped-Query Attention
 NUM_KV_HEADS = 4
 # Dimension of each attention head
-HEAD_DIM = 64
-# Padded head dimension
-PADDED_HEAD_DIM = 64
+HEAD_DIM = 128
 # Total number of blocks in the KV cache
 NUM_BLOCKS = 32
 # Number of tokens per block
diff --git a/tpu_inference/runner/kv_cache.py b/tpu_inference/runner/kv_cache.py
index 236e86c5d..096be0164 100644
--- a/tpu_inference/runner/kv_cache.py
+++ b/tpu_inference/runner/kv_cache.py
@@ -8,6 +8,7 @@
 from torchax.ops.mappings import t2j_dtype
 
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
+import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
@@ -22,10 +23,18 @@ def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
     model_cnt = mesh.shape["model"]
     assert actual_num_kv_heads % model_cnt == 0
 
+    # NOTE(chengjiyao): Currently, the attention kernel is tailored to the
+    # specific model, rather than being determined by the head_dim. If new
+    # models are introduced with a head_dim of 64, this will require additional
+    # model-specific adjustments.
+    get_kv_cache_shape_fn = (
+        rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
+        else rpa.get_kv_cache_shape
+    )
     shape = list(
-        rpa.get_kv_cache_shape(total_num_pages, page_size,
-                               actual_num_kv_heads // model_cnt,
-                               actual_head_dim, kv_dtype))
+        get_kv_cache_shape_fn(total_num_pages, page_size,
+                              actual_num_kv_heads // model_cnt,
+                              actual_head_dim, kv_dtype))
     shape[2] *= model_cnt
     return tuple(shape)
 
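
For reference, below is a minimal, self-contained sketch of the head_dim-based dispatch that get_kv_cache_shape_with_mesh now performs. The two stub shape functions are hypothetical stand-ins for rpa.get_kv_cache_shape and rpa_hd64.get_kv_cache_shape; the real kernels' cache layouts and the tensor values they return are not reproduced here, only the selection and post-scaling pattern from the diff.

# Sketch only: stand-ins for the real kernel modules, which are not imported here.
def _stub_rpa_shape(total_num_pages, page_size, num_kv_heads, head_dim, kv_dtype):
    # Placeholder layout: (pages, page_size, combined K/V heads, head_dim).
    return (total_num_pages, page_size, 2 * num_kv_heads, head_dim)

def _stub_rpa_hd64_shape(total_num_pages, page_size, num_kv_heads, head_dim, kv_dtype):
    # Placeholder layout for the head_dim == 64 kernel variant.
    return (total_num_pages, page_size, 2 * num_kv_heads, head_dim)

def select_kv_cache_shape_fn(actual_head_dim):
    # Mirrors the dispatch introduced in the diff: the hd64 kernel variant is
    # chosen only when head_dim == 64; all other head dims use the default kernel.
    return _stub_rpa_hd64_shape if actual_head_dim == 64 else _stub_rpa_shape

# Example usage: 4 KV heads sharded over a 2-way model mesh. The shape function
# is called with the per-shard head count, and the head axis (index 2) is
# scaled back up by the model parallel degree afterwards, as in the diff.
fn = select_kv_cache_shape_fn(actual_head_dim=128)
shape = list(fn(32, 16, 4 // 2, 128, "bfloat16"))
shape[2] *= 2
print(tuple(shape))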