-from typing import List
+from typing import Any, List

 import jax
 import jax.numpy as jnp
+import numpy as np
+from jax._src import dtypes
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from torchax.ops.mappings import t2j_dtype

 import tpu_commons.kernels.ragged_paged_attention.v3.kernel as rpa
 from tpu_commons.logger import init_logger
@@ -37,11 +40,11 @@ def create_kv_caches(
     cache_dtype: jnp.dtype = DEFAULT_KV_CACHE_DTYPE,
 ) -> List[jax.Array]:
     """
-    Creates the KV caches, a list of arrays, each array is for one attention layer.
+    Creates a list of KV caches, where each array maps to a single attention layer.

     The shape of the KV cache per layer is:
-    (num_blocks, block_size, cdiv(num_kv_heads * 2, packing), packing, head_size).
-    packing = (32 // dtype bits)
+    (num_blocks, block_size, cdiv(num_kv_heads * 2, packing), packing, head_dim)
+    where packing = (32 // dtype bits)
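+    e.g. with a 16-bit dtype such as bfloat16, packing = 32 // 16 = 2.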

     Args:
         num_blocks: The number of blocks in the KV cache.
@@ -50,6 +53,7 @@ def create_kv_caches(
         head_size: The size of each head in the KV cache.
         mesh: The mesh to shard the KV caches across.
         layer_names: The names of the decoder layers in the model.
+        cache_dtype: The data type of the KV cache.

     Returns:
         A list of KV caches, one per decoder layer in the model.
@@ -75,3 +79,41 @@ def _allocate() -> jax.Array:
     for _ in layer_names:
         kv_caches.append(sharded_allocate())
     return kv_caches
+
+
+def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
+    """
+    Calculate the KV cache page size of the RPA kernel.
+
+    Args:
+        mesh: The mesh to shard the KV caches across.
+        kv_cache_specs: Dictionary of KV cache specs.
+
+    Returns:
+        The KV cache page size in bytes.
+    """
+
+    # Import here to avoid a circular import.
+    from vllm.v1.kv_cache_interface import AttentionSpec
+
+    page_size_bytes_set = set()
+    for kv_cache_spec in kv_cache_specs.values():
+        assert isinstance(kv_cache_spec, AttentionSpec)
+
+        # Convert the torch dtype to its JAX equivalent and get its bit width.
+        dtype = t2j_dtype(kv_cache_spec.dtype)
+        bits = dtypes.bit_width(dtype)
+
+        kv_cache_shape = get_kv_cache_shape_with_mesh(
+            mesh=mesh,
+            total_num_pages=1,  # Pass 1 to get the shape of a single page.
+            page_size=kv_cache_spec.block_size,
+            actual_num_kv_heads=kv_cache_spec.num_kv_heads,
+            actual_head_dim=kv_cache_spec.head_size,
+            kv_dtype=dtype,
+        )
+        page_size_bytes = (bits * np.prod(kv_cache_shape)) // 8
+        page_size_bytes_set.add(page_size_bytes)

+    # Ensure the page size is the same for all KV caches.
+    assert len(page_size_bytes_set) == 1
+    return page_size_bytes_set.pop()
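
As a quick sanity check of the page-size math above, here is a worked example with assumed values (bfloat16 caches on a single device, block_size=32, num_kv_heads=8, head_size=128). It is a minimal sketch, not part of this commit; the real computation goes through get_kv_cache_shape_with_mesh() as in the diff:

    # Worked example with assumed values; not part of the commit above.
    bits = 16                                # bfloat16 bit width
    packing = 32 // bits                     # -> 2 values per 32-bit word
    num_kv_heads, head_dim, block_size = 8, 128, 32
    cdiv = lambda a, b: -(-a // b)           # ceiling division
    num_elements = block_size * cdiv(num_kv_heads * 2, packing) * packing * head_dim
    page_size_bytes = (bits * num_elements) // 8   # -> 131072 bytes (128 KiB) per page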