From b10257ed2c747919b4969b78db3fae2a5fa4387e Mon Sep 17 00:00:00 2001 From: faresobeid Date: Mon, 18 May 2026 19:07:34 +0530 Subject: [PATCH 1/4] Add indexCache training support --- .../src/prime_rl/configs/trainer.py | 21 ++++++++++++++ src/prime_rl/trainer/model.py | 4 +++ .../glm_moe_dsa/configuration_glm_moe_dsa.py | 13 +++++++++ .../glm_moe_dsa/modeling_glm_moe_dsa.py | 29 +++++++++++++++---- .../glm_moe_dsa/sparse_mla_attention.py | 29 +++++++++++-------- 5 files changed, 78 insertions(+), 18 deletions(-) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py index a076d1e29c..b29b3b47e0 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py @@ -345,6 +345,27 @@ class ModelConfig(BaseModelConfig): ), ] = True + use_index_cache: Annotated[ + bool, + Field( + description="Enable DSA IndexCache by reusing sparse attention top-k indices across layers. Matches vLLM's use_index_cache HF override.", + ), + ] = False + + index_topk_freq: Annotated[ + int, + Field( + description="DSA IndexCache frequency for recomputing top-k indices. 1 computes every layer; 4 computes on the vLLM PR's example schedule.", + ), + ] = 1 + + index_topk_pattern: Annotated[ + str | None, + Field( + description="Optional DSA IndexCache per-layer pattern where 'F' computes fresh top-k indices and 'S' reuses the previous full layer.", + ), + ] = None + fp8: Annotated[ bool, Field( diff --git a/src/prime_rl/trainer/model.py b/src/prime_rl/trainer/model.py index 910a978a66..2cdb802c6c 100644 --- a/src/prime_rl/trainer/model.py +++ b/src/prime_rl/trainer/model.py @@ -473,6 +473,10 @@ def get_model( model_config.use_grouped_mm = config.moe_use_grouped_mm model_config.fp8 = config.fp8 + for field_name in ("use_index_cache", "index_topk_freq", "index_topk_pattern"): + if field_name in config.model_fields_set or not hasattr(model_config, field_name): + setattr(model_config, field_name, getattr(config, field_name)) + # Ensure pad_token_id is set (some models like Qwen3MoE don't have it). # In transformers v5, token IDs moved from PretrainedConfig to GenerationConfig. if not hasattr(model_config, "pad_token_id") or model_config.pad_token_id is None: diff --git a/src/prime_rl/trainer/models/glm_moe_dsa/configuration_glm_moe_dsa.py b/src/prime_rl/trainer/models/glm_moe_dsa/configuration_glm_moe_dsa.py index caca4fc5e9..d796455b46 100644 --- a/src/prime_rl/trainer/models/glm_moe_dsa/configuration_glm_moe_dsa.py +++ b/src/prime_rl/trainer/models/glm_moe_dsa/configuration_glm_moe_dsa.py @@ -73,6 +73,13 @@ class GlmMoeDsaConfig(PretrainedConfig): Whether to use interleaved RoPE style in the sparse indexer. index_topk (`int`, defaults to 2048): Number of top tokens selected by the sparse indexer. + use_index_cache (`bool`, defaults to `False`): + Whether to reuse sparse attention top-k indices across DSA layers. + index_topk_freq (`int`, defaults to 1): + Frequency for recomputing top-k indices when IndexCache is enabled. + index_topk_pattern (`str`, *optional*): + Optional per-layer pattern where ``"F"`` computes fresh indices and + ``"S"`` reuses the cached indices from the previous full layer. scoring_func (`str`, defaults to `"sigmoid"`): Scoring function for MoE router. Must match the vLLM inference server's expectation (vLLM defaults to ``"softmax"`` when this @@ -141,6 +148,9 @@ def __init__( indexer_rope_interleave=True, pad_token_id=154820, index_topk=2048, + use_index_cache=False, + index_topk_freq=1, + index_topk_pattern=None, scoring_func="sigmoid", topk_method="noaux_tc", use_grouped_mm=True, @@ -194,6 +204,9 @@ def __init__( self.index_head_dim = index_head_dim self.indexer_rope_interleave = indexer_rope_interleave self.index_topk = index_topk + self.use_index_cache = use_index_cache + self.index_topk_freq = index_topk_freq + self.index_topk_pattern = index_topk_pattern self.scoring_func = scoring_func self.topk_method = topk_method self.use_grouped_mm = use_grouped_mm diff --git a/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py b/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py index 2bb90d94ff..4371a5e62a 100644 --- a/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py +++ b/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py @@ -29,7 +29,19 @@ from prime_rl.trainer.models.layers.rotary_emb import RotaryEmbedding, RotaryEmbeddingConfig -def _sparse_mla_attention_args(config: GlmMoeDsaConfig) -> SparseMlaAttentionArgs: +def _index_cache_skip_topk(config: GlmMoeDsaConfig, layer_idx: int) -> bool: + if not getattr(config, "use_index_cache", False): + return False + + index_topk_pattern = getattr(config, "index_topk_pattern", None) + if index_topk_pattern is not None: + return layer_idx < len(index_topk_pattern) and index_topk_pattern[layer_idx] == "S" + + index_topk_freq = getattr(config, "index_topk_freq", 1) + return max(layer_idx - 1, 0) % index_topk_freq != 0 + + +def _sparse_mla_attention_args(config: GlmMoeDsaConfig, layer_idx: int) -> SparseMlaAttentionArgs: if config.q_lora_rank is None: raise ValueError("Sparse MLA attention requires q_lora_rank to be set") return SparseMlaAttentionArgs( @@ -46,6 +58,7 @@ def _sparse_mla_attention_args(config: GlmMoeDsaConfig) -> SparseMlaAttentionArg index_n_heads=config.index_n_heads, index_head_dim=config.index_head_dim, index_topk=config.index_topk, + skip_topk=_index_cache_skip_topk(config, layer_idx), ) @@ -53,7 +66,7 @@ class GlmMoeDsaDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: GlmMoeDsaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = GlmMoeDsaAttention(_sparse_mla_attention_args(config)) + self.self_attn = GlmMoeDsaAttention(_sparse_mla_attention_args(config, layer_idx)) moe_args = MoEArgs( num_experts=config.n_routed_experts, @@ -95,15 +108,17 @@ def forward( position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, ks: Optional[torch.Tensor] = None, ke: Optional[torch.Tensor] = None, + cached_indices: Optional[torch.Tensor] = None, routed_experts: Optional[torch.LongTensor] = None, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor | None]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states, _ = self.self_attn( + hidden_states, cached_indices = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, ks=ks, ke=ke, + cached_indices=cached_indices, ) hidden_states = residual + hidden_states @@ -111,7 +126,7 @@ def forward( hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states, routed_experts=routed_experts) hidden_states = residual + hidden_states - return hidden_states + return hidden_states, cached_indices @auto_docstring @@ -256,13 +271,15 @@ def forward( else: ks, ke = ks_full, ke_full + cached_indices = None for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): routed_experts_layer = routed_experts[:, :, layer_idx, :] if routed_experts is not None else None - hidden_states = decoder_layer( + hidden_states, cached_indices = decoder_layer( hidden_states, position_embeddings=position_embeddings, ks=ks, ke=ke, + cached_indices=cached_indices, routed_experts=routed_experts_layer, ) diff --git a/src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py b/src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py index 2488505e9f..d4e4325067 100644 --- a/src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py +++ b/src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py @@ -32,6 +32,7 @@ class SparseMlaAttentionArgs: index_n_heads: int index_head_dim: int index_topk: int + skip_topk: bool = False class _SparseMLA(torch.autograd.Function): @@ -157,6 +158,7 @@ def __init__(self, args: SparseMlaAttentionArgs): self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, args.hidden_size, bias=args.attention_bias) self.indexer = Indexer(args) + self.skip_topk = args.skip_topk self.scaling = self.qk_head_dim ** (-0.5) self._cp_group: dist.ProcessGroup | None = None @@ -234,6 +236,7 @@ def forward( position_embeddings: tuple[torch.Tensor, torch.Tensor], ks: torch.Tensor | None = None, ke: torch.Tensor | None = None, + cached_indices: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: q_latent, k_compressed_normed, k_rope = self.attn_projections(hidden_states) @@ -241,17 +244,19 @@ def forward( k_compressed_normed = gather_for_cp(k_compressed_normed, self._cp_group) k_rope = gather_for_cp(k_rope, self._cp_group) - indices = self.indexer.compute_sparse_indices( - hidden_states_local=hidden_states, - q_latent_local=q_latent, - ks=ks, - ke=ke, - index_topk=self.args.index_topk, - position_embeddings_full=position_embeddings, - cp_group=self._cp_group, - cp_world_size=self._cp_world_size, - cp_rank=self._cp_rank, - ) + indices = cached_indices + if not self.skip_topk: + indices = self.indexer.compute_sparse_indices( + hidden_states_local=hidden_states, + q_latent_local=q_latent, + ks=ks, + ke=ke, + index_topk=self.args.index_topk, + position_embeddings_full=position_embeddings, + cp_group=self._cp_group, + cp_world_size=self._cp_world_size, + cp_rank=self._cp_rank, + ) sparse_q, sparse_kv, w_v = self.mla_up_proj( q_latent_local=q_latent, @@ -261,4 +266,4 @@ def forward( ) out = _SparseMLA.apply(sparse_q, sparse_kv, indices, self.scaling) - return self.output_proj(out, w_v), None + return self.output_proj(out, w_v), indices From f1e1a5edd7eaad0b3425b0f2725654446e82964d Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Mon, 18 May 2026 14:31:22 +0000 Subject: [PATCH 2/4] fix glm moe dsa index cache threading Co-authored-by: faresobeid --- .../glm_moe_dsa/modeling_glm_moe_dsa.py | 5 +- .../glm_moe_dsa/sparse_mla_attention.py | 5 +- .../models/test_glm_moe_dsa_index_cache.py | 170 ++++++++++++++++++ 3 files changed, 178 insertions(+), 2 deletions(-) create mode 100644 tests/unit/train/models/test_glm_moe_dsa_index_cache.py diff --git a/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py b/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py index 4371a5e62a..fa42e9c879 100644 --- a/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py +++ b/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py @@ -58,6 +58,7 @@ def _sparse_mla_attention_args(config: GlmMoeDsaConfig, layer_idx: int) -> Spars index_n_heads=config.index_n_heads, index_head_dim=config.index_head_dim, index_topk=config.index_topk, + use_index_cache=getattr(config, "use_index_cache", False), skip_topk=_index_cache_skip_topk(config, layer_idx), ) @@ -272,9 +273,10 @@ def forward( ks, ke = ks_full, ke_full cached_indices = None + use_index_cache = getattr(self.config, "use_index_cache", False) for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): routed_experts_layer = routed_experts[:, :, layer_idx, :] if routed_experts is not None else None - hidden_states, cached_indices = decoder_layer( + hidden_states, next_cached_indices = decoder_layer( hidden_states, position_embeddings=position_embeddings, ks=ks, @@ -282,6 +284,7 @@ def forward( cached_indices=cached_indices, routed_experts=routed_experts_layer, ) + cached_indices = next_cached_indices if use_index_cache else None hidden_states = self.norm(hidden_states) return BaseModelOutputWithPast(last_hidden_state=hidden_states) diff --git a/src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py b/src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py index d4e4325067..d87dc183ca 100644 --- a/src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py +++ b/src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py @@ -32,6 +32,7 @@ class SparseMlaAttentionArgs: index_n_heads: int index_head_dim: int index_topk: int + use_index_cache: bool = False skip_topk: bool = False @@ -158,6 +159,7 @@ def __init__(self, args: SparseMlaAttentionArgs): self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, args.hidden_size, bias=args.attention_bias) self.indexer = Indexer(args) + self.use_index_cache = args.use_index_cache self.skip_topk = args.skip_topk self.scaling = self.qk_head_dim ** (-0.5) @@ -266,4 +268,5 @@ def forward( ) out = _SparseMLA.apply(sparse_q, sparse_kv, indices, self.scaling) - return self.output_proj(out, w_v), indices + cached_indices = indices if self.use_index_cache else None + return self.output_proj(out, w_v), cached_indices diff --git a/tests/unit/train/models/test_glm_moe_dsa_index_cache.py b/tests/unit/train/models/test_glm_moe_dsa_index_cache.py new file mode 100644 index 0000000000..87c30c8916 --- /dev/null +++ b/tests/unit/train/models/test_glm_moe_dsa_index_cache.py @@ -0,0 +1,170 @@ +import torch +from torch import nn + +from prime_rl.trainer.models.glm_moe_dsa import sparse_mla_attention +from prime_rl.trainer.models.glm_moe_dsa.configuration_glm_moe_dsa import GlmMoeDsaConfig +from prime_rl.trainer.models.glm_moe_dsa.modeling_glm_moe_dsa import GlmMoeDsaModel +from prime_rl.trainer.models.glm_moe_dsa.sparse_mla_attention import GlmMoeDsaAttention, SparseMlaAttentionArgs + + +def _attention_args(use_index_cache: bool, skip_topk: bool = False) -> SparseMlaAttentionArgs: + return SparseMlaAttentionArgs( + hidden_size=4, + num_attention_heads=1, + kv_lora_rank=2, + q_lora_rank=2, + qk_rope_head_dim=1, + qk_nope_head_dim=1, + qk_head_dim=2, + v_head_dim=2, + attention_bias=False, + rms_norm_eps=1e-5, + index_n_heads=1, + index_head_dim=2, + index_topk=64, + use_index_cache=use_index_cache, + skip_topk=skip_topk, + ) + + +def _stub_attention(monkeypatch, attention: GlmMoeDsaAttention, computed_indices: torch.Tensor) -> dict[str, torch.Tensor]: + captured = {} + + class FakeSparseMLA: + @staticmethod + def apply(sparse_q, sparse_kv, indices, scaling): + captured["indices"] = indices + return sparse_q + + def compute_sparse_indices(**kwargs): + return computed_indices + + monkeypatch.setattr(sparse_mla_attention, "_SparseMLA", FakeSparseMLA) + monkeypatch.setattr( + attention, + "attn_projections", + lambda hidden_states: ( + torch.zeros(1, hidden_states.shape[1], 2), + torch.zeros(1, hidden_states.shape[1], 2), + torch.zeros(1, hidden_states.shape[1], 1), + ), + ) + monkeypatch.setattr( + attention, + "mla_up_proj", + lambda **kwargs: ( + torch.zeros(1, computed_indices.shape[1], 1, 2), + torch.zeros(1, computed_indices.shape[1] + 1, 1, 2), + torch.zeros(1, 2, 2), + ), + ) + monkeypatch.setattr(attention, "output_proj", lambda attn_output, w_v: attn_output) + monkeypatch.setattr(attention.indexer, "compute_sparse_indices", compute_sparse_indices) + + return captured + + +def test_attention_does_not_return_indices_when_index_cache_disabled(monkeypatch): + attention = GlmMoeDsaAttention(_attention_args(use_index_cache=False)) + computed_indices = torch.ones(1, 3, 1, 64, dtype=torch.int32) + captured = _stub_attention(monkeypatch, attention, computed_indices) + + _, returned_indices = attention( + hidden_states=torch.zeros(1, 3, 4), + position_embeddings=(torch.zeros(1, 3, 1), torch.zeros(1, 3, 1)), + ks=torch.arange(3, dtype=torch.int32), + ke=torch.arange(1, 4, dtype=torch.int32), + ) + + assert captured["indices"] is computed_indices + assert returned_indices is None + + +def test_attention_returns_indices_when_index_cache_enabled(monkeypatch): + attention = GlmMoeDsaAttention(_attention_args(use_index_cache=True)) + computed_indices = torch.ones(1, 3, 1, 64, dtype=torch.int32) + _stub_attention(monkeypatch, attention, computed_indices) + + _, returned_indices = attention( + hidden_states=torch.zeros(1, 3, 4), + position_embeddings=(torch.zeros(1, 3, 1), torch.zeros(1, 3, 1)), + ks=torch.arange(3, dtype=torch.int32), + ke=torch.arange(1, 4, dtype=torch.int32), + ) + + assert returned_indices is computed_indices + + +def _tiny_config(use_index_cache: bool) -> GlmMoeDsaConfig: + return GlmMoeDsaConfig( + vocab_size=8, + hidden_size=4, + intermediate_size=8, + moe_intermediate_size=8, + num_hidden_layers=2, + num_attention_heads=1, + num_key_value_heads=1, + n_shared_experts=1, + n_routed_experts=2, + kv_lora_rank=2, + q_lora_rank=2, + qk_rope_head_dim=1, + qk_nope_head_dim=1, + v_head_dim=2, + first_k_dense_replace=2, + index_n_heads=1, + index_head_dim=2, + index_topk=64, + max_position_embeddings=16, + pad_token_id=0, + use_index_cache=use_index_cache, + ) + + +class RecordingLayer(nn.Module): + def __init__(self, next_cached_indices: torch.Tensor): + super().__init__() + self.next_cached_indices = next_cached_indices + self.seen_cached_indices = [] + + def forward( + self, + hidden_states, + position_embeddings=None, + ks=None, + ke=None, + cached_indices=None, + routed_experts=None, + ): + self.seen_cached_indices.append(cached_indices) + return hidden_states, self.next_cached_indices + + +def _run_model_with_recording_layers(use_index_cache: bool) -> list[RecordingLayer]: + model = GlmMoeDsaModel(_tiny_config(use_index_cache=use_index_cache)) + layers = [ + RecordingLayer(torch.ones(1, 3, 1, 64, dtype=torch.int32)), + RecordingLayer(torch.full((1, 3, 1, 64), 2, dtype=torch.int32)), + ] + model.layers = nn.ModuleList(layers) + + model( + inputs_embeds=torch.zeros(1, 3, 4), + position_ids=torch.arange(3).unsqueeze(0), + ) + + return layers + + +def test_model_does_not_thread_indices_when_index_cache_disabled(): + layers = _run_model_with_recording_layers(use_index_cache=False) + + assert layers[0].seen_cached_indices == [None] + assert layers[1].seen_cached_indices == [None] + + +def test_model_threads_indices_when_index_cache_enabled(): + layers = _run_model_with_recording_layers(use_index_cache=True) + + assert layers[0].seen_cached_indices == [None] + assert layers[1].seen_cached_indices[0] is layers[0].next_cached_indices From 9fb996a5ff08fa0a0ecd730a33e2742bcfc2d5c9 Mon Sep 17 00:00:00 2001 From: faresobeid <111092724+faresobeid@users.noreply.github.com> Date: Mon, 18 May 2026 15:33:44 +0100 Subject: [PATCH 3/4] Delete tests/unit/train/models/test_glm_moe_dsa_index_cache.py Signed-off-by: faresobeid <111092724+faresobeid@users.noreply.github.com> --- .../models/test_glm_moe_dsa_index_cache.py | 170 ------------------ 1 file changed, 170 deletions(-) delete mode 100644 tests/unit/train/models/test_glm_moe_dsa_index_cache.py diff --git a/tests/unit/train/models/test_glm_moe_dsa_index_cache.py b/tests/unit/train/models/test_glm_moe_dsa_index_cache.py deleted file mode 100644 index 87c30c8916..0000000000 --- a/tests/unit/train/models/test_glm_moe_dsa_index_cache.py +++ /dev/null @@ -1,170 +0,0 @@ -import torch -from torch import nn - -from prime_rl.trainer.models.glm_moe_dsa import sparse_mla_attention -from prime_rl.trainer.models.glm_moe_dsa.configuration_glm_moe_dsa import GlmMoeDsaConfig -from prime_rl.trainer.models.glm_moe_dsa.modeling_glm_moe_dsa import GlmMoeDsaModel -from prime_rl.trainer.models.glm_moe_dsa.sparse_mla_attention import GlmMoeDsaAttention, SparseMlaAttentionArgs - - -def _attention_args(use_index_cache: bool, skip_topk: bool = False) -> SparseMlaAttentionArgs: - return SparseMlaAttentionArgs( - hidden_size=4, - num_attention_heads=1, - kv_lora_rank=2, - q_lora_rank=2, - qk_rope_head_dim=1, - qk_nope_head_dim=1, - qk_head_dim=2, - v_head_dim=2, - attention_bias=False, - rms_norm_eps=1e-5, - index_n_heads=1, - index_head_dim=2, - index_topk=64, - use_index_cache=use_index_cache, - skip_topk=skip_topk, - ) - - -def _stub_attention(monkeypatch, attention: GlmMoeDsaAttention, computed_indices: torch.Tensor) -> dict[str, torch.Tensor]: - captured = {} - - class FakeSparseMLA: - @staticmethod - def apply(sparse_q, sparse_kv, indices, scaling): - captured["indices"] = indices - return sparse_q - - def compute_sparse_indices(**kwargs): - return computed_indices - - monkeypatch.setattr(sparse_mla_attention, "_SparseMLA", FakeSparseMLA) - monkeypatch.setattr( - attention, - "attn_projections", - lambda hidden_states: ( - torch.zeros(1, hidden_states.shape[1], 2), - torch.zeros(1, hidden_states.shape[1], 2), - torch.zeros(1, hidden_states.shape[1], 1), - ), - ) - monkeypatch.setattr( - attention, - "mla_up_proj", - lambda **kwargs: ( - torch.zeros(1, computed_indices.shape[1], 1, 2), - torch.zeros(1, computed_indices.shape[1] + 1, 1, 2), - torch.zeros(1, 2, 2), - ), - ) - monkeypatch.setattr(attention, "output_proj", lambda attn_output, w_v: attn_output) - monkeypatch.setattr(attention.indexer, "compute_sparse_indices", compute_sparse_indices) - - return captured - - -def test_attention_does_not_return_indices_when_index_cache_disabled(monkeypatch): - attention = GlmMoeDsaAttention(_attention_args(use_index_cache=False)) - computed_indices = torch.ones(1, 3, 1, 64, dtype=torch.int32) - captured = _stub_attention(monkeypatch, attention, computed_indices) - - _, returned_indices = attention( - hidden_states=torch.zeros(1, 3, 4), - position_embeddings=(torch.zeros(1, 3, 1), torch.zeros(1, 3, 1)), - ks=torch.arange(3, dtype=torch.int32), - ke=torch.arange(1, 4, dtype=torch.int32), - ) - - assert captured["indices"] is computed_indices - assert returned_indices is None - - -def test_attention_returns_indices_when_index_cache_enabled(monkeypatch): - attention = GlmMoeDsaAttention(_attention_args(use_index_cache=True)) - computed_indices = torch.ones(1, 3, 1, 64, dtype=torch.int32) - _stub_attention(monkeypatch, attention, computed_indices) - - _, returned_indices = attention( - hidden_states=torch.zeros(1, 3, 4), - position_embeddings=(torch.zeros(1, 3, 1), torch.zeros(1, 3, 1)), - ks=torch.arange(3, dtype=torch.int32), - ke=torch.arange(1, 4, dtype=torch.int32), - ) - - assert returned_indices is computed_indices - - -def _tiny_config(use_index_cache: bool) -> GlmMoeDsaConfig: - return GlmMoeDsaConfig( - vocab_size=8, - hidden_size=4, - intermediate_size=8, - moe_intermediate_size=8, - num_hidden_layers=2, - num_attention_heads=1, - num_key_value_heads=1, - n_shared_experts=1, - n_routed_experts=2, - kv_lora_rank=2, - q_lora_rank=2, - qk_rope_head_dim=1, - qk_nope_head_dim=1, - v_head_dim=2, - first_k_dense_replace=2, - index_n_heads=1, - index_head_dim=2, - index_topk=64, - max_position_embeddings=16, - pad_token_id=0, - use_index_cache=use_index_cache, - ) - - -class RecordingLayer(nn.Module): - def __init__(self, next_cached_indices: torch.Tensor): - super().__init__() - self.next_cached_indices = next_cached_indices - self.seen_cached_indices = [] - - def forward( - self, - hidden_states, - position_embeddings=None, - ks=None, - ke=None, - cached_indices=None, - routed_experts=None, - ): - self.seen_cached_indices.append(cached_indices) - return hidden_states, self.next_cached_indices - - -def _run_model_with_recording_layers(use_index_cache: bool) -> list[RecordingLayer]: - model = GlmMoeDsaModel(_tiny_config(use_index_cache=use_index_cache)) - layers = [ - RecordingLayer(torch.ones(1, 3, 1, 64, dtype=torch.int32)), - RecordingLayer(torch.full((1, 3, 1, 64), 2, dtype=torch.int32)), - ] - model.layers = nn.ModuleList(layers) - - model( - inputs_embeds=torch.zeros(1, 3, 4), - position_ids=torch.arange(3).unsqueeze(0), - ) - - return layers - - -def test_model_does_not_thread_indices_when_index_cache_disabled(): - layers = _run_model_with_recording_layers(use_index_cache=False) - - assert layers[0].seen_cached_indices == [None] - assert layers[1].seen_cached_indices == [None] - - -def test_model_threads_indices_when_index_cache_enabled(): - layers = _run_model_with_recording_layers(use_index_cache=True) - - assert layers[0].seen_cached_indices == [None] - assert layers[1].seen_cached_indices[0] is layers[0].next_cached_indices From 563b77f2e4f7637cc399f1f410995688c188b685 Mon Sep 17 00:00:00 2001 From: faresobeid Date: Mon, 18 May 2026 21:12:37 +0530 Subject: [PATCH 4/4] fix --- .../glm_moe_dsa/configuration_glm_moe_dsa.py | 11 +++++++++++ .../models/glm_moe_dsa/modeling_glm_moe_dsa.py | 14 +------------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/prime_rl/trainer/models/glm_moe_dsa/configuration_glm_moe_dsa.py b/src/prime_rl/trainer/models/glm_moe_dsa/configuration_glm_moe_dsa.py index d796455b46..b6520d7c06 100644 --- a/src/prime_rl/trainer/models/glm_moe_dsa/configuration_glm_moe_dsa.py +++ b/src/prime_rl/trainer/models/glm_moe_dsa/configuration_glm_moe_dsa.py @@ -3,6 +3,17 @@ from transformers.configuration_utils import PretrainedConfig +def _index_cache_skip_topk(config, layer_idx: int) -> bool: + if not getattr(config, "use_index_cache", False): + return False + + index_topk_pattern = getattr(config, "index_topk_pattern", None) + if index_topk_pattern is not None: + return layer_idx < len(index_topk_pattern) and index_topk_pattern[layer_idx] == "S" + + return layer_idx % getattr(config, "index_topk_freq", 1) != 0 + + class GlmMoeDsaConfig(PretrainedConfig): r""" Configuration class for the GLM-5 (GlmMoeDsa) model. diff --git a/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py b/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py index fa42e9c879..699ba60a4b 100644 --- a/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py +++ b/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py @@ -13,7 +13,7 @@ from transformers.utils.deprecation import deprecate_kwarg from prime_rl.trainer.models.base import PreTrainedModelPrimeRL -from prime_rl.trainer.models.glm_moe_dsa.configuration_glm_moe_dsa import GlmMoeDsaConfig +from prime_rl.trainer.models.glm_moe_dsa.configuration_glm_moe_dsa import GlmMoeDsaConfig, _index_cache_skip_topk from prime_rl.trainer.models.glm_moe_dsa.converting_glm_moe_dsa import ( convert_hf_layer_to_tt, convert_hf_to_tt_moe, @@ -29,18 +29,6 @@ from prime_rl.trainer.models.layers.rotary_emb import RotaryEmbedding, RotaryEmbeddingConfig -def _index_cache_skip_topk(config: GlmMoeDsaConfig, layer_idx: int) -> bool: - if not getattr(config, "use_index_cache", False): - return False - - index_topk_pattern = getattr(config, "index_topk_pattern", None) - if index_topk_pattern is not None: - return layer_idx < len(index_topk_pattern) and index_topk_pattern[layer_idx] == "S" - - index_topk_freq = getattr(config, "index_topk_freq", 1) - return max(layer_idx - 1, 0) % index_topk_freq != 0 - - def _sparse_mla_attention_args(config: GlmMoeDsaConfig, layer_idx: int) -> SparseMlaAttentionArgs: if config.q_lora_rank is None: raise ValueError("Sparse MLA attention requires q_lora_rank to be set")