From dab1f0284e7edb719cd7d1b36483ad4a855880ce Mon Sep 17 00:00:00 2001
From: Jeff Bolz
Date: Sat, 5 Apr 2025 11:27:47 -0500
Subject: [PATCH] vulkan: Use fp16 for the flash attention P*V multiplication

This is consistent with the ggml-cuda behavior and the mul_mat fallback.
---
 ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
index 8ddadb8a15de5..a8f4bc41726c2 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
@@ -330,9 +330,11 @@ void main() {
         // resize eM by using smear/reduce
         coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
-        O = eMdiag * O;
+        // multiply with fp16 accumulation, then add to O.
+        coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+        PV = coopMatMulAdd(P_A, V, PV);
 
-        O = coopMatMulAdd(P_A, V, O);
+        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
     }
 
     // If there is split_k, then the split_k resolve shader does the final
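
Note for reviewers (not part of the patch): before this change, P*V was
accumulated directly into the fp32 O accumulator; now the product is
accumulated in a separate fp16 PV matrix and only then added to the
rescaled O, which can hit faster matrix-multiply paths on some GPUs at
the cost of rounding every partial sum to fp16. The standalone C sketch
below illustrates that trade-off with made-up values; it is not derived
from the shader, and _Float16 is a Clang/GCC extension on x86-64 and
AArch64.

#include <stdio.h>

int main(void) {
    float    acc32 = 0.0f; /* accumulator kept in fp32, like the old O path */
    _Float16 acc16 = 0.0f; /* accumulator rounded to fp16 after every add,
                              like the new PV path */

    /* 4096 terms of 2^-12 sum to exactly 1.0 in fp32. In fp16 the partial
     * sum stalls at 0.5, because 0.5 + 2^-12 is a half-ulp tie that
     * round-to-nearest-even sends back to 0.5. */
    for (int i = 0; i < 4096; ++i) {
        float term = 1.0f / 4096.0f;
        acc32 += term;
        acc16 += (_Float16)term;
    }

    printf("fp32 accumulation: %f\n", acc32);        /* prints 1.000000 */
    printf("fp16 accumulation: %f\n", (float)acc16); /* prints 0.500000 */
    return 0;
}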