diff --git a/README.md b/README.md index 5b7654868d..e4a8346ce1 100644 --- a/README.md +++ b/README.md @@ -110,6 +110,15 @@ source $HOME/.local/bin/env uv sync --all-extras ``` +4.1. On aarch64 hosts: build flash-attn from source for your GPU + +> *NOTE*: aarch64 has no prebuilt flash-attn wheel. This step compiles the CUDA extension for your local GPU (~20-30 minutes). Compute capability is auto-detected from `nvidia-smi`; override with `TORCH_CUDA_ARCH_LIST=9.0` (Hopper) / `10.0` (Blackwell) if needed. +> *NOTE*: After this step, you can't run `uv sync --all-extras` or `uv run` as it will uninstall the package, you can avoid it by running `uv sync --inexact` or `uv run --no-sync`. + +```bash +bash scripts/docker-arm64-post-install.sh +``` + 3.1. Optional: Install Flash Attention 3 (on Hopper GPUs only, for flash_attention_3 attention backend) > *NOTE*: This step will take a while, as it builds the Flash Attention 3 extension from source, as it has no wheels prebuilt. diff --git a/scripts/docker-arm64-post-install.sh b/scripts/docker-arm64-post-install.sh index f02b3070b3..55f85a3a03 100755 --- a/scripts/docker-arm64-post-install.sh +++ b/scripts/docker-arm64-post-install.sh @@ -1,17 +1,44 @@ #!/bin/bash -# arm64 post-install fixups for Docker builds. -set -e +# arm64 post-install fixups: rebuild flash-attn from source for the target GPU. +# +# Why this exists: pyproject.toml sets FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE to keep +# `uv sync` fast; on x86_64 it pins a prebuilt wheel to fill in the binary, but no +# such wheel exists for aarch64. Without this script, `import flash_attn` fails on +# aarch64 with `ModuleNotFoundError: No module named 'flash_attn_2_cuda'`. +# +# Defaults preserve the existing Docker behavior (sm_100 / GB200). On a host with +# `nvidia-smi` available, the compute capability is auto-detected from the local +# GPU. Override via env vars if needed: +# TORCH_CUDA_ARCH_LIST e.g. 9.0 (Hopper), 10.0 (Blackwell) +# VENV_PATH path to the venv (default: $(pwd)/.venv) +# MAX_JOBS parallel nvcc jobs (default: 4) +set -euo pipefail -echo "=== building flash-attn from source (sm_100 / GB200) ===" -# Run from /tmp so uv doesn't read pyproject.toml's [tool.uv.extra-build-variables] -# which sets FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE and prevents CUDA kernel compilation. -export TORCH_CUDA_ARCH_LIST="10.0" -export MAX_JOBS=4 +if [ -z "${TORCH_CUDA_ARCH_LIST:-}" ]; then + # Try to detect from the local GPU. Tolerate any failure mode (binary missing, + # driver not loaded, Docker buildx without --gpus) and fall back to GB200. + TORCH_CUDA_ARCH_LIST="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null | head -1 | tr -d ' ' || true)" + : "${TORCH_CUDA_ARCH_LIST:=10.0}" +fi +export TORCH_CUDA_ARCH_LIST + +VENV_PATH="${VENV_PATH:-$(pwd)/.venv}" +if [ ! -x "$VENV_PATH/bin/python" ]; then + echo "ERROR: no python at $VENV_PATH/bin/python. Run from the project root or set VENV_PATH." >&2 + exit 1 +fi + +export MAX_JOBS="${MAX_JOBS:-4}" export FLASH_ATTENTION_FORCE_BUILD=TRUE export FLASH_ATTENTION_SKIP_CUDA_BUILD=FALSE -(cd /tmp && uv pip install --python /app/.venv/bin/python \ - "flash-attn==2.8.3" --no-build-isolation --no-binary flash-attn --no-cache) + +echo "=== building flash-attn from source (TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST, MAX_JOBS=$MAX_JOBS) ===" +echo " target venv: $VENV_PATH" +# Run from /tmp so uv ignores the project's [tool.uv.extra-build-variables], +# which sets FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE and would prevent kernel compilation. +(cd /tmp && uv pip install --python "$VENV_PATH/bin/python" \ + "flash-attn==2.8.3" --no-build-isolation --no-binary flash-attn --no-cache --reinstall-package flash-attn) echo "=== reinstalling flash-attn-cute (flash-attn overwrites it with a stub) ===" -uv pip install --reinstall --no-deps \ +uv pip install --python "$VENV_PATH/bin/python" --reinstall --no-deps \ "flash-attn-4 @ git+https://github.com/Dao-AILab/flash-attention.git@96bd151#subdirectory=flash_attn/cute" diff --git a/scripts/install.sh b/scripts/install.sh index 3b5cfb918b..dd7fa8340f 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -145,6 +145,16 @@ main() { log_info "Installing pre-commit hooks..." uv run pre-commit install + # aarch64 has no prebuilt flash-attn wheel; build it from source for the local GPU. + # Without this, `import flash_attn` fails with `ModuleNotFoundError: flash_attn_2_cuda`. + # Run last so no subsequent uv operation (which implicitly syncs against the lockfile) + # rebuilds flash-attn from PyPI with FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE and undoes this. + if [ "$(uname -m)" = "aarch64" ]; then + log_info "aarch64 detected: building flash-attn from source (this takes 20-30 minutes)..." + log_warn "Future 'uv sync --all-extras' or 'uv run' will remove this build. Use 'uv sync --inexact' or 'uv run --no-sync' to keep it." + bash scripts/docker-arm64-post-install.sh + fi + log_info "Installation completed!" }