Skip to content

Conversation

@tonyjohnvan
Copy link
Contributor

@tonyjohnvan tonyjohnvan commented Feb 11, 2026

Summary

What feature is addressed:
Adds a native Apple MLX backend for the DiT (Diffusion Transformer) decoder inference loop, replacing the PyTorch MPS path on Apple Silicon Macs. The DiT diffusion loop is the single most expensive phase of audio generation -- this change reimplements it in pure MLX to bypass PyTorch-to-MPS overhead entirely.

Why this change is needed:
On Apple Silicon, PyTorch's MPS backend incurs significant dispatch and synchronization overhead for the iterative diffusion loop (8 transformer forward passes per generation). MLX's Metal-native graph execution eliminates this overhead. Real-world benchmarks show 2-3x wall-clock speedup for the DiT diffusion phase on M-series chips, making interactive music generation practical on consumer MacBooks.

The feature is entirely opt-in (checkbox in Gradio UI), auto-detected at init, and falls back gracefully to the existing PyTorch path on any failure -- no existing behavior is affected.


Scope

Files changed (12 modified/added, +1164 / -9 lines)

File Change Purpose
acestep/mlx_dit/__init__.py Added (33 lines) Platform detection: checks macOS + Apple Silicon + MLX package + Metal backend. Cached mlx_available() gate.
acestep/mlx_dit/model.py Added (629 lines) Full MLX reimplementation of AceStepDiTModel: rotary embeddings, multi-head attention with QK-RMSNorm, GQA, sliding window masking, AdaLN DiT layers, timestep embeddings, Conv1d patch embedding, ConvTranspose1d de-patchify.
acestep/mlx_dit/convert.py Added (84 lines) Weight converter: PyTorch state_dict -> MLX arrays. Handles Conv1d [out,in,K]->[out,K,in], ConvTranspose1d [in,out,K]->[out,K,in], Sequential index stripping, rotary buffer skipping.
acestep/mlx_dit/generate.py Added (213 lines) MLX diffusion sampling loop: timestep scheduling, ODE Euler / SDE re-noising, cross-attention KV caching, cover-strength condition switching, seed handling, timing metrics.
acestep/handler.py Modified (+186/-9) Adds _init_mlx_dit(), _mlx_run_diffusion(), MLX fast-path in service_generate() with try/except fallback, use_mlx_dit param in initialize_service(), MLX status in init output.
acestep/gradio_ui/interfaces/generation.py Modified (+11) Adds "MLX DiT (Apple Silicon)" checkbox, auto-enabled when MLX detected, greyed-out otherwise.
acestep/gradio_ui/events/__init__.py Modified (+1) Wires mlx_dit_checkbox into init button inputs.
acestep/gradio_ui/events/generation_handlers.py Modified (+3/-1) Passes mlx_dit param through init_service_wrapper to initialize_service.
acestep/gradio_ui/i18n/{en,he,ja,zh}.json Modified (+3 each) Translation keys: mlx_dit_label, mlx_dit_info_enabled, mlx_dit_info_disabled.
tests/test_mlx_dit.py Added (1382 lines) 74 tests across 17 test classes (see Regression Checks).

What is explicitly out of scope

  • VAE decode remains on PyTorch MPS (only the DiT diffusion loop is accelerated)
  • CUDA/ROCm paths are completely untouched -- MLX init is gated on device in ("mps", "cpu")
  • Training is not affected -- MLX path is inference-only
  • LoRA weights are not converted to MLX (the PyTorch fallback handles LoRA-loaded models)
  • Quantized models are not converted -- torch.compile disables the MLX path (not compile_model gate)
  • LM/text encoder remains on PyTorch; only the decoder DiT forward pass runs in MLX

Risk and Compatibility

Target platform / path

  • macOS + Apple Silicon (M1/M2/M3/M4) + mlx pip package installed
  • Only activates when: use_mlx_dit=True AND device in ("mps", "cpu") AND not compile_model

Confirmation that non-target paths are unchanged

  • Linux/Windows/CUDA: is_mlx_available() returns False on non-Darwin -> MLX init is never attempted. The only code touched in handler.py's generation path is behind if self.use_mlx_dit and self.mlx_decoder is not None, which is always False on non-Apple platforms. The else branch calls the original self.model.generate_audio(**generate_kwargs) identically to main.
  • macOS with MLX disabled: User can uncheck the checkbox -> use_mlx_dit=False -> handler explicitly sets self.mlx_decoder = None; self.use_mlx_dit = False -> PyTorch path runs.
  • macOS with MLX init failure: _init_mlx_dit() catches all exceptions, logs a warning, sets mlx_decoder=None, returns False. Status message shows "Unavailable (PyTorch fallback)".
  • macOS with MLX runtime failure: The service_generate MLX fast-path is wrapped in try/except Exception which logs a warning and falls back to self.model.generate_audio() in the same request.
  • One minor non-functional change: logger.warning -> logger.info for the MPS VAE chunk-size reduction log message (cosmetic, reduces log noise).

Multi-layer fallback chain

Platform check (Darwin?) -> Import check (mlx installed?) -> Runtime check (Metal works?)
  -> Init check (_init_mlx_dit succeeds?) -> Per-request check (try MLX, except -> PyTorch)

Every level is non-fatal. The worst case is a log warning + PyTorch fallback with zero behavioral change.


Regression Checks

Automated tests (74 tests, all passing)

Run: conda run -n ace python -m pytest test_mlx_dit.py -v (4.25s)

Test Class # Tests What it validates
TestMLXAvailabilityDetection 6 is_mlx_available() on Darwin/non-Darwin, import failure, caching
TestWeightConversion 5 Conv1d layout [out,in,K]->[out,K,in], ConvTranspose1d [in,out,K]->[out,K,in], key remapping, rotary skip, convert_and_load integration
TestTimestepSchedule 10 Default shifts (1/2/3), rounding, custom timesteps, trailing zero stripping, truncation >20, empty fallback
TestMLXModelArchitecture 6 from_config, forward output shapes, batch>1, odd seq padding/crop, KV cache population, sliding mask caching
TestMLXCrossAttentionCache 3 Update/get, pre-update state, multi-layer isolation
TestRotaryEmbedding 2 Output shape, cos/sin value range [-1, 1]
TestSwiGLUMLP 1 Output shape preservation
TestMLXDiffusionLoop 9 ODE shape, SDE shape, time_costs, seed reproducibility, seed=None, seed list, custom timesteps, cover-strength switching, numpy output type
TestHandlerMLXIntegration 6 _init_mlx_dit success/failure/skip, tensor conversion roundtrip, None non-cover params, __init__ defaults
TestHandlerInitializeServiceMLXParam 2 use_mlx_dit in signature with default=True, last_init_params storage
TestGradioUIIntegration 2 mlx_dit param in init_service_wrapper, checkbox wired in events
TestI18NKeys 4 All 3 MLX keys present and non-empty in en/he/ja/zh
TestPyTorchFallbackPreserved 5 MLX disabled -> PyTorch path, mlx_decoder=None -> PyTorch, exception -> fallback, user disable clears state, device gating (mps/cpu only, not cuda, not with compile)
TestUtilityFunctions 4 _rotate_half values, _apply_rotary_pos_emb shapes, sliding mask shape + values
TestTimestepEmbedding 2 Output shapes, different timesteps -> different embeddings
TestEdgeCases 5 Single-step diffusion, cover_strength=1.0/0.0 boundary, output not all-zeros, output finite (no NaN/Inf)
TestDiTLayerStandalone 2 Full attention layer forward, sliding attention layer forward

Key scenarios validated

  1. Happy path: MLX available -> weights convert -> diffusion runs -> correct output shape, finite values, seed-reproducible
  2. Graceful degradation: MLX unavailable (non-Darwin, import fails, Metal fails) -> returns False, no crash, PyTorch runs
  3. Runtime failure: MLX diffusion throws -> caught, logged, PyTorch fallback in same request
  4. User opt-out: Checkbox unchecked -> mlx_decoder=None, PyTorch path only
  5. Device gating: Only activates on mps/cpu, never on cuda; blocked when compile_model=True
  6. Weight fidelity: Conv1d/ConvTranspose1d layout transforms verified numerically against PyTorch originals
  7. Edge cases: Odd sequence lengths (padding/crop), single-step schedules, empty timestep lists, cover_strength boundaries

Reviewer Notes

Known pre-existing issues not addressed

  • VAE decode on MPS still requires chunk-size reduction (existing behavior, unchanged)
  • LoRA-loaded models will use the PyTorch path (MLX weight conversion does not handle LoRA adapters)
  • torchao deprecation warnings in test output are from upstream, not introduced by this PR
  • Seems Recent main introduced changes that failed to use MPS vae decode due to 0GB vram detection, resulting extremely slow vae decoding aba0e7b
2026-02-10 23:47:45.444 | INFO     | acestep.handler:generate_music:3604 - [generate_music] Decoding latents with VAE...
2026-02-10 23:47:45.600 | DEBUG    | acestep.handler:generate_music:3622 - [generate_music] Before VAE decode: allocated=0.00GB, max=0.00GB
2026-02-10 23:47:45.600 | INFO     | acestep.handler:generate_music:3629 - [generate_music] Effective free VRAM before VAE decode: 0.00 GB
2026-02-10 23:47:45.600 | WARNING  | acestep.handler:generate_music:3632 - [generate_music] Only 0.00 GB free VRAM — auto-enabling CPU VAE decode
2026-02-10 23:47:45.600 | INFO     | acestep.handler:generate_music:3635 - [generate_music] Moving VAE to CPU for decode (ACESTEP_VAE_ON_CPU=1)...
2026-02-10 23:47:45.995 | INFO     | acestep.handler:generate_music:3642 - [generate_music] Using tiled VAE decode to reduce VRAM usage...

Follow-up items

  • MLX LoRA support: Convert LoRA adapter weights to MLX for full acceleration of fine-tuned models
  • MLX VAE decode: Port the VAE decoder to MLX to eliminate the remaining PyTorch bottleneck
  • Benchmarking CI: Add automated performance regression tests comparing MLX vs MPS wall-clock times
  • Memory profiling: Profile unified memory usage under MLX vs MPS to quantify memory efficiency gains
  • Mixed-precision: Explore fp16/bf16 MLX inference for further speedup on M3+/M4 chips

PS: the test code

