Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,11 @@ torchvision = { index = "pytorch-cu128" }
torchaudio = { index = "pytorch-cu128" }
torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" }
dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" }
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" }
transformers = { git = "https://github.com/JJJYmmm/transformers", rev = "d362c90c378b4b32b54513f1627b6d9d59ccc6a1" }
flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "96bd151" }
pydantic-config = { git = "https://github.com/samsja/pydantic_config.git", branch = "main" }
vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl" }
vllm = [
{ url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_x86_64.whl", marker = "platform_machine == 'x86_64'" },
{ url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl", marker = "platform_machine == 'aarch64'" },
]
vllm = { git = "https://github.com/Zyphra/vllm.git", branch = "zaya1-pr" }
deep-ep = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_ep-1.2.1+29d31c0-cp312-cp312-linux_x86_64.whl" }
deep-gemm = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_gemm-2.5.0+891d57b-cp312-cp312-linux_x86_64.whl" }
nixl-cu12 = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/nixl_cu12-0.10.1-cp312-cp312-linux_x86_64.whl" }
Expand Down
47 changes: 47 additions & 0 deletions scripts/mini_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import Glm4MoeForCausalLM as HFGlm4MoeForCausalLM
from transformers import ZayaForCausalLM as HFZayaForCausalLM
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
Qwen3_5MoeForConditionalGeneration as HFQwen3_5MoeVLM,
)
Expand All @@ -29,6 +30,8 @@
from prime_rl.trainer.models.minimax_m2 import MiniMaxM2Config
from prime_rl.trainer.models.minimax_m2 import MiniMaxM2ForCausalLM as PrimeRLMiniMaxM2ForCausalLM
from prime_rl.trainer.models.qwen3_5_moe import Qwen3_5MoeForCausalLM as PrimeRLQwen3_5MoeVLM
from prime_rl.trainer.models.zaya import ZayaConfig
from prime_rl.trainer.models.zaya import ZayaForCausalLM as PrimeRLZayaForCausalLM
from prime_rl.utils.logger import setup_logger
from prime_rl.utils.utils import default_dtype

Expand Down Expand Up @@ -192,6 +195,50 @@ def _qwen3_5_moe_vlm_config():
"tokenizer_source": "Qwen/Qwen3.5-35B-A3B",
"is_vlm": True,
},
"zaya": {
"config_class": ZayaConfig,
"config_kwargs": dict(
vocab_size=128,
hidden_size=32,
ffn_hidden_size=16,
num_hidden_layers=4,
num_experts=3,
num_attention_heads=4,
num_query_groups=2,
num_key_value_heads=2,
head_dim=8,
max_position_embeddings=512,
norm_epsilon=1e-5,
rope_theta=10000.0,
partial_rotary_factor=0.5,
moe_router_topk=1,
zaya_mlp_expansion=8,
zaya_use_mod=True,
zaya_use_eda=True,
add_bias_linear=False,
attention_bias=False,
lm_head_bias=False,
tie_word_embeddings=True,
use_cache=False,
use_grouped_mm=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
layer_types=["hybrid", "hybrid", "hybrid", "hybrid"],
rope_parameters={
"hybrid": {
"rope_type": "default",
"rope_theta": 10000.0,
"partial_rotary_factor": 0.5,
}
},
_attn_implementation="sdpa",
),
"hf_model_class": HFZayaForCausalLM,
"prime_model_class": PrimeRLZayaForCausalLM,
# Different weight format from official release but this the 'official' HF supported version
"tokenizer_source": "JJJYmmm/ZAYA1-8B-HF",
},
# glm_moe_dsa: HF implementation is incorrect, not supported here
}

Expand Down
6 changes: 5 additions & 1 deletion src/prime_rl/inference/vllm/worker/weight_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@ def load_weights_checkpoint_layerwise(
model_config,
vllm_config,
) -> None:
logger.info("Reloading checkpoint-format weights with vLLM layerwise processing")
device = next(model.parameters()).device
with torch.device(device), set_current_vllm_config(vllm_config):
if getattr(model_config.hf_config, "model_type", None) == "zaya":
model.load_weights(state_iter) # type: ignore
return

