Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions models/ltx_video/models/transformers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,15 +1122,17 @@ def __call__(
skip_layer_mask is not None
and skip_layer_strategy == SkipLayerStrategy.AttentionSkip
):
hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (
1.0 - skip_layer_mask
_mask = skip_layer_mask.to(hidden_states_a.device)
hidden_states = hidden_states_a * _mask + hidden_states * (
1.0 - _mask
)
elif (
skip_layer_mask is not None
and skip_layer_strategy == SkipLayerStrategy.AttentionValues
):
hidden_states_a *= skip_layer_mask
value_for_stg *= 1.0 - skip_layer_mask
_mask = skip_layer_mask.to(hidden_states_a.device)
hidden_states_a *= _mask
value_for_stg *= 1.0 - _mask
hidden_states_a += value_for_stg
hidden_states = hidden_states_a
del value_for_stg
Expand Down
27 changes: 27 additions & 0 deletions plugins/wan2gp-video-mask-creator/plugin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,31 @@
import gradio as gr

# ROCm Windows PyTorch lacks CPU LAPACK, breaking nn.init.orthogonal_ during
# matanyone model construction (the QR decomposition needs torch.geqrf).
# Our torch's GPU linalg works fine; detour CPU calls through GPU only when
# the CPU path raises a LAPACK error. Zero overhead on systems where CPU LAPACK works.
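# Reversible: this is an in-process monkeypatch only (the installed torch package is not
# modified); assign _wan2gp_original_orthogonal back to torch.nn.init.orthogonal_ to undo it.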
import torch
import torch.nn.init as _wan2gp_init
# Keep a handle on the stock initializer. If this module is imported more than
# once, unwrap our own wrapper first so the patch never nests on itself.
_wan2gp_original_orthogonal = getattr(
    _wan2gp_init.orthogonal_, "_wan2gp_original", _wan2gp_init.orthogonal_
)


def _wan2gp_gpu_safe_orthogonal_(tensor, gain=1, generator=None):
    """Drop-in replacement for ``torch.nn.init.orthogonal_`` with a GPU detour.

    The stock initializer is tried first. Only when it raises a
    ``RuntimeError`` whose message mentions ``LAPACK`` or ``geqrf``, the
    tensor lives on the CPU, and CUDA is available, the initialization is
    redone on a CUDA copy and the result is written back into ``tensor``.

    Args:
        tensor: Tensor (or parameter) to initialize in place.
        gain: Scaling factor forwarded to the stock initializer.
        generator: Optional RNG forwarded to the stock initializer.
            NOTE(review): on the GPU detour the same generator object is
            reused; a CPU generator may be rejected for a CUDA tensor --
            confirm if callers ever pass one.

    Returns:
        ``tensor``, initialized in place.

    Raises:
        RuntimeError: Any failure that is not the CPU LAPACK case above is
            re-raised unchanged.
    """
    # Forward ``generator`` only when supplied, in case the installed torch's
    # orthogonal_ does not accept that keyword.
    extra = {} if generator is None else {"generator": generator}
    try:
        return _wan2gp_original_orthogonal(tensor, gain=gain, **extra)
    except RuntimeError as err:
        message = str(err)
        lapack_failure = "LAPACK" in message or "geqrf" in message
        if not (
            tensor.device.type == "cpu"
            and torch.cuda.is_available()
            and lapack_failure
        ):
            raise
        # The write-back is an in-place op on what is usually a leaf parameter
        # with requires_grad=True; autograd rejects that unless grad tracking
        # is disabled, so the whole detour runs under no_grad.
        with torch.no_grad():
            gpu_copy = tensor.detach().to("cuda")
            _wan2gp_original_orthogonal(gpu_copy, gain=gain, **extra)
            tensor.copy_(gpu_copy)
        return tensor


# Remember what we wrapped (used by the re-import guard above), then install.
_wan2gp_gpu_safe_orthogonal_._wan2gp_original = _wan2gp_original_orthogonal
_wan2gp_init.orthogonal_ = _wan2gp_gpu_safe_orthogonal_

from shared.utils.plugins import WAN2GPPlugin
from preprocessing.matanyone import app as matanyone_app

Expand Down
48 changes: 39 additions & 9 deletions shared/kernels/quanto_int8_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@
try:
import triton
import triton.language as tl
from triton.language.extra.cuda import libdevice as tl_libdevice
# Pick the right libdevice for the active backend. The CUDA libdevice
# triggers "__nv_rintf has been dropped" warnings/fallbacks on HIP/ROCm
# in Triton 3.6+, killing the fused-INT8 kernel on AMD. The HIP libdevice
# has the same op surface (rint, etc.) and compiles cleanly on gfx1201.
if getattr(torch.version, "hip", None):
from triton.language.extra.hip import libdevice as tl_libdevice
else:
from triton.language.extra.cuda import libdevice as tl_libdevice

_TRITON_AVAILABLE = True
except Exception: # pragma: no cover
Expand All @@ -21,6 +28,8 @@
_TRITON_AVAILABLE = False


_HIP_BACKEND = bool(getattr(torch.version, "hip", None))

_ENV_ENABLE = "WAN2GP_QUANTO_INT8_TRITON"
_ENV_AUTOTUNE_ENABLE = "WAN2GP_QUANTO_INT8_AUTOTUNE"
_ENV_AUTOTUNE_DEBUG = "WAN2GP_QUANTO_INT8_AUTOTUNE_DEBUG"
Expand Down Expand Up @@ -539,6 +548,7 @@ def _launch_candidate(kind: str, cfg: tuple[int, int, int, int, int], tensors: t
block_k=block_k,
num_warps=num_warps,
num_stages=num_stages,
hip_mode=_HIP_BACKEND,
)
return
a_int8_c, b_int8_c, a_scale_c, b_scale_c, out = tensors
Expand All @@ -562,6 +572,7 @@ def _launch_candidate(kind: str, cfg: tuple[int, int, int, int, int], tensors: t
block_k=block_k,
num_warps=num_warps,
num_stages=num_stages,
hip_mode=_HIP_BACKEND,
)


Expand Down Expand Up @@ -943,6 +954,7 @@ def _fused_dynamic_int8_gemm_kernel(
block_m: tl.constexpr,
block_n: tl.constexpr,
block_k: tl.constexpr,
hip_mode: tl.constexpr = False,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
Expand All @@ -966,7 +978,7 @@ def _fused_dynamic_int8_gemm_kernel(
row_inv_scale = 1.0 / row_scale

# Pass 2: quantize activations on the fly + int8 dot.
acc = tl.zeros((block_m, block_n), dtype=tl.int32)
acc = tl.zeros((block_m, block_n), dtype=tl.float32)
for k0 in range(0, k, block_k):
kk = k0 + offs_k
a = tl.load(
Expand All @@ -985,10 +997,15 @@ def _fused_dynamic_int8_gemm_kernel(
mask=(offs_n[None, :] < n) & (kk[:, None] < k),
other=0,
).to(tl.int8)
acc += tl.dot(a, b)
# gfx1201 (RDNA4/Wave32): bf16 MFMA only works for K>=128; use f32 MFMA (K=2/4)
# which tiles freely for any block_k (covers LTX-Video K=64 heads too).
if hip_mode:
acc += tl.dot(a.to(tl.float32), b.to(tl.float32), out_dtype=tl.float32)
else:
acc += tl.dot(a, b).to(tl.float32)

scales = tl.load(s_ptr + offs_n, mask=offs_n < n, other=0).to(tl.float32)
out = acc.to(tl.float32) * row_scale[:, None] * scales[None, :]
out = acc * row_scale[:, None] * scales[None, :]
tl.store(
c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn,
out,
Expand All @@ -1013,6 +1030,7 @@ def _fused_dynamic_int8_blockscale_gemm_kernel(
block_m: tl.constexpr,
block_n: tl.constexpr,
block_k: tl.constexpr,
hip_mode: tl.constexpr = False,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
Expand Down Expand Up @@ -1041,8 +1059,12 @@ def _fused_dynamic_int8_blockscale_gemm_kernel(
other=0,
).to(tl.int8)

dot_i32 = tl.dot(a, b)
acc += dot_i32.to(tl.float32) * row_scale[:, None]
# gfx1201 (RDNA4/Wave32): bf16 MFMA only works for K>=128; use f32 MFMA (K=2/4)
# which tiles freely for any block_k (covers LTX-Video K=64 heads too).
if hip_mode:
acc += tl.dot(a.to(tl.float32), b.to(tl.float32), out_dtype=tl.float32) * row_scale[:, None]
else:
acc += tl.dot(a, b).to(tl.float32) * row_scale[:, None]

scales = tl.load(s_ptr + offs_n, mask=offs_n < n, other=0).to(tl.float32)
out = acc * scales[None, :]
Expand Down Expand Up @@ -1071,14 +1093,15 @@ def _scaled_int8_gemm_kernel(
block_m: tl.constexpr,
block_n: tl.constexpr,
block_k: tl.constexpr,
hip_mode: tl.constexpr = False,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * block_m + tl.arange(0, block_m)
offs_n = pid_n * block_n + tl.arange(0, block_n)
offs_k = tl.arange(0, block_k)

acc = tl.zeros((block_m, block_n), dtype=tl.int32)
acc = tl.zeros((block_m, block_n), dtype=tl.float32)
for k0 in range(0, k, block_k):
kk = k0 + offs_k
a = tl.load(
Expand All @@ -1092,11 +1115,16 @@ def _scaled_int8_gemm_kernel(
mask=(offs_n[None, :] < n) & (kk[:, None] < k),
other=0,
).to(tl.int8)
acc += tl.dot(a, b)
# gfx1201 (RDNA4/Wave32): bf16 MFMA only works for K>=128; use f32 MFMA (K=2/4)
# which tiles freely for any block_k (covers LTX-Video K=64 heads too).
if hip_mode:
acc += tl.dot(a.to(tl.float32), b.to(tl.float32), out_dtype=tl.float32)
else:
acc += tl.dot(a, b).to(tl.float32)

a_scales = tl.load(a_scales_ptr + offs_m, mask=offs_m < m, other=1).to(tl.float32)
b_scales = tl.load(b_scales_ptr + offs_n, mask=offs_n < n, other=1).to(tl.float32)
out = acc.to(tl.float32) * a_scales[:, None] * b_scales[None, :]
out = acc * a_scales[:, None] * b_scales[None, :]
tl.store(
c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn,
out,
Expand Down Expand Up @@ -1158,6 +1186,7 @@ def _fused_quant_scaled_mm_common(
block_k=block_k,
num_warps=num_warps,
num_stages=num_stages,
hip_mode=_HIP_BACKEND,
)
return out

Expand Down Expand Up @@ -1317,5 +1346,6 @@ def scaled_int8_mm(
block_k=block_k,
num_warps=num_warps,
num_stages=num_stages,
hip_mode=_HIP_BACKEND,
)
return out
46 changes: 31 additions & 15 deletions wgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@
from shared.asyncio_utils import silence_proactor_connection_reset
silence_proactor_connection_reset()

import logging
class _SocketSendFilter(logging.Filter):
    """Suppress asyncio log records that report a failed ``socket.send()``."""

    @staticmethod
    def filter(record):
        # Keep every record except those whose formatted message contains the
        # "socket.send() raised exception" text.
        return "socket.send() raised exception" not in record.getMessage()


logging.getLogger("asyncio").addFilter(_SocketSendFilter())

# ── Apple Silicon MPS patch: MUST come before mmgp import ──
import torch
is_mps = sys.platform == 'darwin' and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
Expand Down Expand Up @@ -755,23 +762,32 @@ def err(error=""):
resolution = inputs["resolution"]
width, height = resolution.split("x")
width, height = int(width), int(height)
image_start = inputs["image_start"]
image_end = inputs["image_end"]
image_refs = inputs["image_refs"]
image_prompt_type = inputs["image_prompt_type"]
audio_prompt_type = inputs["audio_prompt_type"]
# Local patch: Wan2GP UI sometimes drops media keys from the form payload (e.g. when
# the input image is cleared, the model is switched, or a stale tab is submitted).
# Use .get() so we produce a friendly error instead of crashing with KeyError, and
# explicitly check that i2v generations have an input image.
image_start = inputs.get("image_start")
image_end = inputs.get("image_end")
image_refs = inputs.get("image_refs")
image_prompt_type = inputs.get("image_prompt_type")
audio_prompt_type = inputs.get("audio_prompt_type")
if image_prompt_type == None: image_prompt_type = ""
video_prompt_type = inputs["video_prompt_type"]
video_prompt_type = inputs.get("video_prompt_type")
if video_prompt_type == None: video_prompt_type = ""
force_fps = inputs["force_fps"]
audio_guide = inputs["audio_guide"]
audio_guide2 = inputs["audio_guide2"]
audio_source = inputs["audio_source"]
video_guide = inputs["video_guide"]
image_guide = inputs["image_guide"]
video_mask = inputs["video_mask"]
image_mask = inputs["image_mask"]
custom_guide = inputs["custom_guide"]
force_fps = inputs.get("force_fps")
audio_guide = inputs.get("audio_guide")
audio_guide2 = inputs.get("audio_guide2")
audio_source = inputs.get("audio_source")
video_guide = inputs.get("video_guide")
image_guide = inputs.get("image_guide")
video_mask = inputs.get("video_mask")
image_mask = inputs.get("image_mask")
custom_guide = inputs.get("custom_guide")
# Friendly error for i2v generations that landed here without an input image.
if (not image_outputs
and image_start is None
and "i2v" in (get_base_model_type(model_type) or "").lower()):
return err("No input image provided. Upload an image in the I2V panel before clicking Generate.")
speakers_locations = inputs["speakers_locations"]
video_source = inputs["video_source"]
frames_positions = inputs["frames_positions"]
Expand Down