"""Comprehensive tests for the native MLX DiT acceleration feature.

Covers:
  1. MLX availability detection and caching
  2. Weight conversion (PyTorch -> MLX) with layout transforms
  3. Timestep schedule generation (valid, edge, and invalid inputs)
  4. MLX model architecture (construction, forward pass, shapes)
  5. Full diffusion generation loop (ODE / SDE / cover-strength)
  6. Handler-level integration and fallback behaviour
  7. Gradio UI wiring (checkbox, init_service_wrapper passthrough)
  8. i18n key presence for all supported languages

Every test is designed to run WITHOUT a real model checkpoint so the suite
stays fast and CI-friendly.
"""

import importlib
import math
import platform
import sys
import types
from copy import deepcopy
from typing import Dict, List, Optional, Tuple
from unittest import mock

import numpy as np
import pytest
import torch

# ──────────────────────────────────────────────────────────────────────
# Conditional MLX import (tests that require MLX are auto-skipped)
# ──────────────────────────────────────────────────────────────────────

try:
    import mlx.core as mx
    import mlx.nn as nn

    MLX_AVAILABLE = True
except ImportError:
    MLX_AVAILABLE = False

requires_mlx = pytest.mark.skipif(not MLX_AVAILABLE, reason="MLX not available")

# ======================================================================
# 1. Availability detection – acestep.mlx_dit.__init__
# ======================================================================


class TestMLXAvailabilityDetection:
    """Ensure is_mlx_available() and mlx_available() behave correctly."""

    def test_is_mlx_available_returns_bool(self):
        from acestep.mlx_dit import is_mlx_available

        result = is_mlx_available()
        assert isinstance(result, bool)

    def test_mlx_available_caching(self):
        """mlx_available() must cache the first result in _MLX_AVAILABLE."""
        import acestep.mlx_dit as mod

        # Reset cache
        mod._MLX_AVAILABLE = None
        first = mod.mlx_available()
        assert mod._MLX_AVAILABLE is not None
        second = mod.mlx_available()
        assert first == second

    def test_is_mlx_available_non_darwin(self):
        """On non-Darwin platforms, should return False."""
        from acestep.mlx_dit import is_mlx_available

        with mock.patch("acestep.mlx_dit.platform") as mock_plat:
            mock_plat.system.return_value = "Linux"
            assert is_mlx_available() is False

    def test_is_mlx_available_import_failure(self):
        """If mlx import fails, should return False (not raise)."""
        from acestep.mlx_dit import is_mlx_available

        with mock.patch("acestep.mlx_dit.platform") as mock_plat:
            mock_plat.system.return_value = "Darwin"
            with mock.patch.dict(sys.modules, {"mlx": None, "mlx.core": None, "mlx.nn": None}):
                # Force reimport to hit the import error path
                assert is_mlx_available() is False

    @requires_mlx
    def test_is_mlx_available_true_on_apple_silicon(self):
        """On macOS + Apple Silicon with MLX installed, should return True."""
        from acestep.mlx_dit import is_mlx_available

        if platform.system() == "Darwin":
            assert is_mlx_available() is True

    def test_mlx_available_resets_on_cache_clear(self):
        """Clearing _MLX_AVAILABLE re-evaluates."""
        import acestep.mlx_dit as mod

        mod._MLX_AVAILABLE = None
        result = mod.mlx_available()
        assert isinstance(result, bool)
        # Should not be None after call
        assert mod._MLX_AVAILABLE is not None


# ======================================================================
# 2. Weight conversion – acestep.mlx_dit.convert
# ======================================================================


@requires_mlx
class TestWeightConversion:
    """Test convert_decoder_weights and convert_and_load."""

    @staticmethod
    def _make_fake_decoder_state_dict() -> Dict[str, torch.Tensor]:
        """Build a minimal state_dict that exercises all conversion paths."""
        sd = {}
        # Conv1d proj_in (Sequential wrapper index 1)
        sd["proj_in.1.weight"] = torch.randn(256, 192, 2)  # [out, in, K]
        sd["proj_in.1.bias"] = torch.randn(256)
        # ConvTranspose1d proj_out (Sequential wrapper index 1)
        sd["proj_out.1.weight"] = torch.randn(256, 64, 2)  # [in, out, K]
        sd["proj_out.1.bias"] = torch.randn(64)
        # A normal linear weight
        sd["layers.0.self_attn.q_proj.weight"] = torch.randn(256, 256)
        # Rotary embedding buffer (should be skipped)
        sd["layers.0.self_attn.rotary_emb.inv_freq"] = torch.randn(64)
        return sd

    def _make_fake_pytorch_model(self):
        """Create a minimal mock that exposes decoder.state_dict()."""
        sd = self._make_fake_decoder_state_dict()
        decoder = mock.MagicMock()
        decoder.state_dict.return_value = sd
        model = mock.MagicMock()
        model.decoder = decoder
        return model, sd

    def test_convert_decoder_weights_key_remapping(self):
        from acestep.mlx_dit.convert import convert_decoder_weights

        model, _ = self._make_fake_pytorch_model()
        pairs = convert_decoder_weights(model)
        names = [n for n, _ in pairs]

        # Sequential index should be stripped
        assert "proj_in.weight" in names
        assert "proj_in.bias" in names
        assert "proj_out.weight" in names
        assert "proj_out.bias" in names
        # Original indexed keys should NOT appear
        assert "proj_in.1.weight" not in names
        assert "proj_out.1.weight" not in names
        # Rotary embedding should be skipped
        assert all("rotary_emb" not in n for n, _ in pairs)

    def test_conv1d_weight_layout(self):
        """proj_in Conv1d: PT [out, in, K] -> MLX [out, K, in]."""
        from acestep.mlx_dit.convert import convert_decoder_weights

        model, sd = self._make_fake_pytorch_model()
        pairs = convert_decoder_weights(model)
        weight_dict = dict(pairs)

        pt_w = sd["proj_in.1.weight"].numpy()  # [256, 192, 2]
        mlx_w = np.array(weight_dict["proj_in.weight"])
        assert mlx_w.shape == (256, 2, 192), f"Expected (256,2,192), got {mlx_w.shape}"
        # Verify swapaxes(1,2) relationship
        np.testing.assert_allclose(mlx_w, pt_w.swapaxes(1, 2), atol=1e-6)

    def test_convtranspose1d_weight_layout(self):
        """proj_out ConvTranspose1d: PT [in, out, K] -> MLX [out, K, in]."""
        from acestep.mlx_dit.convert import convert_decoder_weights

        model, sd = self._make_fake_pytorch_model()
        pairs = convert_decoder_weights(model)
        weight_dict = dict(pairs)

        pt_w = sd["proj_out.1.weight"].numpy()  # [256, 64, 2]
        mlx_w = np.array(weight_dict["proj_out.weight"])
        assert mlx_w.shape == (64, 2, 256), f"Expected (64,2,256), got {mlx_w.shape}"
        # Verify transpose(1,2,0) relationship
        np.testing.assert_allclose(mlx_w, pt_w.transpose(1, 2, 0), atol=1e-6)

    def test_passthrough_weight_unchanged(self):
        """Non-special weights should transfer values unchanged."""
        from acestep.mlx_dit.convert import convert_decoder_weights

        model, sd = self._make_fake_pytorch_model()
        pairs = convert_decoder_weights(model)
        weight_dict = dict(pairs)

        key = "layers.0.self_attn.q_proj.weight"
        pt_np = sd[key].numpy()
        mlx_np = np.array(weight_dict[key])
        np.testing.assert_allclose(mlx_np, pt_np, atol=1e-6)

    def test_convert_and_load_calls_load_weights(self):
        """convert_and_load must call mlx_decoder.load_weights with the pairs."""
        from acestep.mlx_dit.convert import convert_and_load

        model, _ = self._make_fake_pytorch_model()
        mlx_decoder = mock.MagicMock()
        # parameters() returns a dict so mx.eval can iterate it
        mlx_decoder.parameters.return_value = {}

        convert_and_load(model, mlx_decoder)

        mlx_decoder.load_weights.assert_called_once()
        args = mlx_decoder.load_weights.call_args[0][0]
        assert isinstance(args, list)
        assert all(isinstance(p, tuple) and len(p) == 2 for p in args)


# ======================================================================
# 3. Timestep schedule generation – acestep.mlx_dit.generate
# ======================================================================


class TestTimestepSchedule:
    """Test get_timestep_schedule for various inputs."""

    def test_default_shift_3(self):
        from acestep.mlx_dit.generate import SHIFT_TIMESTEPS, get_timestep_schedule

        result = get_timestep_schedule(shift=3.0)
        assert result == SHIFT_TIMESTEPS[3.0]

    def test_default_shift_1(self):
        from acestep.mlx_dit.generate import SHIFT_TIMESTEPS, get_timestep_schedule

        result = get_timestep_schedule(shift=1.0)
        assert result == SHIFT_TIMESTEPS[1.0]

    def test_default_shift_2(self):
        from acestep.mlx_dit.generate import SHIFT_TIMESTEPS, get_timestep_schedule

        result = get_timestep_schedule(shift=2.0)
        assert result == SHIFT_TIMESTEPS[2.0]

    def test_shift_rounds_to_nearest_valid(self):
        """Non-standard shift values should round to the nearest valid one."""
        from acestep.mlx_dit.generate import SHIFT_TIMESTEPS, get_timestep_schedule

        result = get_timestep_schedule(shift=2.6)
        assert result == SHIFT_TIMESTEPS[3.0]

        result = get_timestep_schedule(shift=1.4)
        assert result == SHIFT_TIMESTEPS[1.0]

    def test_custom_timesteps_passthrough(self):
        from acestep.mlx_dit.generate import VALID_TIMESTEPS, get_timestep_schedule

        ts = [1.0, 0.75, 0.5, 0.25]
        result = get_timestep_schedule(shift=3.0, timesteps=ts)
        assert len(result) == 4
        # Each should be one of the valid timesteps (mapped)
        for t in result:
            assert t in VALID_TIMESTEPS

    def test_custom_timesteps_trailing_zeros_stripped(self):
        from acestep.mlx_dit.generate import get_timestep_schedule

        ts = [1.0, 0.5, 0.0, 0.0]
        result = get_timestep_schedule(shift=3.0, timesteps=ts)
        assert 0.0 not in result
        assert len(result) == 2

    def test_custom_timesteps_all_zeros_falls_back(self):
        from acestep.mlx_dit.generate import SHIFT_TIMESTEPS, get_timestep_schedule

        ts = [0.0, 0.0, 0.0]
        result = get_timestep_schedule(shift=3.0, timesteps=ts)
        # Falls back to default when empty after stripping zeros
        assert result == SHIFT_TIMESTEPS[3.0]

    def test_custom_timesteps_truncation_over_20(self):
        from acestep.mlx_dit.generate import get_timestep_schedule

        ts = list(np.linspace(1.0, 0.05, 25))  # 25 steps
        result = get_timestep_schedule(shift=3.0, timesteps=ts)
        assert len(result) <= 20

    def test_schedule_is_non_empty(self):
        from acestep.mlx_dit.generate import get_timestep_schedule

        for shift in [1.0, 2.0, 3.0]:
            result = get_timestep_schedule(shift=shift)
            assert len(result) > 0

    def test_schedule_values_positive(self):
        from acestep.mlx_dit.generate import get_timestep_schedule

        for shift in [1.0, 2.0, 3.0]:
            for t in get_timestep_schedule(shift=shift):
                assert t > 0, f"Timestep {t} should be positive"


