diff --git a/pyproject.toml b/pyproject.toml index 37c51f3933..7dc5fb8388 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -227,14 +227,11 @@ torchvision = { index = "pytorch-cu128" } torchaudio = { index = "pytorch-cu128" } torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" } dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" } -transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" } +transformers = { git = "https://github.com/JJJYmmm/transformers", rev = "d362c90c378b4b32b54513f1627b6d9d59ccc6a1" } flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "96bd151" } pydantic-config = { git = "https://github.com/samsja/pydantic_config.git", branch = "main" } vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl" } -vllm = [ - { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_x86_64.whl", marker = "platform_machine == 'x86_64'" }, - { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl", marker = "platform_machine == 'aarch64'" }, -] +vllm = { git = "https://github.com/Zyphra/vllm.git", branch = "zaya1-pr" } deep-ep = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_ep-1.2.1+29d31c0-cp312-cp312-linux_x86_64.whl" } deep-gemm = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_gemm-2.5.0+891d57b-cp312-cp312-linux_x86_64.whl" } nixl-cu12 = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/nixl_cu12-0.10.1-cp312-cp312-linux_x86_64.whl" } diff --git a/scripts/mini_moe.py b/scripts/mini_moe.py index 0aca7cf6cf..63343657c6 100644 --- a/scripts/mini_moe.py +++ b/scripts/mini_moe.py @@ -17,6 +17,7 @@ import 
torch from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from transformers import Glm4MoeForCausalLM as HFGlm4MoeForCausalLM +from transformers import ZayaForCausalLM as HFZayaForCausalLM from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( Qwen3_5MoeForConditionalGeneration as HFQwen3_5MoeVLM, ) @@ -29,6 +30,8 @@ from prime_rl.trainer.models.minimax_m2 import MiniMaxM2Config from prime_rl.trainer.models.minimax_m2 import MiniMaxM2ForCausalLM as PrimeRLMiniMaxM2ForCausalLM from prime_rl.trainer.models.qwen3_5_moe import Qwen3_5MoeForCausalLM as PrimeRLQwen3_5MoeVLM +from prime_rl.trainer.models.zaya import ZayaConfig +from prime_rl.trainer.models.zaya import ZayaForCausalLM as PrimeRLZayaForCausalLM from prime_rl.utils.logger import setup_logger from prime_rl.utils.utils import default_dtype @@ -192,6 +195,50 @@ def _qwen3_5_moe_vlm_config(): "tokenizer_source": "Qwen/Qwen3.5-35B-A3B", "is_vlm": True, }, + "zaya": { + "config_class": ZayaConfig, + "config_kwargs": dict( + vocab_size=128, + hidden_size=32, + ffn_hidden_size=16, + num_hidden_layers=4, + num_experts=3, + num_attention_heads=4, + num_query_groups=2, + num_key_value_heads=2, + head_dim=8, + max_position_embeddings=512, + norm_epsilon=1e-5, + rope_theta=10000.0, + partial_rotary_factor=0.5, + moe_router_topk=1, + zaya_mlp_expansion=8, + zaya_use_mod=True, + zaya_use_eda=True, + add_bias_linear=False, + attention_bias=False, + lm_head_bias=False, + tie_word_embeddings=True, + use_cache=False, + use_grouped_mm=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + layer_types=["hybrid", "hybrid", "hybrid", "hybrid"], + rope_parameters={ + "hybrid": { + "rope_type": "default", + "rope_theta": 10000.0, + "partial_rotary_factor": 0.5, + } + }, + _attn_implementation="sdpa", + ), + "hf_model_class": HFZayaForCausalLM, + "prime_model_class": PrimeRLZayaForCausalLM, + # Different weight format from the official release, but this is the 'official' HF-supported version + 
"tokenizer_source": "JJJYmmm/ZAYA1-8B-HF", + }, # glm_moe_dsa: HF implementation is incorrect, not supported here } diff --git a/src/prime_rl/inference/vllm/worker/weight_transfer.py b/src/prime_rl/inference/vllm/worker/weight_transfer.py index 7b448f021a..7b9a9feee0 100644 --- a/src/prime_rl/inference/vllm/worker/weight_transfer.py +++ b/src/prime_rl/inference/vllm/worker/weight_transfer.py @@ -15,9 +15,13 @@ def load_weights_checkpoint_layerwise( model_config, vllm_config, ) -> None: - logger.info("Reloading checkpoint-format weights with vLLM layerwise processing") device = next(model.parameters()).device with torch.device(device), set_current_vllm_config(vllm_config): + if getattr(model_config.hf_config, "model_type", None) == "zaya": + model.load_weights(state_iter) # type: ignore + return + + logger.info("Reloading checkpoint-format weights with vLLM layerwise processing") initialize_layerwise_reload(model) model.load_weights(state_iter) # type: ignore finalize_layerwise_reload(model, model_config) diff --git a/src/prime_rl/trainer/model.py b/src/prime_rl/trainer/model.py index 910a978a66..ae41680e00 100644 --- a/src/prime_rl/trainer/model.py +++ b/src/prime_rl/trainer/model.py @@ -45,7 +45,7 @@ ) from prime_rl.trainer.models.layers.fp8_linear import replace_linear_with_fp8_blockwise_linear from prime_rl.trainer.models.layers.lm_head import inject_prime_lm_head -from prime_rl.trainer.models.layers.moe import LatentMoE, MoE +from prime_rl.trainer.models.layers.moe import LatentMoE, MoE, ZayaMoE from prime_rl.trainer.parallel_dims import ParallelDims from prime_rl.trainer.weights import ( load_state_dict, @@ -334,8 +334,8 @@ def freeze_moe_router(model: nn.Module) -> None: if mlp is None: continue - # Custom implementation: MoE/LatentMoE class with router attribute - if isinstance(mlp, (MoE, LatentMoE)): + # Custom implementation: MoE/LatentMoE/ZayaMoE class with router attribute + if isinstance(mlp, (MoE, LatentMoE, ZayaMoE)): for param in 
mlp.router.parameters(): param.requires_grad = False num_frozen += 1 @@ -384,7 +384,7 @@ def configure_moe_ep_backend(model: nn.Module, config: ModelConfig) -> None: configure_num_sms(config.deepep_num_sms) language_model = get_language_model(model) for transformer_block in language_model.layers: - if not isinstance(transformer_block.mlp, (MoE, LatentMoE)): + if not isinstance(transformer_block.mlp, (MoE, LatentMoE, ZayaMoE)): continue transformer_block.mlp.set_ep_comm_backend(backend) transformer_block.mlp.set_deepep_token_chunk_size(config.deepep_token_chunk_size) @@ -617,7 +617,7 @@ def setup_fsdp(model: nn.Module, config: ModelConfig, parallel_dims: ParallelDim for transformer_block in transformer_layers: block_mlp = getattr(transformer_block, "mlp", None) - if parallel_dims.ep_enabled and block_mlp is not None and isinstance(block_mlp, (MoE, LatentMoE)): + if parallel_dims.ep_enabled and block_mlp is not None and isinstance(block_mlp, (MoE, LatentMoE, ZayaMoE)): fully_shard(block_mlp.experts, mesh=dp_mod_ep_mesh, **fsdp_config) block_mlp.experts.set_gradient_divide_factor(parallel_dims.fsdp_gradient_divide_factor) @@ -674,7 +674,7 @@ def setup_fsdp(model: nn.Module, config: ModelConfig, parallel_dims: ParallelDim for transformer_block, next_transformer_block in zip(transformer_blocks, next_transformer_blocks): if next_transformer_block is not None: next_mlp = getattr(next_transformer_block, "mlp", None) - if next_mlp is not None and isinstance(next_mlp, (MoE, LatentMoE)): + if next_mlp is not None and isinstance(next_mlp, (MoE, LatentMoE, ZayaMoE)): transformer_block.set_modules_to_forward_prefetch([next_transformer_block, next_mlp.experts]) else: transformer_block.set_modules_to_forward_prefetch([next_transformer_block]) @@ -695,7 +695,7 @@ def setup_fsdp(model: nn.Module, config: ModelConfig, parallel_dims: ParallelDim for transformer_block, prev_transformer_block in zip(reversed_transformer_blocks, prev_transformer_blocks): if prev_transformer_block is not 
None: prev_mlp = getattr(prev_transformer_block, "mlp", None) - if prev_mlp is not None and isinstance(prev_mlp, (MoE, LatentMoE)): + if prev_mlp is not None and isinstance(prev_mlp, (MoE, LatentMoE, ZayaMoE)): transformer_block.set_modules_to_backward_prefetch([prev_transformer_block, prev_mlp.experts]) else: transformer_block.set_modules_to_backward_prefetch([prev_transformer_block]) @@ -939,7 +939,7 @@ def apply_ep(model: nn.Module, config: ModelConfig, parallel_dims: ParallelDims) language_model = get_language_model(model) for transformer_block in language_model.layers: block_mlp = getattr(transformer_block, "mlp", None) - if block_mlp is not None and isinstance(block_mlp, (MoE, LatentMoE)): + if block_mlp is not None and isinstance(block_mlp, (MoE, LatentMoE, ZayaMoE)): if config.ep_comm_backend == "torch": parallelize_plan = ExpertParallel() else: @@ -962,7 +962,7 @@ def _move_buffers_to_cuda(model: nn.Module, config: ModelConfig) -> None: def _reset_runtime_moe_buffers(model: nn.Module) -> None: for module in model.modules(): - if isinstance(module, (MoE, LatentMoE)) and module.tokens_per_expert.device.type != "meta": + if isinstance(module, (MoE, LatentMoE, ZayaMoE)) and module.tokens_per_expert.device.type != "meta": module.tokens_per_expert.zero_() diff --git a/src/prime_rl/trainer/models/__init__.py b/src/prime_rl/trainer/models/__init__.py index 180dcfef86..2ea848dcbe 100644 --- a/src/prime_rl/trainer/models/__init__.py +++ b/src/prime_rl/trainer/models/__init__.py @@ -9,6 +9,7 @@ from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.qwen3.configuration_qwen3 import Qwen3Config +import prime_rl._compat # noqa: F401 — apply shims before transitive model imports from prime_rl.trainer.models.afmoe import AfmoeConfig, AfmoeForCausalLM from prime_rl.trainer.models.base import PreTrainedModelPrimeRL from prime_rl.trainer.models.glm4_moe import Glm4MoeConfig, Glm4MoeForCausalLM @@ -22,6 +23,7 @@ from 
prime_rl.trainer.models.qwen3 import Qwen3ForCausalLM from prime_rl.trainer.models.qwen3_5_moe import Qwen3_5MoeConfig, Qwen3_5MoeForCausalLM from prime_rl.trainer.models.qwen3_moe import Qwen3MoeConfig, Qwen3MoeForCausalLM +from prime_rl.trainer.models.zaya import ZayaConfig, ZayaForCausalLM # Make custom config discoverable by AutoConfig AutoConfig.register("afmoe", AfmoeConfig, exist_ok=True) @@ -32,6 +34,7 @@ AutoConfig.register("nemotron_h", NemotronHConfig, exist_ok=True) AutoConfig.register("qwen3_moe", Qwen3MoeConfig, exist_ok=True) AutoConfig.register("qwen3_5_moe_text", Qwen3_5MoeConfig, exist_ok=True) +AutoConfig.register("zaya", ZayaConfig, exist_ok=True) # GptOssConfig is just HF's class - already registered by transformers, no override needed. _CUSTOM_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, OrderedDict()) @@ -46,6 +49,7 @@ _CUSTOM_CAUSAL_LM_MAPPING.register(Qwen3MoeConfig, Qwen3MoeForCausalLM, exist_ok=True) _CUSTOM_CAUSAL_LM_MAPPING.register(Qwen3_5MoeConfig, Qwen3_5MoeForCausalLM, exist_ok=True) _CUSTOM_CAUSAL_LM_MAPPING.register(GptOssConfig, GptOssForCausalLM, exist_ok=True) +_CUSTOM_CAUSAL_LM_MAPPING.register(ZayaConfig, ZayaForCausalLM, exist_ok=True) class AutoModelForCausalLMPrimeRL(_BaseAutoModelClass): diff --git a/src/prime_rl/trainer/models/layers/moe.py b/src/prime_rl/trainer/models/layers/moe.py index c59a720539..ce51b3b825 100644 --- a/src/prime_rl/trainer/models/layers/moe.py +++ b/src/prime_rl/trainer/models/layers/moe.py @@ -10,9 +10,11 @@ import torch import torch.nn.functional as F from torch import nn +from torch.distributed.tensor import DTensor from torchtitan.distributed.expert_parallel import expert_parallel from prime_rl.configs.trainer import EPCommBackend +from prime_rl.trainer.models.layers.norms import RMSNorm, RMSNormConfig @dataclass @@ -528,7 +530,7 @@ def forward( # group tokens together by expert indices from 0 to num_experts and pass that to experts forward selected_experts_indices = 
selected_experts_indices.reshape(-1) num_tokens_per_expert = torch.histc( - selected_experts_indices, + selected_experts_indices.float(), bins=self.num_experts, min=0, max=self.num_experts, @@ -1198,3 +1200,406 @@ def init_weights(self, init_std: float, buffer_device: torch.device): self.tokens_per_expert = torch.zeros(self.experts.num_experts, dtype=torch.float32) if self.load_balance_coeff is not None: self.expert_bias = torch.zeros(self.experts.num_experts, dtype=torch.float32) + + +class ZayaRouterMLP(nn.Module): + def __init__(self, hidden_size: int, num_experts: int, rms_norm_eps: float): + super().__init__() + self.rmsnorm_eda = RMSNorm(RMSNormConfig(hidden_size=hidden_size, eps=rms_norm_eps)) + self.fc1 = nn.Linear(hidden_size, hidden_size, bias=True) + self.fc2 = nn.Linear(hidden_size, hidden_size, bias=True) + self.out_proj = nn.Linear(hidden_size, num_experts, bias=False) + self.act_fn = nn.GELU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.rmsnorm_eda(hidden_states) + hidden_states = self.act_fn(self.fc1(hidden_states)) + hidden_states = self.act_fn(self.fc2(hidden_states)) + return self.out_proj(hidden_states) + + +class ZayaRouter(nn.Module): + def __init__( + self, + layer_idx: int, + hidden_size: int, + num_experts: int, + router_topk: int, + router_hidden_size: int, + use_eda: bool, + norm_epsilon: float, + ): + super().__init__() + self.hidden_size = hidden_size + self.layer_idx = layer_idx + + self.num_experts = num_experts + 1 + self.top_k = router_topk + self.router_hidden_size = router_hidden_size + + self.down_proj = nn.Linear(self.hidden_size, self.router_hidden_size, bias=True) + + self.use_eda = use_eda and self.layer_idx != 0 + if self.use_eda: + self.router_states_scale = nn.Parameter(torch.ones(self.router_hidden_size)) + + self.router_mlp = ZayaRouterMLP(self.router_hidden_size, self.num_experts, norm_epsilon) + + self.register_buffer("balancing_biases", torch.zeros(self.num_experts, 
dtype=torch.float32)) + self.balancing_biases[-1] = -1.0 + + def forward( + self, + hidden_states: torch.Tensor, + router_states: torch.Tensor | None = None, + routed_experts: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + final_shape = (-1, self.top_k) + batch_size, seq_length, _ = hidden_states.shape + + router_hidden_states = self.down_proj(hidden_states) + + if self.use_eda and router_states is not None: + router_hidden_states = router_hidden_states + router_states * self.router_states_scale + + router_hidden_states_next = router_hidden_states[:, -seq_length:].clone() + router_logits = self.router_mlp(router_hidden_states) + router_probs = torch.softmax(router_logits, dim=-1) + + if routed_experts is not None: + router_indices = routed_experts.reshape(batch_size, seq_length, self.top_k) + else: + biased_router_probs = router_probs.detach().to(torch.float32) + self.balancing_biases + _, router_indices = torch.topk(biased_router_probs, self.top_k, dim=-1) + + router_probs = torch.gather(router_probs, dim=2, index=router_indices) + + return ( + router_logits.reshape(-1, self.num_experts), + router_probs.reshape(final_shape), + router_indices.reshape(final_shape), + router_hidden_states_next, + ) + + +def _run_zaya_experts_for_loop_impl( + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, +) -> torch.Tensor: + num_tokens_per_expert = num_tokens_per_expert.tolist() + num_padding = x.shape[0] - sum(num_tokens_per_expert) + + x = torch.split( + x[: sum(num_tokens_per_expert)], + split_size_or_sections=num_tokens_per_expert, + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + gate_up = torch.matmul(x_expert, gate_up_proj[expert_idx].transpose(-2, -1)) + gate, up = gate_up.chunk(2, dim=-1) + h = F.silu(gate) * up + h = torch.matmul(h, down_proj[expert_idx].transpose(-2, -1)) + out_experts_splits.append(h) + out = 
torch.cat(out_experts_splits, dim=0) + out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) + return out + + +@expert_parallel +def _run_zaya_experts_for_loop( + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + _w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, +) -> torch.Tensor: + return _run_zaya_experts_for_loop_impl(gate_up_proj, down_proj, x, num_tokens_per_expert) + + +def _run_zaya_experts_grouped_mm_impl( + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + fp8: bool = False, +) -> torch.Tensor: + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + assert x.dim() == 2 + + if fp8: + from prime_rl.trainer.models.layers.fp8_grouped_gemm import grouped_fp8_gemm + + gate_up = grouped_fp8_gemm(x.bfloat16(), gate_up_proj.bfloat16().transpose(-2, -1), offsets) + gate, up = gate_up.chunk(2, dim=-1) + h = (F.silu(gate) * up).contiguous() + out = grouped_fp8_gemm(h, down_proj.bfloat16().transpose(-2, -1), offsets).type_as(x) + else: + gate_up = torch._grouped_mm(x.bfloat16(), gate_up_proj.bfloat16().transpose(-2, -1), offs=offsets) + gate, up = gate_up.chunk(2, dim=-1) + h = (F.silu(gate) * up).contiguous() + out = torch._grouped_mm(h, down_proj.bfloat16().transpose(-2, -1), offs=offsets).type_as(x) + return out + + +@expert_parallel +def _run_zaya_experts_grouped_mm( + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + _w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, +) -> torch.Tensor: + return _run_zaya_experts_grouped_mm_impl(gate_up_proj, down_proj, x, num_tokens_per_expert) + + +@expert_parallel +def _run_zaya_experts_fp8_grouped_mm( + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + _w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, +) -> torch.Tensor: + return _run_zaya_experts_grouped_mm_impl(gate_up_proj, down_proj, x, num_tokens_per_expert, fp8=True) + + 
+class ZayaGroupedExperts(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + num_experts: int, + use_grouped_mm: bool, + fp8: bool = False, + ): + super().__init__() + self.num_experts = num_experts + self.gate_up_proj = nn.Parameter(torch.empty(num_experts, 2 * hidden_dim, dim)) + self.down_proj = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) + self.use_grouped_mm = use_grouped_mm + self.fp8 = fp8 + self.ep_comm_backend: EPCommBackend = "torch" + + def set_ep_comm_backend(self, backend: EPCommBackend) -> None: + self.ep_comm_backend = backend + + def _forward_deepep(self, x: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor: + gate_up_proj = self.gate_up_proj.to_local() + down_proj = self.down_proj.to_local() + if self.use_grouped_mm: + return _run_zaya_experts_grouped_mm_impl(gate_up_proj, down_proj, x, num_tokens_per_expert, fp8=self.fp8) + return _run_zaya_experts_for_loop_impl(gate_up_proj, down_proj, x, num_tokens_per_expert) + + def forward(self, x: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor: + if self.ep_comm_backend == "deepep": + return self._forward_deepep(x, num_tokens_per_expert) + + num_tokens_per_expert = num_tokens_per_expert.to(torch.int32) + # This DTensor check is basically to ensure accuracy when comparing with local HF for + # tests/unit/train/models/test_zaya.py::test_zaya + # Basically, @expert_parallel does permutations which are not necessary on a single GPU + if not isinstance(self.gate_up_proj, DTensor): + if self.use_grouped_mm: + return _run_zaya_experts_grouped_mm_impl( + self.gate_up_proj, self.down_proj, x, num_tokens_per_expert, fp8=self.fp8 + ) + return _run_zaya_experts_for_loop_impl(self.gate_up_proj, self.down_proj, x, num_tokens_per_expert) + + if self.use_grouped_mm: + if self.fp8: + return _run_zaya_experts_fp8_grouped_mm( + self.gate_up_proj, self.down_proj, self.down_proj, x, num_tokens_per_expert + ) + return _run_zaya_experts_grouped_mm( + 
self.gate_up_proj, self.down_proj, self.down_proj, x, num_tokens_per_expert + ) + return _run_zaya_experts_for_loop(self.gate_up_proj, self.down_proj, self.down_proj, x, num_tokens_per_expert) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.gate_up_proj, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.down_proj, mean=0.0, std=init_std) + + +class ZayaMoE(nn.Module): + def __init__( + self, + layer_idx: int, + hidden_size: int, + moe_intermediate_size: int, + num_experts: int, + *, + num_experts_per_tok: int, + router_hidden_size: int, + norm_epsilon: float, + use_grouped_mm: bool, + use_eda: bool = True, + fp8: bool = False, + ): + super().__init__() + if num_experts_per_tok < 1: + raise ValueError("num_experts_per_tok must be >= 1") + + self.top_k = num_experts_per_tok + self.router = ZayaRouter( + layer_idx=layer_idx, + hidden_size=hidden_size, + num_experts=num_experts, + router_topk=num_experts_per_tok, + router_hidden_size=router_hidden_size, + use_eda=use_eda, + norm_epsilon=norm_epsilon, + ) + self.experts = ZayaGroupedExperts( + dim=hidden_size, + hidden_dim=moe_intermediate_size, + num_experts=num_experts, + use_grouped_mm=use_grouped_mm, + fp8=fp8, + ) + self.reorderer = TokenReorderer(num_experts=self.experts.num_experts, top_k=num_experts_per_tok) + self.ep_comm_backend: EPCommBackend = "torch" + self.experts.set_ep_comm_backend(self.ep_comm_backend) + self.deepep_token_chunk_size: int | None = None + self.register_buffer( + "tokens_per_expert", + torch.zeros(self.experts.num_experts, dtype=torch.float32), + persistent=False, + ) + + def set_ep_comm_backend(self, backend: EPCommBackend) -> None: + self.ep_comm_backend = backend + self.experts.set_ep_comm_backend(backend) + + def set_deepep_token_chunk_size(self, chunk_size: int | None) -> None: + self.deepep_token_chunk_size = chunk_size + + def _run_routed_experts( + self, + x: torch.Tensor, + token_indices_experts_sorted: torch.Tensor, + num_tokens_per_expert: torch.Tensor, 
+ top_scores_experts_sorted: torch.Tensor, + ) -> torch.Tensor: + routed_indices = token_indices_experts_sorted.reshape(-1, 1).expand(-1, x.shape[-1]) + routed_input = torch.gather(x, dim=0, index=routed_indices) + routed_output = self.experts(routed_input, num_tokens_per_expert.to(torch.int32)) + return routed_output * top_scores_experts_sorted.reshape(-1, 1) + + def _run_local_routed_experts(self, x: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor: + return self.experts(x, num_tokens_per_expert) + + def _run_deepep_routed_experts( + self, + x: torch.Tensor, + selected_experts_indices: torch.Tensor, + top_scores: torch.Tensor, + ) -> torch.Tensor: + from prime_rl.trainer.distributed.deepep import ( + combine_tokens, + dispatch_tokens_async, + finalize_dispatch_tokens, + sync_combine, + ) + from prime_rl.trainer.distributed.expert_parallel import get_ep_group + + if x.shape[0] == 0: + return x.new_zeros(x.shape) + + group = get_ep_group(self.experts) + chunk_size = min(self.deepep_token_chunk_size or x.shape[0], x.shape[0]) + + def dispatch_chunk(start: int, end: int): + return dispatch_tokens_async( + x[start:end], + selected_experts_indices[start:end], + top_scores[start:end], + num_experts=self.experts.num_experts, + group=group, + score_before_experts=False, + ) + + def run_pending_chunk(pending_state): + hidden_states, num_tokens_per_expert, dispatch_state = finalize_dispatch_tokens(pending_state) + routed_output = self._run_local_routed_experts(hidden_states, num_tokens_per_expert) + return combine_tokens(routed_output, dispatch_state) + + pending_state = dispatch_chunk(0, chunk_size) + routed_outputs: list[torch.Tensor] = [] + for chunk_start in range(chunk_size, x.shape[0], chunk_size): + chunk_end = min(chunk_start + chunk_size, x.shape[0]) + next_pending_state = dispatch_chunk(chunk_start, chunk_end) + routed_outputs.append(run_pending_chunk(pending_state)) + pending_state = next_pending_state + 
routed_outputs.append(run_pending_chunk(pending_state)) + sync_combine() + routed_output = routed_outputs[0] if len(routed_outputs) == 1 else torch.cat(routed_outputs, dim=0) + return routed_output + + def forward( + self, + hidden_states: torch.Tensor, + prev_router_hidden_states: torch.Tensor | None = None, + routed_experts: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + _, route_prob, expert_choice, prev_router_hidden_states = self.router( + hidden_states, router_states=prev_router_hidden_states, routed_experts=routed_experts + ) + batch_size, seq_length, hidden_size = hidden_states.shape + hidden_states_flat = hidden_states.view(batch_size * seq_length, hidden_size) + + skip_expert = expert_choice == self.experts.num_experts + skip_tokens = skip_expert.squeeze(-1) + non_skip_tokens = ~skip_tokens + + expert_choice_routed = expert_choice[non_skip_tokens] + route_prob_routed = route_prob[non_skip_tokens] + hidden_states_routed = hidden_states_flat[non_skip_tokens] + + num_tokens_per_expert = torch.histc( + expert_choice_routed.reshape(-1).float(), + bins=self.experts.num_experts, + min=0, + max=self.experts.num_experts, + ) + with torch.no_grad(): + self.tokens_per_expert.add_(num_tokens_per_expert) + + out = torch.zeros_like(hidden_states_flat) + if skip_tokens.any(): + out[skip_tokens] = hidden_states_flat[skip_tokens] * route_prob[skip_tokens] + + if hidden_states_routed.shape[0] == 0: + return out.reshape(batch_size, seq_length, hidden_size), prev_router_hidden_states + + if self.ep_comm_backend == "deepep": + expert_output = self._run_deepep_routed_experts( + hidden_states_routed, expert_choice_routed, route_prob_routed + ) + out[non_skip_tokens] = expert_output + return out.reshape(batch_size, seq_length, hidden_size), prev_router_hidden_states + + top_scores_experts_sorted, token_indices_experts_sorted, num_tokens_per_expert = self.reorderer( + route_prob_routed, expert_choice_routed + ) + + routed_output = 
self._run_routed_experts( + hidden_states_routed, token_indices_experts_sorted, num_tokens_per_expert, top_scores_experts_sorted + ) + routed_out = torch.zeros_like(hidden_states_routed) + token_indices_full = token_indices_experts_sorted.reshape(-1, 1).expand(-1, hidden_size) + routed_out = routed_out.scatter_add(dim=0, index=token_indices_full, src=routed_output) + out[non_skip_tokens] = routed_out + return out.reshape(batch_size, seq_length, hidden_size), prev_router_hidden_states + + def init_weights(self, init_std: float, buffer_device: torch.device): + self.experts.init_weights(init_std) + for module in self.router.modules(): + if isinstance(module, nn.Linear): + nn.init.trunc_normal_(module.weight, mean=0.0, std=init_std) + if module.bias is not None: + nn.init.zeros_(module.bias) + with torch.device(buffer_device): + self.tokens_per_expert = torch.zeros(self.experts.num_experts, dtype=torch.float32) diff --git a/src/prime_rl/trainer/models/zaya/__init__.py b/src/prime_rl/trainer/models/zaya/__init__.py new file mode 100644 index 0000000000..b055da94d5 --- /dev/null +++ b/src/prime_rl/trainer/models/zaya/__init__.py @@ -0,0 +1,4 @@ +from .configuration_zaya import ZayaConfig +from .modeling_zaya import ZayaForCausalLM, ZayaModel, ZayaPreTrainedModel + +__all__ = ["ZayaConfig", "ZayaForCausalLM", "ZayaModel", "ZayaPreTrainedModel"] diff --git a/src/prime_rl/trainer/models/zaya/configuration_zaya.py b/src/prime_rl/trainer/models/zaya/configuration_zaya.py new file mode 100644 index 0000000000..765b18c92b --- /dev/null +++ b/src/prime_rl/trainer/models/zaya/configuration_zaya.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from transformers.configuration_utils import PretrainedConfig + + +class ZayaConfig(PretrainedConfig): + model_type = "zaya" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + use_cache=True, + attention_bias=False, + lm_head_bias=False, + vocab_size=262272, + hidden_size=2048, + num_hidden_layers=40, 
+ num_experts=16, + num_attention_heads=8, + head_dim=128, + max_position_embeddings=131072, + pad_token_id=0, + bos_token_id=2, + eos_token_id=106, + tie_word_embeddings=True, + attention_dropout=0.0, + num_experts_per_tok=1, + moe_intermediate_size=2048, + router_hidden_size=256, + zaya_use_eda=True, + sliding_window=None, + rope_parameters=None, + partial_rotary_factor=0.5, + layer_types: list[str] | None = None, + num_key_value_heads=2, + cca_time0=2, + cca_time1=2, + rms_norm_eps: float = 1e-05, + hidden_act: str = "silu", + initializer_range: float = 0.02, + output_router_logits: bool = False, + _attn_implementation="eager", + use_grouped_mm=True, + load_balance_coeff=None, + **kwargs, + ): + if attention_bias: + raise ValueError("PrimeRL Zaya currently supports attention_bias=False") + if num_experts_per_tok != 1: + raise ValueError("PrimeRL Zaya currently supports num_experts_per_tok == 1") + + self.use_cache = use_cache + self.attention_bias = attention_bias + self.lm_head_bias = lm_head_bias + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_experts = num_experts + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + if self.head_dim is None: + raise ValueError("ZayaConfig requires head_dim") + + self.max_position_embeddings = max_position_embeddings + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.tie_word_embeddings = tie_word_embeddings + self.attention_dropout = attention_dropout + self.num_experts_per_tok = num_experts_per_tok + self.moe_intermediate_size = moe_intermediate_size + self.router_hidden_size = router_hidden_size + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.output_router_logits = output_router_logits + self.rms_norm_eps = float(rms_norm_eps) + self.norm_epsilon = self.rms_norm_eps + self.zaya_use_eda 
= zaya_use_eda + self.sliding_window = sliding_window + self.partial_rotary_factor = partial_rotary_factor + self.cca_time0 = cca_time0 + self.cca_time1 = cca_time1 + self.use_grouped_mm = use_grouped_mm + self.load_balance_coeff = load_balance_coeff + self.layer_types = layer_types or ["hybrid"] * self.num_hidden_layers + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=self.tie_word_embeddings, + **kwargs, + ) + + self._attn_implementation = _attn_implementation + self.rope_parameters = rope_parameters + self.rope_scaling = self.rope_parameters + self.validate_architecture() + + def convert_rope_params_to_dict(self, **kwargs): + return kwargs + + def validate_architecture(self) -> None: + if self.num_attention_heads % self.num_key_value_heads != 0: + raise ValueError("num_attention_heads must be a multiple of num_key_value_heads") + if len(self.layer_types) != self.num_hidden_layers: + raise ValueError("layer_types must have one entry per hidden layer") + if invalid_layer_types := set(self.layer_types) - {"hybrid", "hybrid_sliding"}: + raise ValueError(f"layer_types contains unsupported values: {sorted(invalid_layer_types)}") + if "hybrid_sliding" in self.layer_types and self.sliding_window is None: + raise ValueError("sliding_window must be set when layer_types contains hybrid_sliding") + if self.sliding_window is not None and self.sliding_window <= 0: + raise ValueError("sliding_window must be a strictly positive integer") + + +__all__ = ["ZayaConfig"] diff --git a/src/prime_rl/trainer/models/zaya/converting_zaya.py b/src/prime_rl/trainer/models/zaya/converting_zaya.py new file mode 100644 index 0000000000..b18a071914 --- /dev/null +++ b/src/prime_rl/trainer/models/zaya/converting_zaya.py @@ -0,0 +1,81 @@ +import torch +from torch import Tensor + +_HF_GATE_PREFIX = ".mlp.gate." +_PRIME_ROUTER_PREFIX = ".mlp.router." 
def get_max_layer_num(state_dict: dict[str, Tensor]) -> int:
    """Return the number of decoder layers referenced by *state_dict*.

    Keys of the form ``model.layers.<idx>.<...>`` are scanned and the highest
    ``<idx>`` plus one is returned.

    Returns:
        ``0`` when no key starts with ``model.layers.`` (for example a shard
        that only holds embeddings), so callers that iterate
        ``range(get_max_layer_num(...))`` simply do nothing instead of
        crashing on ``max()`` of an empty sequence.
    """
    layer_indices = (int(key.split(".")[2]) for key in state_dict if key.startswith("model.layers."))
    return max(layer_indices, default=-1) + 1


def is_hf_state_dict(state_dict: dict[str, Tensor]) -> bool:
    """Return True if any key contains the HF gate marker (``_HF_GATE_PREFIX``)."""
    return any(_HF_GATE_PREFIX in name for name in state_dict)


def is_prime_state_dict(state_dict: dict[str, Tensor]) -> bool:
    """Return True if any key looks like the PrimeRL layout.

    A key qualifies when it contains ``_PRIME_ROUTER_PREFIX`` or ends with
    ``.mlp.experts.w1``.
    """
    return any(_PRIME_ROUTER_PREFIX in name or name.endswith(".mlp.experts.w1") for name in state_dict)


def _rename_layer_prefix(state_dict: dict[str, Tensor], old_prefix: str, new_prefix: str) -> None:
    """Rename, in place, every key starting with *old_prefix* to start with *new_prefix*."""
    # Snapshot the matching keys first: the dict is mutated inside the loop.
    matching_keys = [key for key in state_dict if key.startswith(old_prefix)]
    for key in matching_keys:
        state_dict[new_prefix + key.removeprefix(old_prefix)] = state_dict.pop(key)


def _make_expert_weights_contiguous(state_dict: dict[str, Tensor], prefix: str) -> None:
    """Replace the fused expert weights under *prefix* with contiguous tensors, if present."""
    for suffix in ("gate_up_proj", "down_proj"):
        key = f"{prefix}.mlp.experts.{suffix}"
        if key in state_dict:
            state_dict[key] = state_dict[key].contiguous()


def convert_hf_layer_to_prime(state_dict: dict[str, Tensor], layer_idx: int) -> None:
    """Convert one layer of *state_dict* from the HF layout to the PrimeRL layout, in place.

    ``model.layers.<layer_idx>.mlp.gate.*`` keys are renamed to
    ``...mlp.router.*``; the fused expert weights keep their names and are
    made contiguous.
    """
    prefix = f"model.layers.{layer_idx}"
    _rename_layer_prefix(state_dict, f"{prefix}.mlp.gate.", f"{prefix}.mlp.router.")
    _make_expert_weights_contiguous(state_dict, prefix)


def convert_prime_layer_to_hf(state_dict: dict[str, Tensor], layer_idx: int) -> None:
    """Convert one layer of *state_dict* from the PrimeRL layout to the HF layout, in place.

    The ``tokens_per_expert`` entry is dropped, ``...mlp.router.*`` keys are
    renamed to ``...mlp.gate.*`` and the fused expert weights are made
    contiguous. If split ``w1``/``w2``/``w3`` expert weights are present they
    are fused: ``gate_up_proj = cat([w1, w3], dim=1)`` and ``down_proj = w2``.

    Raises:
        KeyError: if ``w1`` is present but ``w2`` or ``w3`` is missing. The
            check runs before any mutation, so the dict is left untouched.
    """
    prefix = f"model.layers.{layer_idx}"
    gate_up_key = f"{prefix}.mlp.experts.gate_up_proj"
    down_key = f"{prefix}.mlp.experts.down_proj"
    w1_key, w2_key, w3_key = (f"{prefix}.mlp.experts.{name}" for name in ("w1", "w2", "w3"))

    # Validate the split-weight layout up front so a malformed dict is never half-converted.
    has_split_experts = w1_key in state_dict
    if has_split_experts:
        missing = [key for key in (w2_key, w3_key) if key not in state_dict]
        if missing:
            raise KeyError(f"Incomplete split expert weights for layer {layer_idx}; missing {missing}")

    state_dict.pop(f"{prefix}.mlp.tokens_per_expert", None)
    _rename_layer_prefix(state_dict, f"{prefix}.mlp.router.", f"{prefix}.mlp.gate.")
    _make_expert_weights_contiguous(state_dict, prefix)

    if has_split_experts:
        w1 = state_dict.pop(w1_key)
        w2 = state_dict.pop(w2_key)
        w3 = state_dict.pop(w3_key)
        # Gate (w1) first, then up (w3), stacked along dim 1 of the per-expert weights.
        state_dict[gate_up_key] = torch.cat([w1, w3], dim=1).contiguous()
        state_dict[down_key] = w2.contiguous()


def convert_hf_to_prime(state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
    """Convert every layer of *state_dict* from HF to PrimeRL layout in place and return it."""
    for layer_idx in range(get_max_layer_num(state_dict)):
        convert_hf_layer_to_prime(state_dict, layer_idx)
    return state_dict


def convert_prime_to_hf(state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
    """Convert every layer of *state_dict* from PrimeRL to HF layout in place and return it."""
    for layer_idx in range(get_max_layer_num(state_dict)):
        convert_prime_layer_to_hf(state_dict, layer_idx)
    return state_dict


__all__ = [
    "convert_hf_layer_to_prime",
    "convert_hf_to_prime",
    "convert_prime_layer_to_hf",
    "convert_prime_to_hf",
    "get_max_layer_num",
    "is_hf_state_dict",
    "is_prime_state_dict",
]

# diff --git a/src/prime_rl/trainer/models/zaya/modeling_zaya.py b/src/prime_rl/trainer/models/zaya/modeling_zaya.py
# new file mode 100644
# index 0000000000..d4017a3335
# --- /dev/null
# +++ b/src/prime_rl/trainer/models/zaya/modeling_zaya.py
# @@ -0,0 +1,940 @@
# coding=utf-8
# Copyright 2025 Zyphra and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def _all_to_all_seq_to_head_batched(t: torch.Tensor, cp_size: int, cp_group: dist.ProcessGroup) -> torch.Tensor:
    """Run ``_all_to_all_seq_to_head`` on a tensor that carries a leading batch dim of 1.

    The batch dimension is squeezed away before the exchange and restored on the
    result. Only batch size 1 is accepted; anything else trips the assertion.
    """
    batch_size = t.shape[0]
    assert batch_size == 1, f"Zaya CP currently expects batch size 1, got {batch_size}"
    unbatched = t.squeeze(0)
    exchanged = _all_to_all_seq_to_head(unbatched, cp_size, cp_group)
    return exchanged.unsqueeze(0)
class ZayaResidualScaling(nn.Module):
    """Learned per-channel affine mix of a sub-layer output and its residual.

    Each input is shifted by its own bias, multiplied element-wise by its own
    scale (both vectors of length ``config.hidden_size``), and the two results
    are summed. With the initial values (scale 1, bias 0) this is a plain
    ``hidden_states + residual``.
    """

    def __init__(self, config: ZayaConfig):
        super().__init__()
        width = config.hidden_size
        # Parameters for the sub-layer output stream.
        self.hidden_states_scale = nn.Parameter(torch.ones(width))
        self.hidden_states_bias = nn.Parameter(torch.zeros(width))
        # Parameters for the residual stream.
        self.residual_scale = nn.Parameter(torch.ones(width))
        self.residual_bias = nn.Parameter(torch.zeros(width))

    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        """Return ``(hidden_states + b_h) * s_h + (residual + b_r) * s_r``."""
        scaled_hidden = (hidden_states + self.hidden_states_bias) * self.hidden_states_scale
        scaled_residual = (residual + self.residual_bias) * self.residual_scale
        return scaled_hidden + scaled_residual
def _local_head_channel_indices(self) -> torch.Tensor:
    """Return the fused q/k channel indices owned by this context-parallel rank.

    Channels are laid out as all query heads followed by all key heads, each
    head spanning ``head_dim`` channels. The rank owns a contiguous slice of
    ``num_attention_heads // cp_world_size`` query heads and
    ``num_key_value_heads // cp_world_size`` key heads; the returned 1-D tensor
    lists the query channels first, then the key channels, on the device of the
    depthwise convolution weight.
    """
    rank, world, head_dim = self._cp_rank, self._cp_world_size, self.head_dim
    device = self.conv_qk_depthwise.weight.device

    q_width = (self.num_attention_heads // world) * head_dim
    k_width = (self.num_key_value_heads // world) * head_dim
    # Key channels start right after the channels of every query head.
    k_base = self.num_attention_heads * head_dim

    spans = ((rank * q_width, q_width), (k_base + rank * k_width, k_width))
    return torch.cat([torch.arange(start, start + width, device=device) for start, width in spans])
self.num_attention_heads, self.head_dim) + projected_keys = self.k_proj(hidden_states).view(*input_shape, self.num_key_value_heads, self.head_dim) + value_current = self.v_proj_current(hidden_states) + delayed_v_state = self.v_proj_delayed(hidden_states) + + projected_queries = _all_to_all_seq_to_head_batched(projected_queries, self._cp_world_size, self._cp_group) + projected_keys = _all_to_all_seq_to_head_batched(projected_keys, self._cp_world_size, self._cp_group) + + local_q_heads = projected_queries.shape[-2] + local_kv_heads = projected_keys.shape[-2] + local_groups = local_q_heads // local_kv_heads + query_residual = projected_queries + key_residual = _repeat_kv(projected_keys.transpose(1, 2), local_groups).transpose(1, 2) + query_residual = (query_residual + key_residual) * 0.5 + key_residual = query_residual.view(*query_residual.shape[:2], local_kv_heads, local_groups, self.head_dim).mean( + dim=-2 + ) + + qk_states = torch.cat([projected_queries.flatten(-2), projected_keys.flatten(-2)], dim=-1).transpose(1, 2) + qk_states = F.pad(qk_states, (self.conv_kernel_size, 0)) + channel_idx = self._local_head_channel_indices() + depthwise_weight = self.conv_qk_depthwise.weight.index_select(0, channel_idx) + depthwise_bias = ( + self.conv_qk_depthwise.bias.index_select(0, channel_idx) + if self.conv_qk_depthwise.bias is not None + else None + ) + qk_states = F.conv1d(qk_states, depthwise_weight, depthwise_bias, groups=channel_idx.numel()) + grouped_weight = self.conv_qk_grouped.weight.index_select(0, channel_idx) + grouped_bias = ( + self.conv_qk_grouped.bias.index_select(0, channel_idx) if self.conv_qk_grouped.bias is not None else None + ) + qk_states = F.conv1d(qk_states, grouped_weight, grouped_bias, groups=local_q_heads + local_kv_heads).transpose( + 1, 2 + ) + + q_size = local_q_heads * self.head_dim + query = qk_states[..., :q_size].view(*qk_states.shape[:2], local_q_heads, self.head_dim) + query_residual + key = qk_states[..., 
q_size:].view(*qk_states.shape[:2], local_kv_heads, self.head_dim) + key_residual + + recurrent_v_state = self.v_proj_delayed(hidden_states.new_zeros(input_shape[0], 1, self.hidden_size)) + delayed_v_state_full = gather_for_cp(delayed_v_state.contiguous(), self._cp_group) + value_delayed_full = torch.cat([recurrent_v_state, delayed_v_state_full[:, :-1]], dim=1) + seq_start = self._cp_rank * input_shape[1] + seq_end = seq_start + input_shape[1] + value_delayed = value_delayed_full[:, seq_start:seq_end] + value = torch.cat([value_current, value_delayed], dim=-1).view( + *input_shape, self.num_key_value_heads, self.head_dim + ) + value = _all_to_all_seq_to_head_batched(value, self._cp_world_size, self._cp_group) + return query, key, value + + def _conv_qk_by_sequence(self, qk_states: torch.Tensor, cu_seqlens: torch.Tensor | None) -> torch.Tensor: + # Vectorized version of: + # outputs = [] + # for start, end in zip(cu[:-1], cu[1:]): + # segment = F.pad(qk_states[:, :, start:end], (self.conv_kernel_size, 0)) + # outputs.append(self.conv_qk_grouped(self.conv_qk_depthwise(segment))) + # return torch.cat(outputs, dim=-1).transpose(1, 2) + qk_states = qk_states.transpose(1, 2) + + if cu_seqlens is None or cu_seqlens.numel() <= 2: + qk_states = F.pad(qk_states, (self.conv_kernel_size, 0)) + return self.conv_qk_grouped(self.conv_qk_depthwise(qk_states)).transpose(1, 2) + + B, C, S = qk_states.shape + device = qk_states.device + K = self.conv_kernel_size + + nseq = cu_seqlens.numel() - 1 + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + + seg_id = torch.repeat_interleave( + torch.arange(nseq, device=device, dtype=cu_seqlens.dtype), + lengths, + ) + + orig_idx = torch.arange(S, device=device, dtype=cu_seqlens.dtype) + expanded_idx = orig_idx + K * (seg_id + 1) + + expanded_len = S + K * nseq + expanded = qk_states.new_zeros(B, C, expanded_len) + + expanded.scatter_( + dim=2, + index=expanded_idx.to(torch.long)[None, None, :].expand(B, C, S), + src=qk_states, + ) + + out = 
self.conv_qk_grouped(self.conv_qk_depthwise(expanded)) + + gather_idx = (expanded_idx - K).to(torch.long) + out = out.index_select(dim=2, index=gather_idx) + + return out.transpose(1, 2) + + def _delay_value_by_sequence( + self, + hidden_states: torch.Tensor, + delayed_v_state: torch.Tensor, + cu_seqlens: torch.Tensor | None, + ) -> torch.Tensor: + # Vectorized version of: + # outputs = [] + # for start, end in zip(cu[:-1], cu[1:]): + # segment = delayed_v_state[:, start:end] + # outputs.append(torch.cat([recurrent_v_state, segment[:, :-1]], dim=1)) + # return torch.cat(outputs, dim=1) + input_shape = hidden_states.shape[:-1] + recurrent_v_state = self.v_proj_delayed(hidden_states.new_zeros(input_shape[0], 1, self.hidden_size)) + + if cu_seqlens is None or cu_seqlens.numel() <= 2: + return torch.cat([recurrent_v_state, delayed_v_state[:, :-1]], dim=1) + + B, S, D = delayed_v_state.shape + device = delayed_v_state.device + + idx = torch.arange(S, device=device) + + is_start = torch.zeros(S, device=device, dtype=torch.bool) + is_start[cu_seqlens[:-1].to(torch.long)] = True + + prev_idx = (idx - 1).clamp_min(0) + shifted = delayed_v_state.index_select(dim=1, index=prev_idx) + + return torch.where( + is_start[None, :, None], + recurrent_v_state.expand(B, S, D), + shifted, + ) + + def forward( + self, + hidden_states: torch.Tensor, + padding_mask: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + ): + if self.cp_enabled: + return self._forward_context_parallel(hidden_states, padding_mask) + + if padding_mask is not None: + hidden_states = hidden_states * padding_mask[:, :, None].to(hidden_states.dtype) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + projected_queries = self.q_proj(hidden_states) + projected_keys = self.k_proj(hidden_states) + qk_states = torch.cat([projected_queries, projected_keys], dim=-1) + + query_residual = projected_queries.view(*hidden_shape) + key_residual = 
projected_keys.view(*input_shape, -1, self.head_dim).transpose(1, 2) + key_residual = _repeat_kv(key_residual, self.num_key_value_groups).transpose(1, 2) + query_residual = (query_residual + key_residual) * 0.5 + key_residual = query_residual.view(*input_shape, -1, self.num_key_value_groups, self.head_dim).mean(dim=-2) + + qk_states = self._conv_qk_by_sequence(qk_states, cu_seqlens) + + query_hidden_size = query_residual.shape[-2] * query_residual.shape[-1] + query = qk_states[..., :query_hidden_size].view(*hidden_shape) + query_residual + key = qk_states[..., query_hidden_size:].view(*hidden_shape) + key_residual + + value_current = self.v_proj_current(hidden_states) + delayed_v_state = self.v_proj_delayed(hidden_states) + value_delayed = self._delay_value_by_sequence(hidden_states, delayed_v_state, cu_seqlens) + + value = torch.cat([value_current, value_delayed], dim=-1).view(*hidden_shape) + + return query, key, value + + +class ZayaQKNorm(nn.Module): + def __init__(self, config: ZayaConfig): + super().__init__() + scaling = config.head_dim**-0.5 + self.head_dim_scale = scaling**-1 + self.temp = nn.Parameter(torch.zeros(config.num_key_value_heads)) + self._cp_rank = 0 + self._cp_world_size = 1 + + def set_context_parallel_attributes(self, cp_group: dist.ProcessGroup, cp_rank: int, cp_world_size: int) -> None: + self._cp_rank = cp_rank + self._cp_world_size = cp_world_size + + def forward(self, query_states: torch.Tensor, key_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + norm_eps = torch.finfo(query_states.dtype).eps + query_states = query_states * ( + self.head_dim_scale / query_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps) + ) + key_states = key_states * (self.head_dim_scale / key_states.norm(p=2, dim=-1, keepdim=True).clamp_min(norm_eps)) + temp = self.temp + if self._cp_world_size > 1: + local_kv_heads = temp.shape[0] // self._cp_world_size + start = self._cp_rank * local_kv_heads + temp = temp[start : start + local_kv_heads] + 
class ZayaRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding whose inverse frequencies are always held in float32.

    ``inv_freq`` is cast to float32 on construction and recomputed in float32
    after every ``_apply`` (which backs ``.to()``, ``.cuda()``, ``.to_empty()``
    and similar), so a dtype conversion of the surrounding model never degrades
    the RoPE frequencies. ``original_inv_freq`` is kept as an independent copy.
    """

    def __init__(self, config: RotaryEmbeddingConfig, device=None):
        super().__init__(config, device=device)
        self.inv_freq = self.inv_freq.float()
        self.original_inv_freq = self.inv_freq.clone()

    def _apply(self, fn, *args, **kwargs):
        # Forward any extra arguments (e.g. ``recurse``) so that direct calls
        # such as ``to_empty(recurse=False)`` on this module keep working.
        super()._apply(fn, *args, **kwargs)
        # always force RoPE in fp32: rebuild the frequencies on whatever device
        # the buffer ended up on instead of trusting the converted values.
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, self.inv_freq.device)
        self.inv_freq = inv_freq.float()
        self.original_inv_freq = self.inv_freq.clone()
        return self
def _attention_core(
    self,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    causal_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run scaled-dot-product attention and merge the head dimensions.

    Keys and values are repeated ``num_key_value_groups`` times via
    ``_repeat_kv`` so they match the query heads. When ``causal_mask`` is given
    it is used as the additive mask; otherwise SDPA's built-in causal masking
    is enabled. Dropout is only applied in training mode.

    Returns:
        The attention output with dims 1 and 2 swapped and the last two
        dimensions flattened into one.
    """
    groups = self.num_key_value_groups
    keys = _repeat_kv(key_states, groups)
    values = _repeat_kv(value_states, groups)

    dropout_p = self.attention_dropout if self.training else 0.0
    attn = F.scaled_dot_product_attention(
        query_states,
        keys,
        values,
        attn_mask=causal_mask,
        dropout_p=dropout_p,
        is_causal=causal_mask is None,
        scale=self.scaling,
    )

    # Swap the head and sequence axes, then fold heads into the feature axis.
    attn = attn.transpose(1, 2).contiguous()
    batch, seq_len = attn.shape[:2]
    return attn.view(batch, seq_len, -1)
mask_mapping.get("padding") + + query_states, key_states, value_states = self.qkv_proj(hidden_states, padding_mask, cu_seqlens) + + query_states, key_states = self.qk_norm(query_states, key_states) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if self.cp_enabled and cu_seqlens is None: + # TODO: support packed sequences in Context Parallel + causal_mask = None + attn_output = self._attention_core(query_states, key_states, value_states, causal_mask) + if self.cp_enabled: + attn_output = self._output_context_parallel(attn_output) + attn_output = self.o_proj(attn_output) + + return attn_output, None + + +class ZayaFlashAttention(FlashAttention): + _funcs = { + 2: flash_attn_varlen_func, + 3: flash_attn_3_varlen_func, + 4: flash_attn_4_varlen_func, + } + + def __init__(self, config: ZayaConfig, layer_idx: int, flash_attn_version: int = 2): + nn.Module.__init__( + self + ) # instead of initing with super().__init__() because we don't want to initialize the qkv projections + self.config = config + self.layer_idx = layer_idx + self.head_dim = config.head_dim + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.qkv_proj = ZayaCCAProjection( + config=self.config, + layer_idx=layer_idx, + ) + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.qk_norm = ZayaQKNorm(config) + self._cp_group = None + self._cp_rank = 0 + self._cp_world_size = 1 + + self._flash_attn_version = 
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: dict[str, Any] | None = None,
    padding_mask: torch.Tensor | None = None,
    position_ids: torch.Tensor | None = None,
    position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: int | None = None,
) -> tuple[torch.Tensor, None]:
    """Run CCA projection, QK-norm, RoPE and variable-length flash attention.

    Args:
        hidden_states: Layer input, handed to ``qkv_proj``.
        attention_mask: Optional mapping; only its ``"padding"`` entry is read.
        padding_mask: Padding mask used when ``attention_mask`` carries no
            ``"padding"`` entry. An entry in the mapping (even ``None``) wins.
        position_ids: Not used by this implementation.
        position_embeddings: ``(cos, sin)`` pair for RoPE; must be provided.
        cu_seqlens: Cumulative sequence boundaries for packed input.
        max_seqlen: Longest sequence length in the packed input.

    Returns:
        ``(attn_output, None)`` where ``attn_output`` has passed ``o_proj``.
    """
    mask_mapping = attention_mask or {}
    # Honor the explicit argument instead of silently discarding it.
    padding_mask = mask_mapping.get("padding", padding_mask)

    query_states, key_states, value_states = self.qkv_proj(hidden_states, padding_mask, cu_seqlens)
    query_states, key_states = self.qk_norm(query_states, key_states)

    # RoPE is applied with dims 1 and 2 swapped; swap back afterwards. The
    # values are not rotated, so they keep their layout untouched.
    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(
        query_states.transpose(1, 2), key_states.transpose(1, 2), cos, sin
    )
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)

    if self.cp_enabled:
        # Under context parallelism prefer the globally published boundaries.
        cu_seqlens = ULYSSES_PARAMS.get("cu_seqlens", cu_seqlens)
        max_seqlen = ULYSSES_PARAMS.get("max_seqlen", max_seqlen)
    attn_output = self._attention_core(query_states, key_states, value_states, cu_seqlens, max_seqlen)
    if self.cp_enabled:
        attn_output = self._output_context_parallel(attn_output)
    attn_output = self.o_proj(attn_output)

    return attn_output, None
class ZayaPreTrainedModel(PreTrainedModelPrimeRL):
    """Base class carrying the Zaya config binding and capability flags."""

    config_class = ZayaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Must name the decoder block class actually defined in this module; the
    # previous entries ("ZayaDecoderATTLayer", "ZayaDecoderMLPLayer") match no
    # class here, so they could never keep a decoder layer on a single device.
    _no_split_modules = ["ZayaDecoderLayer"]
    _supports_flash_attn = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_attention_backend = True
range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + self.input_hidden_states_scale = nn.Parameter(torch.ones(config.hidden_size)) + self.input_hidden_states_bias = nn.Parameter(torch.zeros(config.hidden_size)) + self.final_norm = RMSNorm(RMSNormConfig(hidden_size=config.hidden_size, eps=config.norm_epsilon)) + rope_parameters = config.rope_parameters.get("hybrid", config.rope_parameters) + rope_config = copy.copy(config) + rope_config.rope_parameters = rope_parameters + self.rotary_emb = ZayaRotaryEmbedding( + RotaryEmbeddingConfig( + max_position_embeddings=config.max_position_embeddings, + rope_type=rope_parameters["rope_type"], + model_config=rope_config, + ) + ) + # no swa layers for 8B + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @property + def cp_enabled(self) -> bool: + return bool(self.layers) and getattr(self.layers[0].self_attn, "cp_enabled", False) + + @property + def cp_group(self) -> dist.ProcessGroup | None: + if not self.cp_enabled: + return None + return self.layers[0].self_attn._cp_group + + @property + def cp_world_size(self) -> int: + if not self.cp_enabled: + return 1 + return self.layers[0].self_attn._cp_world_size + + def _prepare_causal_mask( + self, + attention_mask: torch.Tensor | None, + inputs_embeds: torch.Tensor, + ) -> torch.Tensor | None: + if attention_mask is None: + return None + + batch_size, seq_length = inputs_embeds.shape[:2] + min_dtype = torch.finfo(inputs_embeds.dtype).min + causal_mask = torch.full( + (seq_length, seq_length), + min_dtype, + device=inputs_embeds.device, + dtype=inputs_embeds.dtype, + ) + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, seq_length, seq_length) + padding_mask = attention_mask[:, None, None, :].to(torch.bool) + return causal_mask.masked_fill(~padding_mask, min_dtype) + + def forward( + 
self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + routed_experts: torch.LongTensor | None = None, + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if position_ids is None: + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0) + + if attention_mask is not None and not isinstance(attention_mask, dict) and attention_mask.ndim != 2: + raise ValueError( + "ZAYA CCA projection requires a 2D `attention_mask` to mask padding tokens before convolution." + ) + + flat_position_ids = position_ids.reshape(-1) + is_packed = position_ids.shape[0] == 1 and ( + (flat_position_ids[1:] == 0).any() if flat_position_ids.numel() > 1 else False + ) + use_flash = self.config._attn_implementation in ("flash_attention_2", "flash_attention_3", "fa4") + if use_flash or is_packed: + cu_seqlens, max_seqlen = get_cu_seqlens_from_position_ids(position_ids) + torch._dynamo.mark_dynamic(cu_seqlens, 0) + else: + cu_seqlens = None + max_seqlen = None + + if use_flash: + causal_mask_mapping = dict.fromkeys(set(self.config.layer_types), None) + elif isinstance(attention_mask, dict): + causal_mask_mapping = attention_mask + elif is_packed: + seq_length = inputs_embeds.shape[1] + min_dtype = torch.finfo(inputs_embeds.dtype).min + causal_mask = torch.full( + (seq_length, seq_length), + min_dtype, + device=inputs_embeds.device, + dtype=inputs_embeds.dtype, + ) + causal_mask = torch.triu(causal_mask, diagonal=1) + segment_ids = (flat_position_ids == 0).cumsum(dim=0) - 1 + same_segment = segment_ids[:, None] == segment_ids[None, :] + causal_mask = causal_mask.masked_fill(~same_segment, min_dtype) + causal_mask = 
causal_mask[None, None, :, :] + if attention_mask is not None: + padding_mask_for_attn = attention_mask[:, None, None, -seq_length:].to(torch.bool) + causal_mask = causal_mask.masked_fill(~padding_mask_for_attn, min_dtype) + causal_mask_mapping = {"hybrid": causal_mask} + else: + mask_kwargs = { + "config": self.config, + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": None, + "position_ids": position_ids, + } + causal_mask_mapping = { + "hybrid": create_causal_mask(**mask_kwargs), + } + + padding_mask = None + if attention_mask is not None and not isinstance(attention_mask, dict): + padding_mask = attention_mask[:, -inputs_embeds.shape[1] :] + if inputs_embeds.shape[1] == 1: + padding_mask = None + + hidden_states = inputs_embeds + rope_position_ids = position_ids + if self.cp_enabled: + rope_position_ids = gather_for_cp(position_ids.contiguous(), self.cp_group) + position_embeddings = self.rotary_emb(hidden_states, rope_position_ids) + hidden_states = ((hidden_states + self.input_hidden_states_bias) * self.input_hidden_states_scale).to( + torch.float32 + ) + prev_router_hidden_states = None + + for layer_idx, decoder_layer in enumerate(self.layers): + routed_experts_layer = routed_experts[:, :, layer_idx, :] if routed_experts is not None else None + mask_mapping = {"causal": causal_mask_mapping[self.config.layer_types[layer_idx]], "padding": padding_mask} + hidden_states, prev_router_hidden_states = decoder_layer( + hidden_states, + prev_router_hidden_states=prev_router_hidden_states, + attention_mask=mask_mapping, + position_embeddings=position_embeddings, + routed_experts=routed_experts_layer, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + hidden_states = self.final_norm(hidden_states.to(dtype=self.final_norm.weight.dtype)) + return MoeModelOutputWithPast(last_hidden_state=hidden_states) + + +class ZayaForCausalLM(ZayaPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": 
"model.embed_tokens.weight"} + + def __init__(self, config: ZayaConfig): + super().__init__(config) + self.model = ZayaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=config.lm_head_bias) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @classmethod + def is_hf_state_dict(cls, state_dict: dict[str, Tensor]) -> bool: + return _is_hf_state_dict(state_dict) + + @classmethod + def is_prime_state_dict(cls, state_dict: dict[str, Tensor]) -> bool: + return _is_prime_state_dict(state_dict) + + @classmethod + def convert_to_hf(cls, state_dict: dict[str, Tensor]) -> dict[str, Tensor]: + return convert_prime_to_hf(state_dict) + + @classmethod + def convert_to_prime(cls, state_dict: dict[str, Tensor]) -> dict[str, Tensor]: + return convert_hf_to_prime(state_dict) + + @classmethod + def convert_layer_to_hf(cls, state_dict: dict[str, Tensor], layer_idx: int) -> dict[str, Tensor]: + convert_prime_layer_to_hf(state_dict, layer_idx) + return state_dict + + @classmethod + def convert_layer_to_prime(cls, state_dict: dict[str, Tensor], layer_idx: int) -> dict[str, Tensor]: + convert_hf_layer_to_prime(state_dict, layer_idx) + return state_dict + + @classmethod + def convert_to_vllm(cls, state_dict: dict[str, Tensor]) -> dict[str, Tensor]: + return convert_prime_to_vllm(state_dict) + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: 
Optional[Cache] = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + temperature: torch.Tensor | None = None, + routed_experts: torch.LongTensor | None = None, + **kwargs, + ) -> PrimeLmOutput: + del cache_position, kwargs + assert use_cache in (None, False), "use_cache is not supported for PrimeRL Zaya" + assert past_key_values is None, "past_key_values is not supported for PrimeRL Zaya" + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + routed_experts=routed_experts, + ) + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + if labels is not None: + labels = labels[:, slice_indices] + if type(self.lm_head) is nn.Linear: + return PrimeLmOutput(logits=self.lm_head(hidden_states[:, slice_indices, :])) + return self.lm_head(hidden_states[:, slice_indices, :], labels, temperature=temperature) + + def init_buffers_post_meta(self): + for rotary_emb in (getattr(self.model, "rotary_emb", None), getattr(self.model, "swa_rotary_emb", None)): + if rotary_emb is None: + continue + inv_freq, rotary_emb.attention_scaling = rotary_emb.rope_init_fn( + rotary_emb.config, rotary_emb.inv_freq.device + ) + rotary_emb.inv_freq = inv_freq + rotary_emb.original_inv_freq = inv_freq + + for module in self.modules(): + if isinstance(module, ZayaMoE): + module.tokens_per_expert = torch.zeros( + module.experts.num_experts, + dtype=torch.float32, + device=module.tokens_per_expert.device, + ) + + +__all__ = ["ZayaConfig", "ZayaForCausalLM", "ZayaModel", "ZayaPreTrainedModel"] diff --git a/src/prime_rl/trainer/models/zaya/vllm_postprocessing.py b/src/prime_rl/trainer/models/zaya/vllm_postprocessing.py new file mode 100644 index 
"""Zaya weight postprocessing for vLLM's original alternating-layer layout.

PrimeRL trains the Transformers-native hybrid decoder layout. The upstream vLLM
Zaya PR used the original alternating attention/MoE layer layout; this module is
the in-memory inverse of ``convert_zaya_weights_to_hf.py``.

Name mapping implemented below, for Transformers-native layer ``i``:

* ``self_attn.*`` and ``input_layernorm.*`` go to vLLM layer ``2 * i``.
* ``mlp.*``, ``post_attention_layernorm.*`` and
  ``post_attention_residual_scale.*`` go to vLLM layer ``2 * i + 1``.
* ``post_mlp_residual_scale.*`` becomes ``res_scale`` of vLLM layer
  ``2 * (i + 1)``, or the top-level ``model.res_scale`` for the last layer.
* ``model.input_*`` becomes ``model.layers.0.res_scale.*``.
"""

