Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ build/
.cache
.vscode/
hosts.json
uv.lock
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ benchmark = [

dev = [
"black>=24.3",
"cmake>=3.27",
"ninja>=1.11",
"ruff>=0.4",
"setuptools>=68",
"pytest>=8.2",
"pytest-mock>=3.14",
"pytest-cov>=5.0",
Expand Down
26 changes: 6 additions & 20 deletions src/backend/server/static_config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import concurrent.futures
import json
import math
from pathlib import Path

from parallax.utils.utils import load_config_only
from parallax_utils.logging_config import get_logger
from scheduling.model_info import ModelInfo

Expand Down Expand Up @@ -33,6 +32,8 @@
"zai-org/GLM-4.7": "mlx-community/GLM-4.7-4bit",
"zai-org/GLM-4.7-Flash": "mlx-community/GLM-4.7-Flash-4bit",
"zai-org/GLM-4.7-Flash-FP8": "mlx-community/GLM-4.7-Flash-8bit",
"zai-org/GLM-5.1": "mlx-community/GLM-5.1",
"zai-org/GLM-5.1-FP8": "mlx-community/GLM-5.1",
# Minimax M2 Models
"MiniMaxAI/MiniMax-M2.7": "mlx-community/MiniMax-M2.7-4bit",
"MiniMaxAI/MiniMax-M2.1": "mlx-community/MiniMax-M2.1-4bit",
Expand Down Expand Up @@ -100,23 +101,7 @@


def get_model_info(model_name, use_hfcache: bool = False):
def _load_config_only(name: str) -> dict:
local_path = Path(name)
if local_path.exists():
config_path = local_path / "config.json"
with open(config_path, "r") as f:
return json.load(f)

# Hugging Face only – download just config.json
from huggingface_hub import hf_hub_download # type: ignore

config_file = hf_hub_download(
repo_id=name, filename="config.json", local_files_only=use_hfcache
)
with open(config_file, "r") as f:
return json.load(f)

config = _load_config_only(model_name)
config = load_config_only(model_name, local_files_only=use_hfcache)

quant_method = config.get("quant_method", None)
quantization_config = config.get("quantization_config", None)
Expand All @@ -139,7 +124,7 @@ def _load_config_only(name: str) -> dict:
mlx_model_name = MODELS.get(model_name, model_name)

if mlx_model_name != model_name:
mlx_config = _load_config_only(mlx_model_name)
mlx_config = load_config_only(mlx_model_name, local_files_only=use_hfcache)
mlx_quant_dict = mlx_config.get("quantization_config", None)
if mlx_quant_dict and "bits" in mlx_quant_dict:
mlx_param_bytes_per_element = mlx_quant_dict["bits"] / 8
Expand All @@ -161,6 +146,7 @@ def _load_config_only(name: str) -> dict:
head_size=config.get("head_dim", 128),
qk_nope_head_dim=config.get("qk_nope_head_dim", None),
qk_rope_head_dim=config.get("qk_rope_head_dim", None),
v_head_dim=config.get("v_head_dim", None),
hidden_dim=config.get("hidden_size", 0),
intermediate_dim=config.get("intermediate_size", 0),
num_attention_heads=config.get("num_attention_heads", 0),
Expand Down
4 changes: 2 additions & 2 deletions src/parallax/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from parallax.server.http_server import launch_http_server, stop_http_server
from parallax.server.server_args import parse_args
from parallax.utils.shared_state import SharedState
from parallax.utils.utils import fetch_model_from_hf, initialize_nccl_port
from parallax.utils.utils import initialize_nccl_port, load_config_only
from parallax_utils.ascii_anime import display_parallax_join
from parallax_utils.logging_config import get_logger, set_log_level
from parallax_utils.version_check import check_latest_release
Expand Down Expand Up @@ -119,7 +119,7 @@ def _wait_executors_check_layer_change(shared_state: SharedState, executor_subpr
display_parallax_join(args.model_path)
check_latest_release()

config = fetch_model_from_hf(args.model_path, local_files_only=args.use_hfcache)
config = load_config_only(args.model_path, local_files_only=args.use_hfcache)
if args.start_layer is None:
args.start_layer = 0
if args.end_layer is None:
Expand Down
16 changes: 10 additions & 6 deletions src/parallax/models/deepseek_v32.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,12 @@ def __call__(
scores = scores * weights
scores = scores.sum(axis=1)
if mask is not None:
scores = mx.where(mask, scores, -float("inf"))
if mask.ndim == 4:
mask = mask[:, 0, :, :]
if mask.dtype == mx.bool_:
scores = mx.where(mask, scores, -float("inf"))
else:
scores = scores + mask.astype(scores.dtype)
return mx.argpartition(scores, kth=-self.index_topk, axis=-1)[..., -self.index_topk :]


Expand Down Expand Up @@ -137,11 +142,10 @@ def __call__(
compressed_kv = self.kv_a_proj_with_mqa(x)
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
k_pe = k_pe.reshape(batch, target_len, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
kv = kv.reshape(batch, target_len, self.num_heads, -1)

k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
k_nope = k_nope.transpose(0, 2, 1, 3)
kv_latent = self.kv_a_layernorm(compressed_kv)
kv_latent = kv_latent[:, None, :, :]
k_nope = self.embed_q(kv_latent, transpose=False)
values = self.unembed_out(kv_latent).transpose(0, 2, 1, 3)
key_cache_global, value_cache_global = cache.get_cache()
indexer_cache = cache.get_indexer_cache()

Expand Down
8 changes: 8 additions & 0 deletions src/parallax/server/shard_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
"minimax_m2": "mlx_lm.models.minimax",
}

ARCHITECTURE_CLASS_ALIASES = {
"GlmMoeDsaForCausalLM": "DeepseekV32ForCausalLM",
}


class MLXModelLoader:
"""
Expand Down Expand Up @@ -93,6 +97,10 @@ def register_block_class(self):
except Exception as e:
logger.warning(f"Failed to load model from {model_file}: {e}")

for alias, target in ARCHITECTURE_CLASS_ALIASES.items():
if target in self.block_class_map:
self.block_class_map[alias] = self.block_class_map[target]

def linear_to_lora_layers(
self,
model: nn.Module,
Expand Down
3 changes: 0 additions & 3 deletions src/parallax/utils/selective_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,6 @@ def selective_model_download(
# Check if file already exists in local cache before downloading
weight_file_path = model_path / weight_file
if weight_file_path.exists():
logger.debug(
f"Weight file {weight_file} already exists locally, skipping download"
)
continue

logger.debug(f"Downloading {weight_file}")
Expand Down
30 changes: 19 additions & 11 deletions src/parallax/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
"""Utility functions."""

import json
import random
import socket
from pathlib import Path
from typing import List

import mlx.core as mx
import numpy as np
import psutil
import torch
import zmq
from mlx_lm.utils import _download, load_config

from parallax.utils.selective_download import download_metadata_only


def is_cuda_available():
Expand Down Expand Up @@ -281,15 +280,24 @@ def combine_padding_and_causal_masks(
return causal_mask + padding_mask_float


def fetch_model_from_hf(name: str, local_files_only: bool = False):
"""Fetch model from huggingface and returns model config"""

if local_files_only:
model_path = download_metadata_only(name, local_files_only=local_files_only)
def load_config_only(name: str, local_files_only: bool = False):
"""Load only config.json from a local path or Hugging Face repo."""
local_path = Path(name)
if local_path.exists():
config_file = local_path / "config.json"
else:
model_path = _download(name)
config = load_config(model_path)
return config
from huggingface_hub import hf_hub_download

config_file = Path(
hf_hub_download(
repo_id=name,
filename="config.json",
local_files_only=local_files_only,
)
)

with open(config_file, "r") as f:
return json.load(f)


def is_port_available(port: int):
Expand Down
89 changes: 89 additions & 0 deletions src/parallax_extensions/ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import importlib
import os
import shutil
import subprocess
import sys
import sysconfig
from pathlib import Path
from types import ModuleType
from typing import Optional
Expand Down Expand Up @@ -34,8 +38,93 @@ def _build_import_error(original_error: Exception) -> ImportError:
return ImportError(msg)


def _build_signature() -> str:
try:
from importlib.metadata import version

mlx_version = version("mlx")
nanobind_version = version("nanobind")
except Exception:
mlx_version = "unknown"
nanobind_version = "unknown"

return "|".join(
[
sys.implementation.cache_tag or "",
sys.version.split()[0],
mx.__file__,
mlx_version,
nanobind_version,
]
)


def _ensure_build_tools() -> None:
missing = []
try:
import setuptools # noqa: F401
except Exception:
missing.append("setuptools")

if shutil.which("cmake") is None:
missing.append("cmake")
if shutil.which("ninja") is None:
missing.append("ninja")

if missing:
subprocess.run(
[sys.executable, "-m", "pip", "install", *missing],
check=True,
)


def _rebuild_for_github_actions() -> None:
"""Build native kernels against the exact Python/MLX used by GitHub macOS CI."""
if os.environ.get("GITHUB_ACTIONS") != "true" or sys.platform != "darwin":
return
if os.environ.get("PARALLAX_SKIP_CI_EXTENSION_REBUILD") == "1":
return

package_dir = Path(__file__).resolve().parent
lib_dir = package_dir / "lib"
ext_suffix = sysconfig.get_config_var("EXT_SUFFIX") or ".so"
expected_ext = lib_dir / f"_ext{ext_suffix}"
stamp = lib_dir / f".ci-build-{sys.implementation.cache_tag or 'python'}"
signature = _build_signature()

if expected_ext.exists() and stamp.exists() and stamp.read_text() == signature:
return

_ensure_build_tools()

env = os.environ.copy()
env["DEBUG"] = "0"
cmake_args = env.get("CMAKE_ARGS", "")
python_arg = f"-DPython_EXECUTABLE={sys.executable}"
env["CMAKE_ARGS"] = f"{python_arg} {cmake_args}".strip()

log_path = Path("/tmp/parallax_ext_build.log")
with log_path.open("w") as log:
subprocess.run(
[sys.executable, "setup.py", "build_ext", "-j8", "--inplace"],
cwd=package_dir,
env=env,
stdout=log,
stderr=subprocess.STDOUT,
check=True,
)

stamp.write_text(signature)
print(f"Rebuilt parallax_extensions native kernels for CI: {expected_ext}")


def load_extension_module() -> ModuleType:
"""Load the compiled extension module for the current Python runtime."""
try:
_rebuild_for_github_actions()
except Exception as exc: # pragma: no cover - GitHub runner dependent
raise _build_import_error(exc) from exc

try:
# Python's import machinery selects the matching ABI-tagged binary
# (e.g. _ext.cpython-312-*.so) from parallax_extensions/lib.
Expand Down
4 changes: 3 additions & 1 deletion src/scheduling/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class ModelInfo:

qk_nope_head_dim: Optional[int] = None
qk_rope_head_dim: Optional[int] = None
v_head_dim: Optional[int] = None
head_size_k: int = None
head_size_v: int = None

Expand All @@ -55,7 +56,8 @@ def __init__(self, **kwargs):
self.head_size_k = self.qk_nope_head_dim + self.qk_rope_head_dim
else:
self.head_size_k = self.head_size
self.head_size_v = self.head_size
v_head_dim = getattr(self, "v_head_dim", None)
self.head_size_v = v_head_dim if v_head_dim is not None else self.head_size

@property
def q_dim(self) -> int:
Expand Down
23 changes: 23 additions & 0 deletions tests/scheduler_tests/test_model_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from scheduling.model_info import ModelInfo


def test_model_info_uses_distinct_value_head_dim():
model_info = ModelInfo(
model_name="zai-org/GLM-5.1",
mlx_model_name="mlx-community/GLM-5.1",
head_size=64,
qk_nope_head_dim=192,
qk_rope_head_dim=64,
v_head_dim=256,
hidden_dim=6144,
intermediate_dim=12288,
num_attention_heads=64,
num_kv_heads=64,
vocab_size=154880,
num_layers=78,
cache_bytes_per_element=2,
)

assert model_info.head_size_k == 256
assert model_info.head_size_v == 256
assert model_info.per_token_per_layer_kv_size == 2 * 64 * (256 + 256)
Loading
Loading