Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,27 @@ class ModelConfig(BaseModelConfig):
),
] = True

use_index_cache: Annotated[
bool,
Field(
description="Enable DSA IndexCache by reusing sparse attention top-k indices across layers. Matches vLLM's use_index_cache HF override.",
),
] = False

index_topk_freq: Annotated[
int,
Field(
description="DSA IndexCache frequency for recomputing top-k indices. 1 computes every layer; 4 computes on the vLLM PR's example schedule.",
),
] = 1

index_topk_pattern: Annotated[
str | None,
Field(
description="Optional DSA IndexCache per-layer pattern where 'F' computes fresh top-k indices and 'S' reuses the previous full layer.",
),
] = None

fp8: Annotated[
bool,
Field(
Expand Down
4 changes: 4 additions & 0 deletions src/prime_rl/trainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,10 @@ def get_model(
model_config.use_grouped_mm = config.moe_use_grouped_mm
model_config.fp8 = config.fp8

for field_name in ("use_index_cache", "index_topk_freq", "index_topk_pattern"):
if field_name in config.model_fields_set or not hasattr(model_config, field_name):
setattr(model_config, field_name, getattr(config, field_name))

# Ensure pad_token_id is set (some models like Qwen3MoE don't have it).
# In transformers v5, token IDs moved from PretrainedConfig to GenerationConfig.
if not hasattr(model_config, "pad_token_id") or model_config.pad_token_id is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,17 @@
from transformers.configuration_utils import PretrainedConfig


def _index_cache_skip_topk(config, layer_idx: int) -> bool:
if not getattr(config, "use_index_cache", False):
return False

index_topk_pattern = getattr(config, "index_topk_pattern", None)
if index_topk_pattern is not None:
return layer_idx < len(index_topk_pattern) and index_topk_pattern[layer_idx] == "S"

return layer_idx % getattr(config, "index_topk_freq", 1) != 0


class GlmMoeDsaConfig(PretrainedConfig):
r"""
Configuration class for the GLM-5 (GlmMoeDsa) model.
Expand Down Expand Up @@ -73,6 +84,13 @@ class GlmMoeDsaConfig(PretrainedConfig):
Whether to use interleaved RoPE style in the sparse indexer.
index_topk (`int`, defaults to 2048):
Number of top tokens selected by the sparse indexer.
use_index_cache (`bool`, defaults to `False`):
Whether to reuse sparse attention top-k indices across DSA layers.
index_topk_freq (`int`, defaults to 1):
Frequency for recomputing top-k indices when IndexCache is enabled.
index_topk_pattern (`str`, *optional*):
Optional per-layer pattern where ``"F"`` computes fresh indices and
``"S"`` reuses the cached indices from the previous full layer.
scoring_func (`str`, defaults to `"sigmoid"`):
Scoring function for MoE router. Must match the vLLM inference
server's expectation (vLLM defaults to ``"softmax"`` when this
Expand Down Expand Up @@ -141,6 +159,9 @@ def __init__(
indexer_rope_interleave=True,
pad_token_id=154820,
index_topk=2048,
use_index_cache=False,
index_topk_freq=1,
index_topk_pattern=None,
scoring_func="sigmoid",
topk_method="noaux_tc",
use_grouped_mm=True,
Expand Down Expand Up @@ -194,6 +215,9 @@ def __init__(
self.index_head_dim = index_head_dim
self.indexer_rope_interleave = indexer_rope_interleave
self.index_topk = index_topk
self.use_index_cache = use_index_cache
self.index_topk_freq = index_topk_freq
self.index_topk_pattern = index_topk_pattern
self.scoring_func = scoring_func
self.topk_method = topk_method
self.use_grouped_mm = use_grouped_mm
Expand Down
22 changes: 15 additions & 7 deletions src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from transformers.utils.deprecation import deprecate_kwarg

from prime_rl.trainer.models.base import PreTrainedModelPrimeRL
from prime_rl.trainer.models.glm_moe_dsa.configuration_glm_moe_dsa import GlmMoeDsaConfig
from prime_rl.trainer.models.glm_moe_dsa.configuration_glm_moe_dsa import GlmMoeDsaConfig, _index_cache_skip_topk
from prime_rl.trainer.models.glm_moe_dsa.converting_glm_moe_dsa import (
convert_hf_layer_to_tt,
convert_hf_to_tt_moe,
Expand All @@ -29,7 +29,7 @@
from prime_rl.trainer.models.layers.rotary_emb import RotaryEmbedding, RotaryEmbeddingConfig


def _sparse_mla_attention_args(config: GlmMoeDsaConfig) -> SparseMlaAttentionArgs:
def _sparse_mla_attention_args(config: GlmMoeDsaConfig, layer_idx: int) -> SparseMlaAttentionArgs:
if config.q_lora_rank is None:
raise ValueError("Sparse MLA attention requires q_lora_rank to be set")
return SparseMlaAttentionArgs(
Expand All @@ -46,14 +46,16 @@ def _sparse_mla_attention_args(config: GlmMoeDsaConfig) -> SparseMlaAttentionArg
index_n_heads=config.index_n_heads,
index_head_dim=config.index_head_dim,
index_topk=config.index_topk,
use_index_cache=getattr(config, "use_index_cache", False),
skip_topk=_index_cache_skip_topk(config, layer_idx),
)


class GlmMoeDsaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: GlmMoeDsaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = GlmMoeDsaAttention(_sparse_mla_attention_args(config))
self.self_attn = GlmMoeDsaAttention(_sparse_mla_attention_args(config, layer_idx))

moe_args = MoEArgs(
num_experts=config.n_routed_experts,
Expand Down Expand Up @@ -95,23 +97,25 @@ def forward(
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
ks: Optional[torch.Tensor] = None,
ke: Optional[torch.Tensor] = None,
cached_indices: Optional[torch.Tensor] = None,
routed_experts: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor | None]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, _ = self.self_attn(
hidden_states, cached_indices = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
ks=ks,
ke=ke,
cached_indices=cached_indices,
)
hidden_states = residual + hidden_states

residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states, routed_experts=routed_experts)
hidden_states = residual + hidden_states
return hidden_states
return hidden_states, cached_indices


@auto_docstring
Expand Down Expand Up @@ -256,15 +260,19 @@ def forward(
else:
ks, ke = ks_full, ke_full

cached_indices = None
use_index_cache = getattr(self.config, "use_index_cache", False)
for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
routed_experts_layer = routed_experts[:, :, layer_idx, :] if routed_experts is not None else None
hidden_states = decoder_layer(
hidden_states, next_cached_indices = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
ks=ks,
ke=ke,
cached_indices=cached_indices,
routed_experts=routed_experts_layer,
)
cached_indices = next_cached_indices if use_index_cache else None

hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(last_hidden_state=hidden_states)
Expand Down
32 changes: 20 additions & 12 deletions src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class SparseMlaAttentionArgs:
index_n_heads: int
index_head_dim: int
index_topk: int
use_index_cache: bool = False
skip_topk: bool = False


class _SparseMLA(torch.autograd.Function):
Expand Down Expand Up @@ -157,6 +159,8 @@ def __init__(self, args: SparseMlaAttentionArgs):

self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, args.hidden_size, bias=args.attention_bias)
self.indexer = Indexer(args)
self.use_index_cache = args.use_index_cache
self.skip_topk = args.skip_topk
self.scaling = self.qk_head_dim ** (-0.5)

self._cp_group: dist.ProcessGroup | None = None
Expand Down Expand Up @@ -234,24 +238,27 @@ def forward(
position_embeddings: tuple[torch.Tensor, torch.Tensor],
ks: torch.Tensor | None = None,
ke: torch.Tensor | None = None,
cached_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
q_latent, k_compressed_normed, k_rope = self.attn_projections(hidden_states)

if self.cp_enabled:
k_compressed_normed = gather_for_cp(k_compressed_normed, self._cp_group)
k_rope = gather_for_cp(k_rope, self._cp_group)

indices = self.indexer.compute_sparse_indices(
hidden_states_local=hidden_states,
q_latent_local=q_latent,
ks=ks,
ke=ke,
index_topk=self.args.index_topk,
position_embeddings_full=position_embeddings,
cp_group=self._cp_group,
cp_world_size=self._cp_world_size,
cp_rank=self._cp_rank,
)
indices = cached_indices
if not self.skip_topk:
indices = self.indexer.compute_sparse_indices(
hidden_states_local=hidden_states,
q_latent_local=q_latent,
ks=ks,
ke=ke,
index_topk=self.args.index_topk,
position_embeddings_full=position_embeddings,
cp_group=self._cp_group,
cp_world_size=self._cp_world_size,
cp_rank=self._cp_rank,
)
Comment thread
faresobeid marked this conversation as resolved.

sparse_q, sparse_kv, w_v = self.mla_up_proj(
q_latent_local=q_latent,
Expand All @@ -261,4 +268,5 @@ def forward(
)

out = _SparseMLA.apply(sparse_q, sparse_kv, indices, self.scaling)
return self.output_proj(out, w_v), None
cached_indices = indices if self.use_index_cache else None
return self.output_proj(out, w_v), cached_indices
Loading