import re

import torch
from torch import Tensor

from prime_rl.trainer.models.zaya.converting_zaya import convert_prime_to_hf, get_max_layer_num, is_prime_state_dict

# Splits "model.layers.<idx>.<rest>" into the layer index and the remaining name.
_LAYER_PATTERN = re.compile(r"^model\.layers\.(\d+)\.(.+)$")

# (Transformers-native prefix, vLLM prefix) pairs. The first matching entry
# wins, so specific prefixes must stay ahead of the generic fallbacks
# ("self_attn.qkv_proj.", "mlp.gate.", "mlp.").
_REVERSE_COMMON_REPLACEMENTS = (
    ("self_attn.qkv_proj.conv_qk_depthwise.", "self_attn.qkv.conv_qk.0."),
    ("self_attn.qkv_proj.conv_qk_grouped.", "self_attn.qkv.conv_qk.1."),
    ("self_attn.qk_norm.temp", "self_attn.qkv.temp"),
    ("self_attn.qkv_proj.q_proj.", "self_attn.qkv.linear_q."),
    ("self_attn.qkv_proj.k_proj.", "self_attn.qkv.linear_k."),
    ("self_attn.qkv_proj.v_proj_current.", "self_attn.qkv.val_proj1."),
    ("self_attn.qkv_proj.v_proj_delayed.", "self_attn.qkv.val_proj2."),
    ("self_attn.qkv_proj.", "self_attn.qkv."),
    ("mlp.gate.router_mlp.rmsnorm_eda.", "zaya_block.router.rmsnorm_eda."),
    ("mlp.gate.router_mlp.fc1.", "zaya_block.router.router_mlp.0."),
    ("mlp.gate.router_mlp.fc2.", "zaya_block.router.router_mlp.2."),
    ("mlp.gate.router_mlp.out_proj.", "zaya_block.router.router_mlp.4."),
    ("mlp.gate.", "zaya_block.router."),
    ("mlp.", "zaya_block."),
)


