fix(install): support aarch64 hosts with any compute capability#2587

Open
andre-fu wants to merge 3 commits into
PrimeIntellect-ai:mainfrom
andre-fu:fix/aarch64-flash-attn-install

@andre-fu andre-fu commented May 21, 2026

Summary

  • scripts/install.sh runs uv sync --all-extras and exits — leaving aarch64 hosts with a broken venv (import flash_attn fails because pyproject.toml sets FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE and the
    x86_64-only prebuilt wheel doesn't match). scripts/docker-arm64-post-install.sh solved this for Docker GB200 builds but hardcoded TORCH_CUDA_ARCH_LIST="10.0" and /app/.venv, so it didn't apply to Hopper
    hosts (H100/H200/GH200).
  • scripts/docker-arm64-post-install.sh: auto-detect TORCH_CUDA_ARCH_LIST from nvidia-smi --query-gpu=compute_cap when available; parameterize VENV_PATH (default $(pwd)/.venv). Preserves the sm_100
    default when no GPU is visible (Docker buildx with no --gpus), so existing GB200 image builds are unchanged.
  • scripts/install.sh: after uv sync --all-extras, invoke the post-install on aarch64 hosts. Includes a log_warn reminder that future uv sync --all-extras will wipe the build (same caveat as the existing
    FA3 flow).
  • README.md: new step 4.1 documenting the aarch64 post-install, modeled on the existing 3.1 Flash Attention 3 section.
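
The detection described in the second bullet can be sketched roughly as follows. This is a minimal illustration, not the PR's actual script: the function name detect_arch_list is hypothetical, and honoring a caller-provided TORCH_CUDA_ARCH_LIST is an assumption. Only the nvidia-smi query, the 10.0 fallback, and the $(pwd)/.venv default come from the description above.

```shell
#!/usr/bin/env bash
# Hypothetical sketch; detect_arch_list is an illustrative name, not taken from the PR.

# Print a TORCH_CUDA_ARCH_LIST value for the visible GPUs, or the fallback if none are visible.
detect_arch_list() {
    local fallback="${1:-10.0}"   # sm_100 default, kept for GPU-less Docker builds
    local caps=""
    if command -v nvidia-smi >/dev/null 2>&1; then
        # One compute capability per GPU (e.g. "9.0"); de-duplicate and join with ';'.
        caps="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null \
            | tr -d ' ' | grep -E '^[0-9]+\.[0-9]+$' | sort -u | paste -sd ';' -)"
    fi
    printf '%s\n' "${caps:-$fallback}"
}

# Assumption: an explicitly exported value wins over auto-detection.
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-$(detect_arch_list)}"
VENV_PATH="${VENV_PATH:-$(pwd)/.venv}"
export TORCH_CUDA_ARCH_LIST VENV_PATH
echo "TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST} VENV_PATH=${VENV_PATH}"
```

Filtering the nvidia-smi output through a numeric pattern means a failed or empty query falls back to the default instead of passing garbage to the compiler.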

Verification

Tested on a Lambda Labs GH200 480GB (aarch64, sm_90, torch 2.11+cu128, cp312).

flash-attn correctness vs torch.nn.functional.scaled_dot_product_attention

Forward + backward across 6 shape/dtype/causal combos:

B  S     H   D    dtype     causal  max out diff  max grad diff
1  128   4   64   float16   False   0.0001        0.0002
1  128   4   64   float16   True    0.0001        0.0005
2  512   8   64   bfloat16  False   0.0020        0.0020
2  512   8   64   bfloat16  True    0.0039        0.0078
1  2048  16  128  bfloat16  True    0.0039        0.0156
4  1024  12  64   bfloat16  True    0.0039        0.0156

Worst output diff is 0.0039 (threshold 0.05) and worst gradient diff is 0.0156 (threshold 0.25), which is typical bf16 vs fp32-accumulation noise.

pytest tests/unit

383 passed, 1 xfailed, 1 failed in 156s. The failure is tests/unit/train/models/test_qwen3_5_moe.py::test_qwen3_5_moe — a pre-existing TileLang/TVM codegen issue in flash-linear-attention's MoE backward
kernel, unrelated to flash-attn 2 / this change.

SFT trainer smoke test

uv run --no-sync sft @ configs/debug/sft/train.toml

All 5 steps complete with attn='flash_attention_2' (per startup config) at ~4500 tok/s, peak memory 13.4 GiB, no errors.

Known follow-ups (not in this PR)

A prebuilt aarch64 wheel hosted on the prime-rl release page (mirroring the x86_64 wheel from mjun0812/flash-attention-prebuild-wheels) would let uv sync just work without the post-install step. Happy to file
separately and attach my built wheel (flash_attn-2.8.3-cp312-cp312-linux_aarch64.whl, 253 MB compressed) if useful.


Note

Medium Risk
Changes the installation flow and build scripts to compile CUDA extensions on aarch64, which can affect setup reliability and build times across different GPU/driver environments, but is limited to install-time behavior.

Overview
Fixes broken flash-attn installs on aarch64 by adding an install-time post-step that rebuilds flash-attn from source for the local GPU, and warns users that a later uv sync --all-extras will remove that build.

Updates scripts/docker-arm64-post-install.sh to be host-friendly (auto-detect TORCH_CUDA_ARCH_LIST via nvidia-smi, configurable VENV_PATH/MAX_JOBS, safer defaults) and documents the required aarch64 post-install step in the README.
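
As a rough illustration of the install-time post-step, the aarch64 branch in scripts/install.sh might look something like the sketch below. The helper names is_aarch64 and run_post_install_if_needed, the POST_INSTALL override, and the log_warn stand-in are all hypothetical; the real script defines its own logging and may structure this differently.

```shell
#!/usr/bin/env bash
# Hypothetical sketch of the aarch64 branch; helper names are illustrative, not from the PR.

# Stand-in for the log_warn helper mentioned in the PR description.
log_warn() { printf '[WARN] %s\n' "$*" >&2; }

# Succeeds when the given (or detected) machine type is aarch64.
is_aarch64() {
    [ "${1:-$(uname -m)}" = "aarch64" ]
}

# Run the post-install only on aarch64. POST_INSTALL is an assumed override for dry runs.
run_post_install_if_needed() {
    local script="${POST_INSTALL:-scripts/docker-arm64-post-install.sh}"
    if is_aarch64 "$@"; then
        VENV_PATH="${VENV_PATH:-$(pwd)/.venv}" bash "$script" || return 1
        log_warn "A later 'uv sync --all-extras' removes the source-built flash-attn; re-run ${script} afterwards."
    fi
}

# Intended call site (assumed): right after `uv sync --all-extras` in the install script.
#   run_post_install_if_needed
```

Passing the machine type as an optional argument keeps the check testable on x86_64 machines without touching uname.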

Reviewed by Cursor Bugbot for commit 187c876.

Comment thread scripts/install.sh Outdated
@andre-fu andre-fu force-pushed the fix/aarch64-flash-attn-install branch from 23c7b29 to b7045c5 Compare May 21, 2026 21:56
Comment thread scripts/install.sh
@andre-fu andre-fu force-pushed the fix/aarch64-flash-attn-install branch from b7045c5 to 5684161 Compare May 21, 2026 22:16
@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.

Reviewed by Cursor Bugbot for commit 5684161.

Comment thread scripts/docker-arm64-post-install.sh
The aarch64 host install path was broken: `uv sync` installs flash-attn
from PyPI source but pyproject sets FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE,
so the compiled extension never builds. `scripts/docker-arm64-post-install.sh`
fixed it for Docker GB200 builds but hardcoded sm_100 and /app/.venv,
leaving Hopper hosts (H100/H200/GH200) without a recipe.

Changes:
- `scripts/docker-arm64-post-install.sh`: auto-detect compute capability
  via nvidia-smi when available; parameterize venv path. Preserves the
  sm_100 default when no GPU is visible (Docker buildx).
- `scripts/install.sh`: call the post-install for aarch64 hosts after
  `uv sync --all-extras`. Previously the script ran uv sync and exited,
  leaving aarch64 users with a broken venv.
- `README.md`: document the aarch64 post-install step (mirrors the
  existing 3.1 Flash Attention 3 pattern).

Validated on GH200 (sm_90, aarch64):
- forward + backward parity vs torch SDPA (max diff < 0.05 / 0.25)
- 383/384 unit tests pass (the 1 failure is unrelated TileLang/MoE)
- SFT trainer smoke test (5 steps, Qwen3-0.6B) runs with flash_attention_2

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@andre-fu andre-fu force-pushed the fix/aarch64-flash-attn-install branch from 5684161 to 7c5a596 Compare May 21, 2026 22:26