Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,23 +43,23 @@ parallax = "parallax.cli:main"
[project.optional-dependencies]

mac = [
"nanobind==2.10.2",
"nanobind==2.12.0",
"torch==2.8.0",
"mlx-lm==0.30.6",
"mlx==0.30.4",
"mlx-lm==0.31.3",
"mlx==0.31.2",
]

gpu = [
"sglang[all]==0.5.12",
"accelerate",
"mlx-lm==0.28.4",
"mlx[cpu]==0.30.0",
"mlx-lm==0.31.3",
"mlx[cpu]==0.31.2",
]

vllm = [
"vllm==0.14.0",
"mlx-lm==0.28.4",
"mlx[cpu]==0.30.0",
"mlx-lm==0.31.3",
"mlx[cpu]==0.31.2",
]

benchmark = [
Expand Down
7 changes: 2 additions & 5 deletions src/parallax/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ def is_mps_available():
def is_metal_available():
"""Check if MLX Metal backend is available"""
try:
import mlx.core as mx

mx.metal.device_info()
return True
return mx.metal.is_available()
except (RuntimeError, AttributeError, ImportError):
return False

Expand All @@ -43,7 +40,7 @@ def get_current_device():
device = "cpu"
if is_cuda_available():
device = "cuda"
if is_mps_available():
if is_metal_available():
device = "mlx"
return device

Expand Down
2 changes: 1 addition & 1 deletion src/parallax_extensions/kernels/paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ void PagedAttentionV1::eval_gpu(
auto kernel = d.get_kernel(kname, lib, hash_name, func_consts);

// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = mx::metal::get_command_encoder(s);
compute_encoder.set_compute_pipeline_state(kernel);

// Shared Memory
Expand Down
6 changes: 3 additions & 3 deletions src/parallax_extensions/kernels/reshape_and_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ namespace parallax_ext {
mx::array reshape_and_cache(
const mx::array& key, // [num_tokens, num_heads, head_size]
const mx::array& value, // [num_tokens, num_heads, head_size]
mx::array& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
mx::array& value_cache, // [num_blocks, num_heads, head_size/x, block_size]
const mx::array& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
const mx::array& value_cache, // [num_blocks, num_heads, head_size/x, block_size]
const mx::array& slot_mapping, // [num_tokens]
mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
Expand Down Expand Up @@ -88,7 +88,7 @@ void ReshapeAndCache::eval_gpu(
auto kernel = d.get_kernel(kname, lib, hash_name, func_consts);

// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = mx::metal::get_command_encoder(s);
compute_encoder.set_compute_pipeline_state(kernel);

// Calculate parameters
Expand Down
4 changes: 2 additions & 2 deletions src/parallax_extensions/kernels/reshape_and_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ namespace parallax_ext {
mx::array reshape_and_cache(
const mx::array& key, // [num_tokens, num_heads, head_size]
const mx::array& value, // [num_tokens, num_heads, head_size]
mx::array& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
mx::array& value_cache, // [num_blocks, num_heads, head_size/x, block_size]
const mx::array& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
const mx::array& value_cache, // [num_blocks, num_heads, head_size/x, block_size]
const mx::array& slot_mapping, // [num_tokens]
mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
);
Expand Down
Binary file modified src/parallax_extensions/lib/_ext.cpython-311-darwin.so
Binary file not shown.
Binary file modified src/parallax_extensions/lib/_ext.cpython-312-darwin.so
Binary file not shown.
Binary file modified src/parallax_extensions/lib/_ext.cpython-313-darwin.so
Binary file not shown.
Binary file modified src/parallax_extensions/lib/libparallax_ext.dylib
Binary file not shown.
Binary file modified src/parallax_extensions/lib/parallax_ext.metallib
Binary file not shown.
33 changes: 33 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from types import SimpleNamespace

from parallax.utils import utils


def test_is_metal_available_uses_mlx_metal_is_available(monkeypatch):
    """is_metal_available() yields True when the patched mx.metal.is_available() does."""
    # Stand-in for the MLX module: only the `metal.is_available` path is provided.
    metal_stub = SimpleNamespace(is_available=lambda: True)
    monkeypatch.setattr(utils, "mx", SimpleNamespace(metal=metal_stub))

    outcome = utils.is_metal_available()

    assert outcome is True


def test_is_metal_available_returns_false_when_metal_api_missing(monkeypatch):
    """is_metal_available() yields False when the patched mx object has no `metal` attribute."""
    # An empty namespace: any access to `.metal` on it raises AttributeError.
    bare_module = SimpleNamespace()
    monkeypatch.setattr(utils, "mx", bare_module)

    outcome = utils.is_metal_available()

    assert outcome is False


def test_get_current_device_prefers_mlx_when_metal_available(monkeypatch):
    """get_current_device() reports "mlx" with CUDA stubbed off and Metal stubbed on."""

    def cuda_absent():
        return False

    def metal_present():
        return True

    # Replace both availability probes on the utils module with fixed answers.
    monkeypatch.setattr(utils, "is_cuda_available", cuda_absent)
    monkeypatch.setattr(utils, "is_metal_available", metal_present)

    selected = utils.get_current_device()

    assert selected == "mlx"


def test_get_current_device_prefers_mlx_when_both_backends_report_available(monkeypatch):
    """get_current_device() reports "mlx" when both the CUDA and Metal probes are stubbed to True."""

    def always_true():
        return True

    # Patch the probes in the same order as before: CUDA first, then Metal.
    for probe_name in ("is_cuda_available", "is_metal_available"):
        monkeypatch.setattr(utils, probe_name, always_true)

    selected = utils.get_current_device()

    assert selected == "mlx"
Loading