diff --git a/scripts/signlerun.sh b/scripts/signlerun.sh new file mode 100644 index 00000000..965ad629 --- /dev/null +++ b/scripts/signlerun.sh @@ -0,0 +1,12 @@ +python3 ./src/parallax/launch.py \ +--model-path /Users/alizen/study_files/Dev/yuhao/models/Qwen3-VL-2B \ +--port 3210 \ +--start-layer 0 \ +--end-layer 28 \ +--kv-block-size 16 \ +--max-sequence-length 1000 \ +--max-num-tokens-per-batch 4096 \ +--kv-cache-memory-fraction 0.3 \ +--max-batch-size 4 \ +--log-level DEBUG \ +--use-hfcache diff --git a/scripts/test.sh b/scripts/test.sh new file mode 100644 index 00000000..a0598a78 --- /dev/null +++ b/scripts/test.sh @@ -0,0 +1,34 @@ +# curl --location 'http://localhost:3210/v1/chat/completions' \ +# --header 'Content-Type: application/json' \ +# --data '{ +# "max_tokens": 10, +# "messages": [ +# { +# "role": "user", +# "content": [ +# { +# "type": "image_url", +# "image_url": { +# "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" +# } +# }, +# {"type": "text", "text": "Describe this image."} +# ] +# } +# ], +# "stream": false +# }' + + +curl --location 'http://localhost:3210/v1/chat/completions' \ + --header 'Content-Type: application/json' \ + --data '{ + "max_tokens": 10, + "messages": [ + { + "role": "user", + "content": "hello" + } + ], + "stream": false + }' diff --git a/src/parallax/metal/paged_attention/paged_attention.metal b/src/parallax/metal/paged_attention/paged_attention.metal index 6780a411..b59612da 100644 --- a/src/parallax/metal/paged_attention/paged_attention.metal +++ b/src/parallax/metal/paged_attention/paged_attention.metal @@ -36,13 +36,13 @@ int kv_head_idx = head_idx / (_num_heads / _num_kv_heads); // Q: [batch, num_heads, k_head_dim] // Thread i loads elements i, i+32, ... -// Support up to 256 head dim (8 * 32) -float q_vec[8] = {0.0f}; +// Support up to 640 head dim (20 * 32) +float q_vec[20] = {0.0f}; int q_offset = batch_idx * _num_heads * _k_head_dim + head_idx * _k_head_dim; for (int i = tid.x; i < _k_head_dim; i += 32) { - if (i < 256) { + if (i < 640) { q_vec[i / 32] = queries[q_offset + i]; } } @@ -50,7 +50,7 @@ for (int i = tid.x; i < _k_head_dim; i += 32) { // Running statistics for Softmax float m_i = -INFINITY; float l_i = 0.0f; -float acc_vec[8] = {0.0f}; +float acc_vec[20] = {0.0f}; int context_len = context_lengths[batch_idx]; int num_context_blocks = (context_len + _block_size - 1) / _block_size; @@ -86,7 +86,7 @@ for (int b = 0; b < num_context_blocks; b++) { // offset inside block: t * k_head_dim + i float k_val = key_cache[k_block_base + t * _k_head_dim + i]; - if (i < 256) { + if (i < 640) { score += q_vec[i / 32] * k_val; } } @@ -106,7 +106,7 @@ for (int b = 0; b < num_context_blocks; b++) { // Accumulate V for (int i = tid.x; i < _v_head_dim; i += 32) { float v_val = value_cache[v_block_base + t * _v_head_dim + i]; - if (i < 256) { + if (i < 640) { acc_vec[i / 32] = acc_vec[i / 32] * alpha + v_val * beta; } } @@ -114,14 +114,14 @@ for (int b = 0; b < num_context_blocks; b++) { } // Finalize Output -for (int i = 0; i < 8; i++) { +for (int i = 0; i < 20; i++) { acc_vec[i] /= l_i; } int out_offset = batch_idx * _num_heads * _v_head_dim + head_idx * _v_head_dim; for (int i = tid.x; i < _v_head_dim; i += 32) { - if (i < 256) { + if (i < 640) { output[out_offset + i] = ({{T}})acc_vec[i / 32]; } } diff --git a/src/parallax/models/glm4_moe_lite.py b/src/parallax/models/glm4_moe_lite.py new file mode 100644 index 00000000..8311c7de --- /dev/null +++ b/src/parallax/models/glm4_moe_lite.py @@ -0,0 +1,215 @@ +from typing import Any, List, Optional + +from parallax_utils.logging_config import get_logger + +logger = get_logger(__name__) + +import mlx.core as mx +from mlx_lm.models.base import scaled_dot_product_attention +from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteAttention as MLXGLM4MoeLiteAttention +from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer as MLXGLM4MoeLiteBlock +from mlx_lm.models.glm4_moe_lite import ModelArgs + +from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache +from parallax.server.cache.base import BaseCache +from parallax.utils.prefix_cache_utils import compute_attention_with_prefix_cache + + +class ParallaxGLM4MoeLiteAttention(MLXGLM4MoeLiteAttention): + """A custom attention module for Parallax, extending the GLM4 MoE Lite Attention class. + + GLM4 MoE Lite uses Multi-head Latent Attention (MLA) similar to DeepSeek V3, but + instead of kv_b_proj, it uses embed_q and unembed_out (MultiLinear): + - embed_q: transforms q_nope from qk_nope_head_dim -> kv_lora_rank (per head) + - unembed_out: transforms attention output from kv_lora_rank -> v_head_dim (per head) + - keys = [kv_latent, k_pe] with 1 KV head (MQA-style) + - values = kv_latent with 1 KV head + """ + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[BaseCache] = None, + block_tables: Optional[mx.array] = None, + context_lengths: Optional[mx.array] = None, + slot_mapping: Optional[mx.array] = None, + prefix_lens: Optional[mx.array] = None, + **kwargs, + ) -> mx.array: + """ + Attention forward pass with explicit KV cache handling. + + Args: + x: (batch, target_len, hidden_dim) - Input hidden states. + mask: (batch, n_q_heads, target_len, source_len) + cache: BaseCache object containing the layer cache. + block_tables: (batch, max_blocks) - PagedKV block tables. + context_lengths: (batch,) - PagedKV sequence lengths. + slot_mapping: (batch * target_len,) - Flattened slot mapping. + prefix_lens: (batch,) - Number of prefix tokens already cached. + + Returns: + output: (batch, target_len, hidden_dim) - Output hidden states. + """ + batch, target_len, _ = x.shape + + # Q projection (with optional LoRA) + if self.q_lora_rank is None: + q = self.q_proj(x) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x))) + + q = q.reshape(batch, target_len, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3) + q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1) + compressed_kv = self.kv_a_proj_with_mqa(x) + compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1) + k_pe = k_pe.reshape(batch, target_len, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3) + + kv_latent = self.kv_a_layernorm(compressed_kv) + + if target_len == 1: + current_pos = context_lengths - 1 + elif prefix_lens is not None: + current_pos = prefix_lens + else: + current_pos = 0 + + q_pe = self.rope(q_pe, offset=current_pos) + k_pe = self.rope(k_pe, offset=current_pos) + + # Transform q_nope into kv_lora_rank space via embed_q (per-head MultiLinear) + kv_latent_expanded = mx.expand_dims(kv_latent, axis=1) + # kv_latent_expanded: (batch, 1, target_len, kv_lora_rank) + + q_nope = self.embed_q(q_nope) + # q_nope: (batch, num_heads, target_len, kv_lora_rank) + + # Construct queries, keys, values + queries = mx.concatenate([q_nope, q_pe], axis=-1) + # queries: (batch, num_heads, target_len, kv_lora_rank + qk_rope_head_dim) + + keys = mx.concatenate([kv_latent_expanded, k_pe], axis=-1) + # keys: (batch, 1, target_len, kv_lora_rank + qk_rope_head_dim) + + # Values = kv_latent (the non-rope part of keys) + # For reshape_and_cache, values shape: (batch, target_len, num_kv_heads=1, kv_lora_rank) + values = mx.expand_dims(kv_latent, axis=2) + # values: (batch, target_len, 1, kv_lora_rank) + + key_cache_global, value_cache_global = cache.get_cache() + block_size = key_cache_global.shape[3] + + # Store keys and values in paged cache + reshape_and_cache( + keys.transpose(0, 2, 1, 3), # (batch, target_len, 1, key_head_dim) + values, # (batch, target_len, 1, kv_lora_rank) + key_cache_global, + value_cache_global, + block_tables, + context_lengths, + block_size, + slot_mapping=slot_mapping, + ) + + if target_len == 1: + # Decode phase: Use Paged Attention + output = paged_attention( + queries, + key_cache_global, + value_cache_global, + block_tables, + context_lengths, + block_size, + self.scale, + 1, # num_kv_heads = 1 (MQA via latent attention) + v_head_dim=self.kv_lora_rank, + ) + # output: (batch, num_heads, 1, kv_lora_rank) + output = self.unembed_out(output) + # output: (batch, num_heads, 1, v_head_dim) + output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) + else: + # Prefill phase + has_prefix_cache = prefix_lens is not None and bool(mx.any(prefix_lens > 0)) + + if has_prefix_cache: + k_new = keys # (batch, 1, target_len, key_head_dim) + v_new = values.transpose(0, 2, 1, 3) # (batch, 1, target_len, kv_lora_rank) + output = compute_attention_with_prefix_cache( + queries, + k_new, + v_new, + cache, + block_tables, + prefix_lens, + target_len, + self.scale, + 1, # num_kv_heads = 1 + mask=mask, + unembed_out=True, # Skip reshape, we need to apply unembed_out first + ) + # output: (batch, num_heads, target_len, kv_lora_rank) + output = self.unembed_out(output) + # output: (batch, num_heads, target_len, v_head_dim) + output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) + else: + # No prefix cache, standard self-attention + if mask is not None: + mask = mx.array(mask, dtype=queries.dtype) + + output = scaled_dot_product_attention( + queries, + keys, + values.transpose(0, 2, 1, 3), # (batch, 1, target_len, kv_lora_rank) + scale=self.scale, + mask=mask, + cache=None, + ) + # output: (batch, num_heads, target_len, kv_lora_rank) + output = self.unembed_out(output) + # output: (batch, num_heads, target_len, v_head_dim) + output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) + + return self.o_proj(output) + + +class ParallaxGLM4MoeLiteBlock(MLXGLM4MoeLiteBlock): + """A custom transformer block for Parallax, extending GLM4 MoE Lite DecoderLayer.""" + + def __init__(self, args: ModelArgs, layer_idx: int, local_layer_idx: int): + super().__init__(args, layer_idx) + self.self_attn = ParallaxGLM4MoeLiteAttention(args) + self.local_layer_idx = local_layer_idx + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[List[Any]] = None, + block_tables: Optional[mx.array] = None, + context_lengths: Optional[mx.array] = None, + slot_mapping: Optional[mx.array] = None, + **kwargs, + ): + r = self.self_attn( + self.input_layernorm(x), + mask, + cache[self.local_layer_idx], + block_tables=block_tables, + context_lengths=context_lengths, + slot_mapping=slot_mapping, + **kwargs, + ) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + @classmethod + def get_architecture(cls): + """Get the architecture name for the block.""" + return "Glm4MoeLiteForCausalLM" + + +EntryClass = ParallaxGLM4MoeLiteBlock diff --git a/src/parallax/server/cache_manager.py b/src/parallax/server/cache_manager.py index 458b17d2..3968de12 100644 --- a/src/parallax/server/cache_manager.py +++ b/src/parallax/server/cache_manager.py @@ -6,7 +6,7 @@ from parallax.server.cache.allocator import BlockAllocator, SlotAllocator from parallax.server.cache.base import BaseCache from parallax.server.cache.dsa_cache import DeepSeekSparseCache -from parallax.server.cache.kv_cache import KVCachePacked +from parallax.server.cache.kv_cache import KVCache, KVCachePacked from parallax.server.cache.linear_cache import LinearCache from parallax_utils.logging_config import get_logger @@ -150,6 +150,17 @@ def _create_cache(self, layer_type: str) -> BaseCache: index_head_dim=self.index_head_dim, index_n_heads=self.index_n_heads, ) + elif self.head_dim != self.head_dim_v: + # Different k/v head dims (e.g. MLA latent attention): + # use standard KVCache layout for Metal shader kernels + return KVCache( + num_blocks=self.num_gpu_blocks, + block_size=self.block_size, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + head_dim_v=self.head_dim_v, + dtype=self.dtype, + ) else: return KVCachePacked( num_blocks=self.num_gpu_blocks, diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index c3ab5be3..d3cc20ed 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -142,13 +142,32 @@ def __init__( ) qk_nope_head_dim = self.config.get("qk_nope_head_dim", None) qk_rope_head_dim = self.config.get("qk_rope_head_dim", None) + kv_lora_rank = self.config.get("kv_lora_rank", None) + v_head_dim = self.config.get("v_head_dim", head_dim) + model_type = self.config.get("model_type", "") + + # KV cache dimensions: default to head_dim / v_head_dim / num_key_value_heads, + # but override for MLA models based on their cache storage format. + cache_head_dim = head_dim + cache_head_dim_v = v_head_dim + cache_num_kv_heads = num_key_value_heads + if qk_nope_head_dim is not None and qk_rope_head_dim is not None: logger.debug( f"qk_nope_head_dim={qk_nope_head_dim}, qk_rope_head_dim={qk_rope_head_dim}" ) - head_dim = qk_nope_head_dim + qk_rope_head_dim - - v_head_dim = self.config.get("v_head_dim", None) + if model_type in ("glm4_moe_lite",) and kv_lora_rank is not None: + # MLA without kv_b_proj: cache stores compressed latent KV (1 head) + cache_head_dim = kv_lora_rank + qk_rope_head_dim + cache_head_dim_v = kv_lora_rank + cache_num_kv_heads = 1 + else: + # MLA with kv_b_proj (DeepSeek V3/V2): cache stores expanded per-head KV + cache_head_dim = qk_nope_head_dim + qk_rope_head_dim + logger.debug( + f"MLA cache ({model_type}): key_dim={cache_head_dim}, " + f"value_dim={cache_head_dim_v}, num_kv_heads={cache_num_kv_heads}" + ) linear_key_head_dim = self.config.get("linear_key_head_dim", None) linear_value_head_dim = self.config.get("linear_value_head_dim", None) linear_conv_kernel_dim = self.config.get("linear_conv_kernel_dim", None) @@ -189,12 +208,12 @@ def __init__( ) self.cache_manager = CacheManager( num_layers=self.num_shard_layers, - num_kv_heads=num_key_value_heads // tp_size, - head_dim=head_dim, + num_kv_heads=cache_num_kv_heads // tp_size, + head_dim=cache_head_dim, dtype=self.dtype, block_size=kv_block_size, cache_memory_fraction=kv_cache_memory_fraction, - head_dim_v=v_head_dim, + head_dim_v=cache_head_dim_v, index_head_dim=index_head_dim, index_n_heads=index_n_heads, layer_types=layer_types, diff --git a/src/parallax/utils/prefix_cache_utils.py b/src/parallax/utils/prefix_cache_utils.py index a4612223..57060402 100644 --- a/src/parallax/utils/prefix_cache_utils.py +++ b/src/parallax/utils/prefix_cache_utils.py @@ -21,6 +21,7 @@ def compute_attention_with_prefix_cache( mask: Optional[mx.array] = None, sinks: Optional[mx.array] = None, window_size: Optional[int] = None, + unembed_out: Optional[bool] = False, ) -> mx.array: """ Compute attention with prefix cache support. @@ -57,13 +58,14 @@ def compute_attention_with_prefix_cache( if max_prefix_len > 0: # Initialize prefix KV arrays with zeros for padding - head_dim = k_new.shape[-1] + k_head_dim = k_new.shape[-1] + v_head_dim = v_new.shape[-1] prefix_k_batch = mx.zeros( - (batch, num_kv_heads, max_prefix_len, head_dim), dtype=k_new.dtype - ) # (batch, n_kv_heads, max_prefix_len, head_dim) + (batch, num_kv_heads, max_prefix_len, k_head_dim), dtype=k_new.dtype + ) # (batch, n_kv_heads, max_prefix_len, k_head_dim) prefix_v_batch = mx.zeros( - (batch, num_kv_heads, max_prefix_len, head_dim), dtype=v_new.dtype - ) # (batch, n_kv_heads, max_prefix_len, head_dim) + (batch, num_kv_heads, max_prefix_len, v_head_dim), dtype=v_new.dtype + ) # (batch, n_kv_heads, max_prefix_len, v_head_dim) # Batch read prefix KV for all requests for i in range(batch): @@ -150,5 +152,6 @@ def compute_attention_with_prefix_cache( cache=None, **attention_kwargs, ) - output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) + if not unembed_out: + output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) return output