Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,35 @@ conda activate hydra
pip install -r requirements.txt
```

<details>
<summary><b>AMD ROCm GPU support (optional)</b></summary>

HyDRA works on AMD GPUs via ROCm. Replace Step 3 with:

```bash
# Install ROCm PyTorch (ROCm 6.4 example)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4

# Install remaining deps (skip torch/torchvision/cupy — already handled above)
grep -vEi "^(torch|torchvision|cupy)" requirements.txt > /tmp/req_clean.txt
pip install -r /tmp/req_clean.txt

# (Recommended) Install FlashAttention for ~16-19% faster inference
# Option A — source build with Triton backend (ROCm 6.x):
git clone https://github.com/Dao-AILab/flash-attention.git /tmp/flash-attention
cd /tmp/flash-attention && git submodule update --init third_party/aiter
FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE python setup.py install

# Option B — AMD AITER for best performance on ROCm 7.x:
# pip install aiter (or build from https://github.com/ROCm/aiter)
# AITER auto-dispatches to CK kernels on ROCm 7.x (~25% faster than Triton)
```

The attention backend is selected automatically at import time:
FA3 → AITER → FA2 → SageAttention → PyTorch SDPA (fallback).

</details>

### Step 4: Download the pretrained Wan2.1 (1.3B) T2V model

- Model link: https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B
Expand Down
36 changes: 31 additions & 5 deletions diffsynth/models/wan_video_dit.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,48 @@
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional
from einops import rearrange
from .utils import hash_state_dict_keys
# from .ucpe_attention import UcpeSelfAttention

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Optional attention backends, probed once at import time.
# Each probe is best-effort: a missing or broken package simply disables that
# backend and the selection below falls through to the next one.
# NOTE: ModuleNotFoundError is a subclass of ImportError, so catching the
# tuple is redundant but kept explicit for readability.
# ---------------------------------------------------------------------------

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    FLASH_ATTN_2_AVAILABLE = False

try:
    # Resolved via importlib so a partially-installed aiter (e.g. built for a
    # different ROCm version) fails this probe instead of crashing the whole
    # module import.
    import importlib as _il
    _aiter_mha = _il.import_module("aiter.ops.mha")
    _aiter_flash_attn_func = _aiter_mha.flash_attn_func
    AITER_AVAILABLE = True
except Exception:
    # Broad on purpose: aiter can raise non-ImportError exceptions at import
    # time on unsupported hardware/driver combinations.
    AITER_AVAILABLE = False

try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    SAGE_ATTN_AVAILABLE = False

# Backend priority: FA3 > AITER > FA2 > SageAttention > PyTorch SDPA fallback.
_ATTN_BACKEND = (
    "flash_attn_3" if FLASH_ATTN_3_AVAILABLE else
    "aiter" if AITER_AVAILABLE else
    "flash_attn_2" if FLASH_ATTN_2_AVAILABLE else
    "sage_attn" if SAGE_ATTN_AVAILABLE else
    "sdpa"
)
logger.info("Attention backend: %s", _ATTN_BACKEND)

def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False):
if compatibility_mode:
Expand All @@ -37,6 +57,12 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v)
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
elif AITER_AVAILABLE:
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
x = _aiter_flash_attn_func(q, k, v)
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
elif FLASH_ATTN_2_AVAILABLE:
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
Expand Down