def _reverse_common(rest: str) -> str:
    """Rewrite the leading prefix of a per-layer weight name to its vLLM form.

    Args:
        rest: Weight name with the ``model.layers.<idx>.`` part removed.

    Returns:
        ``rest`` with the first matching prefix from
        ``_REVERSE_COMMON_REPLACEMENTS`` swapped for its vLLM counterpart, or
        ``rest`` unchanged when no prefix matches.
    """
    for new, old in _REVERSE_COMMON_REPLACEMENTS:
        if rest.startswith(new):
            return old + rest.removeprefix(new)
    return rest


def _convert_hf_weight_name_to_vllm(name: str, num_hidden_layers: int) -> str | None:
    """Map one Transformers-native Zaya weight name to its vLLM name.

    Args:
        name: Weight name in the Transformers-native layout.
        num_hidden_layers: Number of Transformers-native decoder layers; used
            to recognise the last layer's ``post_mlp_residual_scale``.

    Returns:
        The vLLM weight name, the unchanged ``name`` when it is not a
        ``model.input_*`` or ``model.layers.<idx>.*`` entry, or ``None`` for
        the fused expert tensors, which ``_add_vllm_experts`` expands
        separately.

    Raises:
        ValueError: If a per-layer name matches none of the known prefixes.
    """
    if name.startswith("model.input_"):
        return f"model.layers.0.res_scale.{name.removeprefix('model.input_')}"

    match = _LAYER_PATTERN.match(name)
    if match is None:
        return name

    layer_idx = int(match.group(1))
    rest = match.group(2)
    # Each native layer expands into an attention layer followed by a MoE layer.
    attn_layer_idx = 2 * layer_idx
    moe_layer_idx = attn_layer_idx + 1

    if rest.startswith(("mlp.experts.gate_up_proj", "mlp.experts.down_proj")):
        return None

    if rest.startswith("self_attn."):
        return f"model.layers.{attn_layer_idx}.{_reverse_common(rest)}"
    if rest.startswith("input_layernorm."):
        return f"model.layers.{attn_layer_idx}.input_norm.{rest.removeprefix('input_layernorm.')}"
    if rest.startswith("post_mlp_residual_scale."):
        suffix = rest.removeprefix("post_mlp_residual_scale.")
        # The last layer has no following attention layer to carry the scale,
        # so it is stored at the model level instead.
        if layer_idx == num_hidden_layers - 1:
            return f"model.res_scale.{suffix}"
        return f"model.layers.{attn_layer_idx + 2}.res_scale.{suffix}"
    if rest.startswith("mlp."):
        return f"model.layers.{moe_layer_idx}.{_reverse_common(rest)}"
    if rest.startswith("post_attention_layernorm."):
        return f"model.layers.{moe_layer_idx}.input_norm.{rest.removeprefix('post_attention_layernorm.')}"
    if rest.startswith("post_attention_residual_scale."):
        return f"model.layers.{moe_layer_idx}.res_scale.{rest.removeprefix('post_attention_residual_scale.')}"

    raise ValueError(f"Unexpected HF Zaya weight name: {name}")


