diff --git a/README.md b/README.md index aac655f..f6bb128 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,35 @@ conda activate hydra pip install -r requirements.txt ``` +
+### AMD ROCm GPU support (optional) + +HyDRA works on AMD GPUs via ROCm. Replace Step 3 with: + +```bash +# Install ROCm PyTorch (ROCm 6.4 example) +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4 + +# Install remaining deps (skip torch/torchvision/torchaudio — installed above — and cupy) +grep -vEi "^(torch|torchvision|cupy)" requirements.txt > /tmp/req_clean.txt +pip install -r /tmp/req_clean.txt + +# (Recommended) Install FlashAttention for ~16-19% faster inference +# Option A — source build with Triton backend (ROCm 6.x): +git clone https://github.com/Dao-AILab/flash-attention.git /tmp/flash-attention +cd /tmp/flash-attention && git submodule update --init third_party/aiter +FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE python setup.py install + +# Option B — AMD AITER for best performance on ROCm 7.x: +# pip install aiter (or build from https://github.com/ROCm/aiter) +# AITER auto-dispatches to CK kernels on ROCm 7.x (~25% faster than Triton) +``` + +The attention backend is selected automatically at import time: +FA3 → AITER → FA2 → SageAttention → PyTorch SDPA (fallback). + +
+ ### Step 4: Download the pretrained Wan2.1 (1.3B) T2V model - Model link: https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 8a062bf..953117b 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -1,3 +1,4 @@ +import logging import torch import torch.nn as nn import torch.nn.functional as F @@ -5,24 +6,43 @@ from typing import Tuple, Optional from einops import rearrange from .utils import hash_state_dict_keys -# from .ucpe_attention import UcpeSelfAttention + +logger = logging.getLogger(__name__) + try: import flash_attn_interface FLASH_ATTN_3_AVAILABLE = True -except ModuleNotFoundError: +except (ImportError, ModuleNotFoundError): FLASH_ATTN_3_AVAILABLE = False try: import flash_attn FLASH_ATTN_2_AVAILABLE = True -except ModuleNotFoundError: +except (ImportError, ModuleNotFoundError): FLASH_ATTN_2_AVAILABLE = False +try: + import importlib as _il + _aiter_mha = _il.import_module("aiter.ops.mha") + _aiter_flash_attn_func = _aiter_mha.flash_attn_func + AITER_AVAILABLE = True +except Exception: + AITER_AVAILABLE = False + try: from sageattention import sageattn - SAGE_ATTN_AVAILABLE = True -except ModuleNotFoundError: + SAGE_ATTN_AVAILABLE = True +except (ImportError, ModuleNotFoundError): SAGE_ATTN_AVAILABLE = False + +_ATTN_BACKEND = ( + "flash_attn_3" if FLASH_ATTN_3_AVAILABLE else + "aiter" if AITER_AVAILABLE else + "flash_attn_2" if FLASH_ATTN_2_AVAILABLE else + "sage_attn" if SAGE_ATTN_AVAILABLE else + "sdpa" +) +logger.info("Attention backend: %s", _ATTN_BACKEND) def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False): if compatibility_mode: @@ -37,6 +57,12 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) x = flash_attn_interface.flash_attn_func(q, k, v) x = rearrange(x, "b s n d -> b s (n d)", 
n=num_heads) + elif AITER_AVAILABLE: + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = _aiter_flash_attn_func(q, k, v) + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) elif FLASH_ATTN_2_AVAILABLE: q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)