
fix: AMD gfx1201 (RDNA4/ROCm) — INT8 Triton f32 MFMA, LTX Video device fix, validate_settings KeyError #1822

Open
athornbro wants to merge 1 commit into deepbeepmeep:main from athornbro:fix/amd-rocm-gfx1201-ltxvideo

@athornbro


Summary

Five fixes to make Wan2GP and LTX Video run correctly on AMD RDNA4 GPUs (RX 9070 / 9070 XT, gfx1201) under ROCm/HIP on Windows. All changes are no-ops on CUDA.

shared/kernels/quanto_int8_triton.py

HIP libdevice — Use triton.language.extra.hip.libdevice on ROCm instead of the CUDA libdevice. The CUDA libdevice triggers __nv_rintf has been dropped errors under Triton 3.6+ on AMD, silently killing the fused INT8 kernel.
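
A minimal sketch of the backend-dependent import, assuming the CUDA counterpart lives at `triton.language.extra.cuda.libdevice` and that a ROCm build of PyTorch is detected through `torch.version.hip`. The helper names are hypothetical and this is not the PR's actual code:

```python
import importlib

HIP_LIBDEVICE = "triton.language.extra.hip.libdevice"    # module named in this PR
CUDA_LIBDEVICE = "triton.language.extra.cuda.libdevice"  # assumed CUDA counterpart


def libdevice_module_name(is_hip: bool) -> str:
    """Return the libdevice module path for the active backend."""
    return HIP_LIBDEVICE if is_hip else CUDA_LIBDEVICE


def load_libdevice(is_hip: bool):
    """Import the matching libdevice module (requires Triton to be installed)."""
    return importlib.import_module(libdevice_module_name(is_hip))


# In the kernel module one might then write (assumption):
#   import torch
#   libdevice = load_libdevice(torch.version.hip is not None)
```

Keeping the choice in one helper means the CUDA path still imports exactly what it imported before.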

f32 MFMA for int8 dot on gfx1201 — All three Triton INT8 GEMM kernels now cast int8 operands to float32 before tl.dot on HIP. gfx1201 bf16 MFMA only provides K=8 and K=16 intrinsics; Wan 2.2 uses K=256 (works) but LTX Video uses K=64 — no matching bf16 intrinsic → Triton compile crash. Float32 MFMA (K=2/4) tiles freely at any block_k. Accumulators updated from int32 to float32 accordingly.
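
The numerical effect of moving the accumulator from int32 to float32 can be checked without a GPU. The sketch below is an illustration rather than code from the PR: it emulates a float32 accumulator in pure Python (helper names are hypothetical) and relies on the fact that float32 represents every integer exactly only up to 2^24.

```python
import struct


def f32(x: float) -> float:
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def accumulate_f32(products):
    """Sum integer products the way a float32 accumulator would."""
    acc = 0.0
    for p in products:
        acc = f32(acc + float(p))
    return acc


worst = (-128) * (-128)  # largest int8 x int8 product: 16384 == 2**14

# One K=64 tile: the worst-case sum is 64 * 2**14 == 2**20, exact in float32.
tile_sum = accumulate_f32([worst] * 64)

# Past 2**24, float32 can no longer represent every integer,
# whereas an int32 accumulator would hold 2**24 + 1 exactly.
lossy = accumulate_f32([2**24, 1])
```

A single block_k tile is therefore exact, while a sufficiently long reduction can in principle round. Whether that matters for the K dimensions of the tested models is not something the PR states.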

models/ltx_video/models/transformers/attention.py

skip_layer_mask device mismatch: skip_layer_mask is constructed on CPU but hidden_states_a is on GPU, causing RuntimeError: Expected all tensors to be on the same device at the first LTX Video denoising step. Fixed with .to(hidden_states_a.device) on both the AttentionSkip and AttentionValues branches.

wgp.py

validate_settings KeyError fix — Media input keys (image_start, image_end, image_refs, video_guide, etc.) are absent from the payload when inputs are cleared or a different model tab is active. Changed from inputs["key"] to inputs.get("key") throughout, plus a friendly error for I2V without an input image.
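
The difference between the two lookups can be sketched as follows. `validate_media_inputs` and its error text are hypothetical stand-ins, not the function in wgp.py:

```python
def validate_media_inputs(inputs: dict, requires_start_image: bool):
    """Return (ok, message). Hypothetical helper for illustration only."""
    # .get() yields None for a missing key instead of raising KeyError.
    image_start = inputs.get("image_start")

    if requires_start_image and image_start is None:
        return False, "Please provide a start image for image-to-video generation."
    return True, ""


# Payload from a cleared form: the media keys are simply missing.
ok, message = validate_media_inputs({"prompt": "a cat"}, requires_start_image=True)
```

With `inputs["image_start"]` the same payload would raise `KeyError` before any validation message could be produced.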

asyncio socket warning — Suppress spurious socket.send() raised exception log noise on Windows.
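
One way such a message can be silenced is a logging filter on the `asyncio` logger. This is a sketch under that assumption; the PR text does not say which mechanism it uses, and the class name is hypothetical:

```python
import logging


class DropSocketSendNoise(logging.Filter):
    """Drop 'socket.send() raised exception' records and keep everything else."""

    def filter(self, record: logging.LogRecord) -> bool:
        return "socket.send() raised exception" not in record.getMessage()


# asyncio emits its warnings through the logger named "asyncio".
logging.getLogger("asyncio").addFilter(DropSocketSendNoise())
```

Filtering on the message text leaves other asyncio warnings visible, which a blanket log-level change would not.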

plugins/wan2gp-video-mask-creator/plugin.py

ROCm CPU LAPACK fix — ROCm Windows PyTorch ships without CPU LAPACK, so nn.init.orthogonal_ fails during matanyone model construction. Monkey-patches orthogonal_ to retry the QR decomposition on the GPU when the CPU path raises a LAPACK error. Zero overhead on systems with working CPU LAPACK.
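
The retry-on-LAPACK-error pattern described above can be sketched without PyTorch. Everything here is illustrative: `with_gpu_fallback`, the two stand-in initialisers, and the error text are hypothetical, and the PR's actual matching logic is not shown in this description.

```python
def with_gpu_fallback(cpu_fn, gpu_fn):
    """Wrap cpu_fn so that a LAPACK RuntimeError triggers one retry via gpu_fn."""

    def wrapper(*args, **kwargs):
        try:
            return cpu_fn(*args, **kwargs)  # normal path, no extra work
        except RuntimeError as err:
            if "lapack" not in str(err).lower():
                raise  # unrelated error: do not mask it
            return gpu_fn(*args, **kwargs)

    return wrapper


# Stand-ins so the sketch runs without PyTorch (made-up error text).
def cpu_init(name):
    raise RuntimeError("geqrf: build has no LAPACK support")


def gpu_init(name):
    return f"{name} initialised on GPU"


init = with_gpu_fallback(cpu_init, gpu_init)
result = init("weight")
```

In the plugin, `cpu_fn` would presumably be the original `nn.init.orthogonal_` and `gpu_fn` a variant that runs the QR decomposition on the GPU. On a system with working CPU LAPACK the `except` branch is never entered.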

Test platform

  • AMD Radeon RX 9070 XT (gfx1201 / RDNA4 / Wave32), ROCm 7.2, Windows 11
  • Triton 3.7.0 (triton-windows)
  • Tested: Wan 2.2 14B I2V int8 (4-step lightning) and LTX Video 0.9.8 13B int8

Five fixes to make Wan2GP and LTX Video generation work correctly on
AMD RDNA4 GPUs (RX 9070/9070XT, gfx1201) running ROCm/HIP on Windows.

## shared/kernels/quanto_int8_triton.py

1. **HIP libdevice**: Use `triton.language.extra.hip.libdevice` instead of
   the CUDA libdevice on ROCm. The CUDA libdevice triggers
   "__nv_rintf has been dropped" errors under Triton 3.6+ on AMD,
   silently killing the fused INT8 kernel.

2. **f32 MFMA for int8 dot on gfx1201**: All three Triton INT8 GEMM kernels
   now cast operands to float32 (instead of bfloat16) before `tl.dot` when
   running under HIP. gfx1201 bf16 MFMA only provides K=8 and K=16
   intrinsics; Wan 2.2 uses K=256 (works) but LTX Video uses K=64 (no
   matching intrinsic → crash). Float32 MFMA (K=2/4) tiles freely at any
   block_k and covers both models. All three kernels' accumulators were
   changed from int32 to float32 accordingly.

## models/ltx_video/models/transformers/attention.py

3. **skip_layer_mask device mismatch**: `skip_layer_mask` is constructed on
   CPU but `hidden_states_a` is on the GPU, causing a
   "Expected all tensors to be on the same device" RuntimeError at the
   first denoising step. Fix: `.to(hidden_states_a.device)` on both the
   AttentionSkip and AttentionValues branches.

## wgp.py

4. **validate_settings KeyError fix**: Several media input keys
   (`image_start`, `image_end`, `image_refs`, `video_guide`, etc.) are
   absent from the form payload when inputs are cleared or a different
   model tab is active, producing an unhandled KeyError. Changed to
   `.get()` with a friendly error for I2V-without-image.

## plugins/wan2gp-video-mask-creator/plugin.py

5. **ROCm CPU LAPACK fix**: ROCm Windows PyTorch ships without CPU LAPACK,
   so `nn.init.orthogonal_` (used during matanyone model construction)
   fails with a geqrf RuntimeError. Monkey-patches orthogonal_ to
   transparently retry the QR decomposition on the GPU when the CPU path
   raises a LAPACK error. Zero overhead on systems with working CPU LAPACK.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@deepbeepmeep

Owner

thx, but I can't merge this PR as is:

  • not all changes are no-ops for NVIDIA: for instance, some int32 conversions have been turned into float32 conversions. Please provide either a monkey patch centralized in one module dedicated to AMD, or explicit AMD kernels
  • the wgp.py change is unrelated to AMD and should be in a separate PR
