diff --git a/python/sgl_jax/srt/configs/model_config.py b/python/sgl_jax/srt/configs/model_config.py
index 9c9b2db5..b288006f 100644
--- a/python/sgl_jax/srt/configs/model_config.py
+++ b/python/sgl_jax/srt/configs/model_config.py
@@ -169,6 +169,9 @@ def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
             **kwargs,
         )
 
+    def get_padded_head_dim(self) -> int:
+        return (self.head_dim + 127) // 128 * 128
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads (original, not replicated)."""
diff --git a/python/sgl_jax/srt/model_executor/model_runner.py b/python/sgl_jax/srt/model_executor/model_runner.py
index dfb51db6..99e4a32f 100644
--- a/python/sgl_jax/srt/model_executor/model_runner.py
+++ b/python/sgl_jax/srt/model_executor/model_runner.py
@@ -205,6 +205,7 @@ def load_model(self):
         self.model_config.configure_for_tensor_parallel(self.tp_size)
         self.model_config.log_kv_heads_info(self.tp_size)
         self.model_config.hf_config.ep_size = self.ep_size
+        self.model_config.hf_config.head_dim_padded = self.model_config.get_padded_head_dim()
 
         self.model = self.model_loader.load_model(
             model_config=self.model_config,
diff --git a/python/sgl_jax/srt/models/bailing_moe.py b/python/sgl_jax/srt/models/bailing_moe.py
index 7adced76..382c0549 100644
--- a/python/sgl_jax/srt/models/bailing_moe.py
+++ b/python/sgl_jax/srt/models/bailing_moe.py
@@ -210,7 +210,7 @@ def __init__(
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 40960)
-        self.head_dim = getattr(config, "head_dim", None)
+        self.head_dim = getattr(config, "head_dim_padded", None)
         use_qk_norm = getattr(config, "use_qk_norm", False)
         if hasattr(config, "partial_rotary_factor"):
             rotary_dim = int(self.head_dim * config.partial_rotary_factor)
diff --git a/python/sgl_jax/srt/models/llama.py b/python/sgl_jax/srt/models/llama.py
index 1d367bb7..cffc38be 100644
--- a/python/sgl_jax/srt/models/llama.py
+++ b/python/sgl_jax/srt/models/llama.py
@@ -216,7 +216,7 @@ def __init__(
 
         # Support internlm/internlm-7b with bias
         attention_bias = getattr(config, "attention_bias", False) or getattr(config, "bias", False)
-        head_dim = getattr(config, "head_dim", None)
+        head_dim = getattr(config, "head_dim_padded", None)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
diff --git a/python/sgl_jax/srt/models/qwen.py b/python/sgl_jax/srt/models/qwen.py
index 642dc2a5..b6facc07 100644
--- a/python/sgl_jax/srt/models/qwen.py
+++ b/python/sgl_jax/srt/models/qwen.py
@@ -77,15 +77,15 @@ def __init__(
         max_position_embeddings: int,
         rope_theta: float = 10000,
         rope_scaling: dict[str, Any] | None = None,
+        head_dim: int | None = None,
         layer_id: int = 0,
         dtype: jnp.dtype = jnp.float16,
         rngs: nnx.Rngs = None,
     ):
         self.hidden_size = hidden_size
         self.num_heads = num_heads
-        head_size = hidden_size // num_heads
-        self.head_size = head_size
-        self.scaling = head_size**-0.5
+        self.head_dim = head_dim or hidden_size // num_heads
+        self.scaling = head_dim**-0.5
 
         self.q_proj = LinearBase(
             input_size=hidden_size,
@@ -112,7 +112,7 @@ def __init__(
             params_dtype=dtype,
         )
         self.c_proj = LinearBase(
-            input_size=num_heads * head_size,
+            input_size=num_heads * head_dim,
             output_size=hidden_size,
             use_bias=False,
             kernel_axes=("tensor", None),
@@ -122,17 +122,17 @@ def __init__(
 
         # Use torch version of RotaryEmbedding directly
         self.rotary_emb = RotaryEmbedding(
-            head_size=head_size,
-            rotary_dim=head_size,
+            head_size=head_dim,
+            rotary_dim=head_dim,
             max_position_embeddings=max_position_embeddings,
             base=rope_theta,
             is_neox_style=True,
             dtype=dtype,
         )
-        self.scaling = head_size**-0.5
+        self.scaling = head_dim**-0.5
         self.attn = RadixAttention(
             num_heads=num_heads,
-            head_dim=head_size,
+            head_dim=head_dim,
             scaling=self.scaling,
             num_kv_heads=num_heads,
             layer_id=layer_id,
@@ -150,9 +150,9 @@ def __call__(
         k, _ = self.k_proj(hidden_states)
         v, _ = self.v_proj(hidden_states)
 
-        q = q.reshape(-1, self.num_heads, self.head_size)
-        k = k.reshape(-1, self.num_heads, self.head_size)
-        v = v.reshape(-1, self.num_heads, self.head_size)
+        q = q.reshape(-1, self.num_heads, self.head_dim)
+        k = k.reshape(-1, self.num_heads, self.head_dim)
+        v = v.reshape(-1, self.num_heads, self.head_dim)
 
         q, k = self.rotary_emb(positions, q, k)
         attn_output, kv_fused = self.attn(q, k, v, forward_batch, token_to_kv_pool)
@@ -169,7 +169,7 @@ def __init__(
         rngs: nnx.Rngs = None,
     ):
         self.layer_id = layer_id
-
+        head_dim = getattr(config, "head_dim_padded", None)
         self.ln_1 = RMSNorm(
             config.hidden_size,
             epsilon=config.layer_norm_epsilon,
@@ -186,6 +186,7 @@ def __init__(
             config.max_position_embeddings,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
+            head_dim=head_dim,
             layer_id=layer_id,
             dtype=dtype,
             rngs=rngs,
diff --git a/python/sgl_jax/srt/models/qwen2.py b/python/sgl_jax/srt/models/qwen2.py
index cec0f315..87cfb0ce 100644
--- a/python/sgl_jax/srt/models/qwen2.py
+++ b/python/sgl_jax/srt/models/qwen2.py
@@ -179,7 +179,7 @@ def __init__(
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
-        head_dim = getattr(config, "head_dim", None)
+        head_dim = getattr(config, "head_dim_padded", None)
         self.self_attn = Qwen2Attention(
             hidden_size=config.hidden_size,
             num_heads=config.num_attention_heads,
diff --git a/python/sgl_jax/srt/models/qwen3.py b/python/sgl_jax/srt/models/qwen3.py
index 67bbc06b..3cf8c126 100644
--- a/python/sgl_jax/srt/models/qwen3.py
+++ b/python/sgl_jax/srt/models/qwen3.py
@@ -198,7 +198,7 @@ def __init__(
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
-        head_dim = getattr(config, "head_dim", None)
+        head_dim = getattr(config, "head_dim_padded", None)
         self.self_attn = QWen3Attention(
             hidden_size=config.hidden_size,
             num_heads=config.num_attention_heads,
diff --git a/python/sgl_jax/srt/models/qwen3_moe.py b/python/sgl_jax/srt/models/qwen3_moe.py
index 82ee1687..f7256cfa 100644
--- a/python/sgl_jax/srt/models/qwen3_moe.py
+++ b/python/sgl_jax/srt/models/qwen3_moe.py
@@ -152,7 +152,7 @@ def __init__(
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 40960)
-        head_dim = getattr(config, "head_dim", None)
+        head_dim = getattr(config, "head_dim_padded", None)
         self.self_attn = QWen3MoeAttention(
             hidden_size=config.hidden_size,
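
For reference, the new ModelConfig.get_padded_head_dim() helper above rounds the head dimension up to the next multiple of 128 before it is stored on hf_config as head_dim_padded and read by the model definitions. Below is a minimal standalone sketch of that rounding, written purely for illustration; the free function and its names are not part of the patch, and the 128 alignment target is taken directly from the diff.

def padded_head_dim(head_dim: int, multiple: int = 128) -> int:
    # Ceiling-divide by the alignment, then scale back up: values already on a
    # multiple-of-128 boundary are unchanged, everything else rounds up, never down.
    return (head_dim + multiple - 1) // multiple * multiple

assert padded_head_dim(64) == 128
assert padded_head_dim(128) == 128
assert padded_head_dim(130) == 256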