
Conversation

@zhang-hui-yulo
Contributor

@zhang-hui-yulo commented Dec 30, 2025

Add native fattn-mma-f16 for RDNA4. All tests pass; performance tuning is still needed.

All tests have been executed more than 5 times to check for random data errors; no random data errors occur anymore.

resolves #18243

  • Pass FLASH_ATTN_EXT on RDNA4.
  • Disable fattn-mma-f16 for RDNA3; add RDNA3 support in the future.
  • Perf tuning for RDNA4.

FLASH_ATTN_EXT .txt

@zhang-hui-yulo changed the title from "HIP: add fattn-mma for RDNA4" to "HIP: add fattn-mma-f16 for RDNA4" on Dec 30, 2025
@github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Dec 30, 2025
@JohannesGaessler
Collaborator

Since this PR is currently a draft with open TODOs: please ping me whenever you would like a review; otherwise I'll be focusing on other matters.

@zhang-hui-yulo
Contributor Author

Hello @JohannesGaessler ,

Do you have a good way to measure the performance of fattn-mma-f16? The current FLASH_ATTN_EXT performance tests only use the vec and tile fattn kernels. Thank you.

Best Regards
Hui

@JohannesGaessler
Collaborator

Use something like llama-bench -n 0 -d 32768 -p "512,1-256*2". My recommendation would be to always use a real model with llama-bench unless this is not viable for some reason.

@zhang-hui-yulo
Contributor Author

OK, I got the bad news: fattn-mma is 25% slower than fattn-wmma. It looks like most of the increased workload is in ldmatrix_trans; I'm not sure how to make it faster on RDNA4.

@zhang-hui-yulo
Contributor Author

zhang-hui-yulo commented Jan 7, 2026

Hello @JohannesGaessler

It should be done now; it passes test-backend-ops and llama-bench. Using an identity matrix and MMA to do the register transpose is faster than native loading.
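To spell the trick out (a hedged reading of the comment above, not taken from the PR's code; the exact RDNA4 WMMA operand layouts are an assumption here): if the MMA unit writes its accumulator in a per-lane register layout that is the transpose of the layout it reads its B operand in, then multiplying by the identity moves the tile into the other layout as a side effect of the matrix product:

```latex
D = I_{16}\,B = B \quad \text{(same values, now in the accumulator's register layout, i.e. transposed relative to B's input layout)}
```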

cmake -S . -B build -DGGML_HIP=ON -DGPU_TARGETS=gfx1201 -DCMAKE_BUILD_TYPE=Release -DLLAMA_CURL=OFF -DGGML_HIP_ROCWMMA_FATTN=ON
CUDA_VISIBLE_DEVICES=0 ./build/bin/llama-bench --model ../../models/DeepSeek-R1-Distill-Qwen-1.5B/DeepSeek-R1-Distill-Qwen-1.5B_f16.gguf -r 1 -fa 1 -n 0 -d 32768 -p "512,1-512*2" --progress -o sql | sqlite3 ../../models/DeepSeek-R1-Distill-Qwen-1.5B/DeepSeek-R1-Distill-Qwen-1.5B_f16.sqlite
DeepSeek-R1-Distill-Qwen-1.5B_f16.gguf
| Model | Test | t/s master | t/s fattn_for_rdna4 | Speedup |
| --- | --- | --- | --- | --- |
| qwen2 1.5B F16 | pp1@d32768 | 79.31 | 75.60 | 0.95 |
| qwen2 1.5B F16 | pp2@d32768 | 108.58 | 108.84 | 1.00 |
| qwen2 1.5B F16 | pp4@d32768 | 115.21 | 128.15 | 1.11 |
| qwen2 1.5B F16 | pp8@d32768 | 194.81 | 218.75 | 1.12 |
| qwen2 1.5B F16 | pp16@d32768 | 351.71 | 373.34 | 1.06 |
| qwen2 1.5B F16 | pp32@d32768 | 586.19 | 651.52 | 1.11 |
| qwen2 1.5B F16 | pp64@d32768 | 844.73 | 758.69 | 0.90 |
| qwen2 1.5B F16 | pp128@d32768 | 950.54 | 827.33 | 0.87 |
| qwen2 1.5B F16 | pp256@d32768 | 1017.40 | 1191.77 | 1.17 |
| qwen2 1.5B F16 | pp512@d32768 | 1047.45 | 1231.11 | 1.18 |

FLASH_ATTN_EXT_latest.txt

Best Regards
Hui

@zhang-hui-yulo marked this pull request as ready for review January 7, 2026 07:44
Collaborator

@JohannesGaessler left a comment

In terms of correctness and the way it's implemented I approve. Performance does not need to be optimal for a merge; this can be improved in follow-up PRs. I will probably want to do some refactors down the line, but that is a job for me as a maintainer.

Did you make a conscious decision when you copied the Turing configuration, or did you just pick one of them? The context is that I first wrote the kernel for Ampere or newer, with 99 kiB of SRAM/SM, an occupancy of 2, head sizes <= 256, and <= 64 Q columns. I later extended the kernel to support Turing, with 64 kiB of SRAM/SM and head sizes 576/512 for DeepSeek. It would probably make sense to try more configurations that can potentially fit into the 128 kiB of SRAM/CU on RDNA4.
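For illustration, the numbers above written out as a rough per-block shared-memory budget (a sketch with invented names, not the PR's configuration code; the Turing occupancy is an assumption):

```cpp
// Rough SRAM/LDS budget per block for the configurations mentioned above.
struct fattn_smem_budget {
    int sram_bytes; // SRAM (shared memory / LDS) per SM or CU in bytes
    int occupancy;  // blocks expected to be resident per SM/CU
    constexpr int per_block() const { return sram_bytes / occupancy; }
};

constexpr fattn_smem_budget ampere { 99  * 1024, 2 }; // ~49.5 kiB per block, head sizes <= 256, <= 64 Q columns
constexpr fattn_smem_budget turing { 64  * 1024, 1 }; // occupancy assumed; extended to head sizes 576/512
constexpr fattn_smem_budget rdna4  { 128 * 1024, 2 }; // 64 kiB per block at occupancy 2, 128 kiB at occupancy 1

static_assert(rdna4.per_block() > ampere.per_block(),
              "RDNA4 has room for larger tiles than the copied configuration");
```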

The kernel selection logic you added in fattn.cu probably needs to be improved prior to a merge, particularly when it comes to quantized KV cache.

@JohannesGaessler
Collaborator

Quick performance test at the default settings:

| GPU | Model | Microbatch size | Test | t/s b7653 | t/s 103141f | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RX 9060 XT | llama 8B Q4_0 | 1 | pp512@d32768 | 29.45 | 29.03 | 0.99 |
| RX 9060 XT | llama 8B Q4_0 | 2 | pp512@d32768 | 57.01 | 48.12 | 0.84 |
| RX 9060 XT | llama 8B Q4_0 | 4 | pp512@d32768 | 72.29 | 82.72 | 1.14 |
| RX 9060 XT | llama 8B Q4_0 | 8 | pp512@d32768 | 94.93 | 113.99 | 1.20 |
| RX 9060 XT | llama 8B Q4_0 | 16 | pp512@d32768 | 148.10 | 191.49 | 1.29 |
| RX 9060 XT | llama 8B Q4_0 | 32 | pp512@d32768 | 165.09 | 207.68 | 1.26 |
| RX 9060 XT | llama 8B Q4_0 | 64 | pp512@d32768 | 250.37 | 282.48 | 1.13 |
| RX 9060 XT | llama 8B Q4_0 | 128 | pp512@d32768 | 282.13 | 330.75 | 1.17 |
| RX 9060 XT | llama 8B Q4_0 | 256 | pp512@d32768 | 306.39 | 353.56 | 1.15 |
| RX 9060 XT | llama 8B Q4_0 | 512 | pp512@d32768 | 312.65 | 357.47 | 1.14 |

LLaMA 3 has a head size of 128, which is the one the code is generally most optimized for. With a GQA ratio of 4 you need a physical batch size of >= 4 to fully utilize the WMMA tiles with a width of 16; at that point the new implementation already seems to be faster than the combination of the tile and vector kernels.
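For concreteness, the arithmetic behind that threshold, using LLaMA 3 8B's head counts (a sketch of the reasoning, not code from the PR):

```cpp
// LLaMA 3 8B: 32 query heads share 8 KV heads, i.e. a GQA ratio of 4. Q rows from the
// query heads that share one KV head can be packed into the same tile, so a WMMA tile
// of width 16 is only fully occupied once gqa_ratio * n_batch >= 16.
constexpr int n_head     = 32;
constexpr int n_head_kv  = 8;
constexpr int gqa_ratio  = n_head / n_head_kv;      // 4
constexpr int tile_width = 16;                      // WMMA tile width
constexpr int min_batch  = tile_width / gqa_ratio;  // 4 tokens -- matches the crossover in the table
```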

@JohannesGaessler
Collaborator

I forgot: I'm running llama-bench like this:

./build/bin/llama-bench --model models/opt/${mn}-${q}.gguf -r 1 -fa 1 -n 0 -d 32768 -ub "512,1-256*2" --progress -o sql|sqlite3 llama-bench.sqlite

@IMbackK
Collaborator

IMbackK commented Jan 7, 2026

> LLaMA 3 has a head size of 128, which is the one the code is generally most optimized for. With a GQA ratio of 4 you need a physical batch size of >= 4 to fully utilize the WMMA tiles with a width of 16; at that point the new implementation already seems to be faster than the combination of the tile and vector kernels.

What about compared to the wmma kernel?

@JohannesGaessler
Collaborator

I did the rocWMMA benchmark wrong so I had to re-do it, these are the results:

| GPU | Model | Microbatch size | Test | t/s b7653 | t/s 103141f | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RX 9060 XT | llama 8B Q4_0 | 1 | pp512@d32768 | 29.11 | 29.04 | 1.00 |
| RX 9060 XT | llama 8B Q4_0 | 2 | pp512@d32768 | 48.60 | 48.19 | 0.99 |
| RX 9060 XT | llama 8B Q4_0 | 4 | pp512@d32768 | 38.85 | 82.86 | 2.13 |
| RX 9060 XT | llama 8B Q4_0 | 8 | pp512@d32768 | 65.95 | 114.03 | 1.73 |
| RX 9060 XT | llama 8B Q4_0 | 16 | pp512@d32768 | 138.99 | 192.13 | 1.38 |
| RX 9060 XT | llama 8B Q4_0 | 32 | pp512@d32768 | 181.99 | 208.31 | 1.14 |
| RX 9060 XT | llama 8B Q4_0 | 64 | pp512@d32768 | 210.27 | 282.55 | 1.34 |
| RX 9060 XT | llama 8B Q4_0 | 128 | pp512@d32768 | 259.25 | 330.98 | 1.28 |
| RX 9060 XT | llama 8B Q4_0 | 256 | pp512@d32768 | 266.39 | 354.36 | 1.33 |
| RX 9060 XT | llama 8B Q4_0 | 512 | pp512@d32768 | 271.81 | 357.62 | 1.32 |

On RDNA4 it seems to be faster than the rocWMMA kernel as it exists on master.

@zhang-hui-yulo
Contributor Author

> In terms of correctness and the way it's implemented I approve. Performance does not need to be optimal for a merge; this can be improved in follow-up PRs. I will probably want to do some refactors down the line, but that is a job for me as a maintainer.
>
> Did you make a conscious decision when you copied the Turing configuration, or did you just pick one of them? The context is that I first wrote the kernel for Ampere or newer, with 99 kiB of SRAM/SM, an occupancy of 2, head sizes <= 256, and <= 64 Q columns. I later extended the kernel to support Turing, with 64 kiB of SRAM/SM and head sizes 576/512 for DeepSeek. It would probably make sense to try more configurations that can potentially fit into the 128 kiB of SRAM/CU on RDNA4.
>
> The kernel selection logic you added in fattn.cu probably needs to be improved prior to a merge, particularly when it comes to quantized KV cache.

I just copied the config from Turing and haven't started serious tuning yet, as transpose loading took too much of my time. Hopefully future AMD GPUs will have transposed loads from shared memory; RDNA4's transposed global loads don't help much.

I will try to find a better config in the next couple of days.

@JohannesGaessler
Collaborator

I think the issue of transposition can be fixed upon loading the V data from VRAM to SRAM in combination with a permutation of VKQ.
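A minimal sketch of what that could look like (not the PR's loading code; names and tile parameters are invented for illustration): each thread reads V coalesced from global memory and writes it to the transposed position in shared memory, with the leading dimension padded by one element to avoid bank conflicts. The permutation of VKQ mentioned above would then account for the changed layout downstream.

```cpp
#include <cuda_fp16.h>

// V is stored row-major as [kq_stride][head_dim] in global memory; V_smem receives it
// transposed as [head_dim][kq_stride + 1], the +1 padding avoiding shared-memory bank
// conflicts on the strided writes.
template <int head_dim, int kq_stride>
__device__ void load_V_transposed(const half * __restrict__ V, half * __restrict__ V_smem) {
    for (int i = threadIdx.x; i < kq_stride * head_dim; i += blockDim.x) {
        const int row = i / head_dim; // position along the KV sequence
        const int col = i % head_dim; // position along the head dimension
        V_smem[col * (kq_stride + 1) + row] = V[row * head_dim + col];
    }
}
```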

@zhang-hui-yulo
Contributor Author

> I think the issue of transposition can be fixed upon loading the V data from VRAM to SRAM in combination with a permutation of VKQ.

This is also what I thought: transpose the data on the way from gmem to smem. But I really couldn't clean up the transpose loading in native CUDA (CuTe makes it much easier), so I spent some time trying whether an identity matrix and MMA can do the transpose, as it also helps another project of mine. I would suggest adding a TODO first.

Besides, based on the gfx950 spec, I think the next generation of RDNA should have transposed loads from smem; that would make things much easier and remove the need to change the loading logic.

@JohannesGaessler
Collaborator

Can you give me a brief outline of what work you still want to do for this PR and what you intend to do in follow-up PRs in the near future? I think I know how to handle the transposition for RDNA GPUs as they already exist but since I'm working on multiple things in parallel I would prefer to avoid concurrent work on the same code.

@JohannesGaessler
Collaborator

Comparison with #16827 where the WMMA kernel was tuned for RDNA3:

| GPU | Model | Microbatch size | Test | t/s b7653 | t/s 103141f | t/s #16827 |
| --- | --- | --- | --- | --- | --- | --- |
| RX 9060 XT | llama 8B Q4_0 | 1 | pp512@d32768 | 29.11 | 29.04 | 25.54 |
| RX 9060 XT | llama 8B Q4_0 | 2 | pp512@d32768 | 48.60 | 48.19 | 37.81 |
| RX 9060 XT | llama 8B Q4_0 | 4 | pp512@d32768 | 38.85 | 82.86 | 71.59 |
| RX 9060 XT | llama 8B Q4_0 | 8 | pp512@d32768 | 65.95 | 114.03 | 112.20 |
| RX 9060 XT | llama 8B Q4_0 | 16 | pp512@d32768 | 138.99 | 192.13 | 246.41 |
| RX 9060 XT | llama 8B Q4_0 | 32 | pp512@d32768 | 181.99 | 208.31 | 293.08 |
| RX 9060 XT | llama 8B Q4_0 | 64 | pp512@d32768 | 210.27 | 282.55 | 125.35 |
| RX 9060 XT | llama 8B Q4_0 | 128 | pp512@d32768 | 259.25 | 330.98 | 184.33 |
| RX 9060 XT | llama 8B Q4_0 | 256 | pp512@d32768 | 266.39 | 354.36 | 189.31 |
| RX 9060 XT | llama 8B Q4_0 | 512 | pp512@d32768 | 271.81 | 357.62 | 188.91 |

The RDNA3 tunings seem to have been detrimental for large-batch FA on RDNA4, and this PR seems to be the fastest to date. There are some intermediate batch sizes where this PR still seems to be suboptimal, but I think that is a matter of tuning.

@JohannesGaessler
Collaborator

I forgot: it's probably worthwhile to check the logic in fattn-common.cuh w.r.t. whether or not stream-k should be used. As of right now the logic should be treating AMD GPUs as "Ada Lovelace or newer".

@zhang-hui-yulo
Contributor Author

zhang-hui-yulo commented Jan 8, 2026

I think these are the things I still want to do in this PR:

  1. Do some basic tuning of the RDNA4 config in fattn and fattn-mma based on llama 8B.
  2. Try to enable stream-k in fattn-common; hopefully there is no coding bug, as I haven't tested stream-k.

TODOs for follow-up PRs:

  1. gmem-to-smem transposition. I would appreciate it if you could help finish this, as I haven't spent much effort on loading; a gmem-to-smem transpose might be faster than the MMA transpose because it can be hidden by the GEMM main loop.
  2. More perf tuning.
  3. RDNA3 support, as I might find a good way to handle the RDNA3 fused GEMM without smem; of course this also needs the gmem-to-smem transposition.

In the meantime, please give me your comments on this PR so I can update the code at the same time.

@zhang-hui-yulo
Contributor Author

zhang-hui-yulo commented Jan 8, 2026

Hello @JohannesGaessler

I just did some basic tuning for DeepSeek-R1-Distill-Qwen-1.5B; I couldn't get any perf change for Meta-Llama-3-8B-Instruct. Sorry, I'm still not familiar with the llama.cpp model parameters.

I would suggest keeping this PR simple and making more changes in the future.

CUDA_VISIBLE_DEVICES=0 ./build/bin/llama-bench --model ../../models/DeepSeek-R1-Distill-Qwen-1.5B/DeepSeek-R1-Distill-Qwen-1.5B_f16.gguf -r 4 -fa 1 -n 0 -d 32768 -p "512,1-512*2" --progress -o sql | sqlite3 ../../models/DeepSeek-R1-Distill-Qwen-1.5B/DeepSeek-R1-Distill-Qwen-1.5B_f16.sqlite
| Model | Test | t/s master | t/s fattn_for_rdna4 | Speedup |
| --- | --- | --- | --- | --- |
| qwen2 1.5B F16 | pp1@d32768 | 93.47 | 43.18 | 0.46 |
| qwen2 1.5B F16 | pp2@d32768 | 133.78 | 76.90 | 0.57 |
| qwen2 1.5B F16 | pp4@d32768 | 132.69 | 140.32 | 1.06 |
| qwen2 1.5B F16 | pp8@d32768 | 197.53 | 220.01 | 1.11 |
| qwen2 1.5B F16 | pp16@d32768 | 352.13 | 374.62 | 1.06 |
| qwen2 1.5B F16 | pp32@d32768 | 588.38 | 657.14 | 1.12 |
| qwen2 1.5B F16 | pp64@d32768 | 865.42 | 766.73 | 0.89 |
| qwen2 1.5B F16 | pp128@d32768 | 978.27 | 864.89 | 0.88 |
| qwen2 1.5B F16 | pp256@d32768 | 1195.13 | 1352.40 | 1.13 |
| qwen2 1.5B F16 | pp512@d32768 | 1297.93 | 1394.73 | 1.07 |

Best Regards
Hui

fattn-common.cuh:

      const int nblocks_stream_k = max_blocks;

-     const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
+     const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;
Collaborator

Adding amd_wmma_available here is better from an intent standpoint, but this change doesn't do anything in practice; just FYI.

Contributor Author

I'm just following Johannes' suggestion; it looks like tiles_efficiency_percent < 75 is enough for RDNA4.
