diff --git a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py index a076d1e29c..b29b3b47e0 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py @@ -345,6 +345,27 @@ class ModelConfig(BaseModelConfig): ), ] = True + use_index_cache: Annotated[ + bool, + Field( + description="Enable DSA IndexCache by reusing sparse attention top-k indices across layers. Matches vLLM's use_index_cache HF override.", + ), + ] = False + + index_topk_freq: Annotated[ + int, + Field( + description="DSA IndexCache frequency for recomputing top-k indices. 1 computes every layer; 4 computes on the vLLM PR's example schedule.", + ), + ] = 1 + + index_topk_pattern: Annotated[ + str | None, + Field( + description="Optional DSA IndexCache per-layer pattern where 'F' computes fresh top-k indices and 'S' reuses the previous full layer.", + ), + ] = None + fp8: Annotated[ bool, Field( diff --git a/src/prime_rl/trainer/model.py b/src/prime_rl/trainer/model.py index 910a978a66..2cdb802c6c 100644 --- a/src/prime_rl/trainer/model.py +++ b/src/prime_rl/trainer/model.py @@ -473,6 +473,10 @@ def get_model( model_config.use_grouped_mm = config.moe_use_grouped_mm model_config.fp8 = config.fp8 + for field_name in ("use_index_cache", "index_topk_freq", "index_topk_pattern"): + if field_name in config.model_fields_set or not hasattr(model_config, field_name): + setattr(model_config, field_name, getattr(config, field_name)) + # Ensure pad_token_id is set (some models like Qwen3MoE don't have it). # In transformers v5, token IDs moved from PretrainedConfig to GenerationConfig. 
def _index_cache_skip_topk(config, layer_idx: int) -> bool:
    """Return whether layer ``layer_idx`` should reuse cached top-k indices.

    Three optional attributes are read from ``config`` via ``getattr``:

    * ``use_index_cache`` -- when missing or falsy, no layer ever skips.
    * ``index_topk_pattern`` -- when not ``None`` it takes precedence over the
      frequency. ``"S"`` at position ``layer_idx`` requests reuse; any other
      character, or a layer past the end of the pattern, computes fresh
      indices.
    * ``index_topk_freq`` -- consulted only when no pattern is set. Layers whose
      index is a multiple of the frequency compute; all others reuse. A missing
      or ``None`` value is treated as 1 (compute on every layer).

    Layer 0 never skips when a pattern is used, because there is no earlier
    layer whose indices it could reuse.

    Args:
        config: Object carrying the attributes above (all optional).
        layer_idx: Zero-based index of the decoder layer.

    Returns:
        ``True`` if the layer should skip its own top-k computation.

    Raises:
        ValueError: If the frequency is consulted and is not an integer >= 1.
    """
    if not getattr(config, "use_index_cache", False):
        return False

    pattern = getattr(config, "index_topk_pattern", None)
    if pattern is not None:
        # ``0 < layer_idx`` keeps the first layer computing even if the pattern
        # starts with "S"; otherwise it would be handed no indices at all.
        return 0 < layer_idx < len(pattern) and pattern[layer_idx] == "S"

    freq = getattr(config, "index_topk_freq", 1)
    if freq is None:
        freq = 1
    if not isinstance(freq, int) or freq < 1:
        # Fail with a clear message instead of ZeroDivisionError / TypeError
        # from the modulo below, or a silently meaningless negative schedule.
        raise ValueError(f"index_topk_freq must be an integer >= 1, got {freq!r}")

    return layer_idx % freq != 0
Must match the vLLM inference server's expectation (vLLM defaults to ``"softmax"`` when this @@ -141,6 +159,9 @@ def __init__( indexer_rope_interleave=True, pad_token_id=154820, index_topk=2048, + use_index_cache=False, + index_topk_freq=1, + index_topk_pattern=None, scoring_func="sigmoid", topk_method="noaux_tc", use_grouped_mm=True, @@ -194,6 +215,9 @@ def __init__( self.index_head_dim = index_head_dim self.indexer_rope_interleave = indexer_rope_interleave self.index_topk = index_topk + self.use_index_cache = use_index_cache + self.index_topk_freq = index_topk_freq + self.index_topk_pattern = index_topk_pattern self.scoring_func = scoring_func self.topk_method = topk_method self.use_grouped_mm = use_grouped_mm diff --git a/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py b/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py index 2bb90d94ff..699ba60a4b 100644 --- a/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py +++ b/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py @@ -13,7 +13,7 @@ from transformers.utils.deprecation import deprecate_kwarg from prime_rl.trainer.models.base import PreTrainedModelPrimeRL -from prime_rl.trainer.models.glm_moe_dsa.configuration_glm_moe_dsa import GlmMoeDsaConfig +from prime_rl.trainer.models.glm_moe_dsa.configuration_glm_moe_dsa import GlmMoeDsaConfig, _index_cache_skip_topk from prime_rl.trainer.models.glm_moe_dsa.converting_glm_moe_dsa import ( convert_hf_layer_to_tt, convert_hf_to_tt_moe, @@ -29,7 +29,7 @@ from prime_rl.trainer.models.layers.rotary_emb import RotaryEmbedding, RotaryEmbeddingConfig -def _sparse_mla_attention_args(config: GlmMoeDsaConfig) -> SparseMlaAttentionArgs: +def _sparse_mla_attention_args(config: GlmMoeDsaConfig, layer_idx: int) -> SparseMlaAttentionArgs: if config.q_lora_rank is None: raise ValueError("Sparse MLA attention requires q_lora_rank to be set") return SparseMlaAttentionArgs( @@ -46,6 +46,8 @@ def _sparse_mla_attention_args(config: 
GlmMoeDsaConfig) -> SparseMlaAttentionArg index_n_heads=config.index_n_heads, index_head_dim=config.index_head_dim, index_topk=config.index_topk, + use_index_cache=getattr(config, "use_index_cache", False), + skip_topk=_index_cache_skip_topk(config, layer_idx), ) @@ -53,7 +55,7 @@ class GlmMoeDsaDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: GlmMoeDsaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = GlmMoeDsaAttention(_sparse_mla_attention_args(config)) + self.self_attn = GlmMoeDsaAttention(_sparse_mla_attention_args(config, layer_idx)) moe_args = MoEArgs( num_experts=config.n_routed_experts, @@ -95,15 +97,17 @@ def forward( position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, ks: Optional[torch.Tensor] = None, ke: Optional[torch.Tensor] = None, + cached_indices: Optional[torch.Tensor] = None, routed_experts: Optional[torch.LongTensor] = None, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor | None]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states, _ = self.self_attn( + hidden_states, cached_indices = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, ks=ks, ke=ke, + cached_indices=cached_indices, ) hidden_states = residual + hidden_states @@ -111,7 +115,7 @@ def forward( hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states, routed_experts=routed_experts) hidden_states = residual + hidden_states - return hidden_states + return hidden_states, cached_indices @auto_docstring @@ -256,15 +260,19 @@ def forward( else: ks, ke = ks_full, ke_full + cached_indices = None + use_index_cache = getattr(self.config, "use_index_cache", False) for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): routed_experts_layer = routed_experts[:, :, layer_idx, :] if routed_experts is not None else None - hidden_states = 
decoder_layer( + hidden_states, next_cached_indices = decoder_layer( hidden_states, position_embeddings=position_embeddings, ks=ks, ke=ke, + cached_indices=cached_indices, routed_experts=routed_experts_layer, ) + cached_indices = next_cached_indices if use_index_cache else None hidden_states = self.norm(hidden_states) return BaseModelOutputWithPast(last_hidden_state=hidden_states) diff --git a/src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py b/src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py index 2488505e9f..d87dc183ca 100644 --- a/src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py +++ b/src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py @@ -32,6 +32,8 @@ class SparseMlaAttentionArgs: index_n_heads: int index_head_dim: int index_topk: int + use_index_cache: bool = False + skip_topk: bool = False class _SparseMLA(torch.autograd.Function): @@ -157,6 +159,8 @@ def __init__(self, args: SparseMlaAttentionArgs): self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, args.hidden_size, bias=args.attention_bias) self.indexer = Indexer(args) + self.use_index_cache = args.use_index_cache + self.skip_topk = args.skip_topk self.scaling = self.qk_head_dim ** (-0.5) self._cp_group: dist.ProcessGroup | None = None @@ -234,6 +238,7 @@ def forward( position_embeddings: tuple[torch.Tensor, torch.Tensor], ks: torch.Tensor | None = None, ke: torch.Tensor | None = None, + cached_indices: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: q_latent, k_compressed_normed, k_rope = self.attn_projections(hidden_states) @@ -241,17 +246,19 @@ def forward( k_compressed_normed = gather_for_cp(k_compressed_normed, self._cp_group) k_rope = gather_for_cp(k_rope, self._cp_group) - indices = self.indexer.compute_sparse_indices( - hidden_states_local=hidden_states, - q_latent_local=q_latent, - ks=ks, - ke=ke, - index_topk=self.args.index_topk, - position_embeddings_full=position_embeddings, - cp_group=self._cp_group, - 
cp_world_size=self._cp_world_size, - cp_rank=self._cp_rank, - ) + indices = cached_indices + if not self.skip_topk: + indices = self.indexer.compute_sparse_indices( + hidden_states_local=hidden_states, + q_latent_local=q_latent, + ks=ks, + ke=ke, + index_topk=self.args.index_topk, + position_embeddings_full=position_embeddings, + cp_group=self._cp_group, + cp_world_size=self._cp_world_size, + cp_rank=self._cp_rank, + ) sparse_q, sparse_kv, w_v = self.mla_up_proj( q_latent_local=q_latent, @@ -261,4 +268,5 @@ def forward( ) out = _SparseMLA.apply(sparse_q, sparse_kv, indices, self.scaling) - return self.output_proj(out, w_v), None + cached_indices = indices if self.use_index_cache else None + return self.output_proj(out, w_v), cached_indices