Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 116 additions & 3 deletions models/wan/any2video.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
import types
import math
import threading
from contextlib import contextmanager
from functools import partial
from mmgp import offload
Expand Down Expand Up @@ -51,6 +52,104 @@

WAN_USE_FP32_ROPE_FREQS = True

_MPS_TRANSFORMER_LOCK = threading.RLock()
_MPS_TRANSFORMER_CALL_ANNOUNCED_CLEANUP = None
_MPS_SAMPLING_CLEANUP_DEFAULT = False
_MPS_PREVIEW_BYPASS_ANNOUNCED = False


def _mps_sampling_cleanup_enabled():
if os.environ.get("WAN2GP_MPS_DISABLE_SAMPLING_EMPTY_CACHE") == "1":
return False
mode = os.environ.get("WAN2GP_MPS_SAMPLING_EMPTY_CACHE_MODE")
if mode is not None:
return mode.lower() not in {"0", "false", "off", "none"}
return _MPS_SAMPLING_CLEANUP_DEFAULT


def _mps_configure_sampling_empty_cache_default(model_type):
global _MPS_SAMPLING_CLEANUP_DEFAULT
_MPS_SAMPLING_CLEANUP_DEFAULT = (
sys.platform == "darwin"
and hasattr(torch.backends, "mps")
and torch.backends.mps.is_available()
and model_type in {"t2v_1.3B", "ti2v_2_2", "t2v_2_2"}
)


def _mps_generation_boundary_cleanup_enabled():
return os.environ.get("WAN2GP_MPS_GENERATION_BOUNDARY_CLEANUP", "1") != "0"


def _mps_device_active(device):
if sys.platform != "darwin" or not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available():
return False
if isinstance(device, torch.device):
return device.type in {"mps", "cuda"}
device = str(device)
return device.startswith("mps") or device.startswith("cuda")


def _mps_sync_barrier(empty_cache=False):
torch.mps.synchronize()
if empty_cache and _mps_sampling_cleanup_enabled():
torch.mps.empty_cache()
torch.mps.synchronize()


def _call_transformer_for_mps(trans, **call_kwargs):
global _MPS_TRANSFORMER_CALL_ANNOUNCED_CLEANUP
if sys.platform != "darwin" or not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available():
return trans(**call_kwargs)

with _MPS_TRANSFORMER_LOCK:
cleanup_enabled = _mps_sampling_cleanup_enabled()
if _MPS_TRANSFORMER_CALL_ANNOUNCED_CLEANUP != cleanup_enabled:
print("[MPS] Using synchronized WAN transformer calls on macOS", flush=True)
if cleanup_enabled:
print("[MPS] MPS cache cleanup boundaries: step+transformer", flush=True)
print("[MPS] Step-boundary MPS cache cleanup: after every sampler step", flush=True)
else:
print("[MPS] MPS cache cleanup boundaries: disabled", flush=True)
if cleanup_enabled and _mps_generation_boundary_cleanup_enabled():
print("[MPS] Generation-boundary MPS cache cleanup: enabled", flush=True)
_MPS_TRANSFORMER_CALL_ANNOUNCED_CLEANUP = cleanup_enabled
_mps_sync_barrier()
try:
return trans(**call_kwargs)
finally:
_mps_sync_barrier(empty_cache=True)


def _mps_sampler_step_barrier(device):
if _mps_device_active(device):
_mps_sync_barrier(empty_cache=True)


def _mps_generation_boundary_barrier(device):
if (
_mps_device_active(device)
and _mps_generation_boundary_cleanup_enabled()
and _mps_sampling_cleanup_enabled()
):
torch.mps.synchronize()
torch.mps.empty_cache()
torch.mps.synchronize()


def _mps_live_latent_preview_enabled(device):
if not _mps_device_active(device):
return True
return os.environ.get("WAN2GP_MPS_PREVIEW_LATENTS") == "1"


def _mps_announce_preview_bypass():
global _MPS_PREVIEW_BYPASS_ANNOUNCED
if not _MPS_PREVIEW_BYPASS_ANNOUNCED:
print("[MPS] Disabled live WAN latent previews on macOS for sampler stability", flush=True)
_MPS_PREVIEW_BYPASS_ANNOUNCED = True


def get_vista4d_rotary_pos_embed(latents_size):
lat_t, lat_h, lat_w = latents_size
grid_t, grid_h, grid_w = lat_t, lat_h // 2, lat_w // 2
Expand Down Expand Up @@ -499,6 +598,8 @@ def generate(self,
):

model_def = self.model_def
_mps_configure_sampling_empty_cache_default(model_type)
_mps_generation_boundary_barrier(self.device)

if sample_solver =="euler":
sample_scheduler = EulerScheduler(
Expand Down Expand Up @@ -1189,6 +1290,8 @@ def update_guidance(step_no, t, guide_scale, new_guide_scale, guidance_switch_do
callback(-1, None, True, override_num_inference_steps = updated_num_steps, denoising_extra = denoising_extra)

def clear():
if _mps_device_active(self.device):
_mps_sync_barrier()
clear_caches()
gc.collect()
torch.cuda.empty_cache()
Expand Down Expand Up @@ -1357,15 +1460,15 @@ def denoise_with_cfg_fn(latents):
}

if joint_pass and any_guidance:
ret_values = trans( **gen_args , **kwargs)
ret_values = _call_transformer_for_mps(trans, **gen_args, **kwargs)
if self._interrupt:
return clear()
else:
size = len(gen_args["x"]) if any_guidance else 1
ret_values = [None] * size
for x_id in range(size):
sub_gen_args = {k : [v[x_id]] for k, v in gen_args.items() }
ret_values[x_id] = trans( **sub_gen_args, x_id= x_id , **kwargs)[0]
ret_values[x_id] = _call_transformer_for_mps(trans, **sub_gen_args, x_id=x_id, **kwargs)[0]
if self._interrupt:
return clear()
sub_gen_args = None
Expand Down Expand Up @@ -1460,11 +1563,20 @@ def denoise_with_cfg_fn(latents):
sigma = 0 if i == len(timesteps)-1 else timesteps[i+1]/1000
noisy_image = randn[:, :, :source_latents.shape[2]] * sigma + (1 - sigma) * source_latents
latents[:, :, :source_latents.shape[2]] = noisy_image * (1-image_mask_latents) + image_mask_latents * latents[:, :, :source_latents.shape[2]]
noisy_image = None

# Release the denoise output before the MPS step synchronization
# boundary so Metal does not hold onto it longer than needed.
noise_pred = None
_mps_sampler_step_barrier(self.device)

if callback is not None:
if not _mps_live_latent_preview_enabled(self.device):
_mps_announce_preview_bypass()
callback(i, None, False, denoising_extra=denoising_extra)
continue
latents_preview = latents
if ref_images_before and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ]
if ref_images_before and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ]
if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames]
if image_outputs: latents_preview= latents_preview[:, :,-1:] if last_latent_preview else latents_preview[:, :,:1]
if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2)
Expand Down Expand Up @@ -1539,6 +1651,7 @@ def denoise_with_cfg_fn(latents):
if videos.dtype != torch.uint8:
videos = videos.clamp_(-1, 1).add_(1.0).mul_(127.5).round_().clamp_(0, 255).to(torch.uint8)
if BGRA_frames is not None: ret["BGRA_frames"] = BGRA_frames
_mps_generation_boundary_barrier(self.device)
return ret

def get_loras_transformer(self, get_model_recursive_prop, base_model_type, model_type, video_prompt_type, model_mode, **kwargs):
Expand Down
Loading