def _add_vllm_experts(hf_state_dict: dict[str, Tensor], vllm_state_dict: dict[str, Tensor], layer_idx: int) -> None:
    """Expand one layer's fused expert tensors into per-expert vLLM weights.

    Does nothing when ``hf_state_dict`` has no fused ``gate_up_proj`` entry for
    ``layer_idx`` (for example when converting a single-layer state dict).

    Args:
        hf_state_dict: Transformers-native weights to read from.
        vllm_state_dict: Output mapping; per-expert ``linear_fc1.weight`` and
            ``linear_fc2.weight`` entries are added in place.
        layer_idx: Transformers-native layer index.

    Raises:
        KeyError: If ``gate_up_proj`` is present but ``down_proj`` is not.
    """
    gate_up_key = f"model.layers.{layer_idx}.mlp.experts.gate_up_proj"
    down_key = f"model.layers.{layer_idx}.mlp.experts.down_proj"
    if gate_up_key not in hf_state_dict:
        return

    gate_up = hf_state_dict[gate_up_key]
    down = hf_state_dict[down_key]
    old_layer_idx = 2 * layer_idx + 1

    # Dim 0 of the fused tensors indexes the expert.
    for expert_idx in range(gate_up.shape[0]):
        prefix = f"model.layers.{old_layer_idx}.zaya_block.experts.local_experts.{expert_idx}"
        # The fused rows are passed through in their stored order (first half,
        # then second half), so the per-expert slice is used directly instead
        # of splitting it and concatenating the halves back together.
        vllm_state_dict[f"{prefix}.linear_fc1.weight"] = gate_up[expert_idx].contiguous()
        vllm_state_dict[f"{prefix}.linear_fc2.weight"] = down[expert_idx].contiguous()