# ======================================================================
# 4. MLX Model Architecture – acestep.mlx_dit.model
# ======================================================================


@requires_mlx
class TestMLXModelArchitecture:
    """Test model construction and forward pass shapes with small configs."""

    @staticmethod
    def _small_config():
        """Return a minimal config-like object for a tiny model."""
        cfg = types.SimpleNamespace(
            hidden_size=64,
            intermediate_size=128,
            num_hidden_layers=2,
            num_attention_heads=4,
            num_key_value_heads=2,
            head_dim=16,
            rms_norm_eps=1e-6,
            attention_bias=False,
            in_channels=24,           # context (16) + audio_acoustic (8)
            audio_acoustic_hidden_dim=8,
            patch_size=2,
            sliding_window=16,
            layer_types=["sliding_attention", "full_attention"],
            rope_theta=1_000_000.0,
            max_position_embeddings=512,
        )
        return cfg

    def test_from_config_creates_decoder(self):
        from acestep.mlx_dit.model import MLXDiTDecoder

        cfg = self._small_config()
        dec = MLXDiTDecoder.from_config(cfg)
        assert isinstance(dec, MLXDiTDecoder)
        assert len(dec.layers) == cfg.num_hidden_layers

    def test_forward_output_shape(self):
        from acestep.mlx_dit.model import MLXCrossAttentionCache, MLXDiTDecoder

        cfg = self._small_config()
        dec = MLXDiTDecoder.from_config(cfg)
        mx.eval(dec.parameters())

        B, T, C_audio = 1, 16, 8
        C_ctx = cfg.in_channels - C_audio  # 16
        hidden = mx.random.normal((B, T, C_audio))
        enc_hs = mx.random.normal((B, 10, cfg.hidden_size))
        ctx = mx.random.normal((B, T, C_ctx))
        t_step = mx.full((B,), 1.0)

        out, cache = dec(
            hidden_states=hidden,
            timestep=t_step,
            timestep_r=t_step,
            encoder_hidden_states=enc_hs,
            context_latents=ctx,
        )
        mx.eval(out)
        assert out.shape == (B, T, C_audio), f"Expected ({B},{T},{C_audio}), got {out.shape}"

    def test_forward_batch_size_2(self):
        from acestep.mlx_dit.model import MLXDiTDecoder

        cfg = self._small_config()
        dec = MLXDiTDecoder.from_config(cfg)
        mx.eval(dec.parameters())

        B, T, C_audio = 2, 16, 8
        C_ctx = cfg.in_channels - C_audio
        hidden = mx.random.normal((B, T, C_audio))
        enc_hs = mx.random.normal((B, 10, cfg.hidden_size))
        ctx = mx.random.normal((B, T, C_ctx))
        t_step = mx.full((B,), 0.5)

        out, _ = dec(
            hidden_states=hidden,
            timestep=t_step,
            timestep_r=t_step,
            encoder_hidden_states=enc_hs,
            context_latents=ctx,
        )
        mx.eval(out)
        assert out.shape[0] == B

    def test_padding_odd_sequence_length(self):
        """Odd sequence lengths should be padded and output cropped back."""
        from acestep.mlx_dit.model import MLXDiTDecoder

        cfg = self._small_config()
        dec = MLXDiTDecoder.from_config(cfg)
        mx.eval(dec.parameters())

        B, T, C_audio = 1, 15, 8  # T=15 is not divisible by patch_size=2
        C_ctx = cfg.in_channels - C_audio
        hidden = mx.random.normal((B, T, C_audio))
        enc_hs = mx.random.normal((B, 10, cfg.hidden_size))
        ctx = mx.random.normal((B, T, C_ctx))
        t_step = mx.full((B,), 1.0)

        out, _ = dec(
            hidden_states=hidden,
            timestep=t_step,
            timestep_r=t_step,
            encoder_hidden_states=enc_hs,
            context_latents=ctx,
        )
        mx.eval(out)
        # Output should match original T, not padded T
        assert out.shape[1] == T

    def test_cache_reuse_across_steps(self):
        """Cross-attention KV cache should be populated on first call and reused."""
        from acestep.mlx_dit.model import MLXCrossAttentionCache, MLXDiTDecoder

        cfg = self._small_config()
        dec = MLXDiTDecoder.from_config(cfg)
        mx.eval(dec.parameters())

        B, T, C_audio = 1, 16, 8
        C_ctx = cfg.in_channels - C_audio
        hidden = mx.random.normal((B, T, C_audio))
        enc_hs = mx.random.normal((B, 10, cfg.hidden_size))
        ctx = mx.random.normal((B, T, C_ctx))
        t_step = mx.full((B,), 1.0)

        cache = MLXCrossAttentionCache()
        _, cache = dec(
            hidden_states=hidden,
            timestep=t_step,
            timestep_r=t_step,
            encoder_hidden_states=enc_hs,
            context_latents=ctx,
            cache=cache,
            use_cache=True,
        )
        # After first call, cache should have entries for each layer
        for i in range(cfg.num_hidden_layers):
            assert cache.is_updated(i), f"Layer {i} cache not populated"

    def test_sliding_window_mask_caching(self):
        """Sliding mask should be cached for a given sequence length."""
        from acestep.mlx_dit.model import MLXDiTDecoder

        cfg = self._small_config()
        dec = MLXDiTDecoder.from_config(cfg)

        mask1 = dec._get_sliding_mask(32, mx.float32)
        mask2 = dec._get_sliding_mask(32, mx.float32)
        # Should be the exact same object (cached)
        assert mask1 is mask2

        # Different seq_len should produce different mask
        mask3 = dec._get_sliding_mask(64, mx.float32)
        assert mask3 is not mask1


# ======================================================================
# 5. MLX Cross-Attention Cache
# ======================================================================


@requires_mlx
class TestMLXCrossAttentionCache:
    def test_update_and_get(self):
        from acestep.mlx_dit.model import MLXCrossAttentionCache

        cache = MLXCrossAttentionCache()
        k = mx.zeros((1, 4, 10, 16))
        v = mx.ones((1, 4, 10, 16))
        cache.update(k, v, layer_idx=0)
        assert cache.is_updated(0)
        k_out, v_out = cache.get(0)
        np.testing.assert_array_equal(np.array(k_out), np.array(k))
        np.testing.assert_array_equal(np.array(v_out), np.array(v))

    def test_is_updated_false_before_update(self):
        from acestep.mlx_dit.model import MLXCrossAttentionCache

        cache = MLXCrossAttentionCache()
        assert not cache.is_updated(0)
        assert not cache.is_updated(99)

    def test_multiple_layers(self):
        from acestep.mlx_dit.model import MLXCrossAttentionCache

        cache = MLXCrossAttentionCache()
        for i in range(5):
            cache.update(mx.zeros((1, 1, 1, 1)), mx.ones((1, 1, 1, 1)), layer_idx=i)
        for i in range(5):
            assert cache.is_updated(i)
        assert not cache.is_updated(5)


# ======================================================================
# 6. Rotary Embedding
# ======================================================================


@requires_mlx
class TestRotaryEmbedding:
    def test_output_shape(self):
        from acestep.mlx_dit.model import MLXRotaryEmbedding

        rope = MLXRotaryEmbedding(head_dim=64, max_len=128)
        cos, sin = rope(seq_len=32)
        assert cos.shape == (1, 1, 32, 64)
        assert sin.shape == (1, 1, 32, 64)

    def test_cos_sin_range(self):
        from acestep.mlx_dit.model import MLXRotaryEmbedding

        rope = MLXRotaryEmbedding(head_dim=64, max_len=128)
        cos, sin = rope(seq_len=32)
        cos_np = np.array(cos)
        sin_np = np.array(sin)
        assert np.all(cos_np >= -1.0) and np.all(cos_np <= 1.0)
        assert np.all(sin_np >= -1.0) and np.all(sin_np <= 1.0)


# ======================================================================
# 7. SwiGLU MLP
# ======================================================================


@requires_mlx
class TestSwiGLUMLP:
    def test_output_shape(self):
        from acestep.mlx_dit.model import MLXSwiGLUMLP

        mlp = MLXSwiGLUMLP(hidden_size=64, intermediate_size=128)
        mx.eval(mlp.parameters())
        x = mx.random.normal((2, 10, 64))
        out = mlp(x)
        mx.eval(out)
        assert out.shape == (2, 10, 64)


# ======================================================================
# 8. Full diffusion loop – acestep.mlx_dit.generate
# ======================================================================


