Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies = [
"numpy>=1.26",
"pyzmq>=25.0",
"psutil>=5.9.5",
"requests",
"httpx[socks]>=0.26.0",
"aiohttp",
"uvicorn",
Expand Down
31 changes: 19 additions & 12 deletions src/backend/server/static_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@
"Qwen/Qwen3-Next-80B-A3B-Instruct-FP8": "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit",
"Qwen/Qwen3-Next-80B-A3B-Thinking": "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit",
"Qwen/Qwen3-Next-80B-A3B-Thinking-FP8": "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit",
# Qwen 3.6 Series
"Qwen/Qwen3.6-27B": "mlx-community/Qwen3.6-27B-mxfp4",
# Qwen 3 Large MoE Models
"Qwen/Qwen3-235B-A22B-Instruct-2507-FP8": "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit",
"Qwen/Qwen3-235B-A22B-Thinking-2507-FP8": "mlx-community/Qwen3-235B-A22B-Thinking-2507-8bit",
Expand All @@ -100,34 +102,39 @@
NODE_JOIN_COMMAND_PUBLIC_NETWORK = """parallax join -s {scheduler_addr} """


def get_model_info(model_name, use_hfcache: bool = False):
config = load_config_only(model_name, local_files_only=use_hfcache)

def get_param_bytes_per_element(config, model_name: str) -> float:
    """Estimate how many bytes a single model parameter occupies.

    Args:
        config: Model config mapping. ``quant_method`` is read from the top
            level; quantization settings are read from ``quantization_config``
            or, when that is missing or empty, from ``quantization``.
        model_name: Only used in the warning emitted for an unrecognized
            quantization method.

    Returns:
        ``bits / 8`` when the quantization settings carry a ``bits`` entry.
        Otherwise 2 when no quantization method is found, 1 for ``"fp8"``,
        0.5 for the methods listed in the tuple below, and 1 (with a warning)
        for any other method.
    """
    quant_method = config.get("quant_method")
    quantization_config = config.get("quantization_config") or config.get("quantization")

    if quantization_config is not None:
        # An explicit bit width wins over any method-name heuristic.
        bits = quantization_config.get("bits")
        if bits is not None:
            return bits / 8
        if quant_method is None:
            # Fall back to the method name stored inside the quantization
            # settings, under either ``quant_method`` or ``mode``.
            quant_method = quantization_config.get("quant_method") or quantization_config.get(
                "mode"
            )

    if quant_method is None:
        return 2
    if quant_method == "fp8":
        return 1
    if quant_method in ("mxfp4", "int4", "awq", "gptq", "compressed-tensors"):
        return 0.5

    logger.warning(
        f"model_name:{model_name} quant_method {quant_method} not supported in get_model_info method"
    )
    return 1


def get_model_info(model_name, use_hfcache: bool = False):
config = load_config_only(model_name, local_files_only=use_hfcache)

param_bytes_per_element = get_param_bytes_per_element(config, model_name)

mlx_param_bytes_per_element = param_bytes_per_element
mlx_model_name = MODELS.get(model_name, model_name)

if mlx_model_name != model_name:
mlx_config = load_config_only(mlx_model_name, local_files_only=use_hfcache)
mlx_quant_dict = mlx_config.get("quantization_config", None)
if mlx_quant_dict and "bits" in mlx_quant_dict:
mlx_param_bytes_per_element = mlx_quant_dict["bits"] / 8
mlx_param_bytes_per_element = get_param_bytes_per_element(mlx_config, mlx_model_name)

# get local experts
num_local_experts = config.get("num_local_experts", None)
Expand Down
123 changes: 123 additions & 0 deletions src/parallax/models/qwen3_5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
Defines the Qwen3.5 text block for Parallax.
"""

from typing import Any, List, Optional

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.gated_delta import gated_delta_update
from mlx_lm.models.qwen3_5 import DecoderLayer as MLXQwen35Block
from mlx_lm.models.qwen3_5 import GatedDeltaNet as MLXQwen35GatedDeltaNet
from mlx_lm.models.qwen3_5 import TextModelArgs

from parallax.models.qwen3_next import ParallaxQwen3NextAttention
from parallax.server.cache.base import BaseCache


class ParallaxQwen35GatedDeltaNet(MLXQwen35GatedDeltaNet):
    """Gated delta-net layer whose conv and recurrent state live in a Parallax cache.

    The projections, convolution, normalization and output layers are the ones
    defined by the parent class; this subclass only changes where the running
    state is read from and written to.
    """

    def __call__(
        self,
        x: mx.array,
        cache: Optional[BaseCache] = None,
        state_slot_mapping: Optional[mx.array] = None,
        **kwargs,
    ):
        """Run the layer on ``x`` and persist the updated state.

        Args:
            x: Input whose shape unpacks as ``(batch, target_len, _)``.
            cache: Object providing ``read_states`` / ``write_states``. It is
                dereferenced unconditionally, so ``None`` is not actually
                usable despite the default.
            state_slot_mapping: Passed through to the cache calls.
                NOTE(review): presumably selects each request's state slot;
                confirm against the cache implementation.
            **kwargs: Accepted and ignored.

        Returns:
            The output projection of the normalized, gated result.
        """
        batch, target_len, _ = x.shape
        history_len = self.conv_kernel_size - 1

        qkv = self.in_proj_qkv(x)
        z = self.in_proj_z(x).reshape(batch, target_len, self.num_v_heads, self.head_v_dim)
        b = self.in_proj_b(x)
        a = self.in_proj_a(x)

        if target_len == 1:
            # Single-token call: continue from the stored state.
            # NOTE(review): this looks like the decode path; a one-token
            # prefill would also land here -- confirm that is intended.
            conv_state, state = cache.read_states(state_slot_mapping)
        else:
            # Multi-token call: start from a zero conv history and no
            # recurrent state.
            conv_state = mx.zeros((batch, history_len, self.conv_dim), dtype=x.dtype)
            state = None

        conv_input = mx.concatenate([conv_state, qkv], axis=1)
        # The trailing rows of the conv input become the next call's history.
        next_conv_state = conv_input[:, -history_len:]
        conv_out = nn.silu(self.conv1d(conv_input))

        q_flat, k_flat, v_flat = mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1)
        q = q_flat.reshape(batch, target_len, self.num_k_heads, self.head_k_dim)
        k = k_flat.reshape(batch, target_len, self.num_k_heads, self.head_k_dim)
        v = v_flat.reshape(batch, target_len, self.num_v_heads, self.head_v_dim)

        inv_scale = k.shape[-1] ** -0.5
        q = (inv_scale**2) * mx.fast.rms_norm(q, None, 1e-6)
        k = inv_scale * mx.fast.rms_norm(k, None, 1e-6)

        out, state = gated_delta_update(
            q,
            k,
            v,
            a,
            b,
            self.A_log,
            self.dt_bias,
            state,
            use_kernel=not self.training,
        )

        cache.write_states(state_slot_mapping, next_conv_state, state)

        gated = self.norm(out, z)
        return self.out_proj(gated.reshape(batch, target_len, -1))


class ParallaxQwen35Block(MLXQwen35Block):
    """Qwen3.5 decoder layer using Parallax attention modules and per-layer caches."""

    def __init__(self, args: TextModelArgs, layer_idx: int, local_layer_idx: int):
        """Build the layer and replace its mixer with the Parallax variant.

        Args:
            args: Model arguments forwarded to the parent and to the mixer.
            layer_idx: Index handed to the parent constructor.
            local_layer_idx: Index used to pick this layer's entry from the
                ``cache`` list passed to ``__call__``.
        """
        super().__init__(args, layer_idx)
        self.layer_idx = layer_idx
        self.local_layer_idx = local_layer_idx
        # ``is_linear`` is not set in this class.
        # NOTE(review): presumably assigned by the parent constructor; confirm.
        if not self.is_linear:
            self.self_attn = ParallaxQwen3NextAttention(args)
        else:
            self.linear_attn = ParallaxQwen35GatedDeltaNet(args)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[List[Any]] = None,
        block_tables: Optional[mx.array] = None,
        context_lengths: Optional[mx.array] = None,
        slot_mapping: Optional[mx.array] = None,
        **kwargs,
    ):
        """Apply the mixer and the MLP, each with a residual connection.

        ``cache`` is indexed with ``local_layer_idx`` unconditionally, so it
        must be provided even though it defaults to ``None``.
        """
        if self.is_linear:
            # Only the linear mixer takes the state mapping positionally, so
            # it is removed from kwargs on this branch only.
            state_slot_mapping = kwargs.pop("state_slot_mapping", None)
            normed = self.input_layernorm(x)
            layer_cache = cache[self.local_layer_idx]
            mixed = self.linear_attn(normed, layer_cache, state_slot_mapping, **kwargs)
        else:
            normed = self.input_layernorm(x)
            layer_cache = cache[self.local_layer_idx]
            mixed = self.self_attn(
                normed,
                mask,
                layer_cache,
                block_tables=block_tables,
                context_lengths=context_lengths,
                slot_mapping=slot_mapping,
                **kwargs,
            )

        hidden = x + mixed
        return hidden + self.mlp(self.post_attention_layernorm(hidden))

    @classmethod
    def get_architecture(cls):
        """Return the architecture name this block is registered under."""
        return "Qwen3_5ForConditionalGeneration"


# Module-level alias for the block class defined in this module.
# NOTE(review): presumably the name the Parallax model loader looks up -- confirm.
EntryClass = ParallaxQwen35Block
53 changes: 27 additions & 26 deletions src/parallax/server/executor/mlx_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,33 +598,34 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A
else:
padded_inputs, padding_mask = pad_inputs(0, h_or_tokens_list, self.dtype)

# Generate slot_mapping for prefill (only for NEW tokens, starting from prefix_len)
max_len = padded_inputs.shape[1]
slot_mapping_flat = []

for i, req in enumerate(batched_requests):
block_table = block_tables_list[i]
prefix_len = prefix_lens_list[i]
total_len = req.total_length
new_tokens_len = total_len - prefix_len

for seq_idx in range(max_len):
if seq_idx < new_tokens_len:
# Valid new token - map to position after prefix
actual_pos = prefix_len + seq_idx
block_idx = actual_pos // self.cache_manager.block_size
block_offset = actual_pos % self.cache_manager.block_size
physical_block = block_table[block_idx]
slot = physical_block * self.cache_manager.block_size + block_offset
slot_mapping_flat.append(slot)
else:
# Padding token
# Map to -1. The kernel should ignore this.
slot_mapping_flat.append(-1)
slot_mapping_tensor = None
if self.cache_manager.needs_blocks:
# Generate slot_mapping for prefill (only for NEW tokens, starting from prefix_len)
max_len = padded_inputs.shape[1]
slot_mapping_flat = []

for i, req in enumerate(batched_requests):
block_table = block_tables_list[i]
prefix_len = prefix_lens_list[i]
total_len = req.total_length
new_tokens_len = total_len - prefix_len

for seq_idx in range(max_len):
if seq_idx < new_tokens_len:
# Valid new token - map to position after prefix
actual_pos = prefix_len + seq_idx
block_idx = actual_pos // self.cache_manager.block_size
block_offset = actual_pos % self.cache_manager.block_size
physical_block = block_table[block_idx]
slot = physical_block * self.cache_manager.block_size + block_offset
slot_mapping_flat.append(slot)
else:
# Padding token. The kernel should ignore this.
slot_mapping_flat.append(-1)

slot_mapping_tensor = mx.array(slot_mapping_flat, dtype=mx.int64)
slot_mapping_tensor = mx.array(slot_mapping_flat, dtype=mx.int64)

# Pad block tables
# Pad block tables. Linear-only shards do not allocate KV blocks.
max_blocks = max(len(bt) for bt in block_tables_list)
padded_block_tables = []
for bt in block_tables_list:
Expand Down Expand Up @@ -737,7 +738,7 @@ def _prepare_decode_batch(self, batched_requests: List[Request]) -> Optional[Dic
padded_inputs = mx.concatenate(h_or_tokens_list, axis=0) # (Batch, D)
padded_inputs = padded_inputs.reshape(batch_size, 1, -1) # (Batch, 1, D)

# Pad block tables
# Pad block tables. Linear-only shards do not allocate KV blocks.
max_blocks = max(len(bt) for bt in block_tables_list)
padded_block_tables = []
for bt in block_tables_list:
Expand Down
31 changes: 24 additions & 7 deletions src/parallax/server/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import uvicorn
import zmq
import zmq.asyncio
from fastapi.responses import ORJSONResponse, StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
from mlx_lm.tokenizer_utils import StreamingDetokenizer
from mlx_lm.utils import load_config
from pydantic import BaseModel
Expand Down Expand Up @@ -101,6 +101,7 @@ class HTTPRequestInfo:
# tool calling support
tool_state: Optional[ToolCallState] = None
tool_calls: List[Dict[str, Any]] = field(default_factory=list)
enable_thinking: bool = True


class HTTPHandler:
Expand Down Expand Up @@ -137,12 +138,29 @@ def __init__(
self.tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None))
self.detokenizer_class, self.tokenmap = load_detokenizer(model_path, self.tokenizer)

@staticmethod
def _is_thinking_enabled(request: Dict) -> bool:
    """Tell whether the request leaves thinking mode switched on.

    ``chat_template_kwargs`` is read from the top level of the request and
    from ``extra_body``; on conflict the ``extra_body`` entries win. Values
    that are not dicts (for example a JSON ``null``) are ignored instead of
    raising.

    Args:
        request: Parsed request payload.

    Returns:
        ``False`` only when the merged kwargs set ``enable_thinking`` to
        exactly ``False``; ``True`` in every other case.
    """
    merged_kwargs: Dict = {}

    top_level_kwargs = request.get("chat_template_kwargs")
    if isinstance(top_level_kwargs, dict):
        merged_kwargs.update(top_level_kwargs)

    extra_body = request.get("extra_body")
    if isinstance(extra_body, dict):
        nested_kwargs = extra_body.get("chat_template_kwargs")
        if isinstance(nested_kwargs, dict):
            merged_kwargs.update(nested_kwargs)

    # Identity check on purpose: only a literal False disables thinking.
    return merged_kwargs.get("enable_thinking") is not False

def _get_initial_assistant_content(self, request_info: HTTPRequestInfo) -> str:
    """Return the text that opens the assistant message for this model.

    The result is ``"<think>"`` when the lower-cased model path contains
    ``minimax-m2``, or when it contains ``qwen3.6`` / ``qwen3.5`` and the
    request has thinking enabled. Otherwise it is the empty string.
    """
    lowered_path = self.model_path_str.lower()

    is_minimax = "minimax-m2" in lowered_path
    is_thinking_qwen = (
        any(tag in lowered_path for tag in ("qwen3.6", "qwen3.5"))
        and request_info.enable_thinking
    )

    return "<think>" if is_minimax or is_thinking_qwen else ""

def create_request(self, request: Dict):
"""Creates a new request information"""
rid = request["rid"]
stream = request.get("stream", False)
model = request.get("model", "default")
return_probs = request.get("return_probs", False) # Check if probs requested
enable_thinking = self._is_thinking_enabled(request)
chat_object = "chat.completion.chunk" if stream else "chat.completion"
detokenizer = self.detokenizer_class(self.tokenizer, self.tokenmap)
create_time = time.time()
Expand All @@ -156,6 +174,7 @@ def create_request(self, request: Dict):
update_time=update_time,
detokenizer=detokenizer,
return_probs=return_probs,
enable_thinking=enable_thinking,
)
request_info.tool_state = ToolCallState.from_tokenizer(
self.tokenizer, request.get("tools"), stream
Expand Down Expand Up @@ -206,9 +225,7 @@ def _generate_stream_chunk(self, rid, token, is_first=False, is_last=False):

if is_first:
role = "assistant"
content = ""
if "minimax-m2" in self.model_path_str.lower():
content = "<think>"
content = self._get_initial_assistant_content(request_info)
tool_calls = None
elif is_last:
role = None
Expand Down Expand Up @@ -318,7 +335,7 @@ def generate_non_stream_response(self, rid):
choice = response["choices"][0]
choice["message"] = {
"role": "assistant",
"content": request_info.text,
"content": self._get_initial_assistant_content(request_info) + request_info.text,
"reasoning_content": None,
"tool_calls": request_info.tool_calls or None,
}
Expand Down Expand Up @@ -464,7 +481,7 @@ def create_error_response(
):
"""Creates a json error response for the frontend."""
error = ErrorResponse(message=message, type=err_type, code=status_code.value)
return ORJSONResponse(content=error.model_dump(), status_code=error.code)
return JSONResponse(content=error.model_dump(), status_code=error.code)


# Fast API
Expand Down Expand Up @@ -548,7 +565,7 @@ async def v1_chat_completions(raw_request: fastapi.Request):

response = app.state.http_handler.generate_non_stream_response(request_id)
app.state.http_handler.release_request(request_id)
return ORJSONResponse(status_code=200, content=response)
return JSONResponse(status_code=200, content=response)
except Exception as e:
# Handle any unexpected errors during processing
logger.error(f"Error processing non-streaming request {request_id}: {e}")
Expand Down
Loading
Loading