diff --git a/models/wan/any2video.py b/models/wan/any2video.py index 24fdbe348..9a8152e09 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -8,6 +8,7 @@ import sys import types import math +import threading from contextlib import contextmanager from functools import partial from mmgp import offload @@ -51,6 +52,104 @@ WAN_USE_FP32_ROPE_FREQS = True +_MPS_TRANSFORMER_LOCK = threading.RLock() +_MPS_TRANSFORMER_CALL_ANNOUNCED_CLEANUP = None +_MPS_SAMPLING_CLEANUP_DEFAULT = False +_MPS_PREVIEW_BYPASS_ANNOUNCED = False + + +def _mps_sampling_cleanup_enabled(): + if os.environ.get("WAN2GP_MPS_DISABLE_SAMPLING_EMPTY_CACHE") == "1": + return False + mode = os.environ.get("WAN2GP_MPS_SAMPLING_EMPTY_CACHE_MODE") + if mode is not None: + return mode.lower() not in {"0", "false", "off", "none"} + return _MPS_SAMPLING_CLEANUP_DEFAULT + + +def _mps_configure_sampling_empty_cache_default(model_type): + global _MPS_SAMPLING_CLEANUP_DEFAULT + _MPS_SAMPLING_CLEANUP_DEFAULT = ( + sys.platform == "darwin" + and hasattr(torch.backends, "mps") + and torch.backends.mps.is_available() + and model_type in {"t2v_1.3B", "ti2v_2_2", "t2v_2_2"} + ) + + +def _mps_generation_boundary_cleanup_enabled(): + return os.environ.get("WAN2GP_MPS_GENERATION_BOUNDARY_CLEANUP", "1") != "0" + + +def _mps_device_active(device): + if sys.platform != "darwin" or not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(): + return False + if isinstance(device, torch.device): + return device.type in {"mps", "cuda"} + device = str(device) + return device.startswith("mps") or device.startswith("cuda") + + +def _mps_sync_barrier(empty_cache=False): + torch.mps.synchronize() + if empty_cache and _mps_sampling_cleanup_enabled(): + torch.mps.empty_cache() + torch.mps.synchronize() + + +def _call_transformer_for_mps(trans, **call_kwargs): + global _MPS_TRANSFORMER_CALL_ANNOUNCED_CLEANUP + if sys.platform != "darwin" or not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(): + return trans(**call_kwargs) + + with _MPS_TRANSFORMER_LOCK: + cleanup_enabled = _mps_sampling_cleanup_enabled() + if _MPS_TRANSFORMER_CALL_ANNOUNCED_CLEANUP != cleanup_enabled: + print("[MPS] Using synchronized WAN transformer calls on macOS", flush=True) + if cleanup_enabled: + print("[MPS] MPS cache cleanup boundaries: step+transformer", flush=True) + print("[MPS] Step-boundary MPS cache cleanup: after every sampler step", flush=True) + else: + print("[MPS] MPS cache cleanup boundaries: disabled", flush=True) + if cleanup_enabled and _mps_generation_boundary_cleanup_enabled(): + print("[MPS] Generation-boundary MPS cache cleanup: enabled", flush=True) + _MPS_TRANSFORMER_CALL_ANNOUNCED_CLEANUP = cleanup_enabled + _mps_sync_barrier() + try: + return trans(**call_kwargs) + finally: + _mps_sync_barrier(empty_cache=True) + + +def _mps_sampler_step_barrier(device): + if _mps_device_active(device): + _mps_sync_barrier(empty_cache=True) + + +def _mps_generation_boundary_barrier(device): + if ( + _mps_device_active(device) + and _mps_generation_boundary_cleanup_enabled() + and _mps_sampling_cleanup_enabled() + ): + torch.mps.synchronize() + torch.mps.empty_cache() + torch.mps.synchronize() + + +def _mps_live_latent_preview_enabled(device): + if not _mps_device_active(device): + return True + return os.environ.get("WAN2GP_MPS_PREVIEW_LATENTS") == "1" + + +def _mps_announce_preview_bypass(): + global _MPS_PREVIEW_BYPASS_ANNOUNCED + if not _MPS_PREVIEW_BYPASS_ANNOUNCED: + print("[MPS] Disabled live WAN latent previews on macOS for sampler stability", flush=True) + _MPS_PREVIEW_BYPASS_ANNOUNCED = True + + def get_vista4d_rotary_pos_embed(latents_size): lat_t, lat_h, lat_w = latents_size grid_t, grid_h, grid_w = lat_t, lat_h // 2, lat_w // 2 @@ -499,6 +598,8 @@ def generate(self, ): model_def = self.model_def + _mps_configure_sampling_empty_cache_default(model_type) + _mps_generation_boundary_barrier(self.device) if sample_solver =="euler": sample_scheduler = EulerScheduler( @@ -1189,6 +1290,8 @@ def update_guidance(step_no, t, guide_scale, new_guide_scale, guidance_switch_do callback(-1, None, True, override_num_inference_steps = updated_num_steps, denoising_extra = denoising_extra) def clear(): + if _mps_device_active(self.device): + _mps_sync_barrier() clear_caches() gc.collect() torch.cuda.empty_cache() @@ -1357,7 +1460,7 @@ def denoise_with_cfg_fn(latents): } if joint_pass and any_guidance: - ret_values = trans( **gen_args , **kwargs) + ret_values = _call_transformer_for_mps(trans, **gen_args, **kwargs) if self._interrupt: return clear() else: @@ -1365,7 +1468,7 @@ def denoise_with_cfg_fn(latents): ret_values = [None] * size for x_id in range(size): sub_gen_args = {k : [v[x_id]] for k, v in gen_args.items() } - ret_values[x_id] = trans( **sub_gen_args, x_id= x_id , **kwargs)[0] + ret_values[x_id] = _call_transformer_for_mps(trans, **sub_gen_args, x_id=x_id, **kwargs)[0] if self._interrupt: return clear() sub_gen_args = None @@ -1460,11 +1563,20 @@ def denoise_with_cfg_fn(latents): sigma = 0 if i == len(timesteps)-1 else timesteps[i+1]/1000 noisy_image = randn[:, :, :source_latents.shape[2]] * sigma + (1 - sigma) * source_latents latents[:, :, :source_latents.shape[2]] = noisy_image * (1-image_mask_latents) + image_mask_latents * latents[:, :, :source_latents.shape[2]] + noisy_image = None + # Release the denoise output before the MPS step synchronization + # boundary so Metal does not hold onto it longer than needed. + noise_pred = None + _mps_sampler_step_barrier(self.device) if callback is not None: + if not _mps_live_latent_preview_enabled(self.device): + _mps_announce_preview_bypass() + callback(i, None, False, denoising_extra=denoising_extra) + continue latents_preview = latents - if ref_images_before and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ] + if ref_images_before and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ] if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames] if image_outputs: latents_preview= latents_preview[:, :,-1:] if last_latent_preview else latents_preview[:, :,:1] if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2) @@ -1539,6 +1651,7 @@ def denoise_with_cfg_fn(latents): if videos.dtype != torch.uint8: videos = videos.clamp_(-1, 1).add_(1.0).mul_(127.5).round_().clamp_(0, 255).to(torch.uint8) if BGRA_frames is not None: ret["BGRA_frames"] = BGRA_frames + _mps_generation_boundary_barrier(self.device) return ret def get_loras_transformer(self, get_model_recursive_prop, base_model_type, model_type, video_prompt_type, model_mode, **kwargs): diff --git a/shared/mps/device_patch.py b/shared/mps/device_patch.py index fab5f7d7f..61c6fecf9 100644 --- a/shared/mps/device_patch.py +++ b/shared/mps/device_patch.py @@ -16,8 +16,15 @@ os.environ.setdefault('TORCHINDUCTOR', '0') os.environ.setdefault('PYTORCH_ENABLE_MPS_FALLBACK', '1') +_MPS_PATCH_LOGGED = False + + def apply_mps_patch(): """Patch torch.cuda functions for MPS compatibility.""" + global _MPS_PATCH_LOGGED + log_this_call = not _MPS_PATCH_LOGGED + _MPS_PATCH_LOGGED = True + import torch as _torch chip_name = _get_chip_name() @@ -31,13 +38,14 @@ def apply_mps_patch(): dev_cap = (11, 0) bfloat16_supported = True - print(f"[MPS Patch] Detected: {chip_name}, {system_ram_gb:.0f}GB RAM") - print(f"[MPS Patch] Device capability: {dev_cap}, BF16: {bfloat16_supported}") + if log_this_call: + print(f"[MPS] Detected: {chip_name}, {system_ram_gb:.0f}GB RAM") + print(f"[MPS] Device capability: {dev_cap}, BF16: {bfloat16_supported}") # Dummy objects _dummy_stream = types.SimpleNamespace( synchronize=_torch.mps.synchronize, - wait_stream=lambda *a, **kw: None, + wait_stream=lambda *a, **kw: _torch.mps.synchronize(), query=lambda: True, priority=0, ) @@ -54,9 +62,9 @@ def __init__(self): class _DummyEvent: def __init__(self, *a, **kw): pass - def record(self, *a, **kw): pass + def record(self, *a, **kw): _torch.mps.synchronize() def elapsed_time(self, *a, **kw): return 0.0 - def synchronize(self, *a, **kw): pass + def synchronize(self, *a, **kw): _torch.mps.synchronize() def query(self): return True class _DummyDeviceContext: @@ -66,8 +74,11 @@ def __exit__(self, *args): pass class _DummyStreamContext: def __init__(self, s): pass - def __enter__(self): return self - def __exit__(self, *a): pass + def __enter__(self): + _torch.mps.synchronize() + return self + def __exit__(self, *a): + _torch.mps.synchronize() class _DummyGraph: replay = lambda s, *a, **kw: None @@ -105,9 +116,16 @@ def load_state_dict(self, *a): pass _cuda = _torch.cuda # Core function patches + _orig_mps_empty_cache = _torch.mps.empty_cache + def _safe_mps_empty_cache(*args, **kwargs): + _torch.mps.synchronize() + result = _orig_mps_empty_cache(*args, **kwargs) + _torch.mps.synchronize() + return result + _cuda.is_available = lambda: False _cuda._is_compiled = lambda: False - _cuda.empty_cache = _torch.mps.empty_cache + _cuda.empty_cache = _safe_mps_empty_cache _cuda.synchronize = _torch.mps.synchronize _cuda.get_device_capability = lambda device=None: dev_cap _cuda.manual_seed_all = lambda seed: None @@ -125,8 +143,8 @@ def load_state_dict(self, *a): pass class _PatchedStream: priority = 0 def __init__(self, *a, **kw): pass - def synchronize(self, *a, **kw): pass - def wait_stream(self, *a, **kw): pass + def synchronize(self, *a, **kw): _torch.mps.synchronize() + def wait_stream(self, *a, **kw): _torch.mps.synchronize() def query(self): return True _cuda.Stream = _PatchedStream @@ -190,6 +208,12 @@ def _patched_autocast(device_type=None, *args, **kwargs): # Handle torch.cuda.amp.autocast which calls with device_type=None initially if device_type is None and 'device_type' not in kwargs: device_type = 'mps' + # MPS only supports fp16/bf16 autocast + dtype = kwargs.get("dtype", None) + if device_type == "mps": + if dtype not in (_torch.float16, _torch.bfloat16, None): + # safest choice for Apple Silicon + kwargs["dtype"] = _torch.float16 return _orig_autocast(device_type, *args, **kwargs) _torch.autocast = _patched_autocast # Also patch torch.amp.autocast @@ -232,6 +256,17 @@ def _patched_tensor_to(self, *args, **kwargs): # Handle keyword device arg if "device" in kwargs: kwargs["device"] = _replace_cuda_device(kwargs["device"]) + target_device = kwargs.get("device", None) + if target_device is None: + for arg in new_args: + if isinstance(arg, str) and arg.startswith("mps"): + target_device = arg + break + if isinstance(arg, _torch.device) and arg.type == "mps": + target_device = arg + break + if target_device == "mps" or (isinstance(target_device, _torch.device) and target_device.type == "mps"): + kwargs["non_blocking"] = False return _orig_tensor_to(self, *new_args, **kwargs) _torch.Tensor.to = _patched_tensor_to @@ -285,9 +320,10 @@ def patched(*args, **kwargs): return patched setattr(_torch, fn_name, make_patcher(orig)) - print(f"[MPS Patch] Applied successfully") - print(f"[MPS Patch] BF16 supported: {bfloat16_supported}") - print(f"[MPS Patch] Available system RAM: {system_ram_gb:.0f}GB") + if log_this_call: + print(f"[MPS] Applied successfully") + print(f"[MPS] BF16 supported: {bfloat16_supported}") + print(f"[MPS] Available system RAM: {system_ram_gb:.0f}GB") # Fix: Some Wan model loading paths call .weight on an nn.Parameter, # which is a Tensor subclass, not a Module. On MPS this fails because @@ -336,59 +372,76 @@ def _cuda_getDefaultStream_stub(device_index=0): if not hasattr(_torch.mps, 'set_device'): _torch.mps.set_device = lambda device: None -# Force SDPA math backend on MPS to avoid Metal command buffer double-commit crash -# Reference: [IOGPUMetalCommandBuffer validate]:214: failed assertion `commit an already committed command buffer' -# -# Root cause: MPS fallback ops (CPU fallback from PYTORCH_ENABLE_MPS_FALLBACK=1) -# corrupt Metal command buffers when mixed with native MPS SDPA. This affects: -# - Wan 2.2 5B (quanto mbf16 quantization) -# - Wan 2.1 1.3B (standard safetensors, but ops still fallback) -# - Any model where CPU-fallback ops precede an SDPA call -# -# Fix strategy (defense in depth): -# 1. Synchronize MPS before SDPA to flush pending fallback ops -# 2. Force MATH backend (avoids MPS-native SDPA bugs on some macOS versions) -# 3. If SDPA fails with Metal error, fall back to manual attention (matmul + softmax) -# 4. Periodic empty_cache to prevent memory fragmentation +# Native MPS SDPA can still intermittently trip Metal command-buffer assertions +# on WAN video paths. Default to a synchronized matmul fallback for stability. +# Set WAN2GP_MPS_NATIVE_SDPA=1 to opt into native SDPA for diagnostics. if _is_mps: _orig_sdpa = _torch.nn.functional.scaled_dot_product_attention - _sdpa_call_count = [0] - - def _manual_sdpa_fallback(query, key, value, attn_mask=None, is_causal=False, scale=None): + _sdpa_mode_announced = [False] + + def _expand_attn_bias(attn_bias, ndim): + while attn_bias.dim() < ndim: + attn_bias = attn_bias.unsqueeze(0) + return attn_bias + + def _manual_sdpa_fallback( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, + ): """Manual attention fallback: matmul + softmax, no Metal SDPA.""" - # query: (B, H, L, D) or (B, L, H, D) after sdpa_kernel wrap + if enable_gqa: + repeat = query.size(-3) // key.size(-3) + key = key.repeat_interleave(repeat, -3) + value = value.repeat_interleave(repeat, -3) + L = query.size(-2) + S = key.size(-2) D = query.size(-1) - if scale is None: - scale = D ** -0.5 + scale_factor = D ** -0.5 if scale is None else scale - attn_weights = _torch.matmul(query, key.transpose(-2, -1)) * scale - if attn_mask is not None: - attn_weights = attn_weights + attn_mask + attn_bias = _torch.zeros(L, S, dtype=query.dtype, device=query.device) if is_causal: - causal_mask = _torch.triu( - _torch.ones(L, L, device=query.device, dtype=_torch.bool), diagonal=1 - ) - attn_weights = attn_weights.masked_fill(causal_mask, float('-inf')) - attn_weights = _torch.nn.functional.softmax(attn_weights, dim=-1) - return _torch.matmul(attn_weights, value) + causal_mask = _torch.ones(L, S, dtype=_torch.bool, device=query.device).tril() + attn_bias = attn_bias.masked_fill(causal_mask.logical_not(), float("-inf")) - def _patched_sdpa(*args, **kwargs): - # Flush pending MPS fallback ops before SDPA - _torch.mps.synchronize() + if attn_mask is not None: + if attn_mask.dtype == _torch.bool: + attn_bias = _expand_attn_bias(attn_bias, attn_mask.dim()) + attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf")) + else: + attn_bias = attn_bias + attn_mask + + attn_weight = _torch.matmul(query, key.transpose(-2, -1)) * scale_factor + attn_weight = attn_weight + _expand_attn_bias(attn_bias, query.dim()) + attn_weight = _torch.nn.functional.softmax(attn_weight, dim=-1) - # Periodic cache cleanup every 256 SDPA calls - _sdpa_call_count[0] += 1 - if _sdpa_call_count[0] % 256 == 0: - _torch.mps.empty_cache() + if dropout_p: + attn_weight = _torch.dropout(attn_weight, dropout_p, train=True) - try: - with _torch.nn.attention.sdpa_kernel([_torch.nn.attention.SDPBackend.MATH]): - return _orig_sdpa(*args, **kwargs) - except Exception: - # Metal command buffer corruption caught — fall back to manual attention - # This handles cases where synchronize isn't sufficient (e.g. macOS 26.x bugs) - return _manual_sdpa_fallback(*args, **kwargs) + return _torch.matmul(attn_weight, value) + + def _patched_sdpa(*args, **kwargs): + if os.environ.get("WAN2GP_MPS_NATIVE_SDPA") != "1": + if not _sdpa_mode_announced[0]: + print("[MPS] SDPA mode: synchronized manual MPS fallback", flush=True) + _sdpa_mode_announced[0] = True + _torch.mps.synchronize() + out = _manual_sdpa_fallback(*args, **kwargs) + _torch.mps.synchronize() + return out + if not _sdpa_mode_announced[0]: + print("[MPS] SDPA mode: native MPS diagnostic", flush=True) + _sdpa_mode_announced[0] = True + _torch.mps.synchronize() + out = _orig_sdpa(*args, **kwargs) + _torch.mps.synchronize() + return out _torch.nn.functional.scaled_dot_product_attention = _patched_sdpa @@ -396,6 +449,6 @@ def _patched_sdpa(*args, **kwargs): try: apply_mps_patch() except Exception as e: - print(f"[MPS Patch] Failed to apply patch: {e}") + print(f"[MPS] Failed to apply patch: {e}") import traceback traceback.print_exc() diff --git a/wgp.py b/wgp.py index decd81805..20fa10b49 100644 --- a/wgp.py +++ b/wgp.py @@ -3778,7 +3778,7 @@ def _load_models_info(message): base_model_type = get_base_model_type(model_type) model_def = get_model_def(model_type) save_quantized = args.save_quantized and model_def != None - model_filename = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy) + model_filename = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy) if "URLs2" in model_def: model_filename2 = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy, submodel_no=2) # !!!! else: @@ -3967,6 +3967,46 @@ def get_gen_info(state): state["gen"] = cache return cache + +def _env_flag_enabled(name, default=False): + raw_value = os.environ.get(name, None) + if raw_value is None: + return default + return str(raw_value).strip().lower() not in {"0", "false", "off", "no"} + + +def _is_quantized_transformer_filename(filename): + basename = os.path.basename(str(filename or "")).lower() + return "quanto" in basename or "int8" in basename or "fp8" in basename + + +def _mps_webui_inline_worker_required(queue): + if not is_mps: + return False + if os.environ.get("WAN2GP_MPS_WEBUI_INLINE_WORKER") is not None: + return _env_flag_enabled("WAN2GP_MPS_WEBUI_INLINE_WORKER") + if transformer_quantization not in ("int8", "fp8"): + return False + + for task in list(queue or []): + if not isinstance(task, dict): + continue + params = task.get("params", {}) + if not isinstance(params, dict): + continue + model_type = params.get("model_type") + if not model_type: + continue + if get_profile_type_for_model(model_type, params.get("image_mode", 0)) != "video": + continue + if get_model_family(model_type) != "wan": + continue + model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy) + if _is_quantized_transformer_filename(model_filename): + return True + return False + + def build_callback(state, pipe, send_cmd, status, num_inference_steps, preview_meta=None): gen = get_gen_info(state) gen["num_inference_steps"] = num_inference_steps @@ -6755,7 +6795,12 @@ def set_progress_status(status): fit_crop = False fit_canvas = 0 - joint_pass = boost ==1 #and profile != 1 and profile != 3 + joint_pass = boost == 1 #and profile != 1 and profile != 3 + if boost == 1 and sys.platform == "darwin" and _env_flag_enabled("WAN2GP_MPS_DISABLE_JOINT_CFG"): + # Diagnostic escape hatch for older PyTorch/MPS stacks where WAN's + # joint CFG transformer pass could abort inside Metal validation. + joint_pass = False + print("[MPS] Disabled WAN joint CFG pass on macOS for diagnostics", flush=True) skip_steps_cache = None if len(skip_steps_cache_type) == 0 else DynamicClass(cache_type = skip_steps_cache_type) @@ -6853,7 +6898,7 @@ def set_progress_status(status): length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) current_video_length = min(current_video_length, length) - if test_any_sliding_window(model_type) : + if test_any_sliding_window(model_type): sliding_window = current_video_length > sliding_window_size reuse_frames = min(sliding_window_size - latent_size, sliding_window_overlap) else: @@ -8014,7 +8059,11 @@ def queue_worker_func(): finally: send_cmd("worker_exit", None) - async_run_in("generation", queue_worker_func) + if _mps_webui_inline_worker_required(queue): + print("[MPS] Running WebUI generation inline for quantized WAN transformer stability", flush=True) + queue_worker_func() + else: + async_run_in("generation", queue_worker_func) while True: cmd, data = com_stream.output_queue.next()