diff --git a/models/ltx_video/models/transformers/attention.py b/models/ltx_video/models/transformers/attention.py index a87a8a0f3..f2e98b364 100644 --- a/models/ltx_video/models/transformers/attention.py +++ b/models/ltx_video/models/transformers/attention.py @@ -1122,15 +1122,17 @@ def __call__( skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip ): - hidden_states = hidden_states_a * skip_layer_mask + hidden_states * ( - 1.0 - skip_layer_mask + _mask = skip_layer_mask.to(hidden_states_a.device) + hidden_states = hidden_states_a * _mask + hidden_states * ( + 1.0 - _mask ) elif ( skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionValues ): - hidden_states_a *= skip_layer_mask - value_for_stg *= 1.0 - skip_layer_mask + _mask = skip_layer_mask.to(hidden_states_a.device) + hidden_states_a *= _mask + value_for_stg *= 1.0 - _mask hidden_states_a += value_for_stg hidden_states = hidden_states_a del value_for_stg diff --git a/plugins/wan2gp-video-mask-creator/plugin.py b/plugins/wan2gp-video-mask-creator/plugin.py index ae80fb995..a45ce10ca 100644 --- a/plugins/wan2gp-video-mask-creator/plugin.py +++ b/plugins/wan2gp-video-mask-creator/plugin.py @@ -1,4 +1,31 @@ import gradio as gr + +# ROCm Windows PyTorch lacks CPU LAPACK, breaking nn.init.orthogonal_ during +# matanyone model construction (the QR decomposition needs torch.geqrf). +# Our torch's GPU linalg works fine; detour CPU calls through GPU only when +# the CPU path raises a LAPACK error. Zero overhead on systems where CPU LAPACK works. +# Reversible: pip install --force-reinstall torch to restore stock behavior. 
+import torch
+import torch.nn.init as _wan2gp_init
+_wan2gp_original_orthogonal = _wan2gp_init.orthogonal_
+
+def _wan2gp_gpu_safe_orthogonal_(tensor, gain=1, generator=None):
+    """Stock orthogonal_ init that retries on the GPU when the CPU path fails for lack of LAPACK."""
+    try:
+        return _wan2gp_original_orthogonal(tensor, gain=gain, generator=generator)
+    except RuntimeError as e:
+        msg = str(e)
+        lapack_missing = 'LAPACK' in msg or 'geqrf' in msg
+        if not (lapack_missing and tensor.device.type == 'cpu' and torch.cuda.is_available()):
+            raise  # unrelated failure, or no GPU to detour through: surface the original error
+        gpu_t = tensor.detach().to('cuda')  # detached GPU copy; the stock init fills it in place
+        _wan2gp_original_orthogonal(gpu_t, gain=gain, generator=generator)
+        with torch.no_grad():  # in-place write to a requires-grad leaf (nn.Parameter) raises otherwise
+            tensor.copy_(gpu_t.to('cpu'))
+        return tensor
+
+_wan2gp_init.orthogonal_ = _wan2gp_gpu_safe_orthogonal_
+
 from shared.utils.plugins import WAN2GPPlugin from preprocessing.matanyone import app as matanyone_app diff --git a/shared/kernels/quanto_int8_triton.py b/shared/kernels/quanto_int8_triton.py index ecef96250..cb1c55c36 100644 --- a/shared/kernels/quanto_int8_triton.py +++ b/shared/kernels/quanto_int8_triton.py @@ -11,7 +11,14 @@ try: import triton import triton.language as tl - from triton.language.extra.cuda import libdevice as tl_libdevice + # Pick the right libdevice for the active backend. The CUDA libdevice + # triggers "__nv_rintf has been dropped" warnings/fallbacks on HIP/ROCm + # in Triton 3.6+, killing the fused-INT8 kernel on AMD. The HIP libdevice + # has the same op surface (rint, etc.) and compiles cleanly on gfx1201. 
+ if getattr(torch.version, "hip", None): + from triton.language.extra.hip import libdevice as tl_libdevice + else: + from triton.language.extra.cuda import libdevice as tl_libdevice _TRITON_AVAILABLE = True except Exception: # pragma: no cover @@ -21,6 +28,8 @@ _TRITON_AVAILABLE = False +_HIP_BACKEND = bool(getattr(torch.version, "hip", None)) + _ENV_ENABLE = "WAN2GP_QUANTO_INT8_TRITON" _ENV_AUTOTUNE_ENABLE = "WAN2GP_QUANTO_INT8_AUTOTUNE" _ENV_AUTOTUNE_DEBUG = "WAN2GP_QUANTO_INT8_AUTOTUNE_DEBUG" @@ -539,6 +548,7 @@ def _launch_candidate(kind: str, cfg: tuple[int, int, int, int, int], tensors: t block_k=block_k, num_warps=num_warps, num_stages=num_stages, + hip_mode=_HIP_BACKEND, ) return a_int8_c, b_int8_c, a_scale_c, b_scale_c, out = tensors @@ -562,6 +572,7 @@ def _launch_candidate(kind: str, cfg: tuple[int, int, int, int, int], tensors: t block_k=block_k, num_warps=num_warps, num_stages=num_stages, + hip_mode=_HIP_BACKEND, ) @@ -943,6 +954,7 @@ def _fused_dynamic_int8_gemm_kernel( block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr, + hip_mode: tl.constexpr = False, ): pid_m = tl.program_id(0) pid_n = tl.program_id(1) @@ -966,7 +978,7 @@ def _fused_dynamic_int8_gemm_kernel( row_inv_scale = 1.0 / row_scale # Pass 2: quantize activations on the fly + int8 dot. - acc = tl.zeros((block_m, block_n), dtype=tl.int32) + acc = tl.zeros((block_m, block_n), dtype=tl.float32) for k0 in range(0, k, block_k): kk = k0 + offs_k a = tl.load( @@ -985,10 +997,15 @@ def _fused_dynamic_int8_gemm_kernel( mask=(offs_n[None, :] < n) & (kk[:, None] < k), other=0, ).to(tl.int8) - acc += tl.dot(a, b) + # gfx1201 (RDNA4/Wave32): bf16 MFMA only works for K>=128; use f32 MFMA (K=2/4) + # which tiles freely for any block_k (covers LTX-Video K=64 heads too). 
+ if hip_mode: + acc += tl.dot(a.to(tl.float32), b.to(tl.float32), out_dtype=tl.float32) + else: + acc += tl.dot(a, b).to(tl.float32) scales = tl.load(s_ptr + offs_n, mask=offs_n < n, other=0).to(tl.float32) - out = acc.to(tl.float32) * row_scale[:, None] * scales[None, :] + out = acc * row_scale[:, None] * scales[None, :] tl.store( c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, out, @@ -1013,6 +1030,7 @@ def _fused_dynamic_int8_blockscale_gemm_kernel( block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr, + hip_mode: tl.constexpr = False, ): pid_m = tl.program_id(0) pid_n = tl.program_id(1) @@ -1041,8 +1059,12 @@ def _fused_dynamic_int8_blockscale_gemm_kernel( other=0, ).to(tl.int8) - dot_i32 = tl.dot(a, b) - acc += dot_i32.to(tl.float32) * row_scale[:, None] + # gfx1201 (RDNA4/Wave32): bf16 MFMA only works for K>=128; use f32 MFMA (K=2/4) + # which tiles freely for any block_k (covers LTX-Video K=64 heads too). + if hip_mode: + acc += tl.dot(a.to(tl.float32), b.to(tl.float32), out_dtype=tl.float32) * row_scale[:, None] + else: + acc += tl.dot(a, b).to(tl.float32) * row_scale[:, None] scales = tl.load(s_ptr + offs_n, mask=offs_n < n, other=0).to(tl.float32) out = acc * scales[None, :] @@ -1071,6 +1093,7 @@ def _scaled_int8_gemm_kernel( block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr, + hip_mode: tl.constexpr = False, ): pid_m = tl.program_id(0) pid_n = tl.program_id(1) @@ -1078,7 +1101,7 @@ def _scaled_int8_gemm_kernel( offs_n = pid_n * block_n + tl.arange(0, block_n) offs_k = tl.arange(0, block_k) - acc = tl.zeros((block_m, block_n), dtype=tl.int32) + acc = tl.zeros((block_m, block_n), dtype=tl.float32) for k0 in range(0, k, block_k): kk = k0 + offs_k a = tl.load( @@ -1092,11 +1115,16 @@ def _scaled_int8_gemm_kernel( mask=(offs_n[None, :] < n) & (kk[:, None] < k), other=0, ).to(tl.int8) - acc += tl.dot(a, b) + # gfx1201 (RDNA4/Wave32): bf16 MFMA only works for K>=128; use f32 MFMA (K=2/4) + # which tiles 
freely for any block_k (covers LTX-Video K=64 heads too). + if hip_mode: + acc += tl.dot(a.to(tl.float32), b.to(tl.float32), out_dtype=tl.float32) + else: + acc += tl.dot(a, b).to(tl.float32) a_scales = tl.load(a_scales_ptr + offs_m, mask=offs_m < m, other=1).to(tl.float32) b_scales = tl.load(b_scales_ptr + offs_n, mask=offs_n < n, other=1).to(tl.float32) - out = acc.to(tl.float32) * a_scales[:, None] * b_scales[None, :] + out = acc * a_scales[:, None] * b_scales[None, :] tl.store( c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, out, @@ -1158,6 +1186,7 @@ def _fused_quant_scaled_mm_common( block_k=block_k, num_warps=num_warps, num_stages=num_stages, + hip_mode=_HIP_BACKEND, ) return out @@ -1317,5 +1346,6 @@ def scaled_int8_mm( block_k=block_k, num_warps=num_warps, num_stages=num_stages, + hip_mode=_HIP_BACKEND, ) return out diff --git a/wgp.py b/wgp.py index a6efd639b..2b84d278d 100644 --- a/wgp.py +++ b/wgp.py @@ -21,6 +21,13 @@ from shared.asyncio_utils import silence_proactor_connection_reset silence_proactor_connection_reset() +import logging +logging.getLogger("asyncio").addFilter( + type("_SocketSendFilter", (logging.Filter,), { + "filter": staticmethod(lambda r: "socket.send() raised exception" not in r.getMessage()) + })() +) + # ── Apple Silicon MPS patch: MUST come before mmgp import ── import torch is_mps = sys.platform == 'darwin' and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() @@ -755,23 +762,32 @@ def err(error=""): resolution = inputs["resolution"] width, height = resolution.split("x") width, height = int(width), int(height) - image_start = inputs["image_start"] - image_end = inputs["image_end"] - image_refs = inputs["image_refs"] - image_prompt_type = inputs["image_prompt_type"] - audio_prompt_type = inputs["audio_prompt_type"] + # Local patch: Wan2GP UI sometimes drops media keys from the form payload (e.g. when + # the input image is cleared, the model is switched, or a stale tab is submitted). 
+ # Use .get() so we produce a friendly error instead of crashing with KeyError, and + # explicitly check that i2v generations have an input image. + image_start = inputs.get("image_start") + image_end = inputs.get("image_end") + image_refs = inputs.get("image_refs") + image_prompt_type = inputs.get("image_prompt_type") + audio_prompt_type = inputs.get("audio_prompt_type") if image_prompt_type == None: image_prompt_type = "" - video_prompt_type = inputs["video_prompt_type"] + video_prompt_type = inputs.get("video_prompt_type") if video_prompt_type == None: video_prompt_type = "" - force_fps = inputs["force_fps"] - audio_guide = inputs["audio_guide"] - audio_guide2 = inputs["audio_guide2"] - audio_source = inputs["audio_source"] - video_guide = inputs["video_guide"] - image_guide = inputs["image_guide"] - video_mask = inputs["video_mask"] - image_mask = inputs["image_mask"] - custom_guide = inputs["custom_guide"] + force_fps = inputs.get("force_fps") + audio_guide = inputs.get("audio_guide") + audio_guide2 = inputs.get("audio_guide2") + audio_source = inputs.get("audio_source") + video_guide = inputs.get("video_guide") + image_guide = inputs.get("image_guide") + video_mask = inputs.get("video_mask") + image_mask = inputs.get("image_mask") + custom_guide = inputs.get("custom_guide") + # Friendly error for i2v generations that landed here without an input image. + if (not image_outputs + and image_start is None + and "i2v" in (get_base_model_type(model_type) or "").lower()): + return err("No input image provided. Upload an image in the I2V panel before clicking Generate.") speakers_locations = inputs["speakers_locations"] video_source = inputs["video_source"] frames_positions = inputs["frames_positions"]