FlashInfer+ROCm is a port of the FlashInfer library that adds support for AMD GPUs. The project is in an early, active stage of development and does not yet support all functionality implemented upstream. The feature support matrix below lists the currently supported features.
To determine which upstream version a specific FlashInfer+ROCm release is based on, refer to the release tag. The versioning convention, `<upstream_version>+rocm`, directly links each FlashInfer+ROCm release to a corresponding upstream tag. For example, the `0.2.5+rocm` release is synchronized with the upstream `v0.2.5` tag.
## Feature Support Matrix

| Kernel Type | FP16 / BF16 | FP8 (E4M3, E5M2) | Notes |
|---|---|---|---|
| Decode Attention | ✅ | WIP | Supports MHA, GQA, and MQA variants. |
| Prefill Attention | WIP | WIP | MHA/GQA/MQA support is a work in progress. |
| Cascade | WIP | WIP | Not yet ported. |
| MLA | TBD | TBD | Not yet ported. |
| POD | TBD | TBD | Not yet ported. |
| Positional Encoding | TBD | TBD | LLaMA RoPE is supported. |
| Sampling | TBD | TBD | Top-K/Top-P sampling is not yet ported. |
| Normalization | TBD | TBD | RMS-Norm/Layer-Norm is not yet ported. |
## GPU Support

| Model | Architecture |
|---|---|
| MI300X | CDNA3 |
## ROCm Support
6.3.2, 6.4.1
## Docker image compatibility

| Docker Image | ROCm | FlashInfer | PyTorch |
|---|---|---|---|
| TBD | 6.4.1 | 0.2.5 | 2.7.1 |
The recommended docker image for setting up a FlashInfer+ROCm development environment is `rocm/pytorch:rocm6.4.1_ubuntu24.04_py3.12_pytorch_release_2.7.1`:
```bash
docker pull rocm/pytorch:rocm6.4.1_ubuntu24.04_py3.12_pytorch_release_2.7.1
docker run -it --privileged --network=host --device=/dev/kfd --device=/dev/dri \
  --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
  --ipc=host --shm-size 128G --name=<container name> rocm/pytorch:rocm6.4.1_ubuntu24.04_py3.12_pytorch_release_2.7.1
```
Run the following commands to install micromamba and set it up for bash:
"${SHELL}" <(curl -L micro.mamba.pm/install.sh)
source ~/.bashrc
After installing micromamba, create a custom environment for FlashInfer+ROCm development. The micromamba environment is used only to manage the Python version; the rest of the dependencies are installed with `pip`.
```bash
# Create a micromamba env for Python 3.12
micromamba create -n <environment_name> python=3.12 -c conda-forge --override-channels

# Activate the environment
micromamba activate <environment_name>

# Install build and test dependencies using pip
pip install setuptools-scm scikit-build-core pytest numpy cmake ninja pybind11
pip install torch --index-url https://download.pytorch.org/whl/rocm6.4
```
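
Optionally, verify that the ROCm build of PyTorch sees the GPU before building FlashInfer+ROCm. A minimal sanity check (on ROCm wheels, `torch.version.hip` is set and the `torch.cuda` APIs map to HIP):

```python
import torch

# ROCm wheels report the HIP version they were built against.
print("HIP version:", torch.version.hip)

# PyTorch reuses the torch.cuda namespace on ROCm, so this should find the AMD GPU.
assert torch.cuda.is_available(), "No GPU visible; check the --device=/dev/kfd and --device=/dev/dri flags"
print("Device:", torch.cuda.get_device_name(0))
```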
Clone the latest trunk from https://github.com/ROCm/flashinfer.
```bash
git clone https://github.com/ROCm/flashinfer
cd flashinfer/
```
The FlashInfer+ROCm library can be built in two ways: with ahead-of-time (AOT) compiled kernels, or without any AOT kernels.
Building the library with AOT kernels takes more time and local disk space, as several common configurations of the core FlashInfer kernels are compiled during installation.
When building without AOT compilation, every kernel is just-in-time (JIT) compiled the first time it is used (a short sketch illustrating this follows the build instructions below).
- To build with AOT kernels:

  ```bash
  FLASHINFER_HIP_ARCHITECTURES=gfx942 FLASHINFER_AOT_TORCH_EXTS=ON python -m pip wheel . --wheel-dir=./dist/ --no-deps --no-build-isolation -v
  cd dist
  pip install flashinfer-*.whl
  ```
- To build without AOT kernels (JIT), leave the `FLASHINFER_AOT_TORCH_EXTS` build flag unset or set it to `OFF`:

  ```bash
  FLASHINFER_HIP_ARCHITECTURES=gfx942 python -m pip wheel . --wheel-dir=./dist/ --no-deps --no-build-isolation -v
  cd dist
  pip install flashinfer-*.whl
  ```
Note: The `--no-deps` flag assumes that all required dependencies are already available in the build environment; refer to the earlier steps to install the required packages. If building without first installing all Python and build dependencies, omit the `--no-deps` flag and the build step will download all needed dependencies.
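
With a JIT build, the first call to a given kernel configuration pays a one-time compilation cost before the compiled kernel is cached and reused. A rough way to observe this, assuming the wheel is installed and an AMD GPU is visible (timings are illustrative and will vary):

```python
import time
import torch
import flashinfer

q = torch.randn(32, 128).half().to(0)        # [num_qo_heads, head_dim]
k = torch.randn(2048, 32, 128).half().to(0)  # [kv_len, num_kv_heads, head_dim]
v = torch.randn(2048, 32, 128).half().to(0)
torch.cuda.synchronize()

t0 = time.perf_counter()
flashinfer.single_decode_with_kv_cache(q, k, v)  # first call: may trigger JIT compilation
torch.cuda.synchronize()
print(f"first call:  {time.perf_counter() - t0:.2f} s")

t0 = time.perf_counter()
flashinfer.single_decode_with_kv_cache(q, k, v)  # subsequent calls reuse the cached kernel
torch.cuda.synchronize()
print(f"second call: {time.perf_counter() - t0:.4f} s")
```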
Development mode, or editable installs (PEP 660), is supported and can be used with both AOT and JIT builds of the package. To set up an editable install, run:

```bash
FLASHINFER_HIP_ARCHITECTURES=gfx942 python -m pip install --no-build-isolation -ve.
```
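
In general, with an editable install, edits to the Python sources are picked up on the next import without reinstalling; changes to the C++/HIP sources still require recompilation.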
A convenience script is provided inside the `examples` directory that runs two HIPified kernels from FlashInfer, `SingleDecodeWithKVCache` and `BatchDecodeWithKVCache`, and verifies the correctness of the generated results. Run the script as follows:

```bash
cd examples/
python test_batch_decode_example.py
```
If FlashInfer+ROCm was installed without AOT kernels, the output should look as follows:
```
Failed to import __aot_prebuilt_uris__: No module named 'flashinfer.__aot_prebuilt_uris__'
2025-07-23 21:45:24,657 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend
PASS
```
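
Under the hood, a correctness check like this amounts to comparing the kernel output against a reference decode attention computed in plain PyTorch. A minimal sketch of that kind of comparison (the shapes and tolerances here are illustrative, not necessarily what the script uses):

```python
import math
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, kv_len = 32, 32, 128, 2048
q = torch.randn(num_qo_heads, head_dim).half().to(0)
k = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)
v = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)

o = flashinfer.single_decode_with_kv_cache(q, k, v)

# Reference: standard softmax(q @ k^T / sqrt(d)) @ v per head, computed in fp32 for stability.
scores = torch.einsum("hd,nhd->hn", q.float(), k.float()) / math.sqrt(head_dim)
ref = torch.einsum("hn,nhd->hd", torch.softmax(scores, dim=-1), v.float())

torch.testing.assert_close(o.float(), ref, rtol=1e-2, atol=1e-2)
print("PASS")
```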
FlashInfer+ROCm provides a C++ test suite covering all HIP kernels and C++ code. Installing FlashInfer does not automatically install the tests; they have to be configured and built separately.
- To configure and build the test suite:

  ```bash
  cd flashinfer/libflashinfer/tests/hip
  mkdir build && cd build/
  cmake -GNinja -DCMAKE_CXX_COMPILER:PATH=/opt/rocm/bin/amdclang++ -DFLASHINFER_INCLUDE_DIRS=<path to flashinfer include dirs> ..
  ninja
  ```
- To run individual tests:

  ```bash
  ./test_<target_test_name>
  ```
- To run all tests:

  ```bash
  ctest
  ```
The output should look something like this:

```
Test project /root/flashinfer/libflashinfer/tests/hip/build
    Start 1: MathTest
1/6 Test #1: MathTest ......................... Passed 3.40 sec
    Start 2: PosEncTest
2/6 Test #2: PosEncTest ....................... Passed 3.40 sec
    Start 3: CascadeTest
3/6 Test #3: CascadeTest ...................... Passed 985.27 sec
    Start 4: PageTest
4/6 Test #4: PageTest ......................... Passed 112.40 sec
    Start 5: SingleDecodeTest
5/6 Test #5: SingleDecodeTest ................. Passed 35.46 sec
    Start 6: BatchDecodeTest
6/6 Test #6: BatchDecodeTest .................. Passed 556.81 sec

100% tests passed, 0 tests failed out of 6
```
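
When iterating on a single kernel, the usual ctest options apply; for example, `ctest -R SingleDecode --output-on-failure` runs only the tests whose names match the pattern and prints their full output on failure.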
To run the pytest suite, run the following helper script from the project root directory:

```bash
cd scripts/
./run_hip_tests.sh
```
Alternatively, build the CI docker image, which comes with everything pre-installed:

```bash
git clone https://github.com/ROCm/flashinfer
cd flashinfer/
docker build -f docker/Dockerfile.rocm_ci --target flashinfer_base -t <docker-image-tag> .
docker run -it --privileged --network=host --device=/dev/kfd --device=/dev/dri \
  --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
  --ipc=host --shm-size 128G --name=<container name> <docker-image-tag>
```
The docker container comes with micromamba pre-installed; it also builds and pre-installs the AOT version of FlashInfer. To use FlashInfer, first activate the environment and then import the `flashinfer` package from Python:
```bash
micromamba activate flashinfer-py3.12-torch2.7.1-rocm6.4.1
```
```python
import torch
import flashinfer

kv_len = 2048
num_kv_heads = 32
head_dim = 128
k = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)
v = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)

# decode attention
num_qo_heads = 32
q = torch.randn(num_qo_heads, head_dim).half().to(0)
o = flashinfer.single_decode_with_kv_cache(q, k, v)  # decode attention without RoPE on-the-fly
o_rope_on_the_fly = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="ROPE_LLAMA")  # decode with LLaMA style RoPE on-the-fly
```
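
Since the decode kernels support GQA and MQA (see the feature matrix), the same call also works with fewer KV heads than query heads; upstream FlashInfer infers the grouping from the tensor shapes. A small sketch under that assumption:

```python
import torch
import flashinfer

kv_len, head_dim = 2048, 128
num_qo_heads, num_kv_heads = 32, 8  # grouped-query attention: each KV head serves 4 query heads
q = torch.randn(num_qo_heads, head_dim).half().to(0)
k = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)
v = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)
o_gqa = flashinfer.single_decode_with_kv_cache(q, k, v)  # -> [num_qo_heads, head_dim]
```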