@requires_mlx
class TestMLXDiffusionLoop:
    """Test mlx_generate_diffusion with a real tiny model."""

    @staticmethod
    def _make_tiny_decoder():
        from acestep.mlx_dit.model import MLXDiTDecoder

        cfg = types.SimpleNamespace(
            hidden_size=64,
            intermediate_size=128,
            num_hidden_layers=2,
            num_attention_heads=4,
            num_key_value_heads=2,
            head_dim=16,
            rms_norm_eps=1e-6,
            attention_bias=False,
            in_channels=24,
            audio_acoustic_hidden_dim=8,
            patch_size=2,
            sliding_window=16,
            layer_types=["sliding_attention", "full_attention"],
            rope_theta=1_000_000.0,
            max_position_embeddings=512,
        )
        dec = MLXDiTDecoder.from_config(cfg)
        mx.eval(dec.parameters())
        return dec, cfg

    def test_ode_generation_returns_correct_shape(self):
        from acestep.mlx_dit.generate import mlx_generate_diffusion

        dec, cfg = self._make_tiny_decoder()
        B, T, C = 1, 16, 8
        enc_np = np.random.randn(B, 10, cfg.hidden_size).astype(np.float32)
        ctx_np = np.random.randn(B, T, cfg.in_channels - C).astype(np.float32)

        result = mlx_generate_diffusion(
            mlx_decoder=dec,
            encoder_hidden_states_np=enc_np,
            context_latents_np=ctx_np,
            src_latents_shape=(B, T, C),
            seed=42,
            infer_method="ode",
            shift=3.0,
        )
        assert "target_latents" in result
        assert "time_costs" in result
        assert result["target_latents"].shape == (B, T, C)

    def test_sde_generation_returns_correct_shape(self):
        from acestep.mlx_dit.generate import mlx_generate_diffusion

        dec, cfg = self._make_tiny_decoder()
        B, T, C = 1, 16, 8
        enc_np = np.random.randn(B, 10, cfg.hidden_size).astype(np.float32)
        ctx_np = np.random.randn(B, T, cfg.in_channels - C).astype(np.float32)

        result = mlx_generate_diffusion(
            mlx_decoder=dec,
            encoder_hidden_states_np=enc_np,
            context_latents_np=ctx_np,
            src_latents_shape=(B, T, C),
            seed=42,
            infer_method="sde",
            shift=3.0,
        )
        assert result["target_latents"].shape == (B, T, C)

    def test_time_costs_populated(self):
        from acestep.mlx_dit.generate import mlx_generate_diffusion

        dec, cfg = self._make_tiny_decoder()
        B, T, C = 1, 16, 8
        enc_np = np.random.randn(B, 10, cfg.hidden_size).astype(np.float32)
        ctx_np = np.random.randn(B, T, cfg.in_channels - C).astype(np.float32)

        result = mlx_generate_diffusion(
            mlx_decoder=dec,
            encoder_hidden_states_np=enc_np,
            context_latents_np=ctx_np,
            src_latents_shape=(B, T, C),
            seed=42,
        )
        tc = result["time_costs"]
        assert "diffusion_time_cost" in tc
        assert "diffusion_per_step_time_cost" in tc
        assert "total_time_cost" in tc
        assert tc["diffusion_time_cost"] >= 0
        assert tc["total_time_cost"] >= tc["diffusion_time_cost"]

    def test_seed_reproducibility(self):
        from acestep.mlx_dit.generate import mlx_generate_diffusion

        dec, cfg = self._make_tiny_decoder()
        B, T, C = 1, 16, 8
        enc_np = np.random.randn(B, 10, cfg.hidden_size).astype(np.float32)
        ctx_np = np.random.randn(B, T, cfg.in_channels - C).astype(np.float32)

        r1 = mlx_generate_diffusion(
            mlx_decoder=dec,
            encoder_hidden_states_np=enc_np,
            context_latents_np=ctx_np,
            src_latents_shape=(B, T, C),
            seed=123,
            infer_method="ode",
            shift=3.0,
        )
        r2 = mlx_generate_diffusion(
            mlx_decoder=dec,
            encoder_hidden_states_np=enc_np,
            context_latents_np=ctx_np,
            src_latents_shape=(B, T, C),
            seed=123,
            infer_method="ode",
            shift=3.0,
        )
        np.testing.assert_allclose(
            r1["target_latents"], r2["target_latents"], atol=1e-4,
            err_msg="Same seed should produce identical results"
        )

    def test_seed_none_runs_without_error(self):
        from acestep.mlx_dit.generate import mlx_generate_diffusion

        dec, cfg = self._make_tiny_decoder()
        B, T, C = 1, 16, 8
        enc_np = np.random.randn(B, 10, cfg.hidden_size).astype(np.float32)
        ctx_np = np.random.randn(B, T, cfg.in_channels - C).astype(np.float32)

        result = mlx_generate_diffusion(
            mlx_decoder=dec,
            encoder_hidden_states_np=enc_np,
            context_latents_np=ctx_np,
            src_latents_shape=(B, T, C),
            seed=None,
        )
        assert result["target_latents"].shape == (B, T, C)

    def test_seed_list(self):
        from acestep.mlx_dit.generate import mlx_generate_diffusion

        dec, cfg = self._make_tiny_decoder()
        B, T, C = 2, 16, 8
        enc_np = np.random.randn(B, 10, cfg.hidden_size).astype(np.float32)
        ctx_np = np.random.randn(B, T, cfg.in_channels - C).astype(np.float32)

        result = mlx_generate_diffusion(
            mlx_decoder=dec,
            encoder_hidden_states_np=enc_np,
            context_latents_np=ctx_np,
            src_latents_shape=(B, T, C),
            seed=[42, 99],
        )
        assert result["target_latents"].shape == (B, T, C)

    def test_custom_timesteps(self):
        from acestep.mlx_dit.generate import mlx_generate_diffusion

        dec, cfg = self._make_tiny_decoder()
        B, T, C = 1, 16, 8
        enc_np = np.random.randn(B, 10, cfg.hidden_size).astype(np.float32)
        ctx_np = np.random.randn(B, T, cfg.in_channels - C).astype(np.float32)

        result = mlx_generate_diffusion(
            mlx_decoder=dec,
            encoder_hidden_states_np=enc_np,
            context_latents_np=ctx_np,
            src_latents_shape=(B, T, C),
            seed=42,
            timesteps=[1.0, 0.75, 0.5, 0.25],
        )
        assert result["target_latents"].shape == (B, T, C)

    def test_cover_strength_switching(self):
        """When audio_cover_strength < 1.0, non-cover conditions should be used."""
        from acestep.mlx_dit.generate import mlx_generate_diffusion

        dec, cfg = self._make_tiny_decoder()
        B, T, C = 1, 16, 8
        enc_np = np.random.randn(B, 10, cfg.hidden_size).astype(np.float32)
        ctx_np = np.random.randn(B, T, cfg.in_channels - C).astype(np.float32)
        enc_nc_np = np.random.randn(B, 10, cfg.hidden_size).astype(np.float32)
        ctx_nc_np = np.random.randn(B, T, cfg.in_channels - C).astype(np.float32)

        result = mlx_generate_diffusion(
            mlx_decoder=dec,
            encoder_hidden_states_np=enc_np,
            context_latents_np=ctx_np,
            src_latents_shape=(B, T, C),
            seed=42,
            audio_cover_strength=0.5,
            encoder_hidden_states_non_cover_np=enc_nc_np,
            context_latents_non_cover_np=ctx_nc_np,
        )
        assert result["target_latents"].shape == (B, T, C)

    def test_output_is_numpy(self):
        from acestep.mlx_dit.generate import mlx_generate_diffusion

        dec, cfg = self._make_tiny_decoder()
        B, T, C = 1, 16, 8
        enc_np = np.random.randn(B, 10, cfg.hidden_size).astype(np.float32)
        ctx_np = np.random.randn(B, T, cfg.in_channels - C).astype(np.float32)

        result = mlx_generate_diffusion(
            mlx_decoder=dec,
            encoder_hidden_states_np=enc_np,
            context_latents_np=ctx_np,
            src_latents_shape=(B, T, C),
            seed=42,
        )
        assert isinstance(result["target_latents"], np.ndarray)


# ======================================================================
# 9. Handler integration – init_mlx_dit and fallback
# ======================================================================


class TestHandlerMLXIntegration:
    """Test handler-level MLX DiT initialization and fallback paths."""

    def _make_handler_stub(self):
        """Create a minimal AceStepHandler with mocked internals."""
        from acestep.handler import AceStepHandler

        handler = AceStepHandler.__new__(AceStepHandler)
        handler.model = mock.MagicMock()
        handler.config = mock.MagicMock()
        handler.device = "mps"
        handler.dtype = torch.float32
        handler.mlx_decoder = None
        handler.use_mlx_dit = False
        return handler

    @requires_mlx
    def test_init_mlx_dit_success(self):
        """On Apple Silicon with MLX, _init_mlx_dit should succeed."""
        handler = self._make_handler_stub()

        # Provide a config with real numeric attributes so from_config works
        handler.config = types.SimpleNamespace(
            hidden_size=64, intermediate_size=128, num_hidden_layers=2,
            num_attention_heads=4, num_key_value_heads=2, head_dim=16,
            rms_norm_eps=1e-6, attention_bias=False, in_channels=24,
            audio_acoustic_hidden_dim=8, patch_size=2, sliding_window=16,
            layer_types=["sliding_attention", "full_attention"],
            rope_theta=1_000_000.0, max_position_embeddings=512,
        )

        # Mock the converter to avoid needing real weights
        with mock.patch("acestep.mlx_dit.convert.convert_and_load"):
            result = handler._init_mlx_dit()

        assert result is True
        assert handler.use_mlx_dit is True
        assert handler.mlx_decoder is not None

    def test_init_mlx_dit_failure_is_non_fatal(self):
        """If MLX init fails, handler should fall back gracefully."""
        handler = self._make_handler_stub()

        with mock.patch("acestep.mlx_dit.mlx_available", return_value=True):
            with mock.patch(
                "acestep.mlx_dit.model.MLXDiTDecoder.from_config",
                side_effect=RuntimeError("test failure"),
            ):
                result = handler._init_mlx_dit()

        assert result is False
        assert handler.use_mlx_dit is False
        assert handler.mlx_decoder is None

    def test_init_mlx_dit_skipped_when_not_available(self):
        """If MLX is unavailable, init should return False without error."""
        handler = self._make_handler_stub()

        with mock.patch("acestep.mlx_dit.mlx_available", return_value=False):
            result = handler._init_mlx_dit()

        assert result is False
        assert handler.use_mlx_dit is False

    @requires_mlx
    def test_mlx_run_diffusion_converts_torch_to_numpy_and_back(self):
        """_mlx_run_diffusion should accept PyTorch tensors and return PyTorch tensors."""
        handler = self._make_handler_stub()
        handler.device = "cpu"
        handler.dtype = torch.float32

        B, T, C = 1, 16, 8
        fake_result_np = np.random.randn(B, T, C).astype(np.float32)

        with mock.patch(
            "acestep.mlx_dit.generate.mlx_generate_diffusion",
            return_value={
                "target_latents": fake_result_np,
                "time_costs": {"diffusion_time_cost": 0.1, "diffusion_per_step_time_cost": 0.01},
            },
        ):
            result = handler._mlx_run_diffusion(
                encoder_hidden_states=torch.randn(B, 10, 64),
                encoder_attention_mask=torch.ones(B, 10),
                context_latents=torch.randn(B, T, 16),
                src_latents=torch.randn(B, T, C),
                seed=42,
            )

        assert isinstance(result["target_latents"], torch.Tensor)
        assert result["target_latents"].shape == (B, T, C)

    def test_mlx_run_diffusion_handles_non_cover_none(self):
        """Non-cover params as None should not cause errors."""
        handler = self._make_handler_stub()
        handler.device = "cpu"
        handler.dtype = torch.float32

        B, T, C = 1, 16, 8
        fake_result_np = np.random.randn(B, T, C).astype(np.float32)

        with mock.patch(
            "acestep.mlx_dit.generate.mlx_generate_diffusion",
            return_value={
                "target_latents": fake_result_np,
                "time_costs": {},
            },
        ):
            result = handler._mlx_run_diffusion(
                encoder_hidden_states=torch.randn(B, 10, 64),
                encoder_attention_mask=torch.ones(B, 10),
                context_latents=torch.randn(B, T, 16),
                src_latents=torch.randn(B, T, C),
                seed=42,
                encoder_hidden_states_non_cover=None,
                encoder_attention_mask_non_cover=None,
                context_latents_non_cover=None,
            )
        assert isinstance(result["target_latents"], torch.Tensor)

    def test_handler_init_defaults(self):
        """AceStepHandler.__init__ must set mlx_decoder=None and use_mlx_dit=False."""
        from acestep.handler import AceStepHandler

        handler = AceStepHandler()
        assert handler.mlx_decoder is None
        assert handler.use_mlx_dit is False