def _infer_num_hidden_layers(state_dict: dict[str, Tensor]) -> int:
    """Return the layer count derived from ``state_dict`` keys, or 0 if it has no layers."""
    if not any(key.startswith("model.layers.") for key in state_dict):
        return 0
    # NOTE(review): presumably get_max_layer_num returns a count rather than
    # the highest index - confirm against converting_zaya.
    return get_max_layer_num(state_dict)


def convert_hf_to_vllm(state_dict: dict[str, Tensor], num_hidden_layers: int | None = None) -> dict[str, Tensor]:
    """Convert Transformers-native Zaya weights to vLLM original Zaya weights.

    Args:
        state_dict: Transformers-native weights; not modified.
        num_hidden_layers: Number of native decoder layers. Inferred from the
            keys of ``state_dict`` when ``None``; pass it explicitly when
            ``state_dict`` holds only part of the model.

    Returns:
        A new mapping keyed by vLLM weight names. ``lm_head.weight`` is
        omitted and fused expert tensors are replaced by per-expert entries.
    """
    if num_hidden_layers is None:
        num_hidden_layers = _infer_num_hidden_layers(state_dict)
    converted: dict[str, Tensor] = {}

    for name, tensor in state_dict.items():
        # NOTE(review): lm_head is always dropped, presumably because it is
        # tied to the embeddings on the vLLM side - confirm for untied configs.
        if name == "lm_head.weight":
            continue
        target = _convert_hf_weight_name_to_vllm(name, num_hidden_layers)
        if target is not None:
            converted[target] = tensor.contiguous()

    for layer_idx in range(num_hidden_layers):
        _add_vllm_experts(state_dict, converted, layer_idx)

    return converted


def convert_prime_to_vllm(state_dict: dict[str, Tensor], num_hidden_layers: int | None = None) -> dict[str, Tensor]:
    """Convert PrimeRL training-format Zaya weights to vLLM original Zaya weights.

    Args:
        state_dict: Weights in PrimeRL or Transformers-native format.
        num_hidden_layers: Forwarded to ``convert_hf_to_vllm``.

    Returns:
        A new mapping keyed by vLLM weight names.
    """
    # Work on a shallow copy so the conversion below does not rewrite the
    # caller's mapping.
    hf_state_dict = dict(state_dict)
    if is_prime_state_dict(hf_state_dict):
        hf_state_dict = convert_prime_to_hf(hf_state_dict)
    return convert_hf_to_vllm(hf_state_dict, num_hidden_layers=num_hidden_layers)