logger.info("Reloading checkpoint-format weights with vLLM layerwise processing")
initialize_layerwise_reload(model)
model.load_weights(state_iter) # type: ignore
finalize_layerwise_reload(model, model_config)
Expand Down
18 changes: 9 additions & 9 deletions src/prime_rl/trainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
)
from prime_rl.trainer.models.layers.fp8_linear import replace_linear_with_fp8_blockwise_linear
from prime_rl.trainer.models.layers.lm_head import inject_prime_lm_head
from prime_rl.trainer.models.layers.moe import LatentMoE, MoE
from prime_rl.trainer.models.layers.moe import LatentMoE, MoE, ZayaMoE
from prime_rl.trainer.parallel_dims import ParallelDims
from prime_rl.trainer.weights import (
load_state_dict,
Expand Down Expand Up @@ -334,8 +334,8 @@ def freeze_moe_router(model: nn.Module) -> None:
if mlp is None:
continue

# Custom implementation: MoE/LatentMoE class with router attribute
if isinstance(mlp, (MoE, LatentMoE)):
# Custom implementation: MoE/LatentMoE/ZayaMoE class with router attribute
if isinstance(mlp, (MoE, LatentMoE, ZayaMoE)):
for param in mlp.router.parameters():
param.requires_grad = False
num_frozen += 1
Expand Down Expand Up @@ -384,7 +384,7 @@ def configure_moe_ep_backend(model: nn.Module, config: ModelConfig) -> None:
configure_num_sms(config.deepep_num_sms)
language_model = get_language_model(model)
for transformer_block in language_model.layers:
if not isinstance(transformer_block.mlp, (MoE, LatentMoE)):
if not isinstance(transformer_block.mlp, (MoE, LatentMoE, ZayaMoE)):
continue
transformer_block.mlp.set_ep_comm_backend(backend)
transformer_block.mlp.set_deepep_token_chunk_size(config.deepep_token_chunk_size)
Expand Down Expand Up @@ -617,7 +617,7 @@ def setup_fsdp(model: nn.Module, config: ModelConfig, parallel_dims: ParallelDim

for transformer_block in transformer_layers:
block_mlp = getattr(transformer_block, "mlp", None)
if parallel_dims.ep_enabled and block_mlp is not None and isinstance(block_mlp, (MoE, LatentMoE)):
if parallel_dims.ep_enabled and block_mlp is not None and isinstance(block_mlp, (MoE, LatentMoE, ZayaMoE)):
fully_shard(block_mlp.experts, mesh=dp_mod_ep_mesh, **fsdp_config)

block_mlp.experts.set_gradient_divide_factor(parallel_dims.fsdp_gradient_divide_factor)
Expand Down Expand Up @@ -674,7 +674,7 @@ def setup_fsdp(model: nn.Module, config: ModelConfig, parallel_dims: ParallelDim
for transformer_block, next_transformer_block in zip(transformer_blocks, next_transformer_blocks):
if next_transformer_block is not None:
next_mlp = getattr(next_transformer_block, "mlp", None)
if next_mlp is not None and isinstance(next_mlp, (MoE, LatentMoE)):
if next_mlp is not None and isinstance(next_mlp, (MoE, LatentMoE, ZayaMoE)):
transformer_block.set_modules_to_forward_prefetch([next_transformer_block, next_mlp.experts])
else:
transformer_block.set_modules_to_forward_prefetch([next_transformer_block])
Expand All @@ -695,7 +695,7 @@ def setup_fsdp(model: nn.Module, config: ModelConfig, parallel_dims: ParallelDim
for transformer_block, prev_transformer_block in zip(reversed_transformer_blocks, prev_transformer_blocks):
if prev_transformer_block is not None:
prev_mlp = getattr(prev_transformer_block, "mlp", None)
if prev_mlp is not None and isinstance(prev_mlp, (MoE, LatentMoE)):
if prev_mlp is not None and isinstance(prev_mlp, (MoE, LatentMoE, ZayaMoE)):
transformer_block.set_modules_to_backward_prefetch([prev_transformer_block, prev_mlp.experts])
else:
transformer_block.set_modules_to_backward_prefetch([prev_transformer_block])
Expand Down Expand Up @@ -939,7 +939,7 @@ def apply_ep(model: nn.Module, config: ModelConfig, parallel_dims: ParallelDims)
language_model = get_language_model(model)
for transformer_block in language_model.layers:
block_mlp = getattr(transformer_block, "mlp", None)
if block_mlp is not None and isinstance(block_mlp, (MoE, LatentMoE)):
if block_mlp is not None and isinstance(block_mlp, (MoE, LatentMoE, ZayaMoE)):
if config.ep_comm_backend == "torch":
parallelize_plan = ExpertParallel()
else:
Expand All @@ -962,7 +962,7 @@ def _move_buffers_to_cuda(model: nn.Module, config: ModelConfig) -> None:

def _reset_runtime_moe_buffers(model: nn.Module) -> None:
for module in model.modules():
if isinstance(module, (MoE, LatentMoE)) and module.tokens_per_expert.device.type != "meta":
if isinstance(module, (MoE, LatentMoE, ZayaMoE)) and module.tokens_per_expert.device.type != "meta":
module.tokens_per_expert.zero_()


Expand Down
4 changes: 4 additions & 0 deletions src/prime_rl/trainer/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config

import prime_rl._compat # noqa: F401 — apply shims before transitive model imports
from prime_rl.trainer.models.afmoe import AfmoeConfig, AfmoeForCausalLM
from prime_rl.trainer.models.base import PreTrainedModelPrimeRL
from prime_rl.trainer.models.glm4_moe import Glm4MoeConfig, Glm4MoeForCausalLM
Expand All @@ -22,6 +23,7 @@
from prime_rl.trainer.models.qwen3 import Qwen3ForCausalLM
from prime_rl.trainer.models.qwen3_5_moe import Qwen3_5MoeConfig, Qwen3_5MoeForCausalLM
from prime_rl.trainer.models.qwen3_moe import Qwen3MoeConfig, Qwen3MoeForCausalLM
from prime_rl.trainer.models.zaya import ZayaConfig, ZayaForCausalLM

# Make custom config discoverable by AutoConfig
AutoConfig.register("afmoe", AfmoeConfig, exist_ok=True)
Expand All @@ -32,6 +34,7 @@
AutoConfig.register("nemotron_h", NemotronHConfig, exist_ok=True)
AutoConfig.register("qwen3_moe", Qwen3MoeConfig, exist_ok=True)
AutoConfig.register("qwen3_5_moe_text", Qwen3_5MoeConfig, exist_ok=True)
AutoConfig.register("zaya", ZayaConfig, exist_ok=True)
# GptOssConfig is just HF's class - already registered by transformers, no override needed.

_CUSTOM_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, OrderedDict())
Expand All @@ -46,6 +49,7 @@
_CUSTOM_CAUSAL_LM_MAPPING.register(Qwen3MoeConfig, Qwen3MoeForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(Qwen3_5MoeConfig, Qwen3_5MoeForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(GptOssConfig, GptOssForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(ZayaConfig, ZayaForCausalLM, exist_ok=True)


class AutoModelForCausalLMPrimeRL(_BaseAutoModelClass):
Expand Down
Loading