# ======================================================================
# 10. Handler initialize_service – MLX parameter passthrough
# ======================================================================


class TestHandlerInitializeServiceMLXParam:
    """Ensure use_mlx_dit parameter is wired through initialize_service."""

    def test_initialize_service_signature_has_use_mlx_dit(self):
        """initialize_service must accept use_mlx_dit keyword argument."""
        import inspect

        from acestep.handler import AceStepHandler

        sig = inspect.signature(AceStepHandler.initialize_service)
        assert "use_mlx_dit" in sig.parameters
        # Default should be True
        assert sig.parameters["use_mlx_dit"].default is True

    def test_last_init_params_includes_use_mlx_dit(self):
        """After successful init, last_init_params should contain use_mlx_dit."""
        from acestep.handler import AceStepHandler

        handler = AceStepHandler()
        # We can't fully init without a model, but we can check the code path
        # by verifying the parameter is in the signature and testing the dict shape
        # from a mock scenario.
        handler.last_init_params = {
            "use_mlx_dit": True,
            "device": "mps",
        }
        assert "use_mlx_dit" in handler.last_init_params


# ======================================================================
# 11. Gradio UI integration
# ======================================================================


class TestGradioUIIntegration:
    """Test that the Gradio UI correctly wires the MLX DiT checkbox."""

    def test_init_service_wrapper_signature_has_mlx_dit(self):
        """init_service_wrapper must accept mlx_dit parameter."""
        import inspect

        from acestep.gradio_ui.events.generation_handlers import init_service_wrapper

        sig = inspect.signature(init_service_wrapper)
        assert "mlx_dit" in sig.parameters
        # Default should be True
        assert sig.parameters["mlx_dit"].default is True

    def test_events_init_includes_mlx_dit_checkbox(self):
        """The events __init__.py should reference mlx_dit_checkbox."""
        import os

        events_init_path = os.path.join(
            os.path.dirname(__file__),
            "..",
            "acestep",
            "gradio_ui",
            "events",
            "__init__.py",
        )
        with open(events_init_path, "r", encoding="utf-8") as f:
            source = f.read()
        assert "mlx_dit_checkbox" in source


# ======================================================================
# 12. i18n key coverage
# ======================================================================


class TestI18NKeys:
    """Ensure all supported languages have the MLX DiT translation keys."""

    MLX_KEYS = [
        "service.mlx_dit_label",
        "service.mlx_dit_info_enabled",
        "service.mlx_dit_info_disabled",
    ]
    LANGUAGES = ["en", "he", "ja", "zh"]

    @pytest.fixture(params=LANGUAGES)
    def lang_data(self, request):
        import json
        import os

        lang = request.param
        path = os.path.join(
            os.path.dirname(__file__),
            "..",
            "acestep",
            "gradio_ui",
            "i18n",
            f"{lang}.json",
        )
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return lang, data

    def test_mlx_dit_keys_present(self, lang_data):
        lang, data = lang_data
        # Flatten dotted keys -> nested lookup
        for dotted_key in self.MLX_KEYS:
            parts = dotted_key.split(".")
            node = data
            for part in parts:
                assert part in node, (
                    f"Missing i18n key '{dotted_key}' in {lang}.json "
                    f"(could not find '{part}' in {list(node.keys()) if isinstance(node, dict) else type(node)})"
                )
                node = node[part]
            assert isinstance(node, str) and len(node) > 0, (
                f"i18n key '{dotted_key}' in {lang}.json is empty"
            )


# ======================================================================
# 13. Fallback path – PyTorch code path not broken when MLX is off
# ======================================================================


class TestPyTorchFallbackPreserved:
    """Verify that the original PyTorch path works when MLX is disabled."""

    def test_handler_with_mlx_disabled_uses_pytorch_path(self):
        """When use_mlx_dit=False, handler should use model.generate_audio."""
        from acestep.handler import AceStepHandler

        handler = AceStepHandler.__new__(AceStepHandler)
        handler.use_mlx_dit = False
        handler.mlx_decoder = None

        # The condition in service_generate checks these flags
        should_use_mlx = handler.use_mlx_dit and handler.mlx_decoder is not None
        assert should_use_mlx is False

    def test_handler_with_mlx_decoder_none_uses_pytorch_path(self):
        """Even if use_mlx_dit=True, if mlx_decoder is None, PyTorch path runs."""
        from acestep.handler import AceStepHandler

        handler = AceStepHandler.__new__(AceStepHandler)
        handler.use_mlx_dit = True
        handler.mlx_decoder = None

        should_use_mlx = handler.use_mlx_dit and handler.mlx_decoder is not None
        assert should_use_mlx is False

    def test_mlx_exception_falls_back_to_pytorch(self):
        """If MLX diffusion raises, the handler must catch and fall back."""
        from acestep.handler import AceStepHandler

        handler = AceStepHandler.__new__(AceStepHandler)
        handler.use_mlx_dit = True
        handler.mlx_decoder = mock.MagicMock()  # non-None

        # Verify the condition would select MLX path
        should_use_mlx = handler.use_mlx_dit and handler.mlx_decoder is not None
        assert should_use_mlx is True

        # The try/except in service_generate catches Exception from _mlx_run_diffusion
        # and calls model.generate_audio instead. Verify this pattern exists.
        import inspect

        source = inspect.getsource(AceStepHandler)
        # The fallback pattern: except -> model.generate_audio
        assert "MLX diffusion failed" in source
        assert "falling back to PyTorch" in source

    def test_initialize_service_disabled_by_user_sets_flags(self):
        """When use_mlx_dit=False is passed, handler must clear MLX state."""
        from acestep.handler import AceStepHandler

        handler = AceStepHandler.__new__(AceStepHandler)
        handler.mlx_decoder = mock.MagicMock()
        handler.use_mlx_dit = True

        # Simulate the code path when user disables MLX
        # (extracted from initialize_service)
        use_mlx_dit = False
        if not use_mlx_dit:
            handler.mlx_decoder = None
            handler.use_mlx_dit = False

        assert handler.mlx_decoder is None
        assert handler.use_mlx_dit is False

    def test_mlx_only_activates_on_mps_or_cpu(self):
        """MLX init should only happen on 'mps' or 'cpu' devices, not 'cuda'."""
        # Verify the condition from handler.py
        for device, compile_model, expected in [
            ("mps", False, True),
            ("cpu", False, True),
            ("cuda", False, False),
            ("cuda:0", False, False),
            ("mps", True, False),  # compile_model disables MLX
        ]:
            use_mlx_dit = True
            should_try = use_mlx_dit and device in ("mps", "cpu") and not compile_model
            assert should_try == expected, (
                f"device={device}, compile={compile_model}: "
                f"expected {expected}, got {should_try}"
            )


# ======================================================================
# 14. Utility functions – rotate_half, apply_rotary_pos_emb, sliding mask
# ======================================================================


@requires_mlx
class TestUtilityFunctions:
    def test_rotate_half(self):
        from acestep.mlx_dit.model import _rotate_half

        x = mx.array([[1.0, 2.0, 3.0, 4.0]])
        out = _rotate_half(x)
        expected = mx.array([[-3.0, -4.0, 1.0, 2.0]])
        np.testing.assert_allclose(np.array(out), np.array(expected))

    def test_apply_rotary_pos_emb_shapes(self):
        from acestep.mlx_dit.model import _apply_rotary_pos_emb

        B, H, L, D = 1, 4, 8, 16
        q = mx.random.normal((B, H, L, D))
        k = mx.random.normal((B, H, L, D))
        cos = mx.random.normal((1, 1, L, D))
        sin = mx.random.normal((1, 1, L, D))
        q_out, k_out = _apply_rotary_pos_emb(q, k, cos, sin)
        assert q_out.shape == (B, H, L, D)
        assert k_out.shape == (B, H, L, D)

    def test_create_sliding_window_mask_shape(self):
        from acestep.mlx_dit.model import _create_sliding_window_mask

        mask = _create_sliding_window_mask(seq_len=16, window_size=4)
        assert mask.shape == (1, 1, 16, 16)

    def test_create_sliding_window_mask_values(self):
        from acestep.mlx_dit.model import _create_sliding_window_mask

        mask = _create_sliding_window_mask(seq_len=8, window_size=2)
        mask_np = np.array(mask[0, 0])
        # Diagonal should be 0 (within window)
        for i in range(8):
            assert mask_np[i, i] == 0.0
        # Far off-diagonal should be large negative
        assert mask_np[0, 7] < -1e8
        assert mask_np[7, 0] < -1e8
        # Within window should be 0
        assert mask_np[0, 2] == 0.0  # |0-2| = 2 <= 2
        assert mask_np[0, 3] < -1e8  # |0-3| = 3 > 2


