diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index d3e95f678a5..640cfe1c6c6 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -1297,34 +1297,8 @@ void launch_fattn(
     const int cc  = ggml_cuda_info().devices[id].cc;
     const int nsm = ggml_cuda_info().devices[id].nsm;
 
-#ifdef GGML_USE_HIP
-    // HIP/ROCm: bypass the memory pool for f16 temp buffers.
-    // The legacy pool (ggml_cuda_pool_leg) retains peak-sized allocations permanently.
-    // For quantized KV dequant, this means the f16 temp buffer stays allocated,
-    // consuming more VRAM than the quantized KV compression saves — causing OOM.
-    // Using raw alloc+free ensures the memory is released after the kernel completes.
-    struct hip_f16_alloc {
-        half * ptr = nullptr;
-        cudaStream_t stream;
-        hip_f16_alloc(cudaStream_t s) : stream(s) {}
-        ~hip_f16_alloc() {
-            if (ptr) {
-                // Cast to void: hipStreamSynchronize / hipFree are [[nodiscard]] under
-                // HIP's -Werror policy; we're in a destructor and can't propagate errors.
-                (void) cudaStreamSynchronize(stream);
-                (void) cudaFree(ptr);
-            }
-        }
-        void alloc(size_t nelements) {
-            CUDA_CHECK(cudaMalloc(&ptr, nelements * sizeof(half)));
-        }
-    };
-    hip_f16_alloc K_f16(main_stream);
-    hip_f16_alloc V_f16(main_stream);
-#else
     ggml_cuda_pool_alloc<half>   K_f16(pool);
     ggml_cuda_pool_alloc<half>   V_f16(pool);
-#endif
     ggml_cuda_pool_alloc<int>    KV_max(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh
index caa42cee11d..5941b2a140a 100644
--- a/ggml/src/ggml-cuda/fattn-vec.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec.cuh
@@ -43,7 +43,7 @@ static __global__ void flash_attn_ext_vec(
 #ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
-    if (use_logit_softcap && !(D == 128 || D == 256)) {
+    if (use_logit_softcap && !(D == 128 || D == 256 || D == 512)) {
         GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
             max_bias, m0, m1, n_head_log2, logit_softcap,
             ne00, ne01, ne02, ne03,
@@ -64,7 +64,8 @@ static __global__ void flash_attn_ext_vec(
 
 #ifdef GGML_USE_HIP
 #ifdef RDNA
-    constexpr int nthreads_KQ_q = 2;
+    // nthreads_KQ=2 at D=512 exceeds the 256-VGPR limit on RDNA4 (wave32).
+    constexpr int nthreads_KQ_q = (D >= 512) ? 4 : 2;
 #else
     constexpr int nthreads_KQ_q = 4;
 #endif // RDNA
@@ -792,10 +793,29 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten
     }
 }
 
+// D=512 entry point: reads logit_softcap from op_params[2] of dst and forwards to
+// ggml_cuda_flash_attn_ext_vec_case_impl<512, 1, type_K, type_V, use_logit_softcap> (softcap on iff non-zero).
+template <ggml_type type_K, ggml_type type_V>
+void ggml_cuda_flash_attn_ext_vec_case_d512(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    // decode-only (ncols=1): ncols=2 would exceed the 256-VGPR limit on RDNA4.
+    const ggml_tensor * KQV = dst;
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+    if (logit_softcap == 0.0f) {
+        ggml_cuda_flash_attn_ext_vec_case_impl<512, 1, type_K, type_V, false>(ctx, dst);
+    } else {
+        ggml_cuda_flash_attn_ext_vec_case_impl<512, 1, type_K, type_V, true>(ctx, dst);
+    }
+}
+
 #define DECL_FATTN_VEC_CASE(D, type_K, type_V)                              \
     template void ggml_cuda_flash_attn_ext_vec_case                         \
     <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
 
+#define DECL_FATTN_VEC_CASE_D512(type_K, type_V)                         \
+    template void ggml_cuda_flash_attn_ext_vec_case_d512                 \
+    <type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
 #define EXTERN_DECL_FATTN_VEC_CASES(D, type_K)             \
     extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16);  \
     extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \
@@ -924,3 +944,11 @@ extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO2_0);
 extern DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0);
 extern DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0);
 extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0);
+
+// D=512 VEC instances (decode-only, K=q8_0; turbo-K excluded — VGPR-unsafe on RDNA4)
+extern DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_BF16);
+extern DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0);
+extern DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0);
+extern DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0);
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index ffa5a556453..62839ce6663 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -268,11 +268,31 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
     FATTN_VEC_CASE(128, type_K, type_V)       \
     FATTN_VEC_CASE(256, type_K, type_V)       \
 
