fix: AMD gfx1201 (RDNA4/ROCm) — INT8 Triton f32 MFMA, LTX Video device fix, validate_settings KeyError#1822
Open
athornbro wants to merge 1 commit into
Open
Conversation
Four fixes to make Wan2GP and LTX Video generation work correctly on AMD RDNA4 GPUs (RX 9070/9070XT, gfx1201) running ROCm/HIP on Windows. ## shared/kernels/quanto_int8_triton.py 1. **HIP libdevice**: Use `triton.language.extra.hip.libdevice` instead of the CUDA libdevice on ROCm. The CUDA libdevice triggers "__nv_rintf has been dropped" errors under Triton 3.6+ on AMD, silently killing the fused INT8 kernel. 2. **f32 MFMA for int8 dot on gfx1201**: All three Triton INT8 GEMM kernels now cast operands to float32 (instead of bfloat16) before `tl.dot` when running under HIP. gfx1201 bf16 MFMA only provides K=8 and K=16 intrinsics; Wan 2.2 uses K=256 (works) but LTX Video uses K=64 (no matching intrinsic → crash). Float32 MFMA (K=2/4) tiles freely at any block_k and covers both models. All three kernels' accumulators were changed from int32 to float32 accordingly. ## models/ltx_video/models/transformers/attention.py 3. **skip_layer_mask device mismatch**: `skip_layer_mask` is constructed on CPU but `hidden_states_a` is on the GPU, causing a "Expected all tensors to be on the same device" RuntimeError at the first denoising step. Fix: `.to(hidden_states_a.device)` on both the AttentionSkip and AttentionValues branches. ## wgp.py 4. **validate_settings KeyError fix**: Several media input keys (`image_start`, `image_end`, `image_refs`, `video_guide`, etc.) are absent from the form payload when inputs are cleared or a different model tab is active, producing an unhandled KeyError. Changed to `.get()` with a friendly error for I2V-without-image. ## plugins/wan2gp-video-mask-creator/plugin.py 5. **ROCm CPU LAPACK fix**: ROCm Windows PyTorch ships without CPU LAPACK, so `nn.init.orthogonal_` (used during matanyone model construction) fails with a geqrf RuntimeError. Monkey-patches orthogonal_ to transparently retry the QR decomposition on the GPU when the CPU path raises a LAPACK error. Zero overhead on systems with working CPU LAPACK. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Owner
|
thx, but I cant merge this PR as is :
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Five fixes to make Wan2GP and LTX Video run correctly on AMD RDNA4 GPUs (RX 9070 / 9070 XT, gfx1201) under ROCm/HIP on Windows. All changes are no-ops on CUDA.
shared/kernels/quanto_int8_triton.pyHIP libdevice — Use
triton.language.extra.hip.libdeviceon ROCm instead of the CUDA libdevice. The CUDA libdevice triggers__nv_rintf has been droppederrors under Triton 3.6+ on AMD, silently killing the fused INT8 kernel.f32 MFMA for int8 dot on gfx1201 — All three Triton INT8 GEMM kernels now cast int8 operands to
float32beforetl.doton HIP. gfx1201 bf16 MFMA only provides K=8 and K=16 intrinsics; Wan 2.2 uses K=256 (works) but LTX Video uses K=64 — no matching bf16 intrinsic → Triton compile crash. Float32 MFMA (K=2/4) tiles freely at anyblock_k. Accumulators updated fromint32tofloat32accordingly.models/ltx_video/models/transformers/attention.pyskip_layer_maskdevice mismatch —skip_layer_maskis constructed on CPU buthidden_states_ais on GPU, causingRuntimeError: Expected all tensors to be on the same deviceat the first LTX Video denoising step. Fixed with.to(hidden_states_a.device)on both theAttentionSkipandAttentionValuesbranches.wgp.pyvalidate_settingsKeyError fix — Media input keys (image_start,image_end,image_refs,video_guide, etc.) are absent from the payload when inputs are cleared or a different model tab is active. Changed frominputs["key"]toinputs.get("key")throughout, plus a friendly error for I2V without an input image.asyncio socket warning — Suppress spurious
socket.send() raised exceptionlog noise on Windows.plugins/wan2gp-video-mask-creator/plugin.pyROCm CPU LAPACK fix — ROCm Windows PyTorch ships without CPU LAPACK, so
nn.init.orthogonal_fails during matanyone model construction. Monkey-patchesorthogonal_to retry the QR decomposition on the GPU when the CPU path raises a LAPACK error. Zero overhead on systems with working CPU LAPACK.Test platform