From cf8713a584a9b3f6ab1deaf97c92b778acf4f812 Mon Sep 17 00:00:00 2001 From: cclecle <15073640+cclecle@users.noreply.github.com> Date: Fri, 22 May 2026 22:43:58 +0200 Subject: [PATCH] hip: VEC flash-attn for D=512 (Gemma 4) on ROCm with quantized KV Extends the VEC flash-attention path to D=512 heads on HIP/ROCm, enabling Gemma 4 27B/1B decode with quantized KV caches at long context. Problem: the TILE path allocates an f16 temp buffer per FA call (~2 GB at 256K context with D=512) that the legacy pool retains permanently. With quantized KV this consumes more VRAM than the compression saves, causing OOM. The VEC path does inline dequant with no temp buffer, so it avoids the issue entirely. D=512 was previously excluded from VEC because nthreads_KQ=2 exceeds the 256-VGPR limit on RDNA4 (wave32). Set nthreads_KQ=4 for D>=512 to halve Q register use. Decode only (ncols=1); prefill falls back to TILE. Changes: - fattn-vec.cuh: add ggml_cuda_flash_attn_ext_vec_case_d512 template (ncols=1 fixed), expose DECL_FATTN_VEC_CASE_D512 macro - fattn.cu: add FATTN_VEC_CASE_D512 dispatch for K=q8_0 + common V types; extend can_use_vector_kernel to allow D=512 for q8_0 K decode - fattn-vec-instance-q8_0-{f16,q8_0,bf16,turbo2/3/4_0}.cu: add DECL_FATTN_VEC_CASE_D512 instantiations in existing per-V-type TUs Tested: Gemma 4 31B running with q8_0 K + turbo4 V at 200K context on RDNA4 (Radeon PRO AI 9700 XT). Co-Authored-By: 15073640+cclecle@users.noreply.github.com Co-Authored-By: Claude Sonnet 4.6 --- ggml/src/ggml-cuda/fattn-common.cuh | 26 ---------------- ggml/src/ggml-cuda/fattn-vec.cuh | 30 +++++++++++++++++-- ggml/src/ggml-cuda/fattn.cu | 29 +++++++++++++++++- .../fattn-vec-instance-q8_0-bf16.cu | 1 + .../fattn-vec-instance-q8_0-f16.cu | 1 + .../fattn-vec-instance-q8_0-q8_0.cu | 1 + .../fattn-vec-instance-q8_0-turbo2_0.cu | 1 + .../fattn-vec-instance-q8_0-turbo3_0.cu | 1 + .../fattn-vec-instance-q8_0-turbo4_0.cu | 1 + 9 files changed, 62 insertions(+), 29 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index d3e95f678a5..640cfe1c6c6 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -1297,34 +1297,8 @@ void launch_fattn( const int cc = ggml_cuda_info().devices[id].cc; const int nsm = ggml_cuda_info().devices[id].nsm; -#ifdef GGML_USE_HIP - // HIP/ROCm: bypass the memory pool for f16 temp buffers. - // The legacy pool (ggml_cuda_pool_leg) retains peak-sized allocations permanently. - // For quantized KV dequant, this means the f16 temp buffer stays allocated, - // consuming more VRAM than the quantized KV compression saves — causing OOM. - // Using raw alloc+free ensures the memory is released after the kernel completes. - struct hip_f16_alloc { - half * ptr = nullptr; - cudaStream_t stream; - hip_f16_alloc(cudaStream_t s) : stream(s) {} - ~hip_f16_alloc() { - if (ptr) { - // Cast to void: hipStreamSynchronize / hipFree are [[nodiscard]] under - // HIP's -Werror policy; we're in a destructor and can't propagate errors. - (void) cudaStreamSynchronize(stream); - (void) cudaFree(ptr); - } - } - void alloc(size_t nelements) { - CUDA_CHECK(cudaMalloc(&ptr, nelements * sizeof(half))); - } - }; - hip_f16_alloc K_f16(main_stream); - hip_f16_alloc V_f16(main_stream); -#else ggml_cuda_pool_alloc K_f16(pool); ggml_cuda_pool_alloc V_f16(pool); -#endif ggml_cuda_pool_alloc KV_max(pool); ggml_cuda_pool_alloc dst_tmp(pool); ggml_cuda_pool_alloc dst_tmp_meta(pool); diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index caa42cee11d..5941b2a140a 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -43,7 +43,7 @@ static __global__ void flash_attn_ext_vec( #ifdef FLASH_ATTN_AVAILABLE // Skip unused kernel variants for faster compilation: - if (use_logit_softcap && !(D == 128 || D == 256)) { + if (use_logit_softcap && !(D == 128 || D == 256 || D == 512)) { GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, @@ -64,7 +64,8 @@ static __global__ void flash_attn_ext_vec( #ifdef GGML_USE_HIP #ifdef RDNA - constexpr int nthreads_KQ_q = 2; + // nthreads_KQ=2 at D=512 exceeds the 256-VGPR limit on RDNA4 (wave32). + constexpr int nthreads_KQ_q = (D >= 512) ? 4 : 2; #else constexpr int nthreads_KQ_q = 4; #endif // RDNA @@ -792,10 +793,27 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten } } +template +void ggml_cuda_flash_attn_ext_vec_case_d512(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + // decode-only (ncols=1): ncols=2 would exceed the 256-VGPR limit on RDNA4. + const ggml_tensor * KQV = dst; + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + if (logit_softcap == 0.0f) { + ggml_cuda_flash_attn_ext_vec_case_impl<512, 1, type_K, type_V, false>(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_vec_case_impl<512, 1, type_K, type_V, true>(ctx, dst); + } +} + #define DECL_FATTN_VEC_CASE(D, type_K, type_V) \ template void ggml_cuda_flash_attn_ext_vec_case \ (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ +#define DECL_FATTN_VEC_CASE_D512(type_K, type_V) \ + template void ggml_cuda_flash_attn_ext_vec_case_d512 \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + #define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \ @@ -924,3 +942,11 @@ extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO2_0); extern DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0); extern DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0); extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0); + +// D=512 VEC instances (decode-only, K=q8_0; turbo-K excluded — VGPR-unsafe on RDNA4) +extern DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_F16); +extern DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); +extern DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_BF16); +extern DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0); +extern DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0); +extern DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index ffa5a556453..62839ce6663 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -268,11 +268,29 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg FATTN_VEC_CASE(128, type_K, type_V) \ FATTN_VEC_CASE(256, type_K, type_V) \ +#define FATTN_VEC_CASE_D512(type_K, type_V) \ + { \ + const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \ + const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \ + if (Q->ne[0] == 512 && type_K_okay && type_V_okay) { \ + ggml_cuda_flash_attn_ext_vec_case_d512(ctx, dst); \ + return; \ + } \ + } \ + static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_tensor * Q = dst->src[0]; ggml_tensor * K = dst->src[1]; ggml_tensor * V = dst->src[2]; + // D=512 decode (ncols=1) VEC path: K=q8_0 + common V types. + FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_F16) + FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_BF16) + FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0) + FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0) + FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0) + #ifdef GGML_CUDA_FA_ALL_QUANTS FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16) @@ -535,7 +553,16 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: // 192 satisfies % 64 == 0 but has no vec instance (DKQ != DV); force it onto the MMA path. +#ifdef GGML_USE_HIP + // D=512 VEC is decode-only (ncols=1) with K=q8_0; turbo K types excluded (register-unsafe at D=512). + const bool d512_vec_safe = (Q->ne[0] == 512 && Q->ne[1] == 1 && K->type == GGML_TYPE_Q8_0); + const bool can_use_vector_kernel = (Q->ne[0] <= 256 || d512_vec_safe) + && Q->ne[0] % 64 == 0 + && Q->ne[0] != 192 + && K->ne[1] % FATTN_KQ_STRIDE == 0; +#else const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && Q->ne[0] != 192 && K->ne[1] % FATTN_KQ_STRIDE == 0; +#endif #ifdef GGML_USE_HIP // HIP/ROCm: the TILE/MMA/WMMA FA paths allocate unbounded f16 temp buffers @@ -544,7 +571,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // This causes quantized KV to OOM before f16 on the same context length. // Force VEC path which does inline dequant with zero temp buffer overhead. // Trade-off: prefill is slower (sequential query processing). - // Limitation: head_dim > 256 cannot use VEC (falls through to TILE). + // D=512 with K=q8_0 decode (ne[1]==1) now uses VEC; other D>256 fall to TILE. if ((ggml_is_quantized(K->type) || ggml_is_quantized(V->type)) && can_use_vector_kernel) { return BEST_FATTN_KERNEL_VEC; } diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu index 45636e5e70c..0f613ee68c1 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_BF16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu index a1bc3f5a6aa..03153363912 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu index a5b768b111b..19befd44fa5 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo2_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo2_0.cu index 3630d871af4..a93be5669fa 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo2_0.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo2_0.cu @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0); +DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo3_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo3_0.cu index c8a4d9f8993..3c806c2a469 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo3_0.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo3_0.cu @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0); +DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo4_0.cu index 1646ef05dd1..180902f8ca3 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo4_0.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo4_0.cu @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0); +DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0);