# ======================================================================
# 15. Timestep embedding
# ======================================================================


@requires_mlx
class TestTimestepEmbedding:
    def test_output_shapes(self):
        from acestep.mlx_dit.model import MLXTimestepEmbedding

        emb = MLXTimestepEmbedding(in_channels=32, time_embed_dim=64)
        mx.eval(emb.parameters())
        t = mx.array([0.5, 1.0])
        temb, proj = emb(t)
        mx.eval(temb)
        mx.eval(proj)
        assert temb.shape == (2, 64)
        assert proj.shape == (2, 6, 64)

    def test_different_timesteps_produce_different_embeddings(self):
        from acestep.mlx_dit.model import MLXTimestepEmbedding

        emb = MLXTimestepEmbedding(in_channels=32, time_embed_dim=64)
        mx.eval(emb.parameters())
        t1 = mx.array([0.1])
        t2 = mx.array([0.9])
        temb1, _ = emb(t1)
        temb2, _ = emb(t2)
        mx.eval(temb1)
        mx.eval(temb2)
        # Different timesteps should produce different embeddings
        assert not np.allclose(np.array(temb1), np.array(temb2))


# ======================================================================
# 16. Edge cases
# ======================================================================


@requires_mlx
class TestEdgeCases:
    """Edge cases and boundary conditions."""

    def test_single_step_diffusion(self):
        """A timestep schedule with only 1 step should still produce output."""
        from acestep.mlx_dit.generate import mlx_generate_diffusion

        cfg = types.SimpleNamespace(
            hidden_size=64, intermediate_size=128, num_hidden_layers=2,
            num_attention_heads=4, num_key_value_heads=2, head_dim=16,
            rms_norm_eps=1e-6, attention_bias=False, in_channels=24,
            audio_acoustic_hidden_dim=8, patch_size=2, sliding_window=16,
            layer_types=["sliding_attention", "full_attention"],
            rope_theta=1_000_000.0, max_position_embeddings=512,
        )
        from acestep.mlx_dit.model import MLXDiTDecoder

        dec = MLXDiTDecoder.from_config(cfg)
        mx.eval(dec.parameters())

        B, T, C = 1, 16, 8
        enc_np = np.random.randn(B, 10, 64).astype(np.float32)
        ctx_np = np.random.randn(B, T, 16).astype(np.float32)

        result = mlx_generate_diffusion(
            mlx_decoder=dec,
            encoder_hidden_states_np=enc_np,
            context_latents_np=ctx_np,
            src_latents_shape=(B, T, C),
            seed=42,
            timesteps=[1.0],  # single step
        )
        assert result["target_latents"].shape == (B, T, C)

    def test_cover_strength_1_skips_non_cover(self):
        """With cover_strength=1.0, non-cover conditions should never activate."""
        from acestep.mlx_dit.generate import get_timestep_schedule

        schedule = get_timestep_schedule(shift=3.0)
        num_steps = len(schedule)
        cover_steps = int(num_steps * 1.0)
        # All steps are cover steps, so non-cover never activates
        assert cover_steps == num_steps

    def test_cover_strength_0_switches_immediately(self):
        """With cover_strength=0.0, non-cover should activate from step 0."""
        from acestep.mlx_dit.generate import get_timestep_schedule

        schedule = get_timestep_schedule(shift=3.0)
        num_steps = len(schedule)
        cover_steps = int(num_steps * 0.0)
        assert cover_steps == 0

    def test_output_not_all_zeros(self):
        """Model output should not be all zeros (sanity check)."""
        from acestep.mlx_dit.generate import mlx_generate_diffusion
        from acestep.mlx_dit.model import MLXDiTDecoder

        cfg = types.SimpleNamespace(
            hidden_size=64, intermediate_size=128, num_hidden_layers=2,
            num_attention_heads=4, num_key_value_heads=2, head_dim=16,
            rms_norm_eps=1e-6, attention_bias=False, in_channels=24,
            audio_acoustic_hidden_dim=8, patch_size=2, sliding_window=16,
            layer_types=["sliding_attention", "full_attention"],
            rope_theta=1_000_000.0, max_position_embeddings=512,
        )
        dec = MLXDiTDecoder.from_config(cfg)
        mx.eval(dec.parameters())

        B, T, C = 1, 16, 8
        enc_np = np.random.randn(B, 10, 64).astype(np.float32)
        ctx_np = np.random.randn(B, T, 16).astype(np.float32)

        result = mlx_generate_diffusion(
            mlx_decoder=dec,
            encoder_hidden_states_np=enc_np,
            context_latents_np=ctx_np,
            src_latents_shape=(B, T, C),
            seed=42,
        )
        assert not np.allclose(result["target_latents"], 0.0)

    def test_output_finite(self):
        """Model output should not contain NaN or Inf."""
        from acestep.mlx_dit.generate import mlx_generate_diffusion
        from acestep.mlx_dit.model import MLXDiTDecoder

        cfg = types.SimpleNamespace(
            hidden_size=64, intermediate_size=128, num_hidden_layers=2,
            num_attention_heads=4, num_key_value_heads=2, head_dim=16,
            rms_norm_eps=1e-6, attention_bias=False, in_channels=24,
            audio_acoustic_hidden_dim=8, patch_size=2, sliding_window=16,
            layer_types=["sliding_attention", "full_attention"],
            rope_theta=1_000_000.0, max_position_embeddings=512,
        )
        dec = MLXDiTDecoder.from_config(cfg)
        mx.eval(dec.parameters())

        B, T, C = 1, 16, 8
        enc_np = np.random.randn(B, 10, 64).astype(np.float32)
        ctx_np = np.random.randn(B, T, 16).astype(np.float32)

        result = mlx_generate_diffusion(
            mlx_decoder=dec,
            encoder_hidden_states_np=enc_np,
            context_latents_np=ctx_np,
            src_latents_shape=(B, T, C),
            seed=42,
        )
        assert np.all(np.isfinite(result["target_latents"])), "Output contains NaN or Inf"


# ======================================================================
# 17. DiT layer standalone test
# ======================================================================


@requires_mlx
class TestDiTLayerStandalone:
    def test_dit_layer_forward(self):
        from acestep.mlx_dit.model import MLXDiTLayer

        layer = MLXDiTLayer(
            hidden_size=64,
            intermediate_size=128,
            num_attention_heads=4,
            num_key_value_heads=2,
            head_dim=16,
            rms_norm_eps=1e-6,
            attention_bias=False,
            layer_idx=0,
            layer_type="full_attention",
        )
        mx.eval(layer.parameters())

        B, L, D = 1, 8, 64
        hidden = mx.random.normal((B, L, D))
        cos = mx.random.normal((1, 1, L, 16))
        sin = mx.random.normal((1, 1, L, 16))
        temb = mx.random.normal((B, 6, D))
        enc_hs = mx.random.normal((B, 10, D))

        out = layer(
            hidden_states=hidden,
            position_cos_sin=(cos, sin),
            temb=temb,
            self_attn_mask=None,
            encoder_hidden_states=enc_hs,
            encoder_attention_mask=None,
        )
        mx.eval(out)
        assert out.shape == (B, L, D)

    def test_sliding_attention_layer(self):
        from acestep.mlx_dit.model import MLXDiTLayer, _create_sliding_window_mask

        layer = MLXDiTLayer(
            hidden_size=64,
            intermediate_size=128,
            num_attention_heads=4,
            num_key_value_heads=2,
            head_dim=16,
            rms_norm_eps=1e-6,
            attention_bias=False,
            layer_idx=0,
            layer_type="sliding_attention",
            sliding_window=4,
        )
        mx.eval(layer.parameters())

        B, L, D = 1, 8, 64
        hidden = mx.random.normal((B, L, D))
        cos = mx.random.normal((1, 1, L, 16))
        sin = mx.random.normal((1, 1, L, 16))
        temb = mx.random.normal((B, 6, D))
        enc_hs = mx.random.normal((B, 10, D))
        mask = _create_sliding_window_mask(L, 4)

        out = layer(
            hidden_states=hidden,
            position_cos_sin=(cos, sin),
            temb=temb,
            self_attn_mask=mask,
            encoder_hidden_states=enc_hs,
            encoder_attention_mask=None,
        )
        mx.eval(out)
        assert out.shape == (B, L, D)


# Import needed for TestGradioUIIntegration
import inspect

Summary by CodeRabbit

  • New Features
    • Added MLX DiT option for Apple Silicon with UI checkbox and localized labels (EN/HE/JA/ZH).
  • Behavior / Reliability
    • GPU-tier validations and automatic backend/model fallbacks with clear status warnings when an option is unavailable or downgraded.
  • Generation
    • When MLX DiT is available, generation uses the MLX fast path and falls back gracefully to the default path on failure.

@coderabbitai
Copy link

coderabbitai bot commented Feb 11, 2026

📝 Walkthrough

Walkthrough

Adds optional Apple Silicon MLX support for DiT: availability detection, MLX decoder/model/convert/generate modules, integration into AceStepHandler with init and runtime routing (MLX ↔ PyTorch), UI checkbox and i18n entries, and GPU-tier validations for initialization paths.

Changes

Cohort / File(s) Summary
MLX Foundation
acestep/mlx_dit/__init__.py
Adds MLX availability detection with lazy cache and platform checks.
MLX Model & Conversion
acestep/mlx_dit/model.py, acestep/mlx_dit/convert.py
Implements a full MLX DiT decoder and weight conversion utilities (PyTorch → MLX).
MLX Diffusion Loop
acestep/mlx_dit/generate.py
Adds MLX-native diffusion generation loop, timestep scheduling, and timing reporting returning NumPy outputs.
Handler Integration
acestep/handler.py
Extends AceStepHandler with MLX init, _init_mlx_dit, _mlx_run_diffusion, new state flags, and runtime routing/fallbacks to PyTorch.
Event & Generation Wiring
acestep/gradio_ui/events/__init__.py, acestep/gradio_ui/events/generation_handlers.py
Threads new mlx_dit_checkbox into init_btn inputs and adds mlx_dit param to init_service_wrapper with GPU-tier validations before DiT init.
UI Component & i18n
acestep/gradio_ui/interfaces/generation.py, acestep/gradio_ui/i18n/{en,he,ja,zh}.json
Adds mlx_dit_checkbox UI component and localization keys (mlx_dit_label, mlx_dit_info_enabled, mlx_dit_info_disabled) across languages.