class ZayaVLLMWeightPostprocessor:
    """Callable postprocessor shared by checkpoint and NCCL broadcast paths."""

    def __init__(self, num_hidden_layers: int | None = None):
        """Store the layer count to forward on each call (``None`` means infer it)."""
        self.num_hidden_layers = num_hidden_layers

    def __call__(self, state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return ``state_dict`` converted to the vLLM Zaya weight layout."""
        return convert_prime_to_vllm(state_dict, num_hidden_layers=self.num_hidden_layers)


__all__ = ["ZayaVLLMWeightPostprocessor", "convert_hf_to_vllm", "convert_prime_to_vllm"]
a/src/prime_rl/trainer/rl/broadcast/filesystem.py b/src/prime_rl/trainer/rl/broadcast/filesystem.py index 55a92c832d..c26d1ab9d9 100644 --- a/src/prime_rl/trainer/rl/broadcast/filesystem.py +++ b/src/prime_rl/trainer/rl/broadcast/filesystem.py @@ -43,7 +43,9 @@ def broadcast_weights(self, model: nn.Module, step: int) -> None: if not adapter_only: state_dict = gather_weights_on_master(model, is_master=self.world.is_master) - if isinstance(model, PreTrainedModelPrimeRL) and model.is_prime_state_dict(state_dict): + if isinstance(model, PreTrainedModelPrimeRL) and getattr(model.config, "model_type", None) == "zaya": + state_dict = model.convert_to_vllm(state_dict) + elif isinstance(model, PreTrainedModelPrimeRL) and model.is_prime_state_dict(state_dict): model.convert_to_hf(state_dict) else: # For regular transformers models, revert internal format to original HF hub format diff --git a/src/prime_rl/trainer/rl/broadcast/nccl.py b/src/prime_rl/trainer/rl/broadcast/nccl.py index 13a887308a..cbc700c645 100644 --- a/src/prime_rl/trainer/rl/broadcast/nccl.py +++ b/src/prime_rl/trainer/rl/broadcast/nccl.py @@ -90,6 +90,13 @@ def preprocess_layer_checkpoint( layer_state_dict: dict[str, Tensor], layer_idx: int, ) -> dict[str, Tensor]: + if getattr(model.config, "model_type", None) == "zaya": + from prime_rl.trainer.models.zaya.vllm_postprocessing import convert_hf_to_vllm + + if isinstance(model, PreTrainedModelPrimeRL) and layer_idx >= 0: + model.convert_layer_to_hf(layer_state_dict, layer_idx) + return convert_hf_to_vllm(layer_state_dict, num_hidden_layers=model.config.num_hidden_layers) + if isinstance(model, PreTrainedModelPrimeRL) and model.is_prime_state_dict(layer_state_dict): model.convert_layer_to_hf(layer_state_dict, layer_idx) return layer_state_dict diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index fc03e89f3b..4a6ad32ca7 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -196,18 +196,24 @@ 
def load_run_checkpoint(_optimizer, idx: int) -> None: substitute_ulysses_attn, ) - substitute_hf_ulysses_attn(cp_group) - substitute_ulysses_attn(cp_group, attn_impl=config.model.attn) from prime_rl.utils.cp import ( assert_cp_style_supports_model, setup_hybrid_cp, setup_nemotron_h_cp, setup_sparse_mla_cp, + setup_zaya_cp, + is_zaya_model, ) + substitute_hf_ulysses_attn(cp_group) + if not is_zaya_model(model): + # Zaya does a different kind of ulysses where FlashAttention should be vanilla + substitute_ulysses_attn(cp_group, attn_impl=config.model.attn) + assert_cp_style_supports_model(config.model.cp_style, model) # sparse MLA is softmax (works with both ring and ulysses). setup_sparse_mla_cp(model, cp_group, cp_rank, parallel_dims.cp) + setup_zaya_cp(model, cp_group, cp_rank, parallel_dims.cp) # Linear-attn / Mamba layers are only configured under ulysses; with ring # we'd have already raised above. if config.model.cp_style == "ulysses": diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index 1c12b342ee..7f014b6fa5 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -117,10 +117,18 @@ def train(config: SFTConfig): substitute_hf_ulysses_attn, substitute_ulysses_attn, ) + from prime_rl.utils.cp import ( + setup_hybrid_cp, + setup_nemotron_h_cp, + setup_sparse_mla_cp, + setup_zaya_cp, + is_zaya_model, + ) substitute_hf_ulysses_attn(cp_group) - substitute_ulysses_attn(cp_group, attn_impl=config.model.attn) - from prime_rl.utils.cp import setup_hybrid_cp, setup_nemotron_h_cp, setup_sparse_mla_cp + if not is_zaya_model(model): + # Zaya does a different kind of ulysses where FlashAttention should be vanilla + substitute_ulysses_attn(cp_group, attn_impl=config.model.attn) # Set up checkpoint manager logger.info(f"Initializing checkpoint managers ({config.ckpt})") @@ -145,6 +153,7 @@ def train(config: SFTConfig): assert_cp_style_supports_model(config.model.cp_style, model) # sparse MLA is softmax 
(works with both ring and ulysses). setup_sparse_mla_cp(model, cp_group, cp_rank, parallel_dims.cp) + setup_zaya_cp(model, cp_group, cp_rank, parallel_dims.cp) # Linear-attn / Mamba layers are only configured under ulysses; with ring # we'd have already raised above. if config.model.cp_style == "ulysses": diff --git a/src/prime_rl/utils/cp.py b/src/prime_rl/utils/cp.py index 2ac9d45ad7..86fc7d4e42 100644 --- a/src/prime_rl/utils/cp.py +++ b/src/prime_rl/utils/cp.py @@ -34,6 +34,15 @@ def _has_linear_attn_layer(model: nn.Module) -> bool: return False +def is_zaya_model(model: nn.Module) -> bool: + config = getattr(model, "config", None) + if getattr(config, "model_type", None) == "zaya": + return True + inner = getattr(model, "model", None) + config = getattr(inner, "config", None) + return getattr(config, "model_type", None) == "zaya" + + def assert_cp_style_supports_model(cp_style: CPStyle, model: nn.Module) -> None: """Refuse `cp_style='ring'` on models that have linear/SSM attention layers. @@ -50,6 +59,12 @@ def assert_cp_style_supports_model(cp_style: CPStyle, model: nn.Module) -> None: "cp_style='ulysses' instead — its all-to-all on Q/K/V works " "out-of-the-box with non-softmax kernels." ) + if cp_style == "ring" and is_zaya_model(model): + raise ValueError( + "cp_style='ring' is not supported for Zaya because CCA convolution " + "and value shifting require full sequence context before attention. " + "Use cp_style='ulysses' instead." 
+ ) def setup_hybrid_cp(model: nn.Module, cp_group: dist.ProcessGroup, cp_rank: int, cp_world_size: int) -> None: @@ -125,6 +140,27 @@ def setup_sparse_mla_cp(model: nn.Module, cp_group: dist.ProcessGroup, cp_rank: get_logger().info(f"Configured sparse MLA CP on {count} DSA layers") +def setup_zaya_cp(model: nn.Module, cp_group: dist.ProcessGroup, cp_rank: int, cp_world_size: int) -> None: + """Configure Zaya attention layers for context-parallel full-sequence CCA.""" + + inner = getattr(model, "model", model) + layers = getattr(inner, "layers", None) + if layers is None: + return + + count = 0 + for layer in layers: + attn = getattr(layer, "self_attn", None) + if attn is not None and hasattr(attn, "set_context_parallel_attributes"): + attn.set_context_parallel_attributes(cp_group, cp_rank, cp_world_size) + count += 1 + + if count > 0: + from prime_rl.utils.logger import get_logger + + get_logger().info(f"Configured Zaya CP on {count} CCA attention layers") + + def shard_for_cp(t: torch.Tensor, cp_rank: int, cp_world_size: int) -> torch.Tensor: """ Shard a tensor for context parallelism. 
diff --git a/tests/unit/train/models/test_zaya.py b/tests/unit/train/models/test_zaya.py new file mode 100644 index 0000000000..6295c1d259 --- /dev/null +++ b/tests/unit/train/models/test_zaya.py @@ -0,0 +1,552 @@ +# This script should be run with "https://github.com/JJJYmmm/transformers.git" which checks against the (likely) merged transformers PR for Zaya +import copy +from pathlib import Path + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from huggingface_hub import snapshot_download +from torch import nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.parallel import parallelize_module +from torchtitan.distributed.expert_parallel import ExpertParallel +from transformers import ZayaForCausalLM as HFZayaForCausalLM + +# There is something wrong with the quack RMSNorm vs the FP32 implementation at least on (SM120) +import prime_rl.trainer.models.layers.norms as norms + +norms._get_quack_rmsnorm = lambda: None + +from prime_rl.trainer.models.layers.lm_head import inject_prime_lm_head +from prime_rl.trainer.models.layers.rotary_emb import RotaryEmbeddingConfig +from prime_rl.trainer.models.zaya import ZayaConfig +from prime_rl.trainer.models.zaya import ZayaForCausalLM as PrimeRLZayaForCausalLM +from prime_rl.trainer.models.zaya.modeling_zaya import ( + ZayaCCAProjection, + ZayaDecoderLayer, + ZayaFlashAttention, + ZayaQKNorm, + ZayaRotaryEmbedding, + ZayaSPDAAttention, +) +from prime_rl.trainer.weights import load_state_dict +from prime_rl.utils.utils import default_dtype + +pytestmark = [pytest.mark.gpu] + +LOGITS_ATOL = 2e-2 +EMBED_GRAD_ATOL = 2 +_REQUIRES_TWO_CUDA_DEVICES_MSG = "Zaya distributed parity tests require at least two CUDA devices" + + +def _tiny_config(attn_implementation: str = "sdpa"): + config = ZayaConfig( + vocab_size=128, + hidden_size=32, + num_hidden_layers=4, + num_experts=3, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=8, + 
max_position_embeddings=64, + rms_norm_eps=1e-5, + partial_rotary_factor=0.5, + zaya_use_eda=True, + attention_bias=False, + lm_head_bias=False, + tie_word_embeddings=True, + use_cache=False, + use_grouped_mm=False, + layer_types=["hybrid"] * 4, + rope_parameters={ + "hybrid": { + "rope_type": "default", + "rope_theta": 10000.0, + "partial_rotary_factor": 0.5, + } + }, + ) + config._attn_implementation = attn_implementation + return config + + +def _clone_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + return {key: value.detach().clone() for key, value in state_dict.items()} + + +def get_model_pairs( + hf_attn_implementation: str = "sdpa", + prime_attn_implementation: str | None = None, + dtype: torch.dtype = torch.float32, +): + if prime_attn_implementation is None: + prime_attn_implementation = hf_attn_implementation + hf_config = _tiny_config(hf_attn_implementation) + prime_config = _tiny_config(prime_attn_implementation) + + with torch.device("cuda"), default_dtype(dtype): + hf_model = HFZayaForCausalLM(hf_config) + prime_model = PrimeRLZayaForCausalLM._from_config(prime_config) + + with torch.no_grad(): + state_dict = _clone_state_dict(hf_model.state_dict()) + prime_state_keys = set(prime_model.state_dict().keys()) + prime_model.convert_to_prime(state_dict) + prime_model.load_state_dict(state_dict) + + inject_prime_lm_head(prime_model, chunk_size=None) + assert prime_state_keys - set(state_dict.keys()) == set() + return hf_model, prime_model + + +def _assert_logits_and_embed_grads_close(hf_model, prime_model, input_ids, position_ids, attention_mask=None) -> None: + hf_output = hf_model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False) + prime_output = prime_model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) + hf_output.logits.float().sum().backward() + prime_output["logits"].float().sum().backward() + + logits_diff = prime_output["logits"] - 
hf_output.logits + assert torch.allclose(logits_diff, torch.zeros_like(logits_diff), atol=LOGITS_ATOL), ( + f"Max logits diff: {logits_diff.abs().max()}" + ) + + grad_diff = hf_model.model.embed_tokens.weight.grad - prime_model.model.embed_tokens.weight.grad + assert torch.allclose(grad_diff, torch.zeros_like(grad_diff), atol=EMBED_GRAD_ATOL), ( + f"Max grad diff: {grad_diff.abs().max()}" + ) + + +class _PassthroughPrimeZayaBlock(nn.Module): + def forward(self, hidden_states, prev_router_hidden_states=None, routed_experts=None): + return hidden_states, prev_router_hidden_states + + +class _PassthroughHfZayaMoe(nn.Module): + """HF `ZayaSparseMoeBlock` returns `(hidden_states, prev_router_hidden_states)`.""" + + def forward(self, hidden_states, prev_router_hidden_states=None): + return hidden_states, prev_router_hidden_states + + +def test_zaya_attn_only() -> None: + hf_model, prime_model = get_model_pairs() + + for layer in hf_model.model.layers: + layer.mlp = _PassthroughHfZayaMoe() + for layer in prime_model.model.layers: + layer.mlp = _PassthroughPrimeZayaBlock() + + with torch.device("cuda"), default_dtype(torch.float32): + input_ids = torch.randint(0, hf_model.config.vocab_size, (1, 12)) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0) + + _assert_logits_and_embed_grads_close(hf_model, prime_model, input_ids, position_ids) + + +def test_zaya_mlp_only() -> None: + hf_model, prime_model = get_model_pairs() + + def identity_attn_hf(hidden_states, *args, **kwargs): + return hidden_states, None + + def identity_attn_prime( + hidden_states, + *args, + **kwargs, + ): + return hidden_states, None + + for layer in hf_model.model.layers: + layer.self_attn.forward = identity_attn_hf + for layer in prime_model.model.layers: + if hasattr(layer, "self_attn"): + layer.self_attn.forward = identity_attn_prime + + with torch.device("cuda"), default_dtype(torch.float32): + input_ids = torch.randint(0, hf_model.config.vocab_size, (1, 12)) + position_ids = 
torch.arange(input_ids.shape[1]).unsqueeze(0) + + _assert_logits_and_embed_grads_close(hf_model, prime_model, input_ids, position_ids) + + +def test_zaya_packed_matches_unpacked() -> None: + _, prime_model = get_model_pairs() + prime_model.eval() + + with torch.device("cuda"), default_dtype(torch.float32), torch.no_grad(): + first_ids = torch.randint(0, prime_model.config.vocab_size, (1, 5)) + second_ids = torch.randint(0, prime_model.config.vocab_size, (1, 7)) + packed_ids = torch.cat([first_ids, second_ids], dim=1) + packed_position_ids = torch.cat( + [ + torch.arange(first_ids.shape[1]), + torch.arange(second_ids.shape[1]), + ] + ).unsqueeze(0) + + first_logits = prime_model(input_ids=first_ids, position_ids=torch.arange(first_ids.shape[1]).unsqueeze(0))[ + "logits" + ] + second_logits = prime_model( + input_ids=second_ids, + position_ids=torch.arange(second_ids.shape[1]).unsqueeze(0), + )["logits"] + packed_logits = prime_model(input_ids=packed_ids, position_ids=packed_position_ids)["logits"] + + expected_logits = torch.cat([first_logits, second_logits], dim=1) + logits_diff = packed_logits - expected_logits + assert torch.allclose(logits_diff, torch.zeros_like(logits_diff), atol=LOGITS_ATOL), ( + f"Max logits diff: {logits_diff.abs().max()}" + ) + + +@pytest.mark.slow +def test_zaya() -> None: + snapshot = Path(snapshot_download(repo_id="JJJYmmm/ZAYA1-8B-HF", repo_type="model")) + dtype = torch.bfloat16 + device = torch.device("cuda") + + # hf_model = HFZayaForCausalLM.from_pretrained("Zyphra/ZAYA1-8B", torch_dtype=dtype) # Original Zyphra weights (official) + hf_model = HFZayaForCausalLM.from_pretrained("JJJYmmm/ZAYA1-8B-HF", torch_dtype=dtype) # HF PR + hf_model.to(device) + attn_impl = getattr( + hf_model.config, + "_attn_implementation", + getattr(hf_model.config, "attn_implementation", "sdpa"), + ) + prime_config = ZayaConfig.from_pretrained(snapshot) + prime_config._attn_implementation = attn_impl + + prime_model = 
PrimeRLZayaForCausalLM._from_config(prime_config) + sd = load_state_dict(snapshot) + PrimeRLZayaForCausalLM.convert_to_prime(sd) + prime_model.load_state_dict(sd, strict=False) + + prime_model.to(device=device, dtype=dtype) + prime_model.eval() + hf_model.eval() + + vocab = hf_model.config.vocab_size + torch.manual_seed(0) + input_ids = torch.randint(0, vocab, (4, 16), device=device) + position_ids = torch.arange(input_ids.shape[1], device=device).unsqueeze(0).expand(4, -1) + + with torch.no_grad(): + hf_out = hf_model(input_ids=input_ids, position_ids=position_ids, use_cache=False) + prime_out = prime_model(input_ids=input_ids, position_ids=position_ids) + + hf_logits = hf_out.logits.float().cpu() + prime_logits = prime_out["logits"].float().cpu() + max_abs = (prime_logits - hf_logits).abs().max().item() + + assert torch.allclose(prime_logits, hf_logits, atol=5e-2), ( + f"Forward logits mismatch max abs diff {max_abs} (atol=5e-2)" + ) + + +def test_zaya_tiny_roundtrip() -> None: + hf_model, prime_model = get_model_pairs() + + with torch.device("cuda"), default_dtype(torch.float32): + input_ids = torch.randint(0, hf_model.config.vocab_size, (1, 12)) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0) + + _assert_logits_and_embed_grads_close(hf_model, prime_model, input_ids, position_ids) + + with torch.device("cuda"), default_dtype(torch.float32): + hf_from_prime_model = HFZayaForCausalLM(hf_model.config) + converted_state_dict = prime_model.convert_to_hf(prime_model.state_dict()) + hf_from_prime_model.load_state_dict(converted_state_dict) + + hf_model.zero_grad(set_to_none=True) + hf_from_prime_model.zero_grad(set_to_none=True) + hf_output = hf_model(input_ids=input_ids, position_ids=position_ids, use_cache=False) + hf_from_prime_output = hf_from_prime_model(input_ids=input_ids, position_ids=position_ids, use_cache=False) + hf_output.logits.sum().backward() + hf_from_prime_output.logits.sum().backward() + + logits_diff = hf_from_prime_output.logits - 
hf_output.logits + assert torch.allclose(logits_diff, torch.zeros_like(logits_diff), atol=LOGITS_ATOL), ( + f"Max logits diff: {logits_diff.abs().max()}" + ) + grad_diff = hf_from_prime_model.model.embed_tokens.weight.grad - hf_model.model.embed_tokens.weight.grad + assert torch.allclose(grad_diff, torch.zeros_like(grad_diff), atol=EMBED_GRAD_ATOL), ( + f"Max grad diff: {grad_diff.abs().max()}" + ) + + +def test_zaya_attention_mask() -> None: + hf_model, prime_model = get_model_pairs() + + with torch.device("cuda"), default_dtype(torch.float32): + input_ids = torch.randint(0, hf_model.config.vocab_size, (1, 12)) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0) + attention_mask = torch.ones_like(input_ids) + attention_mask[:, -3:] = 0 + + _assert_logits_and_embed_grads_close(hf_model, prime_model, input_ids, position_ids, attention_mask) + + +def _run_zaya_moe_expert_parallel_parity(rank: int, world_size: int, init_file: str) -> None: + torch.cuda.set_device(rank) + dist.init_process_group("nccl", init_method=f"file://{init_file}", rank=rank, world_size=world_size) + try: + device = torch.device("cuda", rank) + torch.manual_seed(0) + config = _tiny_zaya_config() + config.num_experts = world_size + config.num_experts_per_tok = 1 + config.moe_router_topk = 1 + config.use_grouped_mm = False + + with torch.device(device), default_dtype(torch.float32): + local_layer = ZayaDecoderLayer(config, layer_idx=0) + ep_layer = ZayaDecoderLayer(config, layer_idx=0) + ep_layer.load_state_dict(local_layer.state_dict()) + parallelize_module( + ep_layer.mlp.experts, + DeviceMesh("cuda", list(range(world_size))), + ExpertParallel(), + ) + + def identity_attn(hidden_states, *args, **kwargs): + return hidden_states, None + + local_layer.self_attn.forward = identity_attn + ep_layer.self_attn.forward = identity_attn + local_layer.train() + ep_layer.train() + + hidden_states = torch.randn(2, 8, config.hidden_size, device=device, requires_grad=True) + ep_hidden_states = 
hidden_states.detach().clone().requires_grad_() + routed_experts = (torch.arange(hidden_states.shape[1], device=device).reshape(1, -1, 1) % world_size).expand( + hidden_states.shape[0], -1, -1 + ) + + local_output, _ = local_layer(hidden_states, routed_experts=routed_experts) + ep_output, _ = ep_layer(ep_hidden_states, routed_experts=routed_experts) + + max_diff = (ep_output - local_output).abs().max().item() + assert torch.allclose(ep_output, local_output, atol=1e-5, rtol=1e-5), max_diff + + grad = torch.randn_like(local_output) + local_output.backward(grad) + ep_output.backward(grad) + + max_hidden_grad_diff = (ep_hidden_states.grad - hidden_states.grad).abs().max().item() + assert torch.allclose(ep_hidden_states.grad, hidden_states.grad, atol=1e-5, rtol=1e-5), max_hidden_grad_diff + finally: + dist.destroy_process_group() + + +@pytest.mark.gpu +def test_zaya_moe_expert_parallel_matches_local_output_and_backward(tmp_path) -> None: + if torch.cuda.device_count() < 2: + pytest.skip(_REQUIRES_TWO_CUDA_DEVICES_MSG) + world_size = 2 + init_file = tmp_path / "dist_init" + + mp.spawn( + _run_zaya_moe_expert_parallel_parity, + args=(world_size, init_file.as_posix()), + nprocs=world_size, + join=True, + ) + + +def test_zaya_flash_attention_2() -> None: + pytest.importorskip("flash_attn") + torch.manual_seed(0) + hf_model, prime_model = get_model_pairs(prime_attn_implementation="flash_attention_2", dtype=torch.bfloat16) + + with torch.device("cuda"): + input_ids = torch.randint(0, hf_model.config.vocab_size, (1, 12)) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0) + + _assert_logits_and_embed_grads_close(hf_model, prime_model, input_ids, position_ids) + + +def _tiny_zaya_config() -> ZayaConfig: + config = ZayaConfig( + vocab_size=32, + hidden_size=16, + num_hidden_layers=1, + num_experts=2, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=4, + moe_intermediate_size=8, + router_hidden_size=4, + use_grouped_mm=False, + layer_types=["hybrid"], + 
rope_parameters={ + "hybrid": { + "rope_type": "default", + "rope_theta": 10000.0, + "partial_rotary_factor": 0.5, + } + }, + ) + config._attn_implementation = "sdpa" + return config + + +def test_zaya_attention_cp_attributes_propagate_to_children(): + config = _tiny_zaya_config() + attention = ZayaFlashAttention(config, layer_idx=0) + cp_group = object() + + attention.set_context_parallel_attributes(cp_group, cp_rank=1, cp_world_size=2) + + assert attention.cp_enabled + assert attention._cp_group is cp_group + assert attention._cp_rank == 1 + assert attention._cp_world_size == 2 + assert attention.qkv_proj._cp_group is cp_group + assert attention.qkv_proj._cp_rank == 1 + assert attention.qkv_proj._cp_world_size == 2 + assert attention.qk_norm._cp_rank == 1 + assert attention.qk_norm._cp_world_size == 2 + + +def test_zaya_cca_cp_channel_indices_preserve_q_then_k_order(): + config = _tiny_zaya_config() + projection = ZayaCCAProjection(config, layer_idx=0) + projection.set_context_parallel_attributes(object(), cp_rank=1, cp_world_size=2) + + channel_indices = projection._local_head_channel_indices() + + assert channel_indices.tolist() == [8, 9, 10, 11, 12, 13, 14, 15, 20, 21, 22, 23] + + +def test_zaya_qk_norm_slices_temperature_for_local_kv_heads(): + config = _tiny_zaya_config() + qk_norm = ZayaQKNorm(config) + qk_norm.temp.data = torch.tensor([2.0, 3.0]) + qk_norm.set_context_parallel_attributes(object(), cp_rank=1, cp_world_size=2) + + query_states = torch.ones(1, 2, 2, config.head_dim) + key_states = torch.ones(1, 2, 1, config.head_dim) + + query_out, key_out = qk_norm(query_states, key_states) + + expected_norm = config.head_dim**0.5 + assert torch.allclose(query_out.norm(p=2, dim=-1), torch.full((1, 2, 2), expected_norm)) + assert torch.allclose(key_out, torch.full_like(key_out, 3.0)) + + +def _zaya_position_embeddings(config: ZayaConfig, hidden_states: torch.Tensor): + rope_parameters = config.rope_parameters["hybrid"] + rope_config = copy.copy(config) + 
rope_config.rope_parameters = rope_parameters + rotary_emb = ZayaRotaryEmbedding( + RotaryEmbeddingConfig( + max_position_embeddings=config.max_position_embeddings, + rope_type=rope_parameters["rope_type"], + model_config=rope_config, + ) + ).to(hidden_states.device) + position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0) + return rotary_emb(hidden_states, position_ids) + + +def _run_zaya_attention_cp_parity( + rank: int, + world_size: int, + init_file: str, + attention_name: str, + backend: str, + device_type: str, + check_backward: bool, +): + if device_type == "cuda": + torch.cuda.set_device(rank) + dist.init_process_group(backend, init_method=f"file://{init_file}", rank=rank, world_size=world_size) + try: + device = torch.device(device_type, rank) if device_type == "cuda" else torch.device(device_type) + torch.manual_seed(0) + config = _tiny_zaya_config() + attention_cls = {"sdpa": ZayaSPDAAttention, "flash": ZayaFlashAttention}[attention_name] + dtype = torch.bfloat16 if attention_name == "flash" else torch.float32 + full_attention = attention_cls(config, layer_idx=0).to(device=device, dtype=dtype) + cp_attention = attention_cls(config, layer_idx=0).to(device=device, dtype=dtype) + cp_attention.load_state_dict(full_attention.state_dict()) + cp_attention.set_context_parallel_attributes(dist.group.WORLD, rank, world_size) + full_attention.train(check_backward) + cp_attention.train(check_backward) + + hidden_states = torch.randn(1, 8, config.hidden_size, device=device, dtype=dtype) + position_embeddings = _zaya_position_embeddings(config, hidden_states) + cu_seqlens = torch.tensor([0, hidden_states.shape[1]], device=device, dtype=torch.int32) + max_seqlen = hidden_states.shape[1] + + full_hidden_states = hidden_states.detach().clone().requires_grad_(check_backward) + cp_hidden_states = ( + hidden_states.chunk(world_size, dim=1)[rank].detach().clone().contiguous().requires_grad_(check_backward) + ) + + full_output, _ = 
full_attention( + full_hidden_states, + position_embeddings=position_embeddings, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + cp_output, _ = cp_attention( + cp_hidden_states, + position_embeddings=position_embeddings, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + expected = full_output.chunk(world_size, dim=1)[rank] + atol = 1e-2 if attention_name == "flash" else 1e-5 + rtol = 1e-2 if attention_name == "flash" else 1e-5 + max_diff = (cp_output - expected).abs().max().item() + assert torch.allclose(cp_output, expected, atol=atol, rtol=rtol), max_diff + + if check_backward: + grad = torch.randn_like(full_output) + full_output.backward(grad) + cp_output.backward(grad.chunk(world_size, dim=1)[rank]) + + hidden_grad = full_hidden_states.grad.chunk(world_size, dim=1)[rank] + max_hidden_grad_diff = (cp_hidden_states.grad - hidden_grad).abs().max().item() + assert torch.allclose(cp_hidden_states.grad, hidden_grad, atol=atol, rtol=rtol), max_hidden_grad_diff + + for (name, full_param), cp_param in zip(full_attention.named_parameters(), cp_attention.parameters()): + dist.all_reduce(cp_param.grad, op=dist.ReduceOp.SUM) + max_grad_diff = (cp_param.grad - full_param.grad).abs().max().item() + assert torch.allclose(cp_param.grad, full_param.grad, atol=atol, rtol=rtol), (name, max_grad_diff) + finally: + dist.destroy_process_group() + + +def test_zaya_sdpa_context_parallel_matches_non_cp_output_and_backward(tmp_path): + if torch.cuda.device_count() < 2: + pytest.skip(_REQUIRES_TWO_CUDA_DEVICES_MSG) + world_size = 2 + init_file = tmp_path / "dist_init" + + mp.spawn( + _run_zaya_attention_cp_parity, + args=(world_size, init_file.as_posix(), "sdpa", "nccl", "cuda", True), + nprocs=world_size, + join=True, + ) + + +def test_zaya_flash_context_parallel_matches_non_cp_output_and_backward(tmp_path): + pytest.importorskip("flash_attn") + if torch.cuda.device_count() < 2: + pytest.skip(_REQUIRES_TWO_CUDA_DEVICES_MSG) + + world_size = 2 + init_file = tmp_path / 
"dist_init" + + mp.spawn( + _run_zaya_attention_cp_parity, + args=(world_size, init_file.as_posix(), "flash", "nccl", "cuda", True), + nprocs=world_size, + join=True, + ) diff --git a/uv.lock b/uv.lock index 39efa7597f..01288dd4f4 100644 --- a/uv.lock +++ b/uv.lock @@ -72,7 +72,7 @@ overrides = [ { name = "nvidia-cutlass-dsl", specifier = ">=4.4.1" }, { name = "openenv-core" }, { name = "torch", specifier = ">=2.9.0", index = "https://download.pytorch.org/whl/cu128" }, - { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" }, + { name = "transformers", git = "https://github.com/JJJYmmm/transformers?rev=d362c90c378b4b32b54513f1627b6d9d59ccc6a1" }, ] [[package]] @@ -3872,8 +3872,7 @@ dependencies = [ { name = "transformers", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "uvloop", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "verifiers", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, 
- { name = "vllm", version = "0.21.0+cu129", source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "vllm", version = "0.21.0+cu129", source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "vllm", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "wandb", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] @@ -4021,12 +4020,10 @@ requires-dist = [ { name = "torchdata", specifier = ">=0.11.0" }, { name = "torchtitan", git = "https://github.com/pytorch/torchtitan?rev=a1fdd7e" }, { name = "torchvision", index 
= "https://download.pytorch.org/whl/cu128" }, - { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" }, + { name = "transformers", git = "https://github.com/JJJYmmm/transformers?rev=d362c90c378b4b32b54513f1627b6d9d59ccc6a1" }, { name = "uvloop", specifier = ">=0.21.0" }, { name = "verifiers", editable = "deps/verifiers" }, - { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'", specifier = ">=0.21.0" }, - { name = "vllm", marker = "platform_machine == 'aarch64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, - { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, + { name = "vllm", git = "https://github.com/Zyphra/vllm.git?branch=zaya1-pr" }, { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl" }, { name = "wandb", specifier = ">=0.26.1" }, { name = "wiki-search", marker = "extra == 'envs'", editable = "deps/verifiers/environments/wiki_search" }, @@ -4725,7 +4722,7 @@ requires-dist = [ { name = "openai", specifier = ">=1.108.1" }, { name = "openai-harmony", specifier = ">=0.0.8" }, { name = "tiktoken" }, - { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" }, + { name = "transformers", git = "https://github.com/JJJYmmm/transformers?rev=d362c90c378b4b32b54513f1627b6d9d59ccc6a1" }, ] [package.metadata.requires-dev] @@ -5641,8 +5638,8 @@ wheels = [ [[package]] name = "transformers" -version = "5.5.0" -source = { git = "https://github.com/huggingface/transformers.git?rev=c1c3424#c1c34249fa27deefbd4a377dfbf883a39baf5c6d" } +version = "5.8.0.dev0" +source = { git = 
"https://github.com/JJJYmmm/transformers?rev=d362c90c378b4b32b54513f1627b6d9d59ccc6a1#d362c90c378b4b32b54513f1627b6d9d59ccc6a1" } dependencies = [ { name = "huggingface-hub", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "numpy", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, @@ -5893,8 +5890,7 @@ rl = [ { name = "requests", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "torch", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "transformers", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 
'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "vllm", version = "0.21.0+cu129", source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "vllm", version = "0.21.0+cu129", source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "vllm", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "wandb", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 
'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] ta = [ @@ -5966,11 +5962,9 @@ requires-dist = [ { name = "textual" }, { name = "tomli", marker = "python_full_version < '3.11'" }, { name = "torch", marker = "extra == 'rl'", specifier = ">=2.8.0,<2.9.0", index = "https://download.pytorch.org/whl/cu128" }, - { name = "transformers", marker = "extra == 'rl'", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" }, + { name = "transformers", marker = "extra == 'rl'", git = "https://github.com/JJJYmmm/transformers?rev=d362c90c378b4b32b54513f1627b6d9d59ccc6a1" }, { name = "typing-extensions", marker = "python_full_version < '3.12'" }, - { name = "vllm", marker = "platform_machine == 'aarch64' and extra == 'rl'", url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, - { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'rl'", specifier = ">=0.10.0,<0.11.0" }, - { name = "vllm", marker = "platform_machine == 'x86_64' and extra == 'rl'", url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, + { name = "vllm", marker = "extra == 'rl'", git = "https://github.com/Zyphra/vllm.git?branch=zaya1-pr" }, { name = "wandb", marker = "extra == 'rl'" }, ] provides-extras = ["browser", "openenv", "renderers", "rg", "rl", "ta"] @@ -6013,358 +6007,79 @@ wheels = [ [[package]] name = "vllm" -version = "0.21.0+cu129" -source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" } -resolution-markers = [ - "platform_machine == 'aarch64' and sys_platform == 'linux'", -] -dependencies = [ - { name = "aiohttp", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 
'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "anthropic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "apache-tvm-ffi", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "blake3", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "cachetools", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "cbor2", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "cloudpickle", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 
'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "compressed-tensors", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "depyf", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "diskcache", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "einops", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "fastapi", extra = ["standard"], marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "fastsafetensors", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 
'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "filelock", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "flashinfer-cubin", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "flashinfer-python", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "gguf", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "ijson", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "lark", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 
'group-9-verifiers-policy')" }, - { name = "llguidance", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "lm-format-enforcer", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "mcp", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "mistral-common", extra = ["image"], marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "model-hosting-container-standards", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "msgspec", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 
'group-9-verifiers-policy')" }, - { name = "ninja", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "numba", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "nvidia-cudnn-frontend", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "nvidia-cutlass-dsl", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "openai", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = 
"openai-harmony", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "opencv-python-headless", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "opentelemetry-api", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "opentelemetry-exporter-otlp", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "opentelemetry-sdk", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "opentelemetry-semantic-conventions-ai", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 
'group-9-verifiers-policy')" }, - { name = "outlines-core", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "partial-json-parser", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "pillow", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "prometheus-client", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "prometheus-fastapi-instrumentator", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "protobuf", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 
'group-9-verifiers-policy')" }, - { name = "psutil", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "py-cpuinfo", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "pybase64", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "pydantic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "python-json-logger", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "pyyaml", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "pyzmq", 
marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "quack-kernels", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "regex", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "sentencepiece", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "setproctitle", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "setuptools", marker = "(platform_machine == 'aarch64' and 
sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "six", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "tiktoken", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "tilelang", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "tokenizers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "tokenspeed-mla", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "torch", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' 
and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "torchaudio", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "torchvision", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "tqdm", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "transformers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "watchfiles", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and 
extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "xgrammar", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, -] -wheels = [ - { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:de0af3ab4c0cc86e98712bbe89bb30eae967b5bf87873920b7cf13bbfd096aaa" }, -] - -[package.metadata] -requires-dist = [ - { name = "aiohttp", specifier = ">=3.13.3" }, - { name = "anthropic", specifier = ">=0.71.0" }, - { name = "apache-tvm-ffi", specifier = "==0.1.9" }, - { name = "av", marker = "extra == 'audio'" }, - { name = "blake3" }, - { name = "cachetools" }, - { name = "cbor2" }, - { name = "cloudpickle" }, - { name = "compressed-tensors", specifier = "==0.15.0.1" }, - { name = "datasets", marker = "extra == 'bench'" }, - { name = "depyf", specifier = "==0.20.0" }, - { name = "diskcache", specifier = "==5.6.3" }, - { name = "einops" }, - { name = "fastapi", extras = ["standard"], specifier = ">=0.115.0" }, - { name = "fastsafetensors", specifier = ">=0.2.2" }, - { name = "fastsafetensors", marker = "extra == 'fastsafetensors'", specifier = ">=0.2.2" }, - { name = "filelock", specifier = ">=3.16.1" }, - { name = "flashinfer-cubin", specifier = "==0.6.8.post1" }, - { name = "flashinfer-python", specifier = "==0.6.8.post1" }, - { name = "gguf", specifier = ">=0.17.0" }, - { name = "helion", marker = "extra == 'helion'", specifier = "==1.0.0" }, - { name = "ijson" }, - { name = "instanttensor", marker = "extra == 'instanttensor'", specifier = ">=0.1.5" }, - { name = "lark", specifier = "==1.2.2" }, - { name = "llguidance", marker = 
"platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'ppc64le' or platform_machine == 'x86_64'", specifier = ">=1.3.0,<1.4.0" }, - { name = "lm-format-enforcer", specifier = "==0.11.3" }, - { name = "matplotlib", marker = "extra == 'bench'" }, - { name = "mcp" }, - { name = "mistral-common", extras = ["audio"], marker = "extra == 'audio'" }, - { name = "mistral-common", extras = ["image"], specifier = ">=1.11.2" }, - { name = "model-hosting-container-standards", specifier = ">=0.1.14,<1.0.0" }, - { name = "msgspec" }, - { name = "ninja" }, - { name = "numba", specifier = "==0.65.0" }, - { name = "numpy" }, - { name = "nvidia-cudnn-frontend", specifier = ">=1.13.0,<1.19.0" }, - { name = "nvidia-cutlass-dsl", specifier = "==4.4.2" }, - { name = "openai", specifier = ">=2.0.0" }, - { name = "openai-harmony", specifier = ">=0.0.3" }, - { name = "opencv-python-headless", specifier = ">=4.13.0" }, - { name = "opentelemetry-api", specifier = ">=1.27.0" }, - { name = "opentelemetry-api", marker = "extra == 'otel'", specifier = ">=1.26.0" }, - { name = "opentelemetry-exporter-otlp", specifier = ">=1.27.0" }, - { name = "opentelemetry-exporter-otlp", marker = "extra == 'otel'", specifier = ">=1.26.0" }, - { name = "opentelemetry-sdk", specifier = ">=1.27.0" }, - { name = "opentelemetry-sdk", marker = "extra == 'otel'", specifier = ">=1.26.0" }, - { name = "opentelemetry-semantic-conventions-ai", specifier = ">=0.4.1" }, - { name = "opentelemetry-semantic-conventions-ai", marker = "extra == 'otel'", specifier = ">=0.4.1" }, - { name = "outlines-core", specifier = "==0.2.14" }, - { name = "pandas", marker = "extra == 'bench'" }, - { name = "partial-json-parser" }, - { name = "pillow" }, - { name = "plotly", marker = "extra == 'bench'" }, - { name = "prometheus-client", specifier = ">=0.18.0" }, - { name = "prometheus-fastapi-instrumentator", specifier = ">=7.0.0" }, - { name = "protobuf", specifier = 
">=5.29.6,!=6.30.*,!=6.31.*,!=6.32.*,!=6.33.0.*,!=6.33.1.*,!=6.33.2.*,!=6.33.3.*,!=6.33.4.*" }, - { name = "psutil" }, - { name = "py-cpuinfo" }, - { name = "pybase64" }, - { name = "pydantic", specifier = ">=2.12.0" }, - { name = "python-json-logger" }, - { name = "pyyaml" }, - { name = "pyzmq", specifier = ">=25.0.0" }, - { name = "quack-kernels", specifier = ">=0.3.3" }, - { name = "regex" }, - { name = "requests", specifier = ">=2.26.0" }, - { name = "runai-model-streamer", extras = ["azure", "gcs", "s3"], marker = "extra == 'runai'", specifier = ">=0.15.7" }, - { name = "scipy", marker = "extra == 'audio'" }, - { name = "scipy", marker = "extra == 'bench'" }, - { name = "seaborn", marker = "extra == 'bench'" }, - { name = "sentencepiece" }, - { name = "setproctitle" }, - { name = "setuptools", marker = "python_full_version >= '3.12'", specifier = ">=77.0.3,<81.0.0" }, - { name = "six", marker = "python_full_version >= '3.12'", specifier = ">=1.16.0" }, - { name = "smg-grpc-servicer", extras = ["vllm"], marker = "extra == 'grpc'", specifier = ">=0.5.2" }, - { name = "soundfile", marker = "extra == 'audio'" }, - { name = "tensorizer", marker = "extra == 'tensorizer'", specifier = "==2.10.1" }, - { name = "tiktoken", specifier = ">=0.6.0" }, - { name = "tilelang", specifier = "==0.1.9" }, - { name = "tokenizers", specifier = ">=0.21.1" }, - { name = "tokenspeed-mla", specifier = "==0.1.2" }, - { name = "torch", specifier = "==2.11.0" }, - { name = "torchaudio", specifier = "==2.11.0" }, - { name = "torchvision", specifier = "==0.26.0" }, - { name = "tqdm" }, - { name = "transformers", specifier = ">=4.56.0,!=5.0.*,!=5.1.*,!=5.2.*,!=5.3.*,!=5.4.*,!=5.5.0" }, - { name = "typing-extensions", specifier = ">=4.10" }, - { name = "watchfiles" }, - { name = "xgrammar", marker = "platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'ppc64le' or platform_machine == 's390x' or platform_machine == 'x86_64'", specifier = ">=0.2.0,<1.0.0" }, - { 
name = "zentorch-weekly", marker = "extra == 'zen'", specifier = "==5.2.1.dev20260408" }, -] -provides-extras = ["zen", "bench", "tensorizer", "fastsafetensors", "instanttensor", "runai", "audio", "video", "flashinfer", "helion", "grpc", "otel"] - -[[package]] -name = "vllm" -version = "0.21.0+cu129" -source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_x86_64.whl" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] +version = "0.1.dev16384+g6c3178819.cu128" +source = { git = "https://github.com/Zyphra/vllm.git?branch=zaya1-pr#6c3178819e623ab915dd17e0b965269c3b745abd" } dependencies = [ - { name = "aiohttp", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "anthropic", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "apache-tvm-ffi", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "blake3", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { 
name = "cachetools", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "cbor2", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "cloudpickle", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "compressed-tensors", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "depyf", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "diskcache", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "einops", marker = "(platform_machine == 'x86_64' and 
sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "fastapi", extra = ["standard"], marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "fastsafetensors", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "filelock", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "flashinfer-cubin", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "flashinfer-python", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "gguf", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or 
(platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "ijson", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "lark", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "llguidance", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "lm-format-enforcer", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "mcp", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "mistral-common", extra = ["image"], marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 
'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "model-hosting-container-standards", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "msgspec", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "ninja", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "numba", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "numpy", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "nvidia-cudnn-frontend", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 
'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "nvidia-cutlass-dsl", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "openai", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "openai-harmony", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "opencv-python-headless", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "opentelemetry-api", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "opentelemetry-exporter-otlp", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 
'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "opentelemetry-sdk", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "opentelemetry-semantic-conventions-ai", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "outlines-core", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "partial-json-parser", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "pillow", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "prometheus-client", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 
'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "prometheus-fastapi-instrumentator", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "protobuf", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "psutil", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "py-cpuinfo", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "pybase64", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "pydantic", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 
'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "python-json-logger", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "pyyaml", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "pyzmq", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "quack-kernels", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "regex", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "requests", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 
'group-9-verifiers-policy')" }, - { name = "sentencepiece", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "setproctitle", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "setuptools", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "six", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "tiktoken", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "tilelang", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "tokenizers", marker = 
"(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "tokenspeed-mla", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "torch", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "torchaudio", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "torchvision", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "tqdm", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "transformers", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or 
(platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "typing-extensions", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "watchfiles", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "xgrammar", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, -] -wheels = [ - { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:920777691e340df7a8328adfb1e57b9996dbb537edfb654dd32f70844f5f423d" }, + { name = "aiohttp", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "anthropic", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 
'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "apache-tvm-ffi", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "blake3", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "cachetools", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "cbor2", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "cloudpickle", marker = "(platform_machine != 
'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "compressed-tensors", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "depyf", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "diskcache", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "einops", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' 
and extra == 'group-9-verifiers-policy')" }, + { name = "fastapi", extra = ["standard"], marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "fastsafetensors", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "flashinfer-cubin", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "flashinfer-python", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 
'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "gguf", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "ijson", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "lark", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "llguidance", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "lm-format-enforcer", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and 
extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "mcp", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "mistral-common", extra = ["image"], marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "model-hosting-container-standards", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "msgspec", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { 
name = "ninja", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "numba", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "numpy", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "nvidia-cudnn-frontend", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "nvidia-cutlass-dsl", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or 
(sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "openai", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "openai-harmony", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "opencv-python-headless", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "opentelemetry-api", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "opentelemetry-exporter-otlp", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 
'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "opentelemetry-sdk", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "opentelemetry-semantic-conventions-ai", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "outlines-core", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "partial-json-parser", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { 
name = "pillow", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "prometheus-client", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "prometheus-fastapi-instrumentator", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "protobuf", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "psutil", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 
'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "py-cpuinfo", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "pybase64", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "pydantic", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "python-json-logger", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "pyyaml", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or 
(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "pyzmq", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "quack-kernels", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "regex", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "requests", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "sentencepiece", marker = "(platform_machine != 'aarch64' and platform_machine != 
'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "setproctitle", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "setuptools", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "six", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "tiktoken", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" 
}, + { name = "tilelang", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "tokenizers", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "torch", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "torchaudio", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "torchvision", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform 
!= 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "tqdm", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "transformers", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "typing-extensions", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "watchfiles", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "xgrammar", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and 
sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] -[package.metadata] -requires-dist = [ - { name = "aiohttp", specifier = ">=3.13.3" }, - { name = "anthropic", specifier = ">=0.71.0" }, - { name = "apache-tvm-ffi", specifier = "==0.1.9" }, - { name = "av", marker = "extra == 'audio'" }, - { name = "blake3" }, - { name = "cachetools" }, - { name = "cbor2" }, - { name = "cloudpickle" }, - { name = "compressed-tensors", specifier = "==0.15.0.1" }, - { name = "datasets", marker = "extra == 'bench'" }, - { name = "depyf", specifier = "==0.20.0" }, - { name = "diskcache", specifier = "==5.6.3" }, - { name = "einops" }, - { name = "fastapi", extras = ["standard"], specifier = ">=0.115.0" }, - { name = "fastsafetensors", specifier = ">=0.2.2" }, - { name = "fastsafetensors", marker = "extra == 'fastsafetensors'", specifier = ">=0.2.2" }, - { name = "filelock", specifier = ">=3.16.1" }, - { name = "flashinfer-cubin", specifier = "==0.6.8.post1" }, - { name = "flashinfer-python", specifier = "==0.6.8.post1" }, - { name = "gguf", specifier = ">=0.17.0" }, - { name = "helion", marker = "extra == 'helion'", specifier = "==1.0.0" }, - { name = "ijson" }, - { name = "instanttensor", marker = "extra == 'instanttensor'", specifier = ">=0.1.5" }, - { name = "lark", specifier = "==1.2.2" }, - { name = "llguidance", marker = "platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'ppc64le' or platform_machine == 'x86_64'", specifier = ">=1.3.0,<1.4.0" }, - { name = "lm-format-enforcer", specifier = "==0.11.3" }, - { name = "matplotlib", marker = "extra == 'bench'" }, - { name = "mcp" }, - { name = "mistral-common", extras = ["audio"], marker = "extra == 'audio'" }, - { name = "mistral-common", extras = ["image"], specifier = ">=1.11.2" }, - { name = "model-hosting-container-standards", specifier = 
">=0.1.14,<1.0.0" }, - { name = "msgspec" }, - { name = "ninja" }, - { name = "numba", specifier = "==0.65.0" }, - { name = "numpy" }, - { name = "nvidia-cudnn-frontend", specifier = ">=1.13.0,<1.19.0" }, - { name = "nvidia-cutlass-dsl", specifier = "==4.4.2" }, - { name = "openai", specifier = ">=2.0.0" }, - { name = "openai-harmony", specifier = ">=0.0.3" }, - { name = "opencv-python-headless", specifier = ">=4.13.0" }, - { name = "opentelemetry-api", specifier = ">=1.27.0" }, - { name = "opentelemetry-api", marker = "extra == 'otel'", specifier = ">=1.26.0" }, - { name = "opentelemetry-exporter-otlp", specifier = ">=1.27.0" }, - { name = "opentelemetry-exporter-otlp", marker = "extra == 'otel'", specifier = ">=1.26.0" }, - { name = "opentelemetry-sdk", specifier = ">=1.27.0" }, - { name = "opentelemetry-sdk", marker = "extra == 'otel'", specifier = ">=1.26.0" }, - { name = "opentelemetry-semantic-conventions-ai", specifier = ">=0.4.1" }, - { name = "opentelemetry-semantic-conventions-ai", marker = "extra == 'otel'", specifier = ">=0.4.1" }, - { name = "outlines-core", specifier = "==0.2.14" }, - { name = "pandas", marker = "extra == 'bench'" }, - { name = "partial-json-parser" }, - { name = "pillow" }, - { name = "plotly", marker = "extra == 'bench'" }, - { name = "prometheus-client", specifier = ">=0.18.0" }, - { name = "prometheus-fastapi-instrumentator", specifier = ">=7.0.0" }, - { name = "protobuf", specifier = ">=5.29.6,!=6.30.*,!=6.31.*,!=6.32.*,!=6.33.0.*,!=6.33.1.*,!=6.33.2.*,!=6.33.3.*,!=6.33.4.*" }, - { name = "psutil" }, - { name = "py-cpuinfo" }, - { name = "pybase64" }, - { name = "pydantic", specifier = ">=2.12.0" }, - { name = "python-json-logger" }, - { name = "pyyaml" }, - { name = "pyzmq", specifier = ">=25.0.0" }, - { name = "quack-kernels", specifier = ">=0.3.3" }, - { name = "regex" }, - { name = "requests", specifier = ">=2.26.0" }, - { name = "runai-model-streamer", extras = ["azure", "gcs", "s3"], marker = "extra == 'runai'", specifier = 
">=0.15.7" }, - { name = "scipy", marker = "extra == 'audio'" }, - { name = "scipy", marker = "extra == 'bench'" }, - { name = "seaborn", marker = "extra == 'bench'" }, - { name = "sentencepiece" }, - { name = "setproctitle" }, - { name = "setuptools", marker = "python_full_version >= '3.12'", specifier = ">=77.0.3,<81.0.0" }, - { name = "six", marker = "python_full_version >= '3.12'", specifier = ">=1.16.0" }, - { name = "smg-grpc-servicer", extras = ["vllm"], marker = "extra == 'grpc'", specifier = ">=0.5.2" }, - { name = "soundfile", marker = "extra == 'audio'" }, - { name = "tensorizer", marker = "extra == 'tensorizer'", specifier = "==2.10.1" }, - { name = "tiktoken", specifier = ">=0.6.0" }, - { name = "tilelang", specifier = "==0.1.9" }, - { name = "tokenizers", specifier = ">=0.21.1" }, - { name = "tokenspeed-mla", specifier = "==0.1.2" }, - { name = "torch", specifier = "==2.11.0" }, - { name = "torchaudio", specifier = "==2.11.0" }, - { name = "torchvision", specifier = "==0.26.0" }, - { name = "tqdm" }, - { name = "transformers", specifier = ">=4.56.0,!=5.0.*,!=5.1.*,!=5.2.*,!=5.3.*,!=5.4.*,!=5.5.0" }, - { name = "typing-extensions", specifier = ">=4.10" }, - { name = "watchfiles" }, - { name = "xgrammar", marker = "platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'ppc64le' or platform_machine == 's390x' or platform_machine == 'x86_64'", specifier = ">=0.2.0,<1.0.0" }, - { name = "zentorch-weekly", marker = "extra == 'zen'", specifier = "==5.2.1.dev20260408" }, -] -provides-extras = ["zen", "bench", "tensorizer", "fastsafetensors", "instanttensor", "runai", "audio", "video", "flashinfer", "helion", "grpc", "otel"] - [[package]] name = "vllm-router" version = "0.1.22"