+// Calls the D=512 vec instance for (type_K, type_V) and returns from the enclosing function
+// when Q->ne[0] == 512 and the K/V tensor types match (F32 tensors match the F16 instances).
+#define FATTN_VEC_CASE_D512(type_K, type_V)                                                                      \
+    {                                                                                                            \
+        const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
+        const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
+        if (Q->ne[0] == 512 && type_K_okay && type_V_okay) {                                                     \
+            ggml_cuda_flash_attn_ext_vec_case_d512<type_K, type_V>(ctx, dst);                                    \
+            return;                                                                                              \
+        }                                                                                                        \
+    }                                                                                                            \
+
 static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_tensor * Q = dst->src[0];
     ggml_tensor * K = dst->src[1];
     ggml_tensor * V = dst->src[2];
 
+    // D=512 decode (ncols=1) VEC path: K=q8_0 + common V types.
+    FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_F16)
+    FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_BF16)
+    FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0)
+    FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0)
+    FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0)
+
 #ifdef GGML_CUDA_FA_ALL_QUANTS
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16)
@@ -535,7 +555,22 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 
     // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
     // 192 satisfies % 64 == 0 but has no vec instance (DKQ != DV); force it onto the MMA path.
+#ifdef GGML_USE_HIP
+    // D=512 VEC is decode-only (ncols=1) with K=q8_0; turbo K types excluded (register-unsafe at D=512).
+    // V must be a type that has a D=512 instance (see the FATTN_VEC_CASE_D512 list in ggml_cuda_flash_attn_ext_vec);
+    // any other V type would select VEC here and then match no case in the dispatcher.
+    const bool d512_vec_V_okay = V->type == GGML_TYPE_F16      || V->type == GGML_TYPE_F32      ||
+                                 V->type == GGML_TYPE_Q8_0     || V->type == GGML_TYPE_BF16     ||
+                                 V->type == GGML_TYPE_TURBO3_0 || V->type == GGML_TYPE_TURBO2_0 ||
+                                 V->type == GGML_TYPE_TURBO4_0;
+    const bool d512_vec_safe = (Q->ne[0] == 512 && Q->ne[1] == 1 && K->type == GGML_TYPE_Q8_0 && d512_vec_V_okay);
+    const bool can_use_vector_kernel = (Q->ne[0] <= 256 || d512_vec_safe)
+                                     && Q->ne[0] % 64 == 0
+                                     && Q->ne[0] != 192
+                                     && K->ne[1] % FATTN_KQ_STRIDE == 0;
+#else
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && Q->ne[0] != 192 && K->ne[1] % FATTN_KQ_STRIDE == 0;
+#endif
 
 #ifdef GGML_USE_HIP
     // HIP/ROCm: the TILE/MMA/WMMA FA paths allocate unbounded f16 temp buffers
@@ -544,7 +579,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     // This causes quantized KV to OOM before f16 on the same context length.
     // Force VEC path which does inline dequant with zero temp buffer overhead.
     // Trade-off: prefill is slower (sequential query processing).
-    // Limitation: head_dim > 256 cannot use VEC (falls through to TILE).
+    // D=512 with K=q8_0 decode (ne[1]==1) now uses VEC; other D>256 fall to TILE.
     if ((ggml_is_quantized(K->type) || ggml_is_quantized(V->type)) && can_use_vector_kernel) {
         return BEST_FATTN_KERNEL_VEC;
     }
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu
index 45636e5e70c..0f613ee68c1 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu
@@ -5,3 +5,4 @@
 DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
 DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
 DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_BF16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu
index a1bc3f5a6aa..03153363912 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu
@@ -5,3 +5,4 @@
 DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16);
 DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
 DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
index a5b768b111b..19befd44fa5 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
@@ -5,3 +5,4 @@
 DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
 DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
 DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo2_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo2_0.cu
index 3630d871af4..a93be5669fa 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo2_0.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo2_0.cu
@@ -5,3 +5,4 @@
 DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0);
 DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0);
 DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0);
+DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo3_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo3_0.cu
index c8a4d9f8993..3c806c2a469 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo3_0.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo3_0.cu
@@ -5,3 +5,4 @@
 DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0);
 DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0);
 DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0);
+DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo4_0.cu
index 1646ef05dd1..180902f8ca3 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo4_0.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo4_0.cu
@@ -5,3 +5,4 @@
 DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0);
 DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0);
 DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0);
+DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0);