 from torchax.ops.mappings import t2j_dtype

 import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
+import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
 from tpu_inference.logger import init_logger

 logger = init_logger(__name__)
@@ -22,10 +23,18 @@ def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
 
     model_cnt = mesh.shape["model"]
     assert actual_num_kv_heads % model_cnt == 0
+    # NOTE(chengjiyao): Currently, the attention kernel is tailored to the
+    # specific model, rather than being determined by the head_dim. If new
+    # models are introduced with a head_dim of 64, this will require additional
+    # model-specific adjustments.
+    get_kv_cache_shape_fn = (
+        rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
+        else rpa.get_kv_cache_shape
+    )
     shape = list(
-        rpa.get_kv_cache_shape(total_num_pages, page_size,
-                               actual_num_kv_heads // model_cnt,
-                               actual_head_dim, kv_dtype))
+        get_kv_cache_shape_fn(total_num_pages, page_size,
+                              actual_num_kv_heads // model_cnt,
+                              actual_head_dim, kv_dtype))
     shape[2] *= model_cnt
     return tuple(shape)
 
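For context, a minimal standalone sketch of the dispatch this hunk introduces. The two shape helpers below are hypothetical stand-ins with the same call signature as the RPA kernels' `get_kv_cache_shape`; the returned layouts are assumptions for illustration, not the real kernel layouts.

```python
# Hypothetical sketch of the head_dim-based kernel selection above.
# _rpa_shape / _rpa_hd64_shape are stand-ins, not the real RPA kernel helpers.
import jax.numpy as jnp


def _rpa_shape(total_num_pages, page_size, num_kv_heads, head_dim, kv_dtype):
    # Assumed KV-cache layout, for illustration only.
    return (total_num_pages, page_size, num_kv_heads * 2, head_dim)


def _rpa_hd64_shape(total_num_pages, page_size, num_kv_heads, head_dim,
                    kv_dtype):
    # Assumed KV-cache layout for the head_dim == 64 kernel, illustration only.
    return (total_num_pages, page_size, num_kv_heads * 2, head_dim)


def kv_cache_shape_with_model_split(total_num_pages, page_size,
                                    actual_num_kv_heads, actual_head_dim,
                                    kv_dtype, model_cnt):
    # Mirrors the new dispatch: the hd64 kernel is used only when
    # actual_head_dim == 64, otherwise the default RPA kernel.
    shape_fn = _rpa_hd64_shape if actual_head_dim == 64 else _rpa_shape
    shape = list(
        shape_fn(total_num_pages, page_size,
                 actual_num_kv_heads // model_cnt, actual_head_dim, kv_dtype))
    # Undo the per-shard head split on axis 2, as the real helper does.
    shape[2] *= model_cnt
    return tuple(shape)


print(kv_cache_shape_with_model_split(16, 32, 8, 64, jnp.bfloat16, model_cnt=2))
print(kv_cache_shape_with_model_split(16, 32, 8, 128, jnp.bfloat16, model_cnt=2))
```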