-
Notifications
You must be signed in to change notification settings - Fork 586
feat(mlx): Native MLX backend for DiT diffusion on Apple Silicon (2-3x speedup) #439
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
63ddff2
a12a25a
ea724bd
e5ff0fc
d1389aa
ba65120
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -111,7 +111,108 @@ def __init__(self): | |||||||||||||||||||||||||||||||||||||||||||||||||
| self._base_decoder = None # Backup of original decoder | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._lora_adapter_registry = {} # adapter_name -> explicit scaling targets | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self._lora_active_adapter = None | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # MLX DiT acceleration (macOS Apple Silicon only) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.mlx_decoder = None | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.use_mlx_dit = False | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # ------------------------------------------------------------------ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # MLX DiT acceleration helpers | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # ------------------------------------------------------------------ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| def _init_mlx_dit(self) -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """Try to initialize the native MLX DiT decoder for Apple Silicon. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns True on success, False on failure (non-fatal). | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from acestep.mlx_dit import mlx_available | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if not mlx_available(): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info("[MLX-DiT] MLX not available on this platform; skipping.") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| from acestep.mlx_dit.model import MLXDiTDecoder | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from acestep.mlx_dit.convert import convert_and_load | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| mlx_decoder = MLXDiTDecoder.from_config(self.config) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| convert_and_load(self.model, mlx_decoder) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.mlx_decoder = mlx_decoder | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.use_mlx_dit = True | ||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info("[MLX-DiT] Native MLX DiT decoder initialized successfully.") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as exc: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning(f"[MLX-DiT] Failed to initialize MLX decoder (non-fatal): {exc}") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.mlx_decoder = None | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.use_mlx_dit = False | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def _mlx_run_diffusion( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| encoder_hidden_states, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| encoder_attention_mask, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| context_latents, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| src_latents, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| seed, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| infer_method: str = "ode", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| shift: float = 3.0, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| timesteps=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| audio_cover_strength: float = 1.0, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| encoder_hidden_states_non_cover=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| encoder_attention_mask_non_cover=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| context_latents_non_cover=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> Dict[str, Any]: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """Run the diffusion loop using the MLX decoder. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Accepts PyTorch tensors, converts to numpy for MLX, runs the loop, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| and converts results back to PyTorch tensors. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from acestep.mlx_dit.generate import mlx_generate_diffusion | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Convert inputs to numpy (float32) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| enc_np = encoder_hidden_states.detach().cpu().float().numpy() | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ctx_np = context_latents.detach().cpu().float().numpy() | ||||||||||||||||||||||||||||||||||||||||||||||||||
| src_shape = (src_latents.shape[0], src_latents.shape[1], src_latents.shape[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| enc_nc_np = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| encoder_hidden_states_non_cover.detach().cpu().float().numpy() | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if encoder_hidden_states_non_cover is not None else None | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ctx_nc_np = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| context_latents_non_cover.detach().cpu().float().numpy() | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if context_latents_non_cover is not None else None | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Convert timesteps tensor if present | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ts_list = None | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if timesteps is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if hasattr(timesteps, "tolist"): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ts_list = timesteps.tolist() | ||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ts_list = list(timesteps) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| result = mlx_generate_diffusion( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| mlx_decoder=self.mlx_decoder, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| encoder_hidden_states_np=enc_np, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| context_latents_np=ctx_np, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| src_latents_shape=src_shape, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| seed=seed, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| infer_method=infer_method, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| shift=shift, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| timesteps=ts_list, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| audio_cover_strength=audio_cover_strength, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| encoder_hidden_states_non_cover_np=enc_nc_np, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| context_latents_non_cover_np=ctx_nc_np, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Convert result latents back to PyTorch tensor on the correct device | ||||||||||||||||||||||||||||||||||||||||||||||||||
| target_np = result["target_latents"] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| target_tensor = torch.from_numpy(target_np).to(device=self.device, dtype=self.dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| return { | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "target_latents": target_tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "time_costs": result["time_costs"], | ||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_available_checkpoints(self) -> str: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """Return project root directory path""" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Get project root (handler.py is in acestep/, so go up two levels to project root) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -174,6 +275,7 @@ def initialize_service( | |||||||||||||||||||||||||||||||||||||||||||||||||
| offload_dit_to_cpu: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| quantization: Optional[str] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| prefer_source: Optional[str] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| use_mlx_dit: bool = True, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> Tuple[str, bool]: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Initialize DiT model service | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -452,6 +554,16 @@ def _vae_len(vae_self): | |||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Determine actual attention implementation used | ||||||||||||||||||||||||||||||||||||||||||||||||||
| actual_attn = getattr(self.config, "_attn_implementation", "eager") | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Try to initialize native MLX DiT for Apple Silicon acceleration | ||||||||||||||||||||||||||||||||||||||||||||||||||
| mlx_dit_status = "Disabled" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if use_mlx_dit and device in ("mps", "cpu") and not compile_model: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| mlx_ok = self._init_mlx_dit() | ||||||||||||||||||||||||||||||||||||||||||||||||||
| mlx_dit_status = "Active (native MLX)" if mlx_ok else "Unavailable (PyTorch fallback)" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| elif not use_mlx_dit: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| mlx_dit_status = "Disabled by user" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.mlx_decoder = None | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.use_mlx_dit = False | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+558
to
567
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. MLX state isn’t reset when compile_model/unsupported device skips init. If MLX was previously initialized, re‑init with ✅ Suggested fix- if use_mlx_dit and device in ("mps", "cpu") and not compile_model:
+ if use_mlx_dit and device in ("mps", "cpu") and not compile_model:
mlx_ok = self._init_mlx_dit()
mlx_dit_status = "Active (native MLX)" if mlx_ok else "Unavailable (PyTorch fallback)"
- elif not use_mlx_dit:
- mlx_dit_status = "Disabled by user"
- self.mlx_decoder = None
- self.use_mlx_dit = False
+ else:
+ if not use_mlx_dit:
+ mlx_dit_status = "Disabled by user"
+ elif compile_model:
+ mlx_dit_status = "Disabled (torch.compile enabled)"
+ else:
+ mlx_dit_status = "Unavailable (PyTorch fallback)"
+ self.mlx_decoder = None
+ self.use_mlx_dit = False📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||
| status_msg = f"✅ Model initialized successfully on {device}\n" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| status_msg += f"Main model: {acestep_v15_checkpoint_path}\n" | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -461,7 +573,8 @@ def _vae_len(vae_self): | |||||||||||||||||||||||||||||||||||||||||||||||||
| status_msg += f"Attention: {actual_attn}\n" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| status_msg += f"Compiled: {compile_model}\n" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| status_msg += f"Offload to CPU: {self.offload_to_cpu}\n" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| status_msg += f"Offload DiT to CPU: {self.offload_dit_to_cpu}" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| status_msg += f"Offload DiT to CPU: {self.offload_dit_to_cpu}\n" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| status_msg += f"MLX DiT: {mlx_dit_status}" | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Persist latest successful init settings for mode switching (e.g. training preset). | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.last_init_params = { | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -473,6 +586,7 @@ def _vae_len(vae_self): | |||||||||||||||||||||||||||||||||||||||||||||||||
| "offload_to_cpu": offload_to_cpu, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "offload_dit_to_cpu": offload_dit_to_cpu, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "quantization": quantization, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "use_mlx_dit": use_mlx_dit, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "prefer_source": prefer_source, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -2665,7 +2779,8 @@ def service_generate( | |||||||||||||||||||||||||||||||||||||||||||||||||
| # Add custom timesteps if provided (convert to tensor) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if timesteps is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| generate_kwargs["timesteps"] = torch.tensor(timesteps, dtype=torch.float32, device=self.device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info("[service_generate] Generating audio...") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| dit_backend = "MLX (native)" if (self.use_mlx_dit and self.mlx_decoder is not None) else f"PyTorch ({self.device})" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info(f"[service_generate] Generating audio... (DiT backend: {dit_backend})") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| with torch.inference_mode(): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| with self._load_model_context("model"): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Prepare condition tensors first (for LRC timestamp generation) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -2684,8 +2799,65 @@ def service_generate( | |||||||||||||||||||||||||||||||||||||||||||||||||
| is_covers=is_covers, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| outputs = self.model.generate_audio(**generate_kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # ---- MLX fast-path for the diffusion loop ---- | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.use_mlx_dit and self.mlx_decoder is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # For non-cover blend, prepare the non-cover conditions via PyTorch | ||||||||||||||||||||||||||||||||||||||||||||||||||
| enc_hs_nc, enc_am_nc, ctx_nc = None, None, None | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if audio_cover_strength < 1.0 and non_cover_text_hidden_states is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| non_is_covers = torch.zeros_like(is_covers) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| sil_exp = self.silence_latent[:, :src_latents.shape[1], :].expand( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| src_latents.shape[0], -1, -1 | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| enc_hs_nc, enc_am_nc, ctx_nc = self.model.prepare_condition( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| text_hidden_states=non_cover_text_hidden_states, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| text_attention_mask=non_cover_text_attention_masks, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| lyric_hidden_states=lyric_hidden_states, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| lyric_attention_mask=lyric_attention_mask, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| refer_audio_order_mask=refer_audio_order_mask, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states=sil_exp, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| attention_mask=torch.ones( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| sil_exp.shape[0], sil_exp.shape[1], | ||||||||||||||||||||||||||||||||||||||||||||||||||
| device=sil_exp.device, dtype=sil_exp.dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||
| silence_latent=self.silence_latent, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| src_latents=sil_exp, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| chunk_masks=chunk_mask, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| is_covers=non_is_covers, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| ts_arg = generate_kwargs.get("timesteps") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| outputs = self._mlx_run_diffusion( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| encoder_hidden_states=encoder_hidden_states, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| encoder_attention_mask=encoder_attention_mask, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| context_latents=context_latents, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| src_latents=src_latents, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| seed=seed_param, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| infer_method=infer_method, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| shift=shift, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| timesteps=ts_arg, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| audio_cover_strength=audio_cover_strength, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| encoder_hidden_states_non_cover=enc_hs_nc, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| encoder_attention_mask_non_cover=enc_am_nc, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| context_latents_non_cover=ctx_nc, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| _tc = outputs.get("time_costs", {}) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| _dt = _tc.get("diffusion_time_cost", 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| _ps = _tc.get("diffusion_per_step_time_cost", 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| f"[service_generate] DiT diffusion complete via MLX ({_dt:.2f}s total, {_ps:.3f}s/step)." | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as exc: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "[service_generate] MLX diffusion failed (%s); falling back to PyTorch.", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| exc, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| outputs = self.model.generate_audio(**generate_kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info("[service_generate] DiT diffusion via PyTorch (%s)...", self.device) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| outputs = self.model.generate_audio(**generate_kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Add intermediate information to outputs for extra_outputs | ||||||||||||||||||||||||||||||||||||||||||||||||||
| outputs["src_latents"] = src_latents | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -2734,8 +2906,8 @@ def tiled_decode(self, latents, chunk_size: Optional[int] = None, overlap: int = | |||||||||||||||||||||||||||||||||||||||||||||||||
| _mps_overlap = self._MPS_DECODE_OVERLAP | ||||||||||||||||||||||||||||||||||||||||||||||||||
| _needs_reduction = (chunk_size > _mps_chunk) or (overlap > _mps_overlap) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if _needs_reduction: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| f"[tiled_decode] MPS device detected; reducing chunk_size from {chunk_size} " | ||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| f"[tiled_decode] VAE decode via PyTorch MPS; reducing chunk_size from {chunk_size} " | ||||||||||||||||||||||||||||||||||||||||||||||||||
| f"to {min(chunk_size, _mps_chunk)} and overlap from {overlap} " | ||||||||||||||||||||||||||||||||||||||||||||||||||
| f"to {min(overlap, _mps_overlap)} to avoid MPS conv output limit." | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| # Native MLX implementation of the AceStep DiT decoder for Apple Silicon. | ||
| # Provides pure MLX inference with graceful fallback to PyTorch. | ||
|
|
||
| import logging | ||
| import platform | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def is_mlx_available() -> bool: | ||
| """Check if MLX is available on this platform (macOS + Apple Silicon).""" | ||
| if platform.system() != "Darwin": | ||
| return False | ||
| try: | ||
| import mlx.core as mx | ||
| import mlx.nn | ||
| # Verify we can actually create arrays (Metal backend works) | ||
| _ = mx.array([1.0]) | ||
| mx.eval(_) | ||
| return True | ||
| except Exception: | ||
| return False | ||
|
|
||
|
|
||
| _MLX_AVAILABLE = None | ||
|
|
||
|
|
||
| def mlx_available() -> bool: | ||
| """Cached check for MLX availability.""" | ||
| global _MLX_AVAILABLE | ||
| if _MLX_AVAILABLE is None: | ||
| _MLX_AVAILABLE = is_mlx_available() | ||
| return _MLX_AVAILABLE |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Attention masks are dropped in the MLX path.
encoder_attention_maskandencoder_attention_mask_non_coverare accepted but unused. If these masks include padding zeros (likely), MLX cross‑attention will attend to padded tokens and diverge from PyTorch behavior. Please either apply the masks in the MLX decoder path or assert they are all‑ones.🧰 Tools
🪛 Ruff (0.15.0)
[warning] 151-151: Unused method argument:
encoder_attention_mask(ARG002)
[warning] 160-160: Unused method argument:
encoder_attention_mask_non_cover(ARG002)
🤖 Prompt for AI Agents