Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/sgl_jax/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
**kwargs,
)

def get_padded_head_dim(self) -> int:
return (self.head_dim + 127) // 128 * 128

# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads (original, not replicated)."""
Expand Down
1 change: 1 addition & 0 deletions python/sgl_jax/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def load_model(self):
self.model_config.configure_for_tensor_parallel(self.tp_size)
self.model_config.log_kv_heads_info(self.tp_size)
self.model_config.hf_config.ep_size = self.ep_size
self.model_config.hf_config.head_dim_padded = self.model_config.get_padded_head_dim()

self.model = self.model_loader.load_model(
model_config=self.model_config,
Expand Down
2 changes: 1 addition & 1 deletion python/sgl_jax/srt/models/bailing_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def __init__(
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 40960)
self.head_dim = getattr(config, "head_dim", None)
self.head_dim = getattr(config, "head_dim_padded", None)
use_qk_norm = getattr(config, "use_qk_norm", False)
if hasattr(config, "partial_rotary_factor"):
rotary_dim = int(self.head_dim * config.partial_rotary_factor)
Expand Down
2 changes: 1 addition & 1 deletion python/sgl_jax/srt/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def __init__(
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(config, "bias", False)

head_dim = getattr(config, "head_dim", None)
head_dim = getattr(config, "head_dim_padded", None)
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
Expand Down
25 changes: 13 additions & 12 deletions python/sgl_jax/srt/models/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,15 @@ def __init__(
max_position_embeddings: int,
rope_theta: float = 10000,
rope_scaling: dict[str, Any] | None = None,
head_dim: int | None = None,
layer_id: int = 0,
dtype: jnp.dtype = jnp.float16,
rngs: nnx.Rngs = None,
):
self.hidden_size = hidden_size
self.num_heads = num_heads
head_size = hidden_size // num_heads
self.head_size = head_size
self.scaling = head_size**-0.5
self.head_dim = head_dim or hidden_size // num_heads
self.scaling = head_dim**-0.5

self.q_proj = LinearBase(
input_size=hidden_size,
Expand All @@ -112,7 +112,7 @@ def __init__(
params_dtype=dtype,
)
self.c_proj = LinearBase(
input_size=num_heads * head_size,
input_size=num_heads * head_dim,
output_size=hidden_size,
use_bias=False,
kernel_axes=("tensor", None),
Expand All @@ -122,17 +122,17 @@ def __init__(

# Use torch version of RotaryEmbedding directly
self.rotary_emb = RotaryEmbedding(
head_size=head_size,
rotary_dim=head_size,
head_size=head_dim,
rotary_dim=head_dim,
max_position_embeddings=max_position_embeddings,
base=rope_theta,
is_neox_style=True,
dtype=dtype,
)
self.scaling = head_size**-0.5
self.scaling = head_dim**-0.5
self.attn = RadixAttention(
num_heads=num_heads,
head_dim=head_size,
head_dim=head_dim,
scaling=self.scaling,
num_kv_heads=num_heads,
layer_id=layer_id,
Expand All @@ -150,9 +150,9 @@ def __call__(
k, _ = self.k_proj(hidden_states)
v, _ = self.v_proj(hidden_states)

q = q.reshape(-1, self.num_heads, self.head_size)
k = k.reshape(-1, self.num_heads, self.head_size)
v = v.reshape(-1, self.num_heads, self.head_size)
q = q.reshape(-1, self.num_heads, self.head_dim)
k = k.reshape(-1, self.num_heads, self.head_dim)
v = v.reshape(-1, self.num_heads, self.head_dim)

q, k = self.rotary_emb(positions, q, k)
attn_output, kv_fused = self.attn(q, k, v, forward_batch, token_to_kv_pool)
Expand All @@ -169,7 +169,7 @@ def __init__(
rngs: nnx.Rngs = None,
):
self.layer_id = layer_id

head_dim = getattr(config, "head_dim_padded", None)
self.ln_1 = RMSNorm(
config.hidden_size,
epsilon=config.layer_norm_epsilon,
Expand All @@ -186,6 +186,7 @@ def __init__(
config.max_position_embeddings,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
head_dim=head_dim,
layer_id=layer_id,
dtype=dtype,
rngs=rngs,
Expand Down
2 changes: 1 addition & 1 deletion python/sgl_jax/srt/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def __init__(
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
head_dim = getattr(config, "head_dim", None)
head_dim = getattr(config, "head_dim_padded", None)
self.self_attn = Qwen2Attention(
hidden_size=config.hidden_size,
num_heads=config.num_attention_heads,
Expand Down
2 changes: 1 addition & 1 deletion python/sgl_jax/srt/models/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def __init__(
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
head_dim = getattr(config, "head_dim", None)
head_dim = getattr(config, "head_dim_padded", None)
self.self_attn = QWen3Attention(
hidden_size=config.hidden_size,
num_heads=config.num_attention_heads,
Expand Down
2 changes: 1 addition & 1 deletion python/sgl_jax/srt/models/qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def __init__(
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 40960)
head_dim = getattr(config, "head_dim", None)
head_dim = getattr(config, "head_dim_padded", None)

self.self_attn = QWen3MoeAttention(
hidden_size=config.hidden_size,
Expand Down