diff --git a/pyproject.toml b/pyproject.toml index 0bccc6b5..ee8ed28b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "numpy>=1.26", "pyzmq>=25.0", "psutil>=5.9.5", + "requests", "httpx[socks]>=0.26.0", "aiohttp", "uvicorn", diff --git a/src/backend/server/static_config.py b/src/backend/server/static_config.py index e7af2945..beb73891 100644 --- a/src/backend/server/static_config.py +++ b/src/backend/server/static_config.py @@ -84,6 +84,8 @@ "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8": "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit", "Qwen/Qwen3-Next-80B-A3B-Thinking": "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit", "Qwen/Qwen3-Next-80B-A3B-Thinking-FP8": "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit", + # Qwen 3.6 Series + "Qwen/Qwen3.6-27B": "mlx-community/Qwen3.6-27B-mxfp4", # Qwen 3 Large MoE Models "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8": "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit", "Qwen/Qwen3-235B-A22B-Thinking-2507-FP8": "mlx-community/Qwen3-235B-A22B-Thinking-2507-8bit", @@ -100,34 +102,39 @@ NODE_JOIN_COMMAND_PUBLIC_NETWORK = """parallax join -s {scheduler_addr} """ -def get_model_info(model_name, use_hfcache: bool = False): - config = load_config_only(model_name, local_files_only=use_hfcache) - +def get_param_bytes_per_element(config, model_name: str) -> float: quant_method = config.get("quant_method", None) - quantization_config = config.get("quantization_config", None) + quantization_config = config.get("quantization_config") or config.get("quantization") if quant_method is None and quantization_config is not None: - quant_method = quantization_config.get("quant_method", None) + quant_method = quantization_config.get("quant_method") or quantization_config.get("mode") + + if quantization_config is not None and quantization_config.get("bits") is not None: + return quantization_config["bits"] / 8 if quant_method is None: - param_bytes_per_element = 2 + return 2 elif quant_method == "fp8": - param_bytes_per_element = 1 + return 1 elif quant_method in ("mxfp4", "int4", "awq", "gptq", "compressed-tensors"): - param_bytes_per_element = 0.5 + return 0.5 else: - param_bytes_per_element = 1 logger.warning( f"model_name:{model_name} quant_method {quant_method} not supported in get_model_info method" ) + return 1 + + +def get_model_info(model_name, use_hfcache: bool = False): + config = load_config_only(model_name, local_files_only=use_hfcache) + + param_bytes_per_element = get_param_bytes_per_element(config, model_name) mlx_param_bytes_per_element = param_bytes_per_element mlx_model_name = MODELS.get(model_name, model_name) if mlx_model_name != model_name: mlx_config = load_config_only(mlx_model_name, local_files_only=use_hfcache) - mlx_quant_dict = mlx_config.get("quantization_config", None) - if mlx_quant_dict and "bits" in mlx_quant_dict: - mlx_param_bytes_per_element = mlx_quant_dict["bits"] / 8 + mlx_param_bytes_per_element = get_param_bytes_per_element(mlx_config, mlx_model_name) # get local experts num_local_experts = config.get("num_local_experts", None) diff --git a/src/parallax/models/qwen3_5.py b/src/parallax/models/qwen3_5.py new file mode 100644 index 00000000..497bcc22 --- /dev/null +++ b/src/parallax/models/qwen3_5.py @@ -0,0 +1,123 @@ +""" +Defines the Qwen3.5 text block for Parallax. +""" + +from typing import Any, List, Optional + +import mlx.core as mx +import mlx.nn as nn +from mlx_lm.models.gated_delta import gated_delta_update +from mlx_lm.models.qwen3_5 import DecoderLayer as MLXQwen35Block +from mlx_lm.models.qwen3_5 import GatedDeltaNet as MLXQwen35GatedDeltaNet +from mlx_lm.models.qwen3_5 import TextModelArgs + +from parallax.models.qwen3_next import ParallaxQwen3NextAttention +from parallax.server.cache.base import BaseCache + + +class ParallaxQwen35GatedDeltaNet(MLXQwen35GatedDeltaNet): + def __call__( + self, + x: mx.array, + cache: Optional[BaseCache] = None, + state_slot_mapping: Optional[mx.array] = None, + **kwargs, + ): + batch, target_len, _ = x.shape + + qkv = self.in_proj_qkv(x) + z = self.in_proj_z(x).reshape(batch, target_len, self.num_v_heads, self.head_v_dim) + b = self.in_proj_b(x) + a = self.in_proj_a(x) + + if target_len == 1: + conv_state, state = cache.read_states(state_slot_mapping) + else: + conv_state = mx.zeros( + (batch, self.conv_kernel_size - 1, self.conv_dim), + dtype=x.dtype, + ) + state = None + + conv_input = mx.concatenate([conv_state, qkv], axis=1) + next_conv_state = conv_input[:, -(self.conv_kernel_size - 1) :] + conv_out = nn.silu(self.conv1d(conv_input)) + + q, k, v = [ + t.reshape(batch, target_len, h, d) + for t, h, d in zip( + mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1), + [self.num_k_heads, self.num_k_heads, self.num_v_heads], + [self.head_k_dim, self.head_k_dim, self.head_v_dim], + ) + ] + + inv_scale = k.shape[-1] ** -0.5 + q = (inv_scale**2) * mx.fast.rms_norm(q, None, 1e-6) + k = inv_scale * mx.fast.rms_norm(k, None, 1e-6) + + out, state = gated_delta_update( + q, + k, + v, + a, + b, + self.A_log, + self.dt_bias, + state, + use_kernel=not self.training, + ) + + cache.write_states(state_slot_mapping, next_conv_state, state) + + out = self.norm(out, z) + return self.out_proj(out.reshape(batch, target_len, -1)) + + +class ParallaxQwen35Block(MLXQwen35Block): + def __init__(self, args: TextModelArgs, layer_idx: int, local_layer_idx: int): + super().__init__(args, layer_idx) + self.layer_idx = layer_idx + self.local_layer_idx = local_layer_idx + if self.is_linear: + self.linear_attn = ParallaxQwen35GatedDeltaNet(args) + else: + self.self_attn = ParallaxQwen3NextAttention(args) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[List[Any]] = None, + block_tables: Optional[mx.array] = None, + context_lengths: Optional[mx.array] = None, + slot_mapping: Optional[mx.array] = None, + **kwargs, + ): + if self.is_linear: + state_slot_mapping = kwargs.pop("state_slot_mapping", None) + r = self.linear_attn( + self.input_layernorm(x), + cache[self.local_layer_idx], + state_slot_mapping, + **kwargs, + ) + else: + r = self.self_attn( + self.input_layernorm(x), + mask, + cache[self.local_layer_idx], + block_tables=block_tables, + context_lengths=context_lengths, + slot_mapping=slot_mapping, + **kwargs, + ) + h = x + r + return h + self.mlp(self.post_attention_layernorm(h)) + + @classmethod + def get_architecture(cls): + return "Qwen3_5ForConditionalGeneration" + + +EntryClass = ParallaxQwen35Block diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index c3ab5be3..8063b12f 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -598,33 +598,34 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A else: padded_inputs, padding_mask = pad_inputs(0, h_or_tokens_list, self.dtype) - # Generate slot_mapping for prefill (only for NEW tokens, starting from prefix_len) - max_len = padded_inputs.shape[1] - slot_mapping_flat = [] - - for i, req in enumerate(batched_requests): - block_table = block_tables_list[i] - prefix_len = prefix_lens_list[i] - total_len = req.total_length - new_tokens_len = total_len - prefix_len - - for seq_idx in range(max_len): - if seq_idx < new_tokens_len: - # Valid new token - map to position after prefix - actual_pos = prefix_len + seq_idx - block_idx = actual_pos // self.cache_manager.block_size - block_offset = actual_pos % self.cache_manager.block_size - physical_block = block_table[block_idx] - slot = physical_block * self.cache_manager.block_size + block_offset - slot_mapping_flat.append(slot) - else: - # Padding token - # Map to -1. The kernel should ignore this. - slot_mapping_flat.append(-1) + slot_mapping_tensor = None + if self.cache_manager.needs_blocks: + # Generate slot_mapping for prefill (only for NEW tokens, starting from prefix_len) + max_len = padded_inputs.shape[1] + slot_mapping_flat = [] + + for i, req in enumerate(batched_requests): + block_table = block_tables_list[i] + prefix_len = prefix_lens_list[i] + total_len = req.total_length + new_tokens_len = total_len - prefix_len + + for seq_idx in range(max_len): + if seq_idx < new_tokens_len: + # Valid new token - map to position after prefix + actual_pos = prefix_len + seq_idx + block_idx = actual_pos // self.cache_manager.block_size + block_offset = actual_pos % self.cache_manager.block_size + physical_block = block_table[block_idx] + slot = physical_block * self.cache_manager.block_size + block_offset + slot_mapping_flat.append(slot) + else: + # Padding token. The kernel should ignore this. + slot_mapping_flat.append(-1) - slot_mapping_tensor = mx.array(slot_mapping_flat, dtype=mx.int64) + slot_mapping_tensor = mx.array(slot_mapping_flat, dtype=mx.int64) - # Pad block tables + # Pad block tables. Linear-only shards do not allocate KV blocks. max_blocks = max(len(bt) for bt in block_tables_list) padded_block_tables = [] for bt in block_tables_list: @@ -737,7 +738,7 @@ def _prepare_decode_batch(self, batched_requests: List[Request]) -> Optional[Dic padded_inputs = mx.concatenate(h_or_tokens_list, axis=0) # (Batch, D) padded_inputs = padded_inputs.reshape(batch_size, 1, -1) # (Batch, 1, D) - # Pad block tables + # Pad block tables. Linear-only shards do not allocate KV blocks. max_blocks = max(len(bt) for bt in block_tables_list) padded_block_tables = [] for bt in block_tables_list: diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index ad3ca173..e888da22 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -28,7 +28,7 @@ import uvicorn import zmq import zmq.asyncio -from fastapi.responses import ORJSONResponse, StreamingResponse +from fastapi.responses import JSONResponse, StreamingResponse from mlx_lm.tokenizer_utils import StreamingDetokenizer from mlx_lm.utils import load_config from pydantic import BaseModel @@ -101,6 +101,7 @@ class HTTPRequestInfo: # tool calling support tool_state: Optional[ToolCallState] = None tool_calls: List[Dict[str, Any]] = field(default_factory=list) + enable_thinking: bool = True class HTTPHandler: @@ -137,12 +138,29 @@ def __init__( self.tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) self.detokenizer_class, self.tokenmap = load_detokenizer(model_path, self.tokenizer) + @staticmethod + def _is_thinking_enabled(request: Dict) -> bool: + chat_template_kwargs = dict(request.get("chat_template_kwargs", {})) + extra_body = request.get("extra_body") + if isinstance(extra_body, dict) and "chat_template_kwargs" in extra_body: + chat_template_kwargs.update(extra_body["chat_template_kwargs"]) + return chat_template_kwargs.get("enable_thinking") is not False + + def _get_initial_assistant_content(self, request_info: HTTPRequestInfo) -> str: + model_path = self.model_path_str.lower() + if "minimax-m2" in model_path: + return "" + if ("qwen3.6" in model_path or "qwen3.5" in model_path) and request_info.enable_thinking: + return "" + return "" + def create_request(self, request: Dict): """Creates a new request information""" rid = request["rid"] stream = request.get("stream", False) model = request.get("model", "default") return_probs = request.get("return_probs", False) # Check if probs requested + enable_thinking = self._is_thinking_enabled(request) chat_object = "chat.completion.chunk" if stream else "chat.completion" detokenizer = self.detokenizer_class(self.tokenizer, self.tokenmap) create_time = time.time() @@ -156,6 +174,7 @@ def create_request(self, request: Dict): update_time=update_time, detokenizer=detokenizer, return_probs=return_probs, + enable_thinking=enable_thinking, ) request_info.tool_state = ToolCallState.from_tokenizer( self.tokenizer, request.get("tools"), stream @@ -206,9 +225,7 @@ def _generate_stream_chunk(self, rid, token, is_first=False, is_last=False): if is_first: role = "assistant" - content = "" - if "minimax-m2" in self.model_path_str.lower(): - content = "" + content = self._get_initial_assistant_content(request_info) tool_calls = None elif is_last: role = None @@ -318,7 +335,7 @@ def generate_non_stream_response(self, rid): choice = response["choices"][0] choice["message"] = { "role": "assistant", - "content": request_info.text, + "content": self._get_initial_assistant_content(request_info) + request_info.text, "reasoning_content": None, "tool_calls": request_info.tool_calls or None, } @@ -464,7 +481,7 @@ def create_error_response( ): """Creates a json error response for the frontend.""" error = ErrorResponse(message=message, type=err_type, code=status_code.value) - return ORJSONResponse(content=error.model_dump(), status_code=error.code) + return JSONResponse(content=error.model_dump(), status_code=error.code) # Fast API @@ -548,7 +565,7 @@ async def v1_chat_completions(raw_request: fastapi.Request): response = app.state.http_handler.generate_non_stream_response(request_id) app.state.http_handler.release_request(request_id) - return ORJSONResponse(status_code=200, content=response) + return JSONResponse(status_code=200, content=response) except Exception as e: # Handle any unexpected errors during processing logger.error(f"Error processing non-streaming request {request_id}: {e}") diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index 75d09004..233e8954 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -20,6 +20,7 @@ from parallax.server.model import ShardedModel from parallax.utils.tokenizer_utils import load_tokenizer +from parallax.utils.utils import normalize_model_config from parallax_utils.logging_config import get_logger logger = get_logger(__name__) @@ -33,6 +34,8 @@ "GlmMoeDsaForCausalLM": "DeepseekV32ForCausalLM", } +LANGUAGE_MODEL_PREFIX = "language_model." + class MLXModelLoader: """ @@ -249,7 +252,7 @@ def load( else: model_path = _download(self.model_path_str) - config = load_config(model_path) + config = normalize_model_config(load_config(model_path)) tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) architectures = config.get("architectures", None) @@ -279,7 +282,10 @@ def load( try: arch_module = importlib.import_module(model_class) - model_args_class = getattr(arch_module, "ModelArgs") + if model_type == "qwen3_5" and hasattr(arch_module, "TextModelArgs"): + model_args_class = getattr(arch_module, "TextModelArgs") + else: + model_args_class = getattr(arch_module, "ModelArgs") model_args = model_args_class.from_dict(config) except (ImportError, AttributeError) as e: @@ -341,38 +347,43 @@ def load( for key in f.keys(): is_needed = False remapped_key = None + model_key = ( + key[len(LANGUAGE_MODEL_PREFIX) :] + if key.startswith(LANGUAGE_MODEL_PREFIX) + else key + ) # Check if the key belongs to the shard and remap it if ( model_shard.is_first_shard - and "embed_tokens" in key - and key.startswith("model.") + and "embed_tokens" in model_key + and model_key.startswith("model.") ): is_needed = True - remapped_key = key.replace("model.", "", 1) + remapped_key = model_key.replace("model.", "", 1) if model_shard.is_last_shard and config.get("tie_word_embeddings", False): # Also add lm_head mapping for tied embeddings lm_head_key = remapped_key.replace("embed_tokens", "lm_head") shard_weights[lm_head_key] = f[key] elif model_shard.is_last_shard: - if "model.norm" in key: + if "model.norm" in model_key: is_needed = True - remapped_key = key.replace("model.", "", 1) - if "lm_head" in key: + remapped_key = model_key.replace("model.", "", 1) + if "lm_head" in model_key: is_needed = True - remapped_key = key + remapped_key = model_key elif ( config.get("tie_word_embeddings", False) - and "embed_tokens" in key - and key.startswith("model.embed_tokens") + and "embed_tokens" in model_key + and model_key.startswith("model.embed_tokens") ): is_needed = True - remapped_key = key.replace("model.", "", 1).replace( + remapped_key = model_key.replace("model.", "", 1).replace( "embed_tokens", "lm_head" ) - if layer_key_prefix in key: + if layer_key_prefix in model_key: try: - parts = key.split(".") + parts = model_key.split(".") layer_idx = int(parts[2]) if current_start_layer <= layer_idx < current_end_layer: is_needed = True diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index 6c17c1f5..bcae9fc4 100644 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -42,6 +42,7 @@ set_layer_range_for_filtering, ) from parallax.utils.tokenizer_utils import load_tokenizer +from parallax.utils.utils import normalize_model_config logger = logging.getLogger(__name__) @@ -311,7 +312,7 @@ def initialize_sgl_model_runner( model_repo, start_layer=start_layer, end_layer=end_layer, local_files_only=use_hfcache ) - config = load_config(model_path) + config = normalize_model_config(load_config(model_path)) tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) dtype = config.get("torch_dtype") or "bfloat16" @@ -367,11 +368,17 @@ def initialize_sgl_model_runner( # (multi-node PP where this node doesn't have both embed_tokens and lm_head). # For single-node or full-range runs, keep the original setting so that # lm_head correctly shares weights with embed_tokens. - num_hidden_layers = model_config.hf_config.num_hidden_layers + normalized_config = normalize_model_config(model_config.hf_config.to_dict()) + num_hidden_layers = normalized_config["num_hidden_layers"] + model_config.hf_config.num_hidden_layers = num_hidden_layers if start_layer > 0 or end_layer < num_hidden_layers: model_config.hf_config.tie_word_embeddings = False model_config.hf_config.start_layer = start_layer model_config.hf_config.end_layer = end_layer + if hasattr(model_config.hf_config, "text_config"): + model_config.hf_config.text_config.num_hidden_layers = num_hidden_layers + model_config.hf_config.text_config.start_layer = start_layer + model_config.hf_config.text_config.end_layer = end_layer logger.debug(f"model_start_layer: {model_config.hf_config.start_layer}") logger.debug(f"model_end_layer: {model_config.hf_config.end_layer}") diff --git a/src/parallax/utils/utils.py b/src/parallax/utils/utils.py index 63c13a76..44f6c76e 100644 --- a/src/parallax/utils/utils.py +++ b/src/parallax/utils/utils.py @@ -297,7 +297,21 @@ def load_config_only(name: str, local_files_only: bool = False): ) with open(config_file, "r") as f: - return json.load(f) + return normalize_model_config(json.load(f)) + + +def normalize_model_config(config: dict) -> dict: + """Expose nested text model fields at the top level for VLM-style configs.""" + text_config = config.get("text_config") + if config.get("model_type") == "qwen3_5" and isinstance(text_config, dict): + normalized = {**config, **text_config} + normalized["model_type"] = config["model_type"] + normalized["architectures"] = config.get("architectures", normalized.get("architectures")) + normalized["tie_word_embeddings"] = text_config.get( + "tie_word_embeddings", config.get("tie_word_embeddings", False) + ) + return normalized + return config def is_port_available(port: int): diff --git a/tests/test_http_handler.py b/tests/test_http_handler.py index c5598e4e..3cc26f63 100644 --- a/tests/test_http_handler.py +++ b/tests/test_http_handler.py @@ -1,4 +1,5 @@ import asyncio +import json from http import HTTPStatus try: @@ -26,6 +27,111 @@ def is_available(): from parallax.server.http_server import HTTPHandler, HTTPRequestInfo +def _decode_sse_json(chunk: bytes): + line = chunk.decode().strip() + assert line.startswith("data: ") + return json.loads(line[len("data: ") :]) + + +def test_qwen3_6_stream_first_chunk_includes_think_marker(): + handler = HTTPHandler.__new__(HTTPHandler) + handler.model_path_str = "mlx-community/Qwen3.6-27B-mxfp4" + handler.processing_requests = {} + + rid = "req-qwen36-stream" + handler.processing_requests[rid] = HTTPRequestInfo( + id=rid, + stream=True, + model="test-model", + ) + + payload = _decode_sse_json(handler._generate_stream_chunk(rid, None, is_first=True)) + + assert payload["choices"][0]["delta"]["role"] == "assistant" + assert payload["choices"][0]["delta"]["content"] == "" + + +def test_qwen3_6_stream_first_chunk_respects_disable_thinking(): + handler = HTTPHandler.__new__(HTTPHandler) + handler.model_path_str = "mlx-community/Qwen3.6-27B-mxfp4" + handler.processing_requests = {} + + rid = "req-qwen36-no-thinking-stream" + handler.processing_requests[rid] = HTTPRequestInfo( + id=rid, + stream=True, + model="test-model", + enable_thinking=False, + ) + + payload = _decode_sse_json(handler._generate_stream_chunk(rid, None, is_first=True)) + + assert payload["choices"][0]["delta"]["role"] == "assistant" + assert payload["choices"][0]["delta"]["content"] == "" + + +def test_http_handler_thinking_enabled_uses_extra_body_chat_template_kwargs(): + assert ( + HTTPHandler._is_thinking_enabled( + { + "messages": [{"role": "user", "content": "hi"}], + "chat_template_kwargs": {"enable_thinking": True}, + "sampling_params": {"top_k": 3}, + } + ) + is True + ) + assert ( + HTTPHandler._is_thinking_enabled( + { + "chat_template_kwargs": {"enable_thinking": True}, + "extra_body": {"chat_template_kwargs": {"enable_thinking": False}}, + } + ) + is False + ) + + +def test_qwen3_6_non_stream_response_includes_think_marker_when_enabled(): + handler = HTTPHandler.__new__(HTTPHandler) + handler.model_path_str = "mlx-community/Qwen3.6-27B-mxfp4" + handler.processing_requests = {} + + rid = "req-qwen36-non-stream" + request_info = HTTPRequestInfo( + id=rid, + stream=False, + model="test-model", + enable_thinking=True, + ) + request_info.text = "reasoning" + handler.processing_requests[rid] = request_info + + payload = handler.generate_non_stream_response(rid) + + assert payload["choices"][0]["message"]["content"] == "reasoning" + + +def test_qwen3_6_non_stream_response_respects_disable_thinking(): + handler = HTTPHandler.__new__(HTTPHandler) + handler.model_path_str = "mlx-community/Qwen3.6-27B-mxfp4" + handler.processing_requests = {} + + rid = "req-qwen36-no-thinking-non-stream" + request_info = HTTPRequestInfo( + id=rid, + stream=False, + model="test-model", + enable_thinking=False, + ) + request_info.text = "answer" + handler.processing_requests[rid] = request_info + + payload = handler.generate_non_stream_response(rid) + + assert payload["choices"][0]["message"]["content"] == "answer" + + def test_http_handler_marks_non_stream_error(): async def scenario(): handler = HTTPHandler.__new__(HTTPHandler) diff --git a/tests/test_static_config.py b/tests/test_static_config.py index a50ce6b8..24e60b8d 100644 --- a/tests/test_static_config.py +++ b/tests/test_static_config.py @@ -1,5 +1,49 @@ -from backend.server.static_config import MODELS +from backend.server import static_config +from backend.server.static_config import MODELS, get_model_info +from parallax.utils.utils import normalize_model_config def test_glm_5_1_uses_mlx_community_model(): assert MODELS["zai-org/GLM-5.1"] == "mlx-community/GLM-5.1" + + +def test_qwen3_6_mxfp4_is_scheduler_supported(): + assert MODELS["Qwen/Qwen3.6-27B"] == "mlx-community/Qwen3.6-27B-mxfp4" + assert "mlx-community/Qwen3.6-27B-mxfp4" not in MODELS + + +def test_qwen3_6_mxfp4_model_info_uses_text_config(monkeypatch): + def fake_load_config_only(model_name, local_files_only=False): + assert model_name in { + "Qwen/Qwen3.6-27B", + "mlx-community/Qwen3.6-27B-mxfp4", + } + return normalize_model_config( + { + "model_type": "qwen3_5", + "architectures": ["Qwen3_5ForConditionalGeneration"], + "quantization_config": {"bits": 4, "mode": "mxfp4"}, + "text_config": { + "num_hidden_layers": 64, + "head_dim": 256, + "hidden_size": 5120, + "intermediate_size": 17408, + "num_attention_heads": 24, + "num_key_value_heads": 4, + "vocab_size": 248320, + }, + } + ) + + monkeypatch.setattr(static_config, "load_config_only", fake_load_config_only) + + model_info = get_model_info("Qwen/Qwen3.6-27B") + + assert model_info.num_layers == 64 + assert model_info.mlx_model_name == "mlx-community/Qwen3.6-27B-mxfp4" + assert model_info.head_size == 256 + assert model_info.hidden_dim == 5120 + assert model_info.num_attention_heads == 24 + assert model_info.num_kv_heads == 4 + assert model_info.param_bytes_per_element == 0.5 + assert model_info.mlx_param_bytes_per_element == 0.5