fix(install): support aarch64 hosts with any compute capability #2587
Open
andre-fu wants to merge 3 commits into
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Reviewed by Cursor Bugbot for commit 5684161. Configure here.
The aarch64 host install path was broken: `uv sync` installs flash-attn from PyPI source but pyproject sets FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE, so the compiled extension never builds. `scripts/docker-arm64-post-install.sh` fixed it for Docker GB200 builds but hardcoded sm_100 and /app/.venv, leaving Hopper hosts (H100/H200/GH200) without a recipe.

Changes:
- `scripts/docker-arm64-post-install.sh`: auto-detect compute capability via nvidia-smi when available; parameterize venv path. Preserves the sm_100 default when no GPU is visible (Docker buildx).
- `scripts/install.sh`: call the post-install for aarch64 hosts after `uv sync --all-extras`. Previously the script ran uv sync and exited, leaving aarch64 users with a broken venv.
- `README.md`: document the aarch64 post-install step (mirrors the existing 3.1 Flash Attention 3 pattern).

Validated on GH200 (sm_90, aarch64):
- forward + backward parity vs torch SDPA (max diff < 0.05 / 0.25)
- 383/384 unit tests pass (the 1 failure is unrelated TileLang/MoE)
- SFT trainer smoke test (5 steps, Qwen3-0.6B) runs with flash_attention_2

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Summary
- `scripts/install.sh` runs `uv sync --all-extras` and exits, leaving aarch64 hosts with a broken venv (`import flash_attn` fails because `pyproject.toml` sets `FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE` and the x86_64-only prebuilt wheel doesn't match).
- `scripts/docker-arm64-post-install.sh` solved this for Docker GB200 builds but hardcoded `TORCH_CUDA_ARCH_LIST="10.0"` and `/app/.venv`, so it didn't apply to Hopper hosts (H100/H200/GH200).

- `scripts/docker-arm64-post-install.sh`: auto-detect `TORCH_CUDA_ARCH_LIST` from `nvidia-smi --query-gpu=compute_cap` when available; parameterize `VENV_PATH` (default `$(pwd)/.venv`). Preserves the sm_100 default when no GPU is visible (Docker buildx with no `--gpus`), so existing GB200 image builds are unchanged.
- `scripts/install.sh`: after `uv sync --all-extras`, invoke the post-install on aarch64 hosts. Includes a `log_warn` reminder that future `uv sync --all-extras` will wipe the build (same caveat as the existing FA3 flow).
- `README.md`: new step 4.1 documenting the aarch64 post-install, mirrored on the existing 3.1 Flash Attention 3 section.

Verification
Tested on a Lambda Labs GH200 480GB (aarch64, sm_90, torch 2.11+cu128, cp312).
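On a host like this one (sm_90), the auto-detection described in the Summary would be expected to resolve `TORCH_CUDA_ARCH_LIST` to `9.0`. The following is only a minimal sketch of that logic, not the contents of `scripts/docker-arm64-post-install.sh`: the function name `detect_cuda_arch`, the `--format=csv,noheader` flag, the use of the first GPU only, and the `<major>.<minor>` validation are all assumptions made for illustration.

```shell
#!/bin/sh
# Hypothetical helper (not taken from the PR): print a value suitable for
# TORCH_CUDA_ARCH_LIST. Uses the first GPU's compute capability when
# nvidia-smi is usable, otherwise the fallback given as $1 (default "10.0",
# i.e. sm_100, matching the Docker buildx case with no visible GPU).
detect_cuda_arch() {
    _fallback="${1:-10.0}"
    _cap=""
    if command -v nvidia-smi >/dev/null 2>&1; then
        # One line per GPU, e.g. "9.0"; keep the first line, strip whitespace.
        _cap="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null \
            | head -n 1 | tr -d '[:space:]')" || _cap=""
    fi
    # Accept only values shaped like "<major>.<minor>"; anything else
    # (empty output, "[N/A]", error text) falls back to the default.
    if printf '%s\n' "$_cap" | grep -Eq '^[0-9]+\.[0-9]+$'; then
        printf '%s\n' "$_cap"
    else
        printf '%s\n' "$_fallback"
    fi
}

TORCH_CUDA_ARCH_LIST="$(detect_cuda_arch)"
VENV_PATH="${VENV_PATH:-$(pwd)/.venv}"   # default taken from the PR description
export TORCH_CUDA_ARCH_LIST VENV_PATH
echo "TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST} VENV_PATH=${VENV_PATH}"
```

Falling back instead of failing keeps GPU-less image builds working, at the cost of silently targeting sm_100 on a host whose driver cannot report `compute_cap`.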
flash-attn correctness vs `torch.nn.functional.scaled_dot_product_attention`
Forward + backward across 6 shape/dtype/causal combos. Worst output diff 0.0039 (well under 0.05); worst gradient diff 0.0156 (well under 0.25), which is typical bf16 vs fp32-accumulation noise.
`pytest tests/unit`
383 passed, 1 xfailed, 1 failed in 156s. The failure is `tests/unit/train/models/test_qwen3_5_moe.py::test_qwen3_5_moe`, a pre-existing TileLang/TVM codegen issue in `flash-linear-attention`'s MoE backward kernel, unrelated to flash-attn 2 / this change.
SFT trainer smoke test
All 5 steps complete with attn='flash_attention_2' (per startup config) at ~4500 tok/s, peak memory 13.4 GiB, no errors.
Known follow-ups (not in this PR)
A prebuilt aarch64 wheel hosted on the prime-rl release page (mirroring the x86_64 wheel from mjun0812/flash-attention-prebuild-wheels) would let `uv sync` just work without the post-install step. Happy to file separately and attach my built wheel (`flash_attn-2.8.3-cp312-cp312-linux_aarch64.whl`, 253 MB compressed) if useful.
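Until such a wheel exists, the install flow has to gate the post-install on the host architecture. The sketch below is a dry run of that gate under stated assumptions: the helper name `needs_arm64_post_install` is hypothetical, the printed invocation is a guess at how `scripts/install.sh` might call the post-install script, and nothing here runs `uv` or compiles anything.

```shell
#!/bin/sh
# Hypothetical helper (not taken from the PR): succeed only on aarch64.
# $1 optionally overrides the machine name; the default is `uname -m`.
needs_arm64_post_install() {
    case "${1:-$(uname -m)}" in
        aarch64) return 0 ;;
        *) return 1 ;;
    esac
}

# Dry run: print what would happen after `uv sync --all-extras`.
if needs_arm64_post_install; then
    # Assumed invocation; the real call in scripts/install.sh may differ.
    echo "would run: VENV_PATH=$(pwd)/.venv bash scripts/docker-arm64-post-install.sh"
    echo "warning: a later 'uv sync --all-extras' will wipe the source-built flash-attn"
else
    echo "$(uname -m) host: skipping the aarch64 post-install"
fi
```

Passing the machine name as an argument keeps the gate easy to exercise on any host, for example `needs_arm64_post_install x86_64`.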
Note
Medium Risk
Changes the installation flow and build scripts to compile CUDA extensions on aarch64, which can affect setup reliability and build times across different GPU/driver environments, but is limited to install-time behavior.
Overview
Fixes broken `flash-attn` installs on aarch64 by adding an install-time post-step that rebuilds `flash-attn` from source for the local GPU and warns users about avoiding future `uv` syncs that would uninstall it.
Updates `scripts/docker-arm64-post-install.sh` to be host-friendly (auto-detect `TORCH_CUDA_ARCH_LIST` via `nvidia-smi`, configurable `VENV_PATH`/`MAX_JOBS`, safer defaults) and documents the required aarch64 post-install step in the `README`.
Reviewed by Cursor Bugbot for commit 187c876. Bugbot is set up for automated code reviews on this repo.