diff --git a/acestep/gradio_ui/events/__init__.py b/acestep/gradio_ui/events/__init__.py index c76c6c0..90c6e12 100644 --- a/acestep/gradio_ui/events/__init__.py +++ b/acestep/gradio_ui/events/__init__.py @@ -57,6 +57,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase generation_section["offload_dit_to_cpu_checkbox"], generation_section["compile_model_checkbox"], generation_section["quantization_checkbox"], + generation_section["mlx_dit_checkbox"], ], outputs=[ generation_section["init_status"], diff --git a/acestep/gradio_ui/events/generation_handlers.py b/acestep/gradio_ui/events/generation_handlers.py index 8a0dd0c..cf0690e 100644 --- a/acestep/gradio_ui/events/generation_handlers.py +++ b/acestep/gradio_ui/events/generation_handlers.py @@ -441,7 +441,7 @@ def update_model_type_settings(config_path): return get_model_type_ui_settings(is_turbo) -def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu, compile_model, quantization): +def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu, compile_model, quantization, mlx_dit=True): """Wrapper for service initialization, returns status, button state, accordion state, model type settings, and GPU-config-aware UI limits.""" # Convert quantization checkbox to value (int8_weight_only if checked, None if not) quant_value = "int8_weight_only" if quantization else None @@ -479,7 +479,7 @@ def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, devi checkpoint, config_path, device, use_flash_attention=use_flash_attention, compile_model=compile_model, offload_to_cpu=offload_to_cpu, offload_dit_to_cpu=offload_dit_to_cpu, - quantization=quant_value + quantization=quant_value, use_mlx_dit=mlx_dit, ) # Initialize LM handler if requested diff --git a/acestep/gradio_ui/i18n/en.json b/acestep/gradio_ui/i18n/en.json index 5874d53..099711d 100644 --- a/acestep/gradio_ui/i18n/en.json +++ b/acestep/gradio_ui/i18n/en.json @@ -53,6 +53,9 @@ "compile_model_info": "Use torch.compile to optimize model (required for quantization)", "quantization_label": "INT8 Quantization", "quantization_info": "Enable INT8 weight-only quantization to reduce VRAM usage (requires Compile Model)", + "mlx_dit_label": "MLX DiT (Apple Silicon)", + "mlx_dit_info_enabled": "Use native MLX for DiT diffusion on Apple Silicon (faster than MPS)", + "mlx_dit_info_disabled": "MLX not available (requires macOS + Apple Silicon + mlx package)", "init_btn": "Initialize Service", "status_label": "Status", "language_label": "UI Language", diff --git a/acestep/gradio_ui/i18n/he.json b/acestep/gradio_ui/i18n/he.json index e8d549e..6e0535d 100644 --- a/acestep/gradio_ui/i18n/he.json +++ b/acestep/gradio_ui/i18n/he.json @@ -53,6 +53,9 @@ "compile_model_info": "השתמש ב-torch.compile לאופטימיזציה של המודל (נדרש עבור קוונטיזציה)", "quantization_label": "קוונטיזציה INT8", "quantization_info": "הפעל קוונטיזציה של משקולות בלבד (INT8) להפחתת שימוש ב-VRAM (דורש הידור מודל)", + "mlx_dit_label": "MLX DiT (Apple Silicon)", + "mlx_dit_info_enabled": "השתמש ב-MLX מקורי להפצת DiT על Apple Silicon (מהיר יותר מ-MPS)", + "mlx_dit_info_disabled": "MLX לא זמין (דורש macOS + Apple Silicon + חבילת mlx)", "init_btn": "אתחול שירות", "status_label": "מצב", "language_label": "שפת ממשק", diff --git a/acestep/gradio_ui/i18n/ja.json 
b/acestep/gradio_ui/i18n/ja.json index c90aba6..e5b934c 100644 --- a/acestep/gradio_ui/i18n/ja.json +++ b/acestep/gradio_ui/i18n/ja.json @@ -53,6 +53,9 @@ "compile_model_info": "torch.compileでモデルを最適化(量子化に必要)", "quantization_label": "INT8 量子化", "quantization_info": "INT8重み量子化を有効にしてVRAMを節約(モデルのコンパイルが必要)", + "mlx_dit_label": "MLX DiT (Apple Silicon)", + "mlx_dit_info_enabled": "Apple SiliconでMLXネイティブDiT拡散を使用(MPSより高速)", + "mlx_dit_info_disabled": "MLXは利用不可(macOS + Apple Silicon + mlxパッケージが必要)", "init_btn": "サービスを初期化", "status_label": "ステータス", "language_label": "UI言語", diff --git a/acestep/gradio_ui/i18n/zh.json b/acestep/gradio_ui/i18n/zh.json index 97bef71..cd222e4 100644 --- a/acestep/gradio_ui/i18n/zh.json +++ b/acestep/gradio_ui/i18n/zh.json @@ -53,6 +53,9 @@ "compile_model_info": "使用 torch.compile 优化模型(量化必需)", "quantization_label": "INT8 量化", "quantization_info": "启用 INT8 仅权重量化以减少显存占用(需要启用编译模型)", + "mlx_dit_label": "MLX DiT (Apple Silicon)", + "mlx_dit_info_enabled": "使用原生 MLX 加速 DiT 扩散推理(比 MPS 更快)", + "mlx_dit_info_disabled": "MLX 不可用(需要 macOS + Apple Silicon + mlx 包)", "init_btn": "初始化服务", "status_label": "状态", "language_label": "界面语言", diff --git a/acestep/gradio_ui/interfaces/generation.py b/acestep/gradio_ui/interfaces/generation.py index 564b07a..855b36d 100644 --- a/acestep/gradio_ui/interfaces/generation.py +++ b/acestep/gradio_ui/interfaces/generation.py @@ -218,6 +218,16 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua value=quantization_value, info=t("service.quantization_info") + (" (recommended for this tier)" if default_quantization else " (optional for this tier)") ) + # MLX DiT acceleration (macOS Apple Silicon only) + from acestep.mlx_dit import mlx_available as _mlx_avail + _mlx_ok = _mlx_avail() + mlx_dit_value = init_params.get('mlx_dit', _mlx_ok) if service_pre_initialized else _mlx_ok + mlx_dit_checkbox = gr.Checkbox( + label=t("service.mlx_dit_label"), + value=mlx_dit_value, + interactive=_mlx_ok, + info=t("service.mlx_dit_info_enabled") if _mlx_ok else t("service.mlx_dit_info_disabled") + ) init_btn = gr.Button(t("service.init_btn"), variant="primary", size="lg") # Set init_status value from init_params if pre-initialized @@ -775,6 +785,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua "offload_dit_to_cpu_checkbox": offload_dit_to_cpu_checkbox, "compile_model_checkbox": compile_model_checkbox, "quantization_checkbox": quantization_checkbox, + "mlx_dit_checkbox": mlx_dit_checkbox, # LoRA components "lora_path": lora_path, "load_lora_btn": load_lora_btn, diff --git a/acestep/handler.py b/acestep/handler.py index 907c67d..42100d4 100644 --- a/acestep/handler.py +++ b/acestep/handler.py @@ -111,7 +111,108 @@ def __init__(self): self._base_decoder = None # Backup of original decoder self._lora_adapter_registry = {} # adapter_name -> explicit scaling targets self._lora_active_adapter = None - + + # MLX DiT acceleration (macOS Apple Silicon only) + self.mlx_decoder = None + self.use_mlx_dit = False + + # ------------------------------------------------------------------ + # MLX DiT acceleration helpers + # ------------------------------------------------------------------ + def _init_mlx_dit(self) -> bool: + """Try to initialize the native MLX DiT decoder for Apple Silicon. + + Returns True on success, False on failure (non-fatal). 
+ """ + try: + from acestep.mlx_dit import mlx_available + if not mlx_available(): + logger.info("[MLX-DiT] MLX not available on this platform; skipping.") + return False + + from acestep.mlx_dit.model import MLXDiTDecoder + from acestep.mlx_dit.convert import convert_and_load + + mlx_decoder = MLXDiTDecoder.from_config(self.config) + convert_and_load(self.model, mlx_decoder) + self.mlx_decoder = mlx_decoder + self.use_mlx_dit = True + logger.info("[MLX-DiT] Native MLX DiT decoder initialized successfully.") + return True + except Exception as exc: + logger.warning(f"[MLX-DiT] Failed to initialize MLX decoder (non-fatal): {exc}") + self.mlx_decoder = None + self.use_mlx_dit = False + return False + + def _mlx_run_diffusion( + self, + encoder_hidden_states, + encoder_attention_mask, + context_latents, + src_latents, + seed, + infer_method: str = "ode", + shift: float = 3.0, + timesteps=None, + audio_cover_strength: float = 1.0, + encoder_hidden_states_non_cover=None, + encoder_attention_mask_non_cover=None, + context_latents_non_cover=None, + ) -> Dict[str, Any]: + """Run the diffusion loop using the MLX decoder. + + Accepts PyTorch tensors, converts to numpy for MLX, runs the loop, + and converts results back to PyTorch tensors. + """ + import numpy as np + from acestep.mlx_dit.generate import mlx_generate_diffusion + + # Convert inputs to numpy (float32) + enc_np = encoder_hidden_states.detach().cpu().float().numpy() + ctx_np = context_latents.detach().cpu().float().numpy() + src_shape = (src_latents.shape[0], src_latents.shape[1], src_latents.shape[2]) + + enc_nc_np = ( + encoder_hidden_states_non_cover.detach().cpu().float().numpy() + if encoder_hidden_states_non_cover is not None else None + ) + ctx_nc_np = ( + context_latents_non_cover.detach().cpu().float().numpy() + if context_latents_non_cover is not None else None + ) + + # Convert timesteps tensor if present + ts_list = None + if timesteps is not None: + if hasattr(timesteps, "tolist"): + ts_list = timesteps.tolist() + else: + ts_list = list(timesteps) + + result = mlx_generate_diffusion( + mlx_decoder=self.mlx_decoder, + encoder_hidden_states_np=enc_np, + context_latents_np=ctx_np, + src_latents_shape=src_shape, + seed=seed, + infer_method=infer_method, + shift=shift, + timesteps=ts_list, + audio_cover_strength=audio_cover_strength, + encoder_hidden_states_non_cover_np=enc_nc_np, + context_latents_non_cover_np=ctx_nc_np, + ) + + # Convert result latents back to PyTorch tensor on the correct device + target_np = result["target_latents"] + target_tensor = torch.from_numpy(target_np).to(device=self.device, dtype=self.dtype) + + return { + "target_latents": target_tensor, + "time_costs": result["time_costs"], + } + def get_available_checkpoints(self) -> str: """Return project root directory path""" # Get project root (handler.py is in acestep/, so go up two levels to project root) @@ -174,6 +275,7 @@ def initialize_service( offload_dit_to_cpu: bool = False, quantization: Optional[str] = None, prefer_source: Optional[str] = None, + use_mlx_dit: bool = True, ) -> Tuple[str, bool]: """ Initialize DiT model service @@ -452,6 +554,16 @@ def _vae_len(vae_self): # Determine actual attention implementation used actual_attn = getattr(self.config, "_attn_implementation", "eager") + + # Try to initialize native MLX DiT for Apple Silicon acceleration + mlx_dit_status = "Disabled" + if use_mlx_dit and device in ("mps", "cpu") and not compile_model: + mlx_ok = self._init_mlx_dit() + mlx_dit_status = "Active (native MLX)" if mlx_ok else 
"Unavailable (PyTorch fallback)" + elif not use_mlx_dit: + mlx_dit_status = "Disabled by user" + self.mlx_decoder = None + self.use_mlx_dit = False status_msg = f"✅ Model initialized successfully on {device}\n" status_msg += f"Main model: {acestep_v15_checkpoint_path}\n" @@ -461,7 +573,8 @@ def _vae_len(vae_self): status_msg += f"Attention: {actual_attn}\n" status_msg += f"Compiled: {compile_model}\n" status_msg += f"Offload to CPU: {self.offload_to_cpu}\n" - status_msg += f"Offload DiT to CPU: {self.offload_dit_to_cpu}" + status_msg += f"Offload DiT to CPU: {self.offload_dit_to_cpu}\n" + status_msg += f"MLX DiT: {mlx_dit_status}" # Persist latest successful init settings for mode switching (e.g. training preset). self.last_init_params = { @@ -473,6 +586,7 @@ def _vae_len(vae_self): "offload_to_cpu": offload_to_cpu, "offload_dit_to_cpu": offload_dit_to_cpu, "quantization": quantization, + "use_mlx_dit": use_mlx_dit, "prefer_source": prefer_source, } @@ -2665,7 +2779,8 @@ def service_generate( # Add custom timesteps if provided (convert to tensor) if timesteps is not None: generate_kwargs["timesteps"] = torch.tensor(timesteps, dtype=torch.float32, device=self.device) - logger.info("[service_generate] Generating audio...") + dit_backend = "MLX (native)" if (self.use_mlx_dit and self.mlx_decoder is not None) else f"PyTorch ({self.device})" + logger.info(f"[service_generate] Generating audio... (DiT backend: {dit_backend})") with torch.inference_mode(): with self._load_model_context("model"): # Prepare condition tensors first (for LRC timestamp generation) @@ -2684,8 +2799,65 @@ def service_generate( is_covers=is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, ) - - outputs = self.model.generate_audio(**generate_kwargs) + + # ---- MLX fast-path for the diffusion loop ---- + if self.use_mlx_dit and self.mlx_decoder is not None: + try: + # For non-cover blend, prepare the non-cover conditions via PyTorch + enc_hs_nc, enc_am_nc, ctx_nc = None, None, None + if audio_cover_strength < 1.0 and non_cover_text_hidden_states is not None: + non_is_covers = torch.zeros_like(is_covers) + sil_exp = self.silence_latent[:, :src_latents.shape[1], :].expand( + src_latents.shape[0], -1, -1 + ) + enc_hs_nc, enc_am_nc, ctx_nc = self.model.prepare_condition( + text_hidden_states=non_cover_text_hidden_states, + text_attention_mask=non_cover_text_attention_masks, + lyric_hidden_states=lyric_hidden_states, + lyric_attention_mask=lyric_attention_mask, + refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed, + refer_audio_order_mask=refer_audio_order_mask, + hidden_states=sil_exp, + attention_mask=torch.ones( + sil_exp.shape[0], sil_exp.shape[1], + device=sil_exp.device, dtype=sil_exp.dtype, + ), + silence_latent=self.silence_latent, + src_latents=sil_exp, + chunk_masks=chunk_mask, + is_covers=non_is_covers, + ) + + ts_arg = generate_kwargs.get("timesteps") + outputs = self._mlx_run_diffusion( + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + context_latents=context_latents, + src_latents=src_latents, + seed=seed_param, + infer_method=infer_method, + shift=shift, + timesteps=ts_arg, + audio_cover_strength=audio_cover_strength, + encoder_hidden_states_non_cover=enc_hs_nc, + encoder_attention_mask_non_cover=enc_am_nc, + context_latents_non_cover=ctx_nc, + ) + _tc = outputs.get("time_costs", {}) + _dt = _tc.get("diffusion_time_cost", 0) + _ps = _tc.get("diffusion_per_step_time_cost", 0) + logger.info( + f"[service_generate] DiT 
diffusion complete via MLX ({_dt:.2f}s total, {_ps:.3f}s/step)." + ) + except Exception as exc: + logger.warning( + "[service_generate] MLX diffusion failed (%s); falling back to PyTorch.", + exc, + ) + outputs = self.model.generate_audio(**generate_kwargs) + else: + logger.info("[service_generate] DiT diffusion via PyTorch (%s)...", self.device) + outputs = self.model.generate_audio(**generate_kwargs) # Add intermediate information to outputs for extra_outputs outputs["src_latents"] = src_latents @@ -2734,8 +2906,8 @@ def tiled_decode(self, latents, chunk_size: Optional[int] = None, overlap: int = _mps_overlap = self._MPS_DECODE_OVERLAP _needs_reduction = (chunk_size > _mps_chunk) or (overlap > _mps_overlap) if _needs_reduction: - logger.warning( - f"[tiled_decode] MPS device detected; reducing chunk_size from {chunk_size} " + logger.info( + f"[tiled_decode] VAE decode via PyTorch MPS; reducing chunk_size from {chunk_size} " f"to {min(chunk_size, _mps_chunk)} and overlap from {overlap} " f"to {min(overlap, _mps_overlap)} to avoid MPS conv output limit." ) diff --git a/acestep/mlx_dit/__init__.py b/acestep/mlx_dit/__init__.py new file mode 100644 index 0000000..c71ae43 --- /dev/null +++ b/acestep/mlx_dit/__init__.py @@ -0,0 +1,33 @@ +# Native MLX implementation of the AceStep DiT decoder for Apple Silicon. +# Provides pure MLX inference with graceful fallback to PyTorch. + +import logging +import platform + +logger = logging.getLogger(__name__) + + +def is_mlx_available() -> bool: + """Check if MLX is available on this platform (macOS + Apple Silicon).""" + if platform.system() != "Darwin": + return False + try: + import mlx.core as mx + import mlx.nn + # Verify we can actually create arrays (Metal backend works) + _ = mx.array([1.0]) + mx.eval(_) + return True + except Exception: + return False + + +_MLX_AVAILABLE = None + + +def mlx_available() -> bool: + """Cached check for MLX availability.""" + global _MLX_AVAILABLE + if _MLX_AVAILABLE is None: + _MLX_AVAILABLE = is_mlx_available() + return _MLX_AVAILABLE diff --git a/acestep/mlx_dit/convert.py b/acestep/mlx_dit/convert.py new file mode 100644 index 0000000..3305e1d --- /dev/null +++ b/acestep/mlx_dit/convert.py @@ -0,0 +1,84 @@ +# Weight conversion from PyTorch AceStep DiT decoder to native MLX format. + +import logging +from typing import Dict, List, Tuple + +import numpy as np + +logger = logging.getLogger(__name__) + + +def convert_decoder_weights( + pytorch_model, +) -> List[Tuple[str, "mx.array"]]: + """Convert PyTorch decoder weights to a list of (name, mx.array) pairs + suitable for ``mlx_decoder.load_weights()``. + + The function extracts weights from + ``pytorch_model.decoder`` (``AceStepDiTModel``) and converts them to MLX + format, handling: + - Conv1d weight layout: PT ``[out, in, K]`` -> MLX ``[out, K, in]`` + - ConvTranspose1d layout: PT ``[in, out, K]`` -> MLX ``[out, K, in]`` + - nn.Sequential index remapping (Lambda wrappers removed in MLX) + - All other weights are transferred as-is + + Args: + pytorch_model: The full ``AceStepConditionGenerationModel`` (PyTorch). + + Returns: + List of (param_name, mx.array) pairs ready for ``model.load_weights()``. + """ + import mlx.core as mx + + decoder = pytorch_model.decoder + state_dict = decoder.state_dict() + + weights: List[Tuple[str, "mx.array"]] = [] + + for key, value in state_dict.items(): + np_val = value.detach().cpu().float().numpy() + new_key = key + + # PyTorch proj_in is Sequential(Lambda, Conv1d, Lambda) + # The Conv1d is at index 1. 
In MLX we use a bare Conv1d. + if key.startswith("proj_in.1."): + new_key = key.replace("proj_in.1.", "proj_in.") + if new_key.endswith(".weight"): + # PT Conv1d weight: [out, in, K] -> MLX: [out, K, in] + np_val = np_val.swapaxes(1, 2) + + # PyTorch proj_out is Sequential(Lambda, ConvTranspose1d, Lambda) + elif key.startswith("proj_out.1."): + new_key = key.replace("proj_out.1.", "proj_out.") + if new_key.endswith(".weight"): + # PT ConvTranspose1d weight: [in, out, K] -> MLX: [out, K, in] + np_val = np_val.transpose(1, 2, 0) + + # Skip rotary embedding buffers (recomputed in MLX) + elif "rotary_emb" in key: + continue + + weights.append((new_key, mx.array(np_val))) + + logger.info( + "[MLX-DiT] Converted %d decoder parameters to MLX format.", len(weights) + ) + return weights + + +def convert_and_load( + pytorch_model, + mlx_decoder: "MLXDiTDecoder", +) -> None: + """Convert PyTorch decoder weights and load them into an MLX decoder. + + Args: + pytorch_model: The full AceStepConditionGenerationModel (PyTorch). + mlx_decoder: An instance of ``MLXDiTDecoder`` (already constructed). + """ + import mlx.core as mx + + weights = convert_decoder_weights(pytorch_model) + mlx_decoder.load_weights(weights) + mx.eval(mlx_decoder.parameters()) + logger.info("[MLX-DiT] Weights loaded and evaluated successfully.") diff --git a/acestep/mlx_dit/generate.py b/acestep/mlx_dit/generate.py new file mode 100644 index 0000000..5840372 --- /dev/null +++ b/acestep/mlx_dit/generate.py @@ -0,0 +1,213 @@ +# MLX diffusion generation loop for AceStep DiT decoder. +# +# Replicates the timestep scheduling and ODE/SDE stepping from +# ``AceStepConditionGenerationModel.generate_audio`` using pure MLX arrays. + +import logging +import time +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np + +logger = logging.getLogger(__name__) + +# Pre-defined timestep schedules (from modeling_acestep_v15_turbo.py) +VALID_SHIFTS = [1.0, 2.0, 3.0] + +VALID_TIMESTEPS = [ + 1.0, 0.9545454545454546, 0.9333333333333333, 0.9, 0.875, + 0.8571428571428571, 0.8333333333333334, 0.7692307692307693, 0.75, + 0.6666666666666666, 0.6428571428571429, 0.625, 0.5454545454545454, + 0.5, 0.4, 0.375, 0.3, 0.25, 0.2222222222222222, 0.125, +] + +SHIFT_TIMESTEPS = { + 1.0: [1.0, 0.875, 0.75, 0.625, 0.5, 0.375, 0.25, 0.125], + 2.0: [1.0, 0.9333333333333333, 0.8571428571428571, 0.7692307692307693, + 0.6666666666666666, 0.5454545454545454, 0.4, 0.2222222222222222], + 3.0: [1.0, 0.9545454545454546, 0.9, 0.8333333333333334, 0.75, + 0.6428571428571429, 0.5, 0.3], +} + + +def get_timestep_schedule( + shift: float = 3.0, + timesteps: Optional[list] = None, +) -> List[float]: + """Compute the timestep schedule for diffusion sampling. + + Replicates the logic from the turbo model's ``generate_audio`` method. + + Args: + shift: Diffusion timestep shift (1, 2, or 3). + timesteps: Optional custom list of timesteps. + + Returns: + List of timestep values (descending, without trailing 0). 
+ """ + t_schedule_list = None + + if timesteps is not None: + ts_list = list(timesteps) + # Remove trailing zeros + while ts_list and ts_list[-1] == 0: + ts_list.pop() + if len(ts_list) < 1: + logger.warning("timesteps empty after removing zeros; using default shift=%s", shift) + else: + if len(ts_list) > 20: + logger.warning("timesteps length=%d > 20; truncating", len(ts_list)) + ts_list = ts_list[:20] + # Map each timestep to the nearest valid value + mapped = [min(VALID_TIMESTEPS, key=lambda x, t=t: abs(x - t)) for t in ts_list] + t_schedule_list = mapped + + if t_schedule_list is None: + original_shift = shift + shift = min(VALID_SHIFTS, key=lambda x: abs(x - shift)) + if original_shift != shift: + logger.warning("shift=%.2f rounded to nearest valid shift=%.1f", original_shift, shift) + t_schedule_list = SHIFT_TIMESTEPS[shift] + + return t_schedule_list + + +def mlx_generate_diffusion( + mlx_decoder, + encoder_hidden_states_np: np.ndarray, + context_latents_np: np.ndarray, + src_latents_shape: Tuple[int, ...], + seed: Optional[Union[int, List[int]]] = None, + infer_method: str = "ode", + shift: float = 3.0, + timesteps: Optional[list] = None, + audio_cover_strength: float = 1.0, + encoder_hidden_states_non_cover_np: Optional[np.ndarray] = None, + context_latents_non_cover_np: Optional[np.ndarray] = None, +) -> Dict[str, object]: + """Run the complete MLX diffusion loop. + + This is the core generation function. It accepts numpy arrays (converted + from PyTorch tensors by the handler) and returns numpy arrays that the + handler converts back to PyTorch. + + Args: + mlx_decoder: ``MLXDiTDecoder`` instance with loaded weights. + encoder_hidden_states_np: [B, enc_L, D] from prepare_condition (numpy). + context_latents_np: [B, T, C] from prepare_condition (numpy). + src_latents_shape: shape tuple [B, T, 64] for noise generation. + seed: random seed (int, list[int], or None). + infer_method: "ode" or "sde". + shift: timestep shift factor. + timesteps: optional custom timestep list. + audio_cover_strength: cover strength (0-1). + encoder_hidden_states_non_cover_np: optional [B, enc_L, D] for non-cover. + context_latents_non_cover_np: optional [B, T, C] for non-cover. + + Returns: + Dict with ``"target_latents"`` (numpy) and ``"time_costs"`` dict. 
+ """ + import mlx.core as mx + from .model import MLXCrossAttentionCache + + time_costs = {} + total_start = time.time() + + # Convert numpy arrays to MLX + enc_hs = mx.array(encoder_hidden_states_np) + ctx = mx.array(context_latents_np) + + enc_hs_nc = mx.array(encoder_hidden_states_non_cover_np) if encoder_hidden_states_non_cover_np is not None else None + ctx_nc = mx.array(context_latents_non_cover_np) if context_latents_non_cover_np is not None else None + + bsz = src_latents_shape[0] + T = src_latents_shape[1] + C = src_latents_shape[2] + + # ---- Noise preparation ---- + if seed is None: + noise = mx.random.normal((bsz, T, C)) + elif isinstance(seed, list): + parts = [] + for s in seed: + if s is None or s < 0: + parts.append(mx.random.normal((1, T, C))) + else: + key = mx.random.key(int(s)) + parts.append(mx.random.normal((1, T, C), key=key)) + noise = mx.concatenate(parts, axis=0) + else: + key = mx.random.key(int(seed)) + noise = mx.random.normal((bsz, T, C), key=key) + + # ---- Timestep schedule ---- + t_schedule_list = get_timestep_schedule(shift, timesteps) + num_steps = len(t_schedule_list) + + cover_steps = int(num_steps * audio_cover_strength) + + # ---- Diffusion loop ---- + cache = MLXCrossAttentionCache() + xt = noise + + diff_start = time.time() + + for step_idx in range(num_steps): + current_t = t_schedule_list[step_idx] + t_curr = mx.full((bsz,), current_t) + + # Switch to non-cover conditions when appropriate + if step_idx >= cover_steps and enc_hs_nc is not None: + enc_hs = enc_hs_nc + ctx = ctx_nc + cache = MLXCrossAttentionCache() + + vt, cache = mlx_decoder( + hidden_states=xt, + timestep=t_curr, + timestep_r=t_curr, + encoder_hidden_states=enc_hs, + context_latents=ctx, + cache=cache, + use_cache=True, + ) + + # Evaluate to ensure computation is complete before next step + mx.eval(vt) + + # Final step: compute x0 + if step_idx == num_steps - 1: + t_unsq = mx.expand_dims(mx.expand_dims(t_curr, axis=-1), axis=-1) + xt = xt - vt * t_unsq + mx.eval(xt) + break + + # ODE / SDE update + next_t = t_schedule_list[step_idx + 1] + if infer_method == "sde": + t_unsq = mx.expand_dims(mx.expand_dims(t_curr, axis=-1), axis=-1) + pred_clean = xt - vt * t_unsq + # Re-noise with next timestep + new_noise = mx.random.normal(xt.shape) + xt = next_t * new_noise + (1.0 - next_t) * pred_clean + else: + # ODE Euler step: x_{t+1} = x_t - v_t * dt + dt = current_t - next_t + dt_arr = mx.full((bsz, 1, 1), dt) + xt = xt - vt * dt_arr + + mx.eval(xt) + + diff_end = time.time() + total_end = time.time() + + time_costs["diffusion_time_cost"] = diff_end - diff_start + time_costs["diffusion_per_step_time_cost"] = time_costs["diffusion_time_cost"] / max(num_steps, 1) + time_costs["total_time_cost"] = total_end - total_start + + # Convert result back to numpy + result_np = np.array(xt) + return { + "target_latents": result_np, + "time_costs": time_costs, + } diff --git a/acestep/mlx_dit/model.py b/acestep/mlx_dit/model.py new file mode 100644 index 0000000..6bdea0a --- /dev/null +++ b/acestep/mlx_dit/model.py @@ -0,0 +1,629 @@ +# This module re-implements the diffusion transformer decoder from +# modeling_acestep_v15_turbo.py using pure MLX operations for optimal +# performance on Apple Silicon. 
+ +import math +from typing import Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + + +# --------------------------------------------------------------------------- +# Utility helpers +# --------------------------------------------------------------------------- + +def _rotate_half(x: mx.array) -> mx.array: + """Rotate the last dimension by splitting in half and swapping with negation.""" + half = x.shape[-1] // 2 + x1 = x[..., :half] + x2 = x[..., half:] + return mx.concatenate([-x2, x1], axis=-1) + + +def _apply_rotary_pos_emb( + q: mx.array, k: mx.array, cos: mx.array, sin: mx.array +) -> Tuple[mx.array, mx.array]: + """Apply rotary position embeddings to query and key tensors. + + Args: + q, k: [B, n_heads, L, head_dim] + cos, sin: [1, 1, L, head_dim] + """ + q_embed = (q * cos) + (_rotate_half(q) * sin) + k_embed = (k * cos) + (_rotate_half(k) * sin) + return q_embed, k_embed + + +def _create_sliding_window_mask( + seq_len: int, window_size: int, dtype: mx.Dtype = mx.float32 +) -> mx.array: + """Create a bidirectional sliding-window additive attention mask. + + Positions within ``window_size`` of each other get ``0``; all others + receive a large negative value (``-1e9``). + + Returns: + [1, 1, seq_len, seq_len] + """ + indices = mx.arange(seq_len) + # diff[i, j] = |i - j| + diff = mx.abs(indices[:, None] - indices[None, :]) + zeros = mx.zeros(diff.shape, dtype=dtype) + neginf = mx.full(diff.shape, -1e9, dtype=dtype) + mask = mx.where(diff <= window_size, zeros, neginf) + return mask[None, None, :, :] # [1, 1, L, L] + + +# --------------------------------------------------------------------------- +# Rotary Position Embedding +# --------------------------------------------------------------------------- + +class MLXRotaryEmbedding(nn.Module): + """Pre-computes and caches cos/sin tables for rotary position embeddings.""" + + def __init__(self, head_dim: int, max_len: int = 32768, base: float = 1_000_000.0): + super().__init__() + self.head_dim = head_dim + self.max_len = max_len + self.base = base + + inv_freq = 1.0 / ( + base ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim) + ) + positions = mx.arange(max_len).astype(mx.float32) + freqs = positions[:, None] * inv_freq[None, :] # [max_len, head_dim//2] + freqs = mx.concatenate([freqs, freqs], axis=-1) # [max_len, head_dim] + self._cos = mx.cos(freqs) # [max_len, head_dim] + self._sin = mx.sin(freqs) # [max_len, head_dim] + + def __call__(self, seq_len: int) -> Tuple[mx.array, mx.array]: + """Return (cos, sin) each shaped [1, 1, seq_len, head_dim].""" + cos = self._cos[:seq_len][None, None, :, :] + sin = self._sin[:seq_len][None, None, :, :] + return cos, sin + + +# --------------------------------------------------------------------------- +# Cross-Attention KV Cache +# --------------------------------------------------------------------------- + +class MLXCrossAttentionCache: + """Simple KV cache for cross-attention layers. + + Cross-attention K/V are computed from encoder hidden states once on the + first diffusion step and re-used for all subsequent steps. 
+ """ + + def __init__(self): + self._keys: dict[int, mx.array] = {} + self._values: dict[int, mx.array] = {} + self._updated: set[int] = set() + + def update(self, key: mx.array, value: mx.array, layer_idx: int): + self._keys[layer_idx] = key + self._values[layer_idx] = value + self._updated.add(layer_idx) + + def is_updated(self, layer_idx: int) -> bool: + return layer_idx in self._updated + + def get(self, layer_idx: int) -> Tuple[mx.array, mx.array]: + return self._keys[layer_idx], self._values[layer_idx] + + +# --------------------------------------------------------------------------- +# Core Layers +# --------------------------------------------------------------------------- + +class MLXSwiGLUMLP(nn.Module): + """SwiGLU MLP (equivalent to Qwen3MLP): gate * silu(gate_proj) * up_proj.""" + + def __init__(self, hidden_size: int, intermediate_size: int): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + + def __call__(self, x: mx.array) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class MLXAttention(nn.Module): + """Multi-head attention with QK-RMSNorm for the AceStep DiT. + + Supports both self-attention (with RoPE) and cross-attention (with + optional KV caching). + """ + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + rms_norm_eps: float, + attention_bias: bool, + layer_idx: int, + is_cross_attention: bool = False, + sliding_window: Optional[int] = None, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_attention_heads + self.num_kv_heads = num_key_value_heads + self.head_dim = head_dim + self.n_rep = num_attention_heads // num_key_value_heads + self.scale = head_dim ** -0.5 + self.layer_idx = layer_idx + self.is_cross_attention = is_cross_attention + self.sliding_window = sliding_window + + self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias) + self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias) + self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias) + self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias) + + self.q_norm = nn.RMSNorm(head_dim, eps=rms_norm_eps) + self.k_norm = nn.RMSNorm(head_dim, eps=rms_norm_eps) + + @staticmethod + def _repeat_kv(x: mx.array, n_rep: int) -> mx.array: + """Repeat KV heads for GQA: [B, n_kv, L, D] -> [B, n_kv*n_rep, L, D].""" + if n_rep == 1: + return x + B, n_kv, L, D = x.shape + x = mx.expand_dims(x, axis=2) # [B, n_kv, 1, L, D] + x = mx.broadcast_to(x, (B, n_kv, n_rep, L, D)) + return x.reshape(B, n_kv * n_rep, L, D) + + def __call__( + self, + hidden_states: mx.array, + position_cos_sin: Optional[Tuple[mx.array, mx.array]] = None, + attention_mask: Optional[mx.array] = None, + encoder_hidden_states: Optional[mx.array] = None, + cache: Optional[MLXCrossAttentionCache] = None, + use_cache: bool = False, + ) -> mx.array: + B, L, _ = hidden_states.shape + + # Project queries (always from hidden_states) + q = self.q_proj(hidden_states) + q = self.q_norm(q.reshape(B, L, self.num_heads, self.head_dim)) + q = q.transpose(0, 2, 1, 3) # [B, n_heads, L, D] + + if self.is_cross_attention and encoder_hidden_states is not None: + # Cross-attention: K,V come from encoder + if cache is 
not None and cache.is_updated(self.layer_idx): + k, v = cache.get(self.layer_idx) + else: + enc_L = encoder_hidden_states.shape[1] + k = self.k_proj(encoder_hidden_states) + k = self.k_norm(k.reshape(B, enc_L, self.num_kv_heads, self.head_dim)) + k = k.transpose(0, 2, 1, 3) + v = self.v_proj(encoder_hidden_states).reshape( + B, enc_L, self.num_kv_heads, self.head_dim + ).transpose(0, 2, 1, 3) + if cache is not None and use_cache: + cache.update(k, v, self.layer_idx) + else: + # Self-attention: K,V come from hidden_states + k = self.k_proj(hidden_states) + k = self.k_norm(k.reshape(B, L, self.num_kv_heads, self.head_dim)) + k = k.transpose(0, 2, 1, 3) + v = self.v_proj(hidden_states).reshape( + B, L, self.num_kv_heads, self.head_dim + ).transpose(0, 2, 1, 3) + + # Apply RoPE to self-attention Q,K + if position_cos_sin is not None: + cos, sin = position_cos_sin + q, k = _apply_rotary_pos_emb(q, k, cos, sin) + + # GQA: repeat KV heads to match Q heads + k = self._repeat_kv(k, self.n_rep) + v = self._repeat_kv(v, self.n_rep) + + # Scaled dot-product attention + attn_out = mx.fast.scaled_dot_product_attention( + q, k, v, scale=self.scale, mask=attention_mask + ) + + # Merge heads and project output: [B, n_heads, L, D] -> [B, L, hidden] + attn_out = attn_out.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(attn_out) + + +# --------------------------------------------------------------------------- +# DiT Layer +# --------------------------------------------------------------------------- + +class MLXDiTLayer(nn.Module): + """A single DiT transformer layer with AdaLN modulation. + + Implements: + 1. Self-attention with adaptive layer norm (AdaLN) + 2. Cross-attention to encoder hidden states + 3. Feed-forward MLP with adaptive layer norm + """ + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + rms_norm_eps: float, + attention_bias: bool, + layer_idx: int, + layer_type: str, + sliding_window: Optional[int] = None, + ): + super().__init__() + self.layer_type = layer_type + sw = sliding_window if layer_type == "sliding_attention" else None + + # 1. Self-attention + self.self_attn_norm = nn.RMSNorm(hidden_size, eps=rms_norm_eps) + self.self_attn = MLXAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + rms_norm_eps=rms_norm_eps, + attention_bias=attention_bias, + layer_idx=layer_idx, + is_cross_attention=False, + sliding_window=sw, + ) + + # 2. Cross-attention + self.cross_attn_norm = nn.RMSNorm(hidden_size, eps=rms_norm_eps) + self.cross_attn = MLXAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + rms_norm_eps=rms_norm_eps, + attention_bias=attention_bias, + layer_idx=layer_idx, + is_cross_attention=True, + ) + + # 3. 
MLP + self.mlp_norm = nn.RMSNorm(hidden_size, eps=rms_norm_eps) + self.mlp = MLXSwiGLUMLP(hidden_size, intermediate_size) + + # AdaLN modulation table (6 values: shift/scale/gate for self-attn & MLP) + self.scale_shift_table = mx.zeros((1, 6, hidden_size)) + + def __call__( + self, + hidden_states: mx.array, + position_cos_sin: Tuple[mx.array, mx.array], + temb: mx.array, + self_attn_mask: Optional[mx.array], + encoder_hidden_states: Optional[mx.array], + encoder_attention_mask: Optional[mx.array], + cache: Optional[MLXCrossAttentionCache] = None, + use_cache: bool = False, + ) -> mx.array: + # AdaLN modulation from timestep embeddings + # scale_shift_table: [1, 6, D], temb: [B, 6, D] + modulation = self.scale_shift_table + temb # [B, 6, D] + parts = mx.split(modulation, 6, axis=1) + # Each part: [B, 1, D] + shift_msa, scale_msa, gate_msa = parts[0], parts[1], parts[2] + c_shift_msa, c_scale_msa, c_gate_msa = parts[3], parts[4], parts[5] + + # Step 1: Self-attention with AdaLN + normed = self.self_attn_norm(hidden_states) + normed = normed * (1.0 + scale_msa) + shift_msa + attn_out = self.self_attn( + normed, + position_cos_sin=position_cos_sin, + attention_mask=self_attn_mask, + ) + hidden_states = hidden_states + attn_out * gate_msa + + # Step 2: Cross-attention + normed = self.cross_attn_norm(hidden_states) + cross_out = self.cross_attn( + normed, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + cache=cache, + use_cache=use_cache, + ) + hidden_states = hidden_states + cross_out + + # Step 3: MLP with AdaLN + normed = self.mlp_norm(hidden_states) + normed = normed * (1.0 + c_scale_msa) + c_shift_msa + ff_out = self.mlp(normed) + hidden_states = hidden_states + ff_out * c_gate_msa + + return hidden_states + + +# --------------------------------------------------------------------------- +# Timestep Embedding +# --------------------------------------------------------------------------- + +class MLXTimestepEmbedding(nn.Module): + """Sinusoidal timestep embedding followed by MLP.""" + + def __init__(self, in_channels: int = 256, time_embed_dim: int = 2048, scale: float = 1000.0): + super().__init__() + self.in_channels = in_channels + self.scale = scale + + self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True) + self.act1 = nn.SiLU() + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=True) + self.act2 = nn.SiLU() + self.time_proj = nn.Linear(time_embed_dim, time_embed_dim * 6, bias=True) + + def _sinusoidal_embedding(self, t: mx.array, dim: int, max_period: int = 10000) -> mx.array: + """Create sinusoidal timestep embeddings. 
+ + Args: + t: 1-D array of shape [N] + dim: embedding dimension + Returns: + [N, dim] + """ + t = t * self.scale + half = dim // 2 + freqs = mx.exp( + -math.log(max_period) + * mx.arange(half).astype(mx.float32) / half + ) + args = t[:, None].astype(mx.float32) * freqs[None, :] + embedding = mx.concatenate([mx.cos(args), mx.sin(args)], axis=-1) + if dim % 2: + embedding = mx.concatenate( + [embedding, mx.zeros_like(embedding[:, :1])], axis=-1 + ) + return embedding + + def __call__(self, t: mx.array) -> Tuple[mx.array, mx.array]: + """ + Args: + t: [B] timestep values + Returns: + temb: [B, D] + timestep_proj: [B, 6, D] + """ + t_freq = self._sinusoidal_embedding(t, self.in_channels) + temb = self.linear_1(t_freq.astype(t.dtype)) + temb = self.act1(temb) + temb = self.linear_2(temb) + proj = self.time_proj(self.act2(temb)) # [B, D*6] + timestep_proj = proj.reshape(proj.shape[0], 6, -1) # [B, 6, D] + return temb, timestep_proj + + +# --------------------------------------------------------------------------- +# Full DiT Decoder +# --------------------------------------------------------------------------- + +class MLXDiTDecoder(nn.Module): + """Native MLX implementation of AceStepDiTModel (the diffusion transformer decoder). + + Mirrors the PyTorch ``AceStepDiTModel`` class exactly: + - Patch-based input projection (Conv1d) + - Timestep conditioning via dual TimestepEmbedding + - N DiT transformer layers with self/cross-attention and AdaLN + - Patch-based output projection (ConvTranspose1d) + - Adaptive output layer norm + """ + + def __init__( + self, + hidden_size: int = 2048, + intermediate_size: int = 6144, + num_hidden_layers: int = 24, + num_attention_heads: int = 16, + num_key_value_heads: int = 8, + head_dim: int = 128, + rms_norm_eps: float = 1e-6, + attention_bias: bool = False, + in_channels: int = 192, + audio_acoustic_hidden_dim: int = 64, + patch_size: int = 2, + sliding_window: int = 128, + layer_types: Optional[list] = None, + rope_theta: float = 1_000_000.0, + max_position_embeddings: int = 32768, + ): + super().__init__() + self.hidden_size = hidden_size + self.patch_size = patch_size + inner_dim = hidden_size + + if layer_types is None: + layer_types = [ + "sliding_attention" if bool((i + 1) % 2) else "full_attention" + for i in range(num_hidden_layers) + ] + + # Rotary position embeddings + self.rotary_emb = MLXRotaryEmbedding( + head_dim, max_len=max_position_embeddings, base=rope_theta + ) + + # Input projection: Conv1d patch embedding + # MLX Conv1d uses channels-last: [B, L, C] -> [B, L//stride, out_C] + self.proj_in = nn.Conv1d( + in_channels=in_channels, + out_channels=inner_dim, + kernel_size=patch_size, + stride=patch_size, + padding=0, + ) + + # Timestep embeddings (two: t and t-r) + self.time_embed = MLXTimestepEmbedding(in_channels=256, time_embed_dim=inner_dim) + self.time_embed_r = MLXTimestepEmbedding(in_channels=256, time_embed_dim=inner_dim) + + # Condition embedder + self.condition_embedder = nn.Linear(inner_dim, inner_dim, bias=True) + + # Transformer layers + self.layers = [ + MLXDiTLayer( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + rms_norm_eps=rms_norm_eps, + attention_bias=attention_bias, + layer_idx=i, + layer_type=layer_types[i], + sliding_window=sliding_window, + ) + for i in range(num_hidden_layers) + ] + + # Output + self.norm_out = nn.RMSNorm(inner_dim, eps=rms_norm_eps) + self.proj_out = nn.ConvTranspose1d( + 
in_channels=inner_dim, + out_channels=audio_acoustic_hidden_dim, + kernel_size=patch_size, + stride=patch_size, + padding=0, + ) + + # Output adaptive layer norm modulation (2 values: shift, scale) + self.scale_shift_table = mx.zeros((1, 2, inner_dim)) + + # Pre-compute sliding window mask (will be set on first forward) + self._sliding_masks: dict[int, mx.array] = {} + self._sliding_window = sliding_window + self._layer_types = layer_types + + def _get_sliding_mask(self, seq_len: int, dtype: mx.Dtype) -> mx.array: + if seq_len not in self._sliding_masks: + self._sliding_masks[seq_len] = _create_sliding_window_mask( + seq_len, self._sliding_window, dtype + ) + return self._sliding_masks[seq_len] + + def __call__( + self, + hidden_states: mx.array, + timestep: mx.array, + timestep_r: mx.array, + encoder_hidden_states: mx.array, + context_latents: mx.array, + cache: Optional[MLXCrossAttentionCache] = None, + use_cache: bool = True, + ) -> Tuple[mx.array, Optional[MLXCrossAttentionCache]]: + """ + Args: + hidden_states: noisy latents [B, T, 64] + timestep: [B] current timestep + timestep_r: [B] reference timestep + encoder_hidden_states: [B, enc_L, D] from condition encoder + context_latents: [B, T, C_ctx] (src_latents + chunk_masks) + cache: cross-attention KV cache + use_cache: whether to cache cross-attention KV + + Returns: + (output_hidden_states, cache) + """ + # Timestep embeddings + temb_t, proj_t = self.time_embed(timestep) + temb_r, proj_r = self.time_embed_r(timestep - timestep_r) + temb = temb_t + temb_r # [B, D] + timestep_proj = proj_t + proj_r # [B, 6, D] + + # Concatenate context with hidden states: [B, T, C_ctx + 64] -> [B, T, in_channels] + hidden_states = mx.concatenate([context_latents, hidden_states], axis=-1) + + original_seq_len = hidden_states.shape[1] + + # Pad to multiple of patch_size + pad_length = 0 + if hidden_states.shape[1] % self.patch_size != 0: + pad_length = self.patch_size - (hidden_states.shape[1] % self.patch_size) + # Pad along time dimension + padding = mx.zeros( + (hidden_states.shape[0], pad_length, hidden_states.shape[2]), + dtype=hidden_states.dtype, + ) + hidden_states = mx.concatenate([hidden_states, padding], axis=1) + + # Patch embedding: [B, T, in_ch] -> [B, T//patch, D] + hidden_states = self.proj_in(hidden_states) + + # Project encoder states + encoder_hidden_states = self.condition_embedder(encoder_hidden_states) + + seq_len = hidden_states.shape[1] + dtype = hidden_states.dtype + + # Position embeddings (RoPE) + cos, sin = self.rotary_emb(seq_len) + + # Attention masks + # Self-attention: full layers get None; sliding layers get windowed mask + # Cross-attention: always None (no masking) + sliding_mask = None + has_sliding = any(lt == "sliding_attention" for lt in self._layer_types) + if has_sliding: + sliding_mask = self._get_sliding_mask(seq_len, dtype) + + # Process through transformer layers + for layer in self.layers: + self_attn_mask = sliding_mask if layer.layer_type == "sliding_attention" else None + hidden_states = layer( + hidden_states, + position_cos_sin=(cos, sin), + temb=timestep_proj, + self_attn_mask=self_attn_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=None, + cache=cache, + use_cache=use_cache, + ) + + # Output adaptive layer norm + shift, scale = mx.split( + self.scale_shift_table + mx.expand_dims(temb, axis=1), 2, axis=1 + ) + hidden_states = self.norm_out(hidden_states) * (1.0 + scale) + shift + + # De-patchify: [B, T//patch, D] -> [B, T, out_channels] + hidden_states = 
self.proj_out(hidden_states) + + # Crop back to original sequence length + hidden_states = hidden_states[:, :original_seq_len, :] + + return hidden_states, cache + + @classmethod + def from_config(cls, config) -> "MLXDiTDecoder": + """Construct from an AceStepConfig (transformers PretrainedConfig).""" + return cls( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + num_hidden_layers=config.num_hidden_layers, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), + rms_norm_eps=config.rms_norm_eps, + attention_bias=config.attention_bias, + in_channels=config.in_channels, + audio_acoustic_hidden_dim=config.audio_acoustic_hidden_dim, + patch_size=config.patch_size, + sliding_window=config.sliding_window if config.sliding_window else 128, + layer_types=config.layer_types, + rope_theta=config.rope_theta, + max_position_embeddings=config.max_position_embeddings, + )
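
The sketches below illustrate the main mechanisms introduced by this patch. They are plain Python/numpy so they run without MLX or the repository installed; any object, value, or helper name that does not appear in the diff above is illustrative only, not part of the patch.

The new use_mlx_dit flag defaults to True in initialize_service, and the MLX path is only attempted when the device is "mps" or "cpu", compile_model is off, and the mlx package imports successfully; any failure falls back to PyTorch. A hedged usage sketch (the dit_handler object and argument values are assumptions, mirroring init_service_wrapper above):

from acestep.mlx_dit import mlx_available

print("MLX available:", mlx_available())
# dit_handler.initialize_service(
#     checkpoint, config_path, device="mps",
#     use_flash_attention=False, compile_model=False,
#     offload_to_cpu=False, offload_dit_to_cpu=False,
#     quantization=None, use_mlx_dit=True,   # flag added by this patch
# )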
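
A minimal numpy sketch of the axis permutations applied by convert_decoder_weights() for proj_in / proj_out: PyTorch Conv1d stores weights as [out, in, K] while MLX expects [out, K, in], and PyTorch ConvTranspose1d stores [in, out, K]. Shapes use the decoder defaults from MLXDiTDecoder (hidden 2048, in_channels 192, audio dim 64, patch size 2); no real checkpoint is touched.

import numpy as np

# Conv1d (proj_in): PyTorch [out, in, K] -> MLX [out, K, in]
pt_conv_w = np.random.randn(2048, 192, 2).astype(np.float32)
mlx_conv_w = pt_conv_w.swapaxes(1, 2)
assert mlx_conv_w.shape == (2048, 2, 192)

# ConvTranspose1d (proj_out): PyTorch [in, out, K] -> MLX [out, K, in]
pt_deconv_w = np.random.randn(2048, 64, 2).astype(np.float32)
mlx_deconv_w = pt_deconv_w.transpose(1, 2, 0)
assert mlx_deconv_w.shape == (64, 2, 2048)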
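
A self-contained reference for the schedule selection in get_timestep_schedule(): custom timesteps have trailing zeros dropped, are capped at 20 entries, and are snapped to the nearest entry of VALID_TIMESTEPS; otherwise shift is rounded to the nearest of 1/2/3 and the matching 8-step preset is used. The constants are copied from the diff; the function name below is illustrative.

VALID_SHIFTS = [1.0, 2.0, 3.0]
VALID_TIMESTEPS = [
    1.0, 0.9545454545454546, 0.9333333333333333, 0.9, 0.875,
    0.8571428571428571, 0.8333333333333334, 0.7692307692307693, 0.75,
    0.6666666666666666, 0.6428571428571429, 0.625, 0.5454545454545454,
    0.5, 0.4, 0.375, 0.3, 0.25, 0.2222222222222222, 0.125,
]
SHIFT_TIMESTEPS = {
    1.0: [1.0, 0.875, 0.75, 0.625, 0.5, 0.375, 0.25, 0.125],
    2.0: [1.0, 0.9333333333333333, 0.8571428571428571, 0.7692307692307693,
          0.6666666666666666, 0.5454545454545454, 0.4, 0.2222222222222222],
    3.0: [1.0, 0.9545454545454546, 0.9, 0.8333333333333334, 0.75,
          0.6428571428571429, 0.5, 0.3],
}

def pick_schedule(shift=3.0, timesteps=None):
    if timesteps is not None:
        ts = list(timesteps)
        while ts and ts[-1] == 0:   # drop trailing zeros, as in the diff
            ts.pop()
        ts = ts[:20]                # cap at 20 steps
        if ts:
            return [min(VALID_TIMESTEPS, key=lambda v, t=t: abs(v - t)) for t in ts]
    shift = min(VALID_SHIFTS, key=lambda v: abs(v - shift))
    return SHIFT_TIMESTEPS[shift]

print(pick_schedule(shift=2.7))                          # rounds to the shift=3.0 preset
print(pick_schedule(timesteps=[0.98, 0.6, 0.31, 0.0]))   # -> [1.0, 0.625, 0.3]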
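
The cover/non-cover blend in mlx_generate_diffusion() reduces to a step count: the first int(num_steps * audio_cover_strength) steps run with the cover conditions, and any remaining steps switch to the non-cover conditions (when those were prepared). A two-line illustration:

num_steps, audio_cover_strength = 8, 0.75
cover_steps = int(num_steps * audio_cover_strength)   # -> 6
print(["cover" if i < cover_steps else "non-cover" for i in range(num_steps)])
# ['cover', 'cover', 'cover', 'cover', 'cover', 'cover', 'non-cover', 'non-cover']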
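
A toy numpy check of the ODE update in the diffusion loop: given a velocity prediction v_t, the Euler step is x_next = x_t - v_t * (t - t_next), and the final step jumps straight to x_0 = x_t - v_t * t. With the idealized rectified-flow velocity v = noise - x_0 (constant along the straight path x_t = (1 - t) * x_0 + t * noise), the loop recovers x_0 exactly; the real decoder predicts v_t from the network instead.

import numpy as np

rng = np.random.default_rng(0)
x0 = rng.normal(size=(1, 4, 2)).astype(np.float32)        # stand-in "clean" latents
noise = rng.normal(size=x0.shape).astype(np.float32)
schedule = [1.0, 0.9545454545454546, 0.9, 0.8333333333333334, 0.75,
            0.6428571428571429, 0.5, 0.3]                  # shift=3.0 preset

x_t = noise
for i, t in enumerate(schedule):
    v = noise - x0                                         # idealized velocity prediction
    if i == len(schedule) - 1:
        x_t = x_t - v * t                                  # final step: jump to x_0
        break
    x_t = x_t - v * (t - schedule[i + 1])                  # Euler ODE step

assert np.allclose(x_t, x0, atol=1e-5)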
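
A numpy sketch of the half-split rotary embedding used by _rotate_half / _apply_rotary_pos_emb, checking two properties that should hold for RoPE: position 0 is the identity, and the rotation preserves per-vector norms.

import numpy as np

head_dim, max_len, base = 8, 16, 1_000_000.0
inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
freqs = np.arange(max_len)[:, None] * inv_freq[None, :]
freqs = np.concatenate([freqs, freqs], axis=-1)            # [max_len, head_dim]
cos, sin = np.cos(freqs), np.sin(freqs)

def rotate_half(x):
    half = x.shape[-1] // 2
    return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)

q = np.random.randn(max_len, head_dim)                     # one vector per position
q_rot = q * cos + rotate_half(q) * sin

assert np.allclose(q_rot[0], q[0])                         # position 0: identity
assert np.allclose(np.linalg.norm(q_rot, axis=-1),
                   np.linalg.norm(q, axis=-1))             # norms preserved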
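
The additive sliding-window mask from _create_sliding_window_mask() in numpy: positions within window of each other contribute 0, everything else a large negative value that vanishes after softmax.

import numpy as np

def sliding_window_mask(seq_len, window):
    idx = np.arange(seq_len)
    diff = np.abs(idx[:, None] - idx[None, :])
    return np.where(diff <= window, 0.0, -1e9)[None, None]   # [1, 1, L, L]

print(sliding_window_mask(5, 1)[0, 0])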
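
The GQA head expansion in MLXAttention._repeat_kv in numpy: each of the num_key_value_heads K/V heads is tiled n_rep times so that every query head has a matching K/V head.

import numpy as np

B, n_kv, n_rep, L, D = 1, 2, 3, 4, 5
kv = np.random.randn(B, n_kv, L, D)
expanded = np.broadcast_to(kv[:, :, None], (B, n_kv, n_rep, L, D)).reshape(B, n_kv * n_rep, L, D)

assert expanded.shape == (1, 6, 4, 5)
assert np.allclose(expanded[0, 0], expanded[0, 2])   # heads 0..2 all share KV head 0
assert np.allclose(expanded[0, 3], kv[0, 1])         # heads 3..5 share KV head 1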
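
The AdaLN modulation in MLXDiTLayer in numpy: the per-layer scale_shift_table [1, 6, D] is added to the timestep projection [B, 6, D], split into shift/scale/gate triples for the self-attention and MLP branches, and applied as normed * (1 + scale) + shift with a gated residual. RMSNorm, attention, and the MLP are replaced by identity stand-ins here, so only the modulation arithmetic is shown.

import numpy as np

B, L, D = 2, 10, 8
scale_shift_table = np.zeros((1, 6, D), dtype=np.float32)   # learned; zeros at init
timestep_proj = np.random.randn(B, 6, D).astype(np.float32)
x = np.random.randn(B, L, D).astype(np.float32)

mod = scale_shift_table + timestep_proj                     # [B, 6, D]
shift_msa, scale_msa, gate_msa, c_shift, c_scale, c_gate = np.split(mod, 6, axis=1)

normed = x                                                  # stand-in for RMSNorm(x)
attn_out = normed * (1.0 + scale_msa) + shift_msa           # stand-in for self-attention(...)
x = x + attn_out * gate_msa                                 # gated residual

normed = x                                                  # stand-in for RMSNorm(x)
mlp_out = normed * (1.0 + c_scale) + c_shift                # stand-in for MLP(...)
x = x + mlp_out * c_gate
print(x.shape)                                              # (2, 10, 8)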
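
The sinusoidal part of MLXTimestepEmbedding in numpy (before the two Linear/SiLU layers and the 6-way projection): timesteps in [0, 1] are scaled by 1000, projected onto log-spaced frequencies, and encoded as [cos | sin]. The odd-dimension padding branch is omitted since the default dim (256) is even.

import math
import numpy as np

def sinusoidal_embedding(t, dim=256, scale=1000.0, max_period=10000):
    t = np.asarray(t, dtype=np.float32) * scale
    half = dim // 2
    freqs = np.exp(-math.log(max_period) * np.arange(half, dtype=np.float32) / half)
    args = t[:, None] * freqs[None, :]
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)   # [N, dim]

print(sinusoidal_embedding([1.0, 0.5, 0.3]).shape)   # (3, 256)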