Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,15 @@ source $HOME/.local/bin/env
uv sync --all-extras
```

4.1. On aarch64 hosts: build flash-attn from source for your GPU

> *NOTE*: aarch64 has no prebuilt flash-attn wheel. This step compiles the CUDA extension for your local GPU (~20-30 minutes). Compute capability is auto-detected from `nvidia-smi`; override with `TORCH_CUDA_ARCH_LIST=9.0` (Hopper) / `10.0` (Blackwell) if needed.
> *NOTE*: After this step, you can't run `uv sync --all-extras` or `uv run` as it will uninstall the package, you can avoid it by running `uv sync --inexact` or `uv run --no-sync`.

```bash
bash scripts/docker-arm64-post-install.sh
```

3.1. Optional: Install Flash Attention 3 (on Hopper GPUs only, for flash_attention_3 attention backend)

> *NOTE*: This step will take a while, as it builds the Flash Attention 3 extension from source, as it has no wheels prebuilt.
Expand Down
47 changes: 37 additions & 10 deletions scripts/docker-arm64-post-install.sh
Original file line number Diff line number Diff line change
@@ -1,17 +1,44 @@
#!/bin/bash
# arm64 post-install fixups for Docker builds.
set -e
# arm64 post-install fixups: rebuild flash-attn from source for the target GPU.
#
# Why this exists: pyproject.toml sets FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE to keep
# `uv sync` fast; on x86_64 it pins a prebuilt wheel to fill in the binary, but no
# such wheel exists for aarch64. Without this script, `import flash_attn` fails on
# aarch64 with `ModuleNotFoundError: No module named 'flash_attn_2_cuda'`.
#
# Defaults preserve the existing Docker behavior (sm_100 / GB200). On a host with
# `nvidia-smi` available, the compute capability is auto-detected from the local
# GPU. Override via env vars if needed:
# TORCH_CUDA_ARCH_LIST e.g. 9.0 (Hopper), 10.0 (Blackwell)
# VENV_PATH path to the venv (default: $(pwd)/.venv)
# MAX_JOBS parallel nvcc jobs (default: 4)
set -euo pipefail

echo "=== building flash-attn from source (sm_100 / GB200) ==="
# Run from /tmp so uv doesn't read pyproject.toml's [tool.uv.extra-build-variables]
# which sets FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE and prevents CUDA kernel compilation.
export TORCH_CUDA_ARCH_LIST="10.0"
export MAX_JOBS=4
if [ -z "${TORCH_CUDA_ARCH_LIST:-}" ]; then
# Try to detect from the local GPU. Tolerate any failure mode (binary missing,
# driver not loaded, Docker buildx without --gpus) and fall back to GB200.
TORCH_CUDA_ARCH_LIST="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null | head -1 | tr -d ' ' || true)"
: "${TORCH_CUDA_ARCH_LIST:=10.0}"
fi
Comment thread
cursor[bot] marked this conversation as resolved.
export TORCH_CUDA_ARCH_LIST

VENV_PATH="${VENV_PATH:-$(pwd)/.venv}"
if [ ! -x "$VENV_PATH/bin/python" ]; then
echo "ERROR: no python at $VENV_PATH/bin/python. Run from the project root or set VENV_PATH." >&2
exit 1
fi

export MAX_JOBS="${MAX_JOBS:-4}"
export FLASH_ATTENTION_FORCE_BUILD=TRUE
export FLASH_ATTENTION_SKIP_CUDA_BUILD=FALSE
(cd /tmp && uv pip install --python /app/.venv/bin/python \
"flash-attn==2.8.3" --no-build-isolation --no-binary flash-attn --no-cache)

echo "=== building flash-attn from source (TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST, MAX_JOBS=$MAX_JOBS) ==="
echo " target venv: $VENV_PATH"
# Run from /tmp so uv ignores the project's [tool.uv.extra-build-variables],
# which sets FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE and would prevent kernel compilation.
(cd /tmp && uv pip install --python "$VENV_PATH/bin/python" \
"flash-attn==2.8.3" --no-build-isolation --no-binary flash-attn --no-cache --reinstall-package flash-attn)

echo "=== reinstalling flash-attn-cute (flash-attn overwrites it with a stub) ==="
uv pip install --reinstall --no-deps \
uv pip install --python "$VENV_PATH/bin/python" --reinstall --no-deps \
"flash-attn-4 @ git+https://github.com/Dao-AILab/flash-attention.git@96bd151#subdirectory=flash_attn/cute"
10 changes: 10 additions & 0 deletions scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,16 @@ main() {
log_info "Installing pre-commit hooks..."
uv run pre-commit install

# aarch64 has no prebuilt flash-attn wheel; build it from source for the local GPU.
# Without this, `import flash_attn` fails with `ModuleNotFoundError: flash_attn_2_cuda`.
# Run last so no subsequent uv operation (which implicitly syncs against the lockfile)
# rebuilds flash-attn from PyPI with FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE and undoes this.
if [ "$(uname -m)" = "aarch64" ]; then
log_info "aarch64 detected: building flash-attn from source (this takes 20-30 minutes)..."
log_warn "Future 'uv sync --all-extras' or 'uv run' will remove this build. Use 'uv sync --inexact' or 'uv run --no-sync' to keep it."
bash scripts/docker-arm64-post-install.sh
fi

Comment thread
cursor[bot] marked this conversation as resolved.
log_info "Installation completed!"
}

Expand Down