1 change: 1 addition & 0 deletions acestep/gradio_ui/events/__init__.py
@@ -57,6 +57,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
generation_section["offload_dit_to_cpu_checkbox"],
generation_section["compile_model_checkbox"],
generation_section["quantization_checkbox"],
generation_section["mlx_dit_checkbox"],
],
outputs=[
generation_section["init_status"],
4 changes: 2 additions & 2 deletions acestep/gradio_ui/events/generation_handlers.py
@@ -441,7 +441,7 @@ def update_model_type_settings(config_path):
return get_model_type_ui_settings(is_turbo)


def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu, compile_model, quantization):
def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu, compile_model, quantization, mlx_dit=True):
"""Wrapper for service initialization, returns status, button state, accordion state, model type settings, and GPU-config-aware UI limits."""
# Convert quantization checkbox to value (int8_weight_only if checked, None if not)
quant_value = "int8_weight_only" if quantization else None
@@ -479,7 +479,7 @@ def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, devi
checkpoint, config_path, device,
use_flash_attention=use_flash_attention, compile_model=compile_model,
offload_to_cpu=offload_to_cpu, offload_dit_to_cpu=offload_dit_to_cpu,
quantization=quant_value
quantization=quant_value, use_mlx_dit=mlx_dit,
)

# Initialize LM handler if requested
3 changes: 3 additions & 0 deletions acestep/gradio_ui/i18n/en.json
@@ -53,6 +53,9 @@
"compile_model_info": "Use torch.compile to optimize model (required for quantization)",
"quantization_label": "INT8 Quantization",
"quantization_info": "Enable INT8 weight-only quantization to reduce VRAM usage (requires Compile Model)",
"mlx_dit_label": "MLX DiT (Apple Silicon)",
"mlx_dit_info_enabled": "Use native MLX for DiT diffusion on Apple Silicon (faster than MPS)",
"mlx_dit_info_disabled": "MLX not available (requires macOS + Apple Silicon + mlx package)",
"init_btn": "Initialize Service",
"status_label": "Status",
"language_label": "UI Language",
3 changes: 3 additions & 0 deletions acestep/gradio_ui/i18n/he.json
@@ -53,6 +53,9 @@
"compile_model_info": "השתמש ב-torch.compile לאופטימיזציה של המודל (נדרש עבור קוונטיזציה)",
"quantization_label": "קוונטיזציה INT8",
"quantization_info": "הפעל קוונטיזציה של משקולות בלבד (INT8) להפחתת שימוש ב-VRAM (דורש הידור מודל)",
"mlx_dit_label": "MLX DiT (Apple Silicon)",
"mlx_dit_info_enabled": "השתמש ב-MLX מקורי להפצת DiT על Apple Silicon (מהיר יותר מ-MPS)",
"mlx_dit_info_disabled": "MLX לא זמין (דורש macOS + Apple Silicon + חבילת mlx)",
"init_btn": "אתחול שירות",
"status_label": "מצב",
"language_label": "שפת ממשק",
3 changes: 3 additions & 0 deletions acestep/gradio_ui/i18n/ja.json
@@ -53,6 +53,9 @@
"compile_model_info": "torch.compileでモデルを最適化(量子化に必要)",
"quantization_label": "INT8 量子化",
"quantization_info": "INT8重み量子化を有効にしてVRAMを節約(モデルのコンパイルが必要)",
"mlx_dit_label": "MLX DiT (Apple Silicon)",
"mlx_dit_info_enabled": "Apple SiliconでMLXネイティブDiT拡散を使用(MPSより高速)",
"mlx_dit_info_disabled": "MLXは利用不可(macOS + Apple Silicon + mlxパッケージが必要)",
"init_btn": "サービスを初期化",
"status_label": "ステータス",
"language_label": "UI言語",
3 changes: 3 additions & 0 deletions acestep/gradio_ui/i18n/zh.json
@@ -53,6 +53,9 @@
"compile_model_info": "使用 torch.compile 优化模型(量化必需)",
"quantization_label": "INT8 量化",
"quantization_info": "启用 INT8 仅权重量化以减少显存占用(需要启用编译模型)",
"mlx_dit_label": "MLX DiT (Apple Silicon)",
"mlx_dit_info_enabled": "使用原生 MLX 加速 DiT 扩散推理(比 MPS 更快)",
"mlx_dit_info_disabled": "MLX 不可用(需要 macOS + Apple Silicon + mlx 包)",
"init_btn": "初始化服务",
"status_label": "状态",
"language_label": "界面语言",
11 changes: 11 additions & 0 deletions acestep/gradio_ui/interfaces/generation.py
@@ -218,6 +218,16 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
value=quantization_value,
info=t("service.quantization_info") + (" (recommended for this tier)" if default_quantization else " (optional for this tier)")
)
# MLX DiT acceleration (macOS Apple Silicon only)
from acestep.mlx_dit import mlx_available as _mlx_avail
_mlx_ok = _mlx_avail()
mlx_dit_value = init_params.get('mlx_dit', _mlx_ok) if service_pre_initialized else _mlx_ok
mlx_dit_checkbox = gr.Checkbox(
label=t("service.mlx_dit_label"),
value=mlx_dit_value,
interactive=_mlx_ok,
info=t("service.mlx_dit_info_enabled") if _mlx_ok else t("service.mlx_dit_info_disabled")
)

init_btn = gr.Button(t("service.init_btn"), variant="primary", size="lg")
# Set init_status value from init_params if pre-initialized
@@ -775,6 +785,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
"offload_dit_to_cpu_checkbox": offload_dit_to_cpu_checkbox,
"compile_model_checkbox": compile_model_checkbox,
"quantization_checkbox": quantization_checkbox,
"mlx_dit_checkbox": mlx_dit_checkbox,
# LoRA components
"lora_path": lora_path,
"load_lora_btn": load_lora_btn,
186 changes: 179 additions & 7 deletions acestep/handler.py
@@ -111,7 +111,108 @@ def __init__(self):
self._base_decoder = None # Backup of original decoder
self._lora_adapter_registry = {} # adapter_name -> explicit scaling targets
self._lora_active_adapter = None


# MLX DiT acceleration (macOS Apple Silicon only)
self.mlx_decoder = None
self.use_mlx_dit = False

# ------------------------------------------------------------------
# MLX DiT acceleration helpers
# ------------------------------------------------------------------
def _init_mlx_dit(self) -> bool:
"""Try to initialize the native MLX DiT decoder for Apple Silicon.

Returns True on success, False on failure (non-fatal).
"""
try:
from acestep.mlx_dit import mlx_available
if not mlx_available():
logger.info("[MLX-DiT] MLX not available on this platform; skipping.")
return False

from acestep.mlx_dit.model import MLXDiTDecoder
from acestep.mlx_dit.convert import convert_and_load

mlx_decoder = MLXDiTDecoder.from_config(self.config)
convert_and_load(self.model, mlx_decoder)
self.mlx_decoder = mlx_decoder
self.use_mlx_dit = True
logger.info("[MLX-DiT] Native MLX DiT decoder initialized successfully.")
return True
except Exception as exc:
logger.warning(f"[MLX-DiT] Failed to initialize MLX decoder (non-fatal): {exc}")
self.mlx_decoder = None
self.use_mlx_dit = False
return False

def _mlx_run_diffusion(
self,
encoder_hidden_states,
encoder_attention_mask,
context_latents,
src_latents,
seed,
infer_method: str = "ode",
shift: float = 3.0,
timesteps=None,
audio_cover_strength: float = 1.0,
encoder_hidden_states_non_cover=None,
encoder_attention_mask_non_cover=None,
context_latents_non_cover=None,
) -> Dict[str, Any]:
"""Run the diffusion loop using the MLX decoder.

Accepts PyTorch tensors, converts to numpy for MLX, runs the loop,
and converts results back to PyTorch tensors.
"""
import numpy as np
from acestep.mlx_dit.generate import mlx_generate_diffusion

# Convert inputs to numpy (float32)
enc_np = encoder_hidden_states.detach().cpu().float().numpy()
ctx_np = context_latents.detach().cpu().float().numpy()
src_shape = (src_latents.shape[0], src_latents.shape[1], src_latents.shape[2])

enc_nc_np = (
encoder_hidden_states_non_cover.detach().cpu().float().numpy()
if encoder_hidden_states_non_cover is not None else None
)
ctx_nc_np = (
context_latents_non_cover.detach().cpu().float().numpy()
if context_latents_non_cover is not None else None
)

# Convert timesteps tensor if present
ts_list = None
if timesteps is not None:
if hasattr(timesteps, "tolist"):
ts_list = timesteps.tolist()
else:
ts_list = list(timesteps)

result = mlx_generate_diffusion(
mlx_decoder=self.mlx_decoder,
encoder_hidden_states_np=enc_np,
context_latents_np=ctx_np,
src_latents_shape=src_shape,
seed=seed,
infer_method=infer_method,
shift=shift,
timesteps=ts_list,
audio_cover_strength=audio_cover_strength,
encoder_hidden_states_non_cover_np=enc_nc_np,
context_latents_non_cover_np=ctx_nc_np,
)

# Convert result latents back to PyTorch tensor on the correct device
target_np = result["target_latents"]
target_tensor = torch.from_numpy(target_np).to(device=self.device, dtype=self.dtype)

return {
"target_latents": target_tensor,
"time_costs": result["time_costs"],
}
Comment on lines +148 to +214

⚠️ Potential issue | 🟠 Major

Attention masks are dropped in the MLX path.

encoder_attention_mask and encoder_attention_mask_non_cover are accepted but unused. If these masks include padding zeros (likely), MLX cross‑attention will attend to padded tokens and diverge from PyTorch behavior. Please either apply the masks in the MLX decoder path or assert they are all‑ones.

🧰 Tools
🪛 Ruff (0.15.0)
[warning] 151-151: Unused method argument: encoder_attention_mask (ARG002)
[warning] 160-160: Unused method argument: encoder_attention_mask_non_cover (ARG002)
🤖 Prompt for AI Agents
In `@acestep/handler.py` around lines 148-214, the MLX path drops the attention masks: _mlx_run_diffusion accepts encoder_attention_mask and encoder_attention_mask_non_cover but never uses them, so MLX cross-attention may attend to padded tokens. Update _mlx_run_diffusion to either (A) forward the masks into mlx_generate_diffusion (add encoder_attention_mask_np and encoder_attention_mask_non_cover_np parameters, converting the tensors to numpy the same way as enc_np/enc_nc_np, and make sure mlx_generate_diffusion and mlx_decoder consume them), or (B) assert the masks are all-ones before calling mlx_generate_diffusion (e.g., check mask.detach().cpu().numpy().all()) to guarantee parity with the PyTorch behavior.
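
Following up on the review above, here is a minimal sketch of option (B): a guard that refuses the MLX fast-path when a mask carries padding. The helper name _assert_masks_all_ones is hypothetical and not part of this PR; option (A) would instead convert the masks to numpy alongside enc_np/enc_nc_np and pass them through mlx_generate_diffusion, provided the MLX decoder actually consumes them.

import torch

def _assert_masks_all_ones(*masks) -> None:
    # Hypothetical guard for _mlx_run_diffusion (option B): fail fast if any
    # attention mask contains padding zeros, so MLX cross-attention never
    # silently attends to padded tokens.
    for mask in masks:
        if mask is not None and not bool(torch.all(mask != 0)):
            raise ValueError(
                "[MLX-DiT] Attention mask contains padding; refusing the MLX fast-path."
            )

# Usage inside _mlx_run_diffusion, before calling mlx_generate_diffusion:
#     _assert_masks_all_ones(encoder_attention_mask, encoder_attention_mask_non_cover)

Raising here is enough for correctness because service_generate already wraps the MLX call in a try/except that falls back to self.model.generate_audio.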


def get_available_checkpoints(self) -> str:
"""Return project root directory path"""
# Get project root (handler.py is in acestep/, so go up two levels to project root)
@@ -174,6 +275,7 @@ def initialize_service(
offload_dit_to_cpu: bool = False,
quantization: Optional[str] = None,
prefer_source: Optional[str] = None,
use_mlx_dit: bool = True,
) -> Tuple[str, bool]:
"""
Initialize DiT model service
@@ -452,6 +554,16 @@ def _vae_len(vae_self):

# Determine actual attention implementation used
actual_attn = getattr(self.config, "_attn_implementation", "eager")

# Try to initialize native MLX DiT for Apple Silicon acceleration
mlx_dit_status = "Disabled"
if use_mlx_dit and device in ("mps", "cpu") and not compile_model:
mlx_ok = self._init_mlx_dit()
mlx_dit_status = "Active (native MLX)" if mlx_ok else "Unavailable (PyTorch fallback)"
elif not use_mlx_dit:
mlx_dit_status = "Disabled by user"
self.mlx_decoder = None
self.use_mlx_dit = False

Comment on lines +558 to 567

⚠️ Potential issue | 🟠 Major

MLX state isn’t reset when compile_model/unsupported device skips init.

If MLX was previously initialized, re‑init with compile_model=True or a non‑MPS/CPU device keeps the old mlx_decoder + use_mlx_dit, so service_generate can still take the MLX path despite the guard. Reset state in the non‑MLX branch.

✅ Suggested fix
-            if use_mlx_dit and device in ("mps", "cpu") and not compile_model:
+            if use_mlx_dit and device in ("mps", "cpu") and not compile_model:
                 mlx_ok = self._init_mlx_dit()
                 mlx_dit_status = "Active (native MLX)" if mlx_ok else "Unavailable (PyTorch fallback)"
-            elif not use_mlx_dit:
-                mlx_dit_status = "Disabled by user"
-                self.mlx_decoder = None
-                self.use_mlx_dit = False
+            else:
+                if not use_mlx_dit:
+                    mlx_dit_status = "Disabled by user"
+                elif compile_model:
+                    mlx_dit_status = "Disabled (torch.compile enabled)"
+                else:
+                    mlx_dit_status = "Unavailable (PyTorch fallback)"
+                self.mlx_decoder = None
+                self.use_mlx_dit = False
🤖 Prompt for AI Agents
In `@acestep/handler.py` around lines 550-559, the MLX state (self.mlx_decoder and self.use_mlx_dit) must be explicitly reset whenever MLX init is skipped or fails, so service_generate cannot incorrectly take the MLX path. Update the block that sets mlx_dit_status and calls self._init_mlx_dit() so that: when the init path is not taken because compile_model is True or device is not in ("mps", "cpu"), set self.mlx_decoder = None and self.use_mlx_dit = False; and when self._init_mlx_dit() returns False, also ensure self.mlx_decoder = None and self.use_mlx_dit = False while setting mlx_dit_status accordingly.
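
As a quick sanity check for the suggested fix, a hedged sketch (hypothetical test helper, not part of this PR; assumes the remaining initialize_service arguments have workable defaults):

def check_mlx_state_reset(handler, checkpoint, config_path):
    # First init on MPS without compilation: the MLX decoder may become active.
    handler.initialize_service(checkpoint, config_path, "mps", use_mlx_dit=True)
    # Re-init with torch.compile enabled: the guard skips MLX, so the state must
    # be cleared or service_generate could still take the stale MLX path.
    handler.initialize_service(checkpoint, config_path, "mps",
                               compile_model=True, use_mlx_dit=True)
    assert handler.mlx_decoder is None and handler.use_mlx_dit is False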

status_msg = f"✅ Model initialized successfully on {device}\n"
status_msg += f"Main model: {acestep_v15_checkpoint_path}\n"
@@ -461,7 +573,8 @@ def _vae_len(vae_self):
status_msg += f"Attention: {actual_attn}\n"
status_msg += f"Compiled: {compile_model}\n"
status_msg += f"Offload to CPU: {self.offload_to_cpu}\n"
status_msg += f"Offload DiT to CPU: {self.offload_dit_to_cpu}"
status_msg += f"Offload DiT to CPU: {self.offload_dit_to_cpu}\n"
status_msg += f"MLX DiT: {mlx_dit_status}"

# Persist latest successful init settings for mode switching (e.g. training preset).
self.last_init_params = {
@@ -473,6 +586,7 @@ def _vae_len(vae_self):
"offload_to_cpu": offload_to_cpu,
"offload_dit_to_cpu": offload_dit_to_cpu,
"quantization": quantization,
"use_mlx_dit": use_mlx_dit,
"prefer_source": prefer_source,
}

@@ -2665,7 +2779,8 @@ def service_generate(
# Add custom timesteps if provided (convert to tensor)
if timesteps is not None:
generate_kwargs["timesteps"] = torch.tensor(timesteps, dtype=torch.float32, device=self.device)
logger.info("[service_generate] Generating audio...")
dit_backend = "MLX (native)" if (self.use_mlx_dit and self.mlx_decoder is not None) else f"PyTorch ({self.device})"
logger.info(f"[service_generate] Generating audio... (DiT backend: {dit_backend})")
with torch.inference_mode():
with self._load_model_context("model"):
# Prepare condition tensors first (for LRC timestamp generation)
@@ -2684,8 +2799,65 @@ def service_generate(
is_covers=is_covers,
precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz,
)

outputs = self.model.generate_audio(**generate_kwargs)

# ---- MLX fast-path for the diffusion loop ----
if self.use_mlx_dit and self.mlx_decoder is not None:
try:
# For non-cover blend, prepare the non-cover conditions via PyTorch
enc_hs_nc, enc_am_nc, ctx_nc = None, None, None
if audio_cover_strength < 1.0 and non_cover_text_hidden_states is not None:
non_is_covers = torch.zeros_like(is_covers)
sil_exp = self.silence_latent[:, :src_latents.shape[1], :].expand(
src_latents.shape[0], -1, -1
)
enc_hs_nc, enc_am_nc, ctx_nc = self.model.prepare_condition(
text_hidden_states=non_cover_text_hidden_states,
text_attention_mask=non_cover_text_attention_masks,
lyric_hidden_states=lyric_hidden_states,
lyric_attention_mask=lyric_attention_mask,
refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
refer_audio_order_mask=refer_audio_order_mask,
hidden_states=sil_exp,
attention_mask=torch.ones(
sil_exp.shape[0], sil_exp.shape[1],
device=sil_exp.device, dtype=sil_exp.dtype,
),
silence_latent=self.silence_latent,
src_latents=sil_exp,
chunk_masks=chunk_mask,
is_covers=non_is_covers,
)

ts_arg = generate_kwargs.get("timesteps")
outputs = self._mlx_run_diffusion(
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
context_latents=context_latents,
src_latents=src_latents,
seed=seed_param,
infer_method=infer_method,
shift=shift,
timesteps=ts_arg,
audio_cover_strength=audio_cover_strength,
encoder_hidden_states_non_cover=enc_hs_nc,
encoder_attention_mask_non_cover=enc_am_nc,
context_latents_non_cover=ctx_nc,
)
_tc = outputs.get("time_costs", {})
_dt = _tc.get("diffusion_time_cost", 0)
_ps = _tc.get("diffusion_per_step_time_cost", 0)
logger.info(
f"[service_generate] DiT diffusion complete via MLX ({_dt:.2f}s total, {_ps:.3f}s/step)."
)
except Exception as exc:
logger.warning(
"[service_generate] MLX diffusion failed (%s); falling back to PyTorch.",
exc,
)
outputs = self.model.generate_audio(**generate_kwargs)
else:
logger.info("[service_generate] DiT diffusion via PyTorch (%s)...", self.device)
outputs = self.model.generate_audio(**generate_kwargs)

# Add intermediate information to outputs for extra_outputs
outputs["src_latents"] = src_latents
@@ -2734,8 +2906,8 @@ def tiled_decode(self, latents, chunk_size: Optional[int] = None, overlap: int =
_mps_overlap = self._MPS_DECODE_OVERLAP
_needs_reduction = (chunk_size > _mps_chunk) or (overlap > _mps_overlap)
if _needs_reduction:
logger.warning(
f"[tiled_decode] MPS device detected; reducing chunk_size from {chunk_size} "
logger.info(
f"[tiled_decode] VAE decode via PyTorch MPS; reducing chunk_size from {chunk_size} "
f"to {min(chunk_size, _mps_chunk)} and overlap from {overlap} "
f"to {min(overlap, _mps_overlap)} to avoid MPS conv output limit."
)
33 changes: 33 additions & 0 deletions acestep/mlx_dit/__init__.py
@@ -0,0 +1,33 @@
# Native MLX implementation of the AceStep DiT decoder for Apple Silicon.
# Provides pure MLX inference with graceful fallback to PyTorch.

import logging
import platform

logger = logging.getLogger(__name__)


def is_mlx_available() -> bool:
"""Check if MLX is available on this platform (macOS + Apple Silicon)."""
if platform.system() != "Darwin":
return False
try:
import mlx.core as mx
import mlx.nn
# Verify we can actually create arrays (Metal backend works)
_ = mx.array([1.0])
mx.eval(_)
return True
except Exception:
return False


_MLX_AVAILABLE = None


def mlx_available() -> bool:
"""Cached check for MLX availability."""
global _MLX_AVAILABLE
if _MLX_AVAILABLE is None:
_MLX_AVAILABLE = is_mlx_available()
return _MLX_AVAILABLE
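
For reference, a minimal usage sketch (hypothetical helper, not part of the PR) showing how a caller can gate a fast-path on the cached check, mirroring the guards applied in handler.py and generation.py above:

from acestep.mlx_dit import mlx_available

def pick_dit_backend(device: str, compile_model: bool) -> str:
    # Same conditions initialize_service applies before trying MLX:
    # an MPS/CPU device, no torch.compile, and MLX importable on Apple Silicon.
    if mlx_available() and device in ("mps", "cpu") and not compile_model:
        return "MLX (native)"
    return f"PyTorch ({device})"

Because the module caches the probe result in _MLX_AVAILABLE, repeated calls avoid re-running the Metal array check.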