Sequence Diagram

sequenceDiagram
    participant UI as Gradio UI
    participant Event as Event Handler
    participant GenH as Generation Handler
    participant Handler as AceStepHandler
    participant MLX as MLX Module
    participant PT as PyTorch Backend

    UI->>Event: init_btn (mlx_dit_checkbox)
    Event->>GenH: init_service_wrapper(mlx_dit)
    GenH->>GenH: GPU-tier & LM validations
    GenH->>Handler: initialize_service(use_mlx_dit=mlx_dit)
    Handler->>MLX: mlx_available()
    MLX-->>Handler: available / unavailable
    alt MLX available & enabled
        Handler->>MLX: _init_mlx_dit()
        MLX-->>Handler: mlx_decoder ready
    else MLX unavailable or disabled
        Handler-->>Handler: disable MLX path
    end
    Handler-->>GenH: init complete
    UI->>Handler: service_generate(inputs)
    alt MLX decoder ready
        Handler->>MLX: _mlx_run_diffusion(...)
        MLX-->>Handler: results (NumPy)
    else fallback
        Handler->>PT: PyTorch diffusion(...)
        PT-->>Handler: results (Tensor)
    end
    Handler-->>UI: generation results
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

Suggested reviewers

  • ChuxiJ

Poem

🐰✨ I hopped into code on Apple breeze,
I swapped PyTorch hops for native ease,
We convert weights, then run the song,
If MLX falters, PyTorch keeps strong,
Hooray — faster DiT steps, hop along!

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 62.50% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the main change: adding a native MLX backend for DiT diffusion inference on Apple Silicon with a claimed 2-3x speedup, which aligns with the substantial changes across multiple modules.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

No actionable comments were generated in the recent review. 🎉


Comment @coderabbitai help to get the list of available commands and usage tips.

@tonyjohnvan
Copy link
Contributor Author

@ChuxiJ Please take a look when you have time, this initial Native MLX DiT acceleration will give roughly 2x to 3x perf boost on my test machine, please test and try at your convince, native support for VAE will be next.

