feat(mlx): Native MLX backend for DiT diffusion on Apple Silicon (2-3x speedup) #439
Conversation
📝 Walkthrough
Adds optional Apple Silicon MLX support for DiT: availability detection, MLX decoder/model/convert/generate modules, integration into AceStepHandler with init and runtime routing (MLX ↔ PyTorch), UI checkbox and i18n entries, and GPU-tier validations for initialization paths.
Sequence Diagram
```mermaid
sequenceDiagram
    participant UI as Gradio UI
    participant Event as Event Handler
    participant GenH as Generation Handler
    participant Handler as AceStepHandler
    participant MLX as MLX Module
    participant PT as PyTorch Backend
    UI->>Event: init_btn (mlx_dit_checkbox)
    Event->>GenH: init_service_wrapper(mlx_dit)
    GenH->>GenH: GPU-tier & LM validations
    GenH->>Handler: initialize_service(use_mlx_dit=mlx_dit)
    Handler->>MLX: mlx_available()
    MLX-->>Handler: available / unavailable
    alt MLX available & enabled
        Handler->>MLX: _init_mlx_dit()
        MLX-->>Handler: mlx_decoder ready
    else MLX unavailable or disabled
        Handler-->>Handler: disable MLX path
    end
    Handler-->>GenH: init complete
    UI->>Handler: service_generate(inputs)
    alt MLX decoder ready
        Handler->>MLX: _mlx_run_diffusion(...)
        MLX-->>Handler: results (NumPy)
    else fallback
        Handler->>PT: PyTorch diffusion(...)
        PT-->>Handler: results (Tensor)
    end
    Handler-->>UI: generation results
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks: ✅ 2 passed, ❌ 1 failed (1 warning)
@ChuxiJ Please take a look when you have time. This initial native MLX DiT acceleration gives roughly a 2x to 3x perf boost on my test machine; please test and try at your convenience. Native support for the VAE will be next. Also, a new finding during testing, worth mentioning but not related: aba0e7b breaks MPS VAE decoding because the VRAM check returns 0 on macOS; feel free to patch it.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@acestep/handler.py`:
- Around line 148-214: The MLX path drops attention masks: _mlx_run_diffusion
accepts encoder_attention_mask and encoder_attention_mask_non_cover but never
uses them, so MLX cross-attention may attend padded tokens; update
_mlx_run_diffusion to either (A) forward the masks into mlx_generate_diffusion
(add encoder_attention_mask_np and encoder_attention_mask_non_cover_np
parameters by converting tensors to numpy like enc_np/enc_nc_np) and ensure
mlx_generate_diffusion and mlx_decoder consume them, or (B) assert the masks are
all-ones before calling mlx_generate_diffusion (e.g., check
mask.detach().cpu().numpy().all()) to guarantee parity with PyTorch behavior;
reference function names: _mlx_run_diffusion, mlx_generate_diffusion,
mlx_decoder, encoder_attention_mask, encoder_attention_mask_non_cover when
applying the chosen fix.
- Around line 550-559: The MLX state (self.mlx_decoder and self.use_mlx_dit)
must be explicitly reset whenever MLX init is skipped or fails so
service_generate won't incorrectly take the MLX path; update the block that sets
mlx_dit_status and calls self._init_mlx_dit() so that: when the init path is not
taken because compile_model is True or device not in ("mps","cpu"), set
self.mlx_decoder = None and self.use_mlx_dit = False; and when
self._init_mlx_dit() returns False (mlx_ok is False) also ensure
self.mlx_decoder = None and self.use_mlx_dit = False while setting
mlx_dit_status accordingly; keep mlx_dit_status assignments as shown and
reference the symbols mlx_dit_status, use_mlx_dit, compile_model, device,
self._init_mlx_dit(), self.mlx_decoder, self.use_mlx_dit, and service_generate.
In `@acestep/mlx_dit/generate.py`:
- Around line 185-192: The SDE branch ignores the provided RNG key and uses
global randomness; update the SDE re-noise to derive a per-step key via
mx.random.split from the existing seed/key and pass that key into
mx.random.normal for new_noise. Locate the SDE block around infer_method ==
"sde" (variables t_schedule_list, step_idx, t_curr, pred_clean, new_noise) and
replace the call new_noise = mx.random.normal(xt.shape) with a deterministic
draw using a split key (e.g., split the main RNG with mx.random.split(main_key,
num=total_steps) or split per iteration and use the resulting subkey in
mx.random.normal(shape=xt.shape, key=subkey)), ensuring the main seed/key is
threaded into the loop so each step uses its own derived key.
```python
def _mlx_run_diffusion(
    self,
    encoder_hidden_states,
    encoder_attention_mask,
    context_latents,
    src_latents,
    seed,
    infer_method: str = "ode",
    shift: float = 3.0,
    timesteps=None,
    audio_cover_strength: float = 1.0,
    encoder_hidden_states_non_cover=None,
    encoder_attention_mask_non_cover=None,
    context_latents_non_cover=None,
) -> Dict[str, Any]:
    """Run the diffusion loop using the MLX decoder.

    Accepts PyTorch tensors, converts to numpy for MLX, runs the loop,
    and converts results back to PyTorch tensors.
    """
    import numpy as np
    from acestep.mlx_dit.generate import mlx_generate_diffusion

    # Convert inputs to numpy (float32)
    enc_np = encoder_hidden_states.detach().cpu().float().numpy()
    ctx_np = context_latents.detach().cpu().float().numpy()
    src_shape = (src_latents.shape[0], src_latents.shape[1], src_latents.shape[2])

    enc_nc_np = (
        encoder_hidden_states_non_cover.detach().cpu().float().numpy()
        if encoder_hidden_states_non_cover is not None else None
    )
    ctx_nc_np = (
        context_latents_non_cover.detach().cpu().float().numpy()
        if context_latents_non_cover is not None else None
    )

    # Convert timesteps tensor if present
    ts_list = None
    if timesteps is not None:
        if hasattr(timesteps, "tolist"):
            ts_list = timesteps.tolist()
        else:
            ts_list = list(timesteps)

    result = mlx_generate_diffusion(
        mlx_decoder=self.mlx_decoder,
        encoder_hidden_states_np=enc_np,
        context_latents_np=ctx_np,
        src_latents_shape=src_shape,
        seed=seed,
        infer_method=infer_method,
        shift=shift,
        timesteps=ts_list,
        audio_cover_strength=audio_cover_strength,
        encoder_hidden_states_non_cover_np=enc_nc_np,
        context_latents_non_cover_np=ctx_nc_np,
    )

    # Convert result latents back to PyTorch tensor on the correct device
    target_np = result["target_latents"]
    target_tensor = torch.from_numpy(target_np).to(device=self.device, dtype=self.dtype)

    return {
        "target_latents": target_tensor,
        "time_costs": result["time_costs"],
    }
```
Attention masks are dropped in the MLX path.
encoder_attention_mask and encoder_attention_mask_non_cover are accepted but unused. If these masks include padding zeros (likely), MLX cross‑attention will attend to padded tokens and diverge from PyTorch behavior. Please either apply the masks in the MLX decoder path or assert they are all‑ones.
🧰 Tools
🪛 Ruff (0.15.0)
[warning] 151-151: Unused method argument: encoder_attention_mask (ARG002)
[warning] 160-160: Unused method argument: encoder_attention_mask_non_cover (ARG002)
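As a rough illustration of option (B) from the comment above, a guard like the following could sit before the MLX call. This is a sketch only; the helper name and its placement inside `_mlx_run_diffusion` are assumptions, not code from the PR.

```python
# Sketch only: a helper (hypothetical name) that checks the masks are trivially
# all-ones before entering the MLX path, so ignoring them stays numerically
# equivalent to the PyTorch behavior.
def masks_are_all_ones(*masks) -> bool:
    for mask in masks:
        if mask is not None and not bool(mask.detach().cpu().numpy().all()):
            return False
    return True

# Inside _mlx_run_diffusion, before calling mlx_generate_diffusion:
# if not masks_are_all_ones(encoder_attention_mask, encoder_attention_mask_non_cover):
#     raise RuntimeError("MLX DiT path assumes all-ones attention masks")
#     # the try/except in service_generate would then fall back to PyTorch
```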
```python
# Try to initialize native MLX DiT for Apple Silicon acceleration
mlx_dit_status = "Disabled"
if use_mlx_dit and device in ("mps", "cpu") and not compile_model:
    mlx_ok = self._init_mlx_dit()
    mlx_dit_status = "Active (native MLX)" if mlx_ok else "Unavailable (PyTorch fallback)"
elif not use_mlx_dit:
    mlx_dit_status = "Disabled by user"
    self.mlx_decoder = None
    self.use_mlx_dit = False
```
MLX state isn’t reset when compile_model/unsupported device skips init.
If MLX was previously initialized, re‑init with compile_model=True or a non‑MPS/CPU device keeps the old mlx_decoder + use_mlx_dit, so service_generate can still take the MLX path despite the guard. Reset state in the non‑MLX branch.
✅ Suggested fix
```diff
     if use_mlx_dit and device in ("mps", "cpu") and not compile_model:
         mlx_ok = self._init_mlx_dit()
         mlx_dit_status = "Active (native MLX)" if mlx_ok else "Unavailable (PyTorch fallback)"
-    elif not use_mlx_dit:
-        mlx_dit_status = "Disabled by user"
-        self.mlx_decoder = None
-        self.use_mlx_dit = False
+    else:
+        if not use_mlx_dit:
+            mlx_dit_status = "Disabled by user"
+        elif compile_model:
+            mlx_dit_status = "Disabled (torch.compile enabled)"
+        else:
+            mlx_dit_status = "Unavailable (PyTorch fallback)"
+        self.mlx_decoder = None
+        self.use_mlx_dit = False
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
# Try to initialize native MLX DiT for Apple Silicon acceleration
mlx_dit_status = "Disabled"
if use_mlx_dit and device in ("mps", "cpu") and not compile_model:
    mlx_ok = self._init_mlx_dit()
    mlx_dit_status = "Active (native MLX)" if mlx_ok else "Unavailable (PyTorch fallback)"
else:
    if not use_mlx_dit:
        mlx_dit_status = "Disabled by user"
    elif compile_model:
        mlx_dit_status = "Disabled (torch.compile enabled)"
    else:
        mlx_dit_status = "Unavailable (PyTorch fallback)"
    self.mlx_decoder = None
    self.use_mlx_dit = False
```
```python
# ODE / SDE update
next_t = t_schedule_list[step_idx + 1]
if infer_method == "sde":
    t_unsq = mx.expand_dims(mx.expand_dims(t_curr, axis=-1), axis=-1)
    pred_clean = xt - vt * t_unsq
    # Re-noise with next timestep
    new_noise = mx.random.normal(xt.shape)
    xt = next_t * new_noise + (1.0 - next_t) * pred_clean
```
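Read as a rectified-flow re-noising step (my paraphrase of the snippet, not wording from the PR), the velocity prediction is first used to estimate the clean sample, which is then re-noised at the next timestep:

$$
\hat{x}_0 = x_t - t\,v_\theta(x_t, t), \qquad
x_{t_{\text{next}}} = t_{\text{next}}\,\varepsilon + (1 - t_{\text{next}})\,\hat{x}_0,
\quad \varepsilon \sim \mathcal{N}(0, I).
$$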
🧩 Analysis chain
🏁 Scripts executed:
```shell
head -n 200 acestep/mlx_dit/generate.py | tail -n 50

# Get the full context around lines 185-192
sed -n '170,210p' acestep/mlx_dit/generate.py

# Check what parameters the generate function takes, especially seed
sed -n '1,50p' acestep/mlx_dit/generate.py | head -n 40

# Find the generate function signature and seed handling
sed -n '40,150p' acestep/mlx_dit/generate.py

# Look for where noise is initially created with seed
rg "seed" acestep/mlx_dit/generate.py -B 2 -A 2

# Search for mx.random.key and mx.random.split usage to verify API
rg "mx\.random\.(key|split)" acestep/ -B 1 -A 1

# Check for any other MLX RNG usage patterns in the codebase
rg "mx\.random" acestep/ -A 1 | head -50

# Look at the full loop structure to understand variable scope
sed -n '140,230p' acestep/mlx_dit/generate.py

# Check if mx.random.split exists in MLX docs or examples in the repo
rg "split" acestep/ | grep -i random

# Check MLX imports to understand available functions
head -20 acestep/mlx_dit/generate.py
```
🌐 Web query: MLX random.split function documentation
💡 Result:
mlx.core.random.split (also available as mx.random.split) splits a PRNG key into multiple independent subkeys.
Signature: split(key: array, num: int = 2, stream: None | Stream | Device = None) -> array [1]
Parameters:
- key: input PRNG key to split [1]
- num: number of subkeys to generate (default 2) [1]
- stream: optional execution stream/device [1]
Returns:
- An array of subkeys, with num as the first dimension [1]
Related (Random module overview):
- MLX random functions use an implicit global PRNG by default, but accept an optional key for explicit control; MLX uses a splittable Threefry-style PRNG (JAX-like design). [2]
Citations:
- [1] https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.random.split.html?utm_source=openai
- [2] https://ml-explore.github.io/mlx/build/html/python/random.html?utm_source=openai
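For reference, a minimal sketch of the key-splitting pattern documented above (assumes the mlx package is installed; the seed value and shapes are arbitrary):

```python
import mlx.core as mx

root_key = mx.random.key(42)            # deterministic root key derived from a seed
keys = mx.random.split(root_key)        # default num=2 -> array of two independent subkeys
step_key = keys[0]                      # pick one subkey for a single diffusion step
noise = mx.random.normal((2, 8), key=step_key)  # same seed -> identical noise across runs
```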
🏁 Scripts executed:
```shell
# Search for any mx.random.split usage examples in the repo
rg "mx\.random" acestep/ -A 3 -B 1 | grep -A 5 -B 5 "split"

# Check if there are any test files or examples that show MLX RNG patterns
find acestep -name "*.py" -type f -exec grep -l "mx.random" {} \;
```
SDE path ignores the provided seed, breaking determinism.
The initial noise is seeded with mx.random.key(), but line 188 calls new_noise = mx.random.normal(xt.shape) without a key parameter. This means every SDE step uses the global random state, making runs non-deterministic even with an identical seed—unlike the PyTorch implementation.
Derive independent per-step RNG keys from the provided seed using mx.random.split():
Suggested fix
```diff
-    if seed is None:
+    rng_key = None
+    if seed is None:
         noise = mx.random.normal((bsz, T, C))
     else:
         key = mx.random.key(int(seed))
         noise = mx.random.normal((bsz, T, C), key=key)
+        rng_key = key

     # ---- Diffusion loop ----
     ...
             if infer_method == "sde":
                 t_unsq = mx.expand_dims(mx.expand_dims(t_curr, axis=-1), axis=-1)
                 pred_clean = xt - vt * t_unsq
                 # Re-noise with next timestep
-                new_noise = mx.random.normal(xt.shape)
+                if rng_key is not None:
+                    rng_key, step_key = mx.random.split(rng_key)
+                    new_noise = mx.random.normal(xt.shape, key=step_key)
+                else:
+                    new_noise = mx.random.normal(xt.shape)
```
Thanks a lot for your contribution adding DiT support on MLX! It's great to see the speed boosted by 2–3x; this is really helpful. If you have time, could you also take a look at the VAE? It would be awesome if we could get it working with MLX too. Thanks again for your work!
You are very welcome! Yes, native support for the VAE is definitely on my radar; I'm taking a look now. If this PR is safe to merge, I'll base the MLX VAE work on this branch or the new main.
@tonyjohnvan
It seems recent main introduced changes that break MPS VAE decoding due to 0 GB VRAM detection (aba0e7b), so the VAE falls back to limited CPU decoding, which is extremely slow. I have it fixed in my WIP MLX VAE branch, but it looks like you also fixed it in this PR? #443


Summary
What feature is addressed:
Adds a native Apple MLX backend for the DiT (Diffusion Transformer) decoder inference loop, replacing the PyTorch MPS path on Apple Silicon Macs. The DiT diffusion loop is the single most expensive phase of audio generation -- this change reimplements it in pure MLX to bypass PyTorch-to-MPS overhead entirely.
Why this change is needed:
On Apple Silicon, PyTorch's MPS backend incurs significant dispatch and synchronization overhead for the iterative diffusion loop (8 transformer forward passes per generation). MLX's Metal-native graph execution eliminates this overhead. Real-world benchmarks show 2-3x wall-clock speedup for the DiT diffusion phase on M-series chips, making interactive music generation practical on consumer MacBooks.
The feature is entirely opt-in (checkbox in Gradio UI), auto-detected at init, and falls back gracefully to the existing PyTorch path on any failure -- no existing behavior is affected.
Scope
Files changed (12 modified/added, +1164 / -9 lines)
- acestep/mlx_dit/__init__.py: mlx_available() gate.
- acestep/mlx_dit/model.py: AceStepDiTModel: rotary embeddings, multi-head attention with QK-RMSNorm, GQA, sliding window masking, AdaLN DiT layers, timestep embeddings, Conv1d patch embedding, ConvTranspose1d de-patchify.
- acestep/mlx_dit/convert.py: weight conversion: Conv1d [out,in,K] -> [out,K,in], ConvTranspose1d [in,out,K] -> [out,K,in], Sequential index stripping, rotary buffer skipping (see the layout sketch after this list).
- acestep/mlx_dit/generate.py: MLX diffusion loop (mlx_generate_diffusion).
- acestep/handler.py: _init_mlx_dit(), _mlx_run_diffusion(), MLX fast-path in service_generate() with try/except fallback, use_mlx_dit param in initialize_service(), MLX status in init output.
- acestep/gradio_ui/interfaces/generation.py: MLX DiT checkbox in the generation UI.
- acestep/gradio_ui/events/__init__.py: mlx_dit_checkbox wired into init button inputs.
- acestep/gradio_ui/events/generation_handlers.py: mlx_dit param threaded through init_service_wrapper to initialize_service.
- acestep/gradio_ui/i18n/{en,he,ja,zh}.json: mlx_dit_label, mlx_dit_info_enabled, mlx_dit_info_disabled.
- tests/test_mlx_dit.py: unit tests (74 tests).
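A minimal numpy sketch of the Conv1d weight-layout swap described for convert.py (illustrative variable names and shapes; ConvTranspose1d is handled analogously per the list above):

```python
import numpy as np

# PyTorch stores Conv1d weights as [out_channels, in_channels, kernel];
# MLX expects [out_channels, kernel, in_channels].
pt_conv_weight = np.random.randn(64, 32, 3).astype(np.float32)  # [out, in, K] (PyTorch layout)
mlx_conv_weight = pt_conv_weight.transpose(0, 2, 1)             # [out, K, in] (MLX layout)
assert mlx_conv_weight.shape == (64, 3, 32)
```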
device in ("mps", "cpu")torch.compiledisables the MLX path (not compile_modelgate)Risk and Compatibility
Target platform / path
The MLX path activates only when the mlx pip package is installed, use_mlx_dit=True, device in ("mps", "cpu"), AND not compile_model.
Confirmation that non-target paths are unchanged
- is_mlx_available() returns False on non-Darwin -> MLX init is never attempted. The only code touched in handler.py's generation path is behind if self.use_mlx_dit and self.mlx_decoder is not None, which is always False on non-Apple platforms. The else branch calls the original self.model.generate_audio(**generate_kwargs) identically to main.
- use_mlx_dit=False -> handler explicitly sets self.mlx_decoder = None; self.use_mlx_dit = False -> PyTorch path runs.
- _init_mlx_dit() catches all exceptions, logs a warning, sets mlx_decoder=None, and returns False. Status message shows "Unavailable (PyTorch fallback)".
- The service_generate MLX fast-path is wrapped in try/except Exception, which logs a warning and falls back to self.model.generate_audio() in the same request.
- logger.warning -> logger.info for the MPS VAE chunk-size reduction log message (cosmetic, reduces log noise).
Multi-layer fallback chain
Every level is non-fatal. The worst case is a log warning + PyTorch fallback with zero behavioral change.
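A rough structural sketch of the fallback chain described above (not the literal handler.py code; the function name and the mlx_kwargs/generate_kwargs dictionaries are assumptions for illustration):

```python
import logging

logger = logging.getLogger(__name__)

def run_diffusion_with_fallback(handler, mlx_kwargs: dict, generate_kwargs: dict):
    """Sketch of the MLX fast-path with PyTorch fallback (assumed names)."""
    if handler.use_mlx_dit and handler.mlx_decoder is not None:
        try:
            # Native MLX diffusion loop (fast path on Apple Silicon).
            return handler._mlx_run_diffusion(**mlx_kwargs)
        except Exception as exc:
            # Any MLX failure degrades to the original PyTorch path within the same request.
            logger.warning("MLX diffusion failed, falling back to PyTorch: %s", exc)
    # Default / fallback: unchanged PyTorch behavior.
    return handler.model.generate_audio(**generate_kwargs)
```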
Regression Checks
Automated tests (74 tests, all passing)
Run:
conda run -n ace python -m pytest test_mlx_dit.py -v (4.25s)
- TestMLXAvailabilityDetection: is_mlx_available() on Darwin/non-Darwin, import failure, caching
- TestWeightConversion: Conv1d [out,in,K] -> [out,K,in], ConvTranspose1d [in,out,K] -> [out,K,in], key remapping, rotary skip, convert_and_load integration
- TestTimestepSchedule
- TestMLXModelArchitecture: from_config, forward output shapes, batch>1, odd seq padding/crop, KV cache population, sliding mask caching
- TestMLXCrossAttentionCache
- TestRotaryEmbedding
- TestSwiGLUMLP
- TestMLXDiffusionLoop
- TestHandlerMLXIntegration: _init_mlx_dit success/failure/skip, tensor conversion roundtrip, None non-cover params, __init__ defaults
- TestHandlerInitializeServiceMLXParam: use_mlx_dit in signature with default=True, last_init_params storage
- TestGradioUIIntegration: mlx_dit param in init_service_wrapper, checkbox wired in events
- TestI18NKeys
- TestPyTorchFallbackPreserved
- TestUtilityFunctions: _rotate_half values, _apply_rotary_pos_emb shapes, sliding mask shape + values
- TestTimestepEmbedding
- TestEdgeCases
- TestDiTLayerStandalone
Key scenarios validated
- With mlx_decoder=None, only the PyTorch path is taken.
- MLX runs only on mps/cpu, never on cuda; blocked when compile_model=True.
Reviewer Notes
Known pre-existing issues not addressed
- torchao deprecation warnings in test output are from upstream, not introduced by this PR.
Follow-up items
PS: the test code