also a new findlng during my test, worth mentioning but not related: aba0e7b breaks the MPS VAE decoding due to vram check returns 0 in MacOS, feel free to patch it.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@acestep/handler.py`:
- Around line 148-214: The MLX path drops attention masks: _mlx_run_diffusion
accepts encoder_attention_mask and encoder_attention_mask_non_cover but never
uses them, so MLX cross-attention may attend padded tokens; update
_mlx_run_diffusion to either (A) forward the masks into mlx_generate_diffusion
(add encoder_attention_mask_np and encoder_attention_mask_non_cover_np
parameters by converting tensors to numpy like enc_np/enc_nc_np) and ensure
mlx_generate_diffusion and mlx_decoder consume them, or (B) assert the masks are
all-ones before calling mlx_generate_diffusion (e.g., check
mask.detach().cpu().numpy().all()) to guarantee parity with PyTorch behavior;
reference function names: _mlx_run_diffusion, mlx_generate_diffusion,
mlx_decoder, encoder_attention_mask, encoder_attention_mask_non_cover when
applying the chosen fix.
- Around line 550-559: The MLX state (self.mlx_decoder and self.use_mlx_dit)
must be explicitly reset whenever MLX init is skipped or fails so
service_generate won't incorrectly take the MLX path; update the block that sets
mlx_dit_status and calls self._init_mlx_dit() so that: when the init path is not
taken because compile_model is True or device not in ("mps","cpu"), set
self.mlx_decoder = None and self.use_mlx_dit = False; and when
self._init_mlx_dit() returns False (mlx_ok is False) also ensure
self.mlx_decoder = None and self.use_mlx_dit = False while setting
mlx_dit_status accordingly; keep mlx_dit_status assignments as shown and
reference the symbols mlx_dit_status, use_mlx_dit, compile_model, device,
self._init_mlx_dit(), self.mlx_decoder, self.use_mlx_dit, and service_generate.

In `@acestep/mlx_dit/generate.py`:
- Around line 185-192: The SDE branch ignores the provided RNG key and uses
global randomness; update the SDE re-noise to derive a per-step key via
mx.random.split from the existing seed/key and pass that key into
mx.random.normal for new_noise. Locate the SDE block around infer_method ==
"sde" (variables t_schedule_list, step_idx, t_curr, pred_clean, new_noise) and
replace the call new_noise = mx.random.normal(xt.shape) with a deterministic
draw using a split key (e.g., split the main RNG with mx.random.split(main_key,
num=total_steps) or split per iteration and use the resulting subkey in
mx.random.normal(shape=xt.shape, key=subkey)), ensuring the main seed/key is
threaded into the loop so each step uses its own derived key.

Comment on lines +148 to +214
def _mlx_run_diffusion(
self,
encoder_hidden_states,
encoder_attention_mask,
context_latents,
src_latents,
seed,
infer_method: str = "ode",
shift: float = 3.0,
timesteps=None,
audio_cover_strength: float = 1.0,
encoder_hidden_states_non_cover=None,
encoder_attention_mask_non_cover=None,
context_latents_non_cover=None,
) -> Dict[str, Any]:
"""Run the diffusion loop using the MLX decoder.

Accepts PyTorch tensors, converts to numpy for MLX, runs the loop,
and converts results back to PyTorch tensors.
"""
import numpy as np
from acestep.mlx_dit.generate import mlx_generate_diffusion

# Convert inputs to numpy (float32)
enc_np = encoder_hidden_states.detach().cpu().float().numpy()
ctx_np = context_latents.detach().cpu().float().numpy()
src_shape = (src_latents.shape[0], src_latents.shape[1], src_latents.shape[2])

enc_nc_np = (
encoder_hidden_states_non_cover.detach().cpu().float().numpy()
if encoder_hidden_states_non_cover is not None else None
)
ctx_nc_np = (
context_latents_non_cover.detach().cpu().float().numpy()
if context_latents_non_cover is not None else None
)

# Convert timesteps tensor if present
ts_list = None
if timesteps is not None:
if hasattr(timesteps, "tolist"):
ts_list = timesteps.tolist()
else:
ts_list = list(timesteps)

result = mlx_generate_diffusion(
mlx_decoder=self.mlx_decoder,
encoder_hidden_states_np=enc_np,
context_latents_np=ctx_np,
src_latents_shape=src_shape,
seed=seed,
infer_method=infer_method,
shift=shift,
timesteps=ts_list,
audio_cover_strength=audio_cover_strength,
encoder_hidden_states_non_cover_np=enc_nc_np,
context_latents_non_cover_np=ctx_nc_np,
)

# Convert result latents back to PyTorch tensor on the correct device
target_np = result["target_latents"]
target_tensor = torch.from_numpy(target_np).to(device=self.device, dtype=self.dtype)

return {
"target_latents": target_tensor,
"time_costs": result["time_costs"],
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Attention masks are dropped in the MLX path.

encoder_attention_mask and encoder_attention_mask_non_cover are accepted but unused. If these masks include padding zeros (likely), MLX cross‑attention will attend to padded tokens and diverge from PyTorch behavior. Please either apply the masks in the MLX decoder path or assert they are all‑ones.

🧰 Tools
🪛 Ruff (0.15.0)

[warning] 151-151: Unused method argument: encoder_attention_mask

(ARG002)


[warning] 160-160: Unused method argument: encoder_attention_mask_non_cover

(ARG002)

🤖 Prompt for AI Agents
In `@acestep/handler.py` around lines 148 - 214, The MLX path drops attention
masks: _mlx_run_diffusion accepts encoder_attention_mask and
encoder_attention_mask_non_cover but never uses them, so MLX cross-attention may
attend padded tokens; update _mlx_run_diffusion to either (A) forward the masks
into mlx_generate_diffusion (add encoder_attention_mask_np and
encoder_attention_mask_non_cover_np parameters by converting tensors to numpy
like enc_np/enc_nc_np) and ensure mlx_generate_diffusion and mlx_decoder consume
them, or (B) assert the masks are all-ones before calling mlx_generate_diffusion
(e.g., check mask.detach().cpu().numpy().all()) to guarantee parity with PyTorch
behavior; reference function names: _mlx_run_diffusion, mlx_generate_diffusion,
mlx_decoder, encoder_attention_mask, encoder_attention_mask_non_cover when
applying the chosen fix.

Comment on lines +550 to 559
# Try to initialize native MLX DiT for Apple Silicon acceleration
mlx_dit_status = "Disabled"
if use_mlx_dit and device in ("mps", "cpu") and not compile_model:
mlx_ok = self._init_mlx_dit()
mlx_dit_status = "Active (native MLX)" if mlx_ok else "Unavailable (PyTorch fallback)"
elif not use_mlx_dit:
mlx_dit_status = "Disabled by user"
self.mlx_decoder = None
self.use_mlx_dit = False

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

MLX state isn’t reset when compile_model/unsupported device skips init.

If MLX was previously initialized, re‑init with compile_model=True or a non‑MPS/CPU device keeps the old mlx_decoder + use_mlx_dit, so service_generate can still take the MLX path despite the guard. Reset state in the non‑MLX branch.

✅ Suggested fix
-            if use_mlx_dit and device in ("mps", "cpu") and not compile_model:
+            if use_mlx_dit and device in ("mps", "cpu") and not compile_model:
                 mlx_ok = self._init_mlx_dit()
                 mlx_dit_status = "Active (native MLX)" if mlx_ok else "Unavailable (PyTorch fallback)"
-            elif not use_mlx_dit:
-                mlx_dit_status = "Disabled by user"
-                self.mlx_decoder = None
-                self.use_mlx_dit = False
+            else:
+                if not use_mlx_dit:
+                    mlx_dit_status = "Disabled by user"
+                elif compile_model:
+                    mlx_dit_status = "Disabled (torch.compile enabled)"
+                else:
+                    mlx_dit_status = "Unavailable (PyTorch fallback)"
+                self.mlx_decoder = None
+                self.use_mlx_dit = False
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Try to initialize native MLX DiT for Apple Silicon acceleration
mlx_dit_status = "Disabled"
if use_mlx_dit and device in ("mps", "cpu") and not compile_model:
mlx_ok = self._init_mlx_dit()
mlx_dit_status = "Active (native MLX)" if mlx_ok else "Unavailable (PyTorch fallback)"
elif not use_mlx_dit:
mlx_dit_status = "Disabled by user"
self.mlx_decoder = None
self.use_mlx_dit = False
# Try to initialize native MLX DiT for Apple Silicon acceleration
mlx_dit_status = "Disabled"
if use_mlx_dit and device in ("mps", "cpu") and not compile_model:
mlx_ok = self._init_mlx_dit()
mlx_dit_status = "Active (native MLX)" if mlx_ok else "Unavailable (PyTorch fallback)"
else:
if not use_mlx_dit:
mlx_dit_status = "Disabled by user"
elif compile_model:
mlx_dit_status = "Disabled (torch.compile enabled)"
else:
mlx_dit_status = "Unavailable (PyTorch fallback)"
self.mlx_decoder = None
self.use_mlx_dit = False
🤖 Prompt for AI Agents
In `@acestep/handler.py` around lines 550 - 559, The MLX state (self.mlx_decoder
and self.use_mlx_dit) must be explicitly reset whenever MLX init is skipped or
fails so service_generate won't incorrectly take the MLX path; update the block
that sets mlx_dit_status and calls self._init_mlx_dit() so that: when the init
path is not taken because compile_model is True or device not in ("mps","cpu"),
set self.mlx_decoder = None and self.use_mlx_dit = False; and when
self._init_mlx_dit() returns False (mlx_ok is False) also ensure
self.mlx_decoder = None and self.use_mlx_dit = False while setting
mlx_dit_status accordingly; keep mlx_dit_status assignments as shown and
reference the symbols mlx_dit_status, use_mlx_dit, compile_model, device,
self._init_mlx_dit(), self.mlx_decoder, self.use_mlx_dit, and service_generate.

Comment on lines +185 to +192
# ODE / SDE update
next_t = t_schedule_list[step_idx + 1]
if infer_method == "sde":
t_unsq = mx.expand_dims(mx.expand_dims(t_curr, axis=-1), axis=-1)
pred_clean = xt - vt * t_unsq
# Re-noise with next timestep
new_noise = mx.random.normal(xt.shape)
xt = next_t * new_noise + (1.0 - next_t) * pred_clean
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

head -n 200 acestep/mlx_dit/generate.py | tail -n 50

Repository: ace-step/ACE-Step-1.5

Length of output: 1634


🏁 Script executed:

# Get the full context around lines 185-192
sed -n '170,210p' acestep/mlx_dit/generate.py

Repository: ace-step/ACE-Step-1.5

Length of output: 1431


🏁 Script executed:

# Check what parameters the generate function takes, especially seed
sed -n '1,50p' acestep/mlx_dit/generate.py | head -n 40

Repository: ace-step/ACE-Step-1.5

Length of output: 1377


🏁 Script executed:

# Find the generate function signature and seed handling
sed -n '40,150p' acestep/mlx_dit/generate.py

Repository: ace-step/ACE-Step-1.5

Length of output: 4187


🏁 Script executed:

# Look for where noise is initially created with seed
rg "seed" acestep/mlx_dit/generate.py -B 2 -A 2

Repository: ace-step/ACE-Step-1.5

Length of output: 946


🏁 Script executed:

# Search for mx.random.key and mx.random.split usage to verify API
rg "mx\.random\.(key|split)" acestep/ -B 1 -A 1

Repository: ace-step/ACE-Step-1.5

Length of output: 468


🏁 Script executed:

# Check for any other MLX RNG usage patterns in the codebase
rg "mx\.random" acestep/ -A 1 | head -50

Repository: ace-step/ACE-Step-1.5

Length of output: 1154


🏁 Script executed:

# Look at the full loop structure to understand variable scope
sed -n '140,230p' acestep/mlx_dit/generate.py

Repository: ace-step/ACE-Step-1.5

Length of output: 2424


🏁 Script executed:

# Check if mx.random.split exists in MLX docs or examples in the repo
rg "split" acestep/ | grep -i random

Repository: ace-step/ACE-Step-1.5

Length of output: 182


🏁 Script executed:

# Check MLX imports to understand available functions
head -20 acestep/mlx_dit/generate.py

Repository: ace-step/ACE-Step-1.5

Length of output: 727


🌐 Web query:

MLX random.split function documentation

💡 Result:

mlx.core.random.split (also available as mx.random.split) splits a PRNG key into multiple independent subkeys.

Signature

  • split(key: array, num: int = 2, stream: None | Stream | Device = None) -> array [1]

Parameters

  • key: input PRNG key to split [1]
  • num: number of subkeys to generate (default 2) [1]
  • stream: optional execution stream/device [1]

Returns

  • An array of subkeys, with num as the first dimension [1]

Docs: [1]

Related (Random module overview)

  • MLX random functions use an implicit global PRNG by default, but accept an optional key for explicit control; MLX uses a splittable Threefry-style PRNG (JAX-like design). [2]

Sources

Citations:


🏁 Script executed:

# Search for any mx.random.split usage examples in the repo
rg "mx\.random" acestep/ -A 3 -B 1 | grep -A 5 -B 5 "split"

Repository: ace-step/ACE-Step-1.5

Length of output: 47


🏁 Script executed:

# Check if there are any test files or examples that show MLX RNG patterns
find acestep -name "*.py" -type f -exec grep -l "mx.random" {} \;

Repository: ace-step/ACE-Step-1.5

Length of output: 117


SDE path ignores the provided seed, breaking determinism.

The initial noise is seeded with mx.random.key(), but line 188 calls new_noise = mx.random.normal(xt.shape) without a key parameter. This means every SDE step uses the global random state, making runs non-deterministic even with an identical seed—unlike the PyTorch implementation.

Derive independent per-step RNG keys from the provided seed using mx.random.split():

Suggested fix
-    if seed is None:
+    rng_key = None
+    if seed is None:
         noise = mx.random.normal((bsz, T, C))
-    else:
+    else:
         key = mx.random.key(int(seed))
         noise = mx.random.normal((bsz, T, C), key=key)
+        rng_key = key
 
         # ---- Diffusion loop ----
         ...
         if infer_method == "sde":
             t_unsq = mx.expand_dims(mx.expand_dims(t_curr, axis=-1), axis=-1)
             pred_clean = xt - vt * t_unsq
             # Re-noise with next timestep
-            new_noise = mx.random.normal(xt.shape)
+            if rng_key is not None:
+                rng_key, step_key = mx.random.split(rng_key)
+                new_noise = mx.random.normal(xt.shape, key=step_key)
+            else:
+                new_noise = mx.random.normal(xt.shape)
🤖 Prompt for AI Agents
In `@acestep/mlx_dit/generate.py` around lines 185 - 192, The SDE branch ignores
the provided RNG key and uses global randomness; update the SDE re-noise to
derive a per-step key via mx.random.split from the existing seed/key and pass
that key into mx.random.normal for new_noise. Locate the SDE block around
infer_method == "sde" (variables t_schedule_list, step_idx, t_curr, pred_clean,
new_noise) and replace the call new_noise = mx.random.normal(xt.shape) with a
deterministic draw using a split key (e.g., split the main RNG with
mx.random.split(main_key, num=total_steps) or split per iteration and use the
resulting subkey in mx.random.normal(shape=xt.shape, key=subkey)), ensuring the
main seed/key is threaded into the loop so each step uses its own derived key.

@ChuxiJ
Copy link
Contributor

ChuxiJ commented Feb 11, 2026

Thanks a lot for your contribution to adding DiT support on MLX! It’s great to see the speed has been boosted by 2–3x—this is really helpful.

If you have time, could you also take a look at VAE? It would be awesome if we could get it working with MLX too.

Thanks again for your work!

@tonyjohnvan
Copy link
Contributor Author

Thanks a lot for your contribution to adding DiT support on MLX! It’s great to see the speed has been boosted by 2–3x—this is really helpful.

If you have time, could you also take a look at VAE? It would be awesome if we could get it working with MLX too.

Thanks again for your work!

You are very welcome! yup definitely native support for VAE is under my radar, taking a look now, if this PR is safe to merge I'll based on this one or the new main for the MLX VAE

@ChuxiJ ChuxiJ merged commit 9a3f495 into ace-step:main Feb 11, 2026
1 check passed
@ChuxiJ
Copy link
Contributor

ChuxiJ commented Feb 11, 2026

@tonyjohnvan
image
I have merged this PR. There are two issues you need to pay attention to:

  1. The model is not compiled by default on Mac. I found that if I enable compilation, it fails because it cannot use dit_mlx as the backend.
  2. The VAE decoder is extremely slow on my end. The cause is still unknown. You may need to look into this as well.

@tonyjohnvan
Copy link
Contributor Author

tonyjohnvan commented Feb 11, 2026

@tonyjohnvan image I have merged this PR. There are two issues you need to pay attention to:

  1. The model is not compiled by default on Mac. I found that if I enable compilation, it fails because it cannot use dit_mlx as the backend.
  2. The VAE decoder is extremely slow on my end. The cause is still unknown. You may need to look into this as well.

@ChuxiJ

  1. got it, will work on that in a separate PR
  2. yup this is as what I mentioned in the commit message:

Seems Recent main introduced changes that failed to use MPS vae decode due to 0GB vram detection, resulting extremely slow vae decoding aba0e7b

2026-02-10 23:47:45.444 | INFO     | acestep.handler:generate_music:3604 - [generate_music] Decoding latents with VAE...
2026-02-10 23:47:45.600 | DEBUG    | acestep.handler:generate_music:3622 - [generate_music] Before VAE decode: allocated=0.00GB, max=0.00GB
2026-02-10 23:47:45.600 | INFO     | acestep.handler:generate_music:3629 - [generate_music] Effective free VRAM before VAE decode: 0.00 GB
2026-02-10 23:47:45.600 | WARNING  | acestep.handler:generate_music:3632 - [generate_music] Only 0.00 GB free VRAM — auto-enabling CPU VAE decode
2026-02-10 23:47:45.600 | INFO     | acestep.handler:generate_music:3635 - [generate_music] Moving VAE to CPU for decode (ACESTEP_VAE_ON_CPU=1)...
2026-02-10 23:47:45.995 | INFO     | acestep.handler:generate_music:3642 - [generate_music] Using tiled VAE decode to reduce VRAM usage...

so the VAE fall back to limited CPU decoding which is extremely slow, I have it fixed in my wip MLX VAE branch, but seems you also fixed it in this PR? #443

@coderabbitai coderabbitai bot mentioned this pull request Feb 13, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants