Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions ggml/src/ggml-cuda/fattn-common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1297,34 +1297,8 @@ void launch_fattn(
const int cc = ggml_cuda_info().devices[id].cc;
const int nsm = ggml_cuda_info().devices[id].nsm;

#ifdef GGML_USE_HIP
// HIP/ROCm: bypass the memory pool for f16 temp buffers.
// The legacy pool (ggml_cuda_pool_leg) retains peak-sized allocations permanently.
// For quantized KV dequant, this means the f16 temp buffer stays allocated,
// consuming more VRAM than the quantized KV compression saves — causing OOM.
// Using raw alloc+free ensures the memory is released after the kernel completes.
struct hip_f16_alloc {
half * ptr = nullptr;
cudaStream_t stream;
hip_f16_alloc(cudaStream_t s) : stream(s) {}
~hip_f16_alloc() {
if (ptr) {
// Cast to void: hipStreamSynchronize / hipFree are [[nodiscard]] under
// HIP's -Werror policy; we're in a destructor and can't propagate errors.
(void) cudaStreamSynchronize(stream);
(void) cudaFree(ptr);
}
}
void alloc(size_t nelements) {
CUDA_CHECK(cudaMalloc(&ptr, nelements * sizeof(half)));
}
};
hip_f16_alloc K_f16(main_stream);
hip_f16_alloc V_f16(main_stream);
#else
ggml_cuda_pool_alloc<half> K_f16(pool);
ggml_cuda_pool_alloc<half> V_f16(pool);
#endif
ggml_cuda_pool_alloc<int> KV_max(pool);
ggml_cuda_pool_alloc<float> dst_tmp(pool);
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
Expand Down
30 changes: 28 additions & 2 deletions ggml/src/ggml-cuda/fattn-vec.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ static __global__ void flash_attn_ext_vec(
#ifdef FLASH_ATTN_AVAILABLE

// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
if (use_logit_softcap && !(D == 128 || D == 256 || D == 512)) {
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
Expand All @@ -64,7 +64,8 @@ static __global__ void flash_attn_ext_vec(

#ifdef GGML_USE_HIP
#ifdef RDNA
constexpr int nthreads_KQ_q = 2;
// nthreads_KQ=2 at D=512 exceeds the 256-VGPR limit on RDNA4 (wave32).
constexpr int nthreads_KQ_q = (D >= 512) ? 4 : 2;
#else
constexpr int nthreads_KQ_q = 4;
#endif // RDNA
Expand Down Expand Up @@ -792,10 +793,27 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten
}
}

template <ggml_type type_K, ggml_type type_V>
void ggml_cuda_flash_attn_ext_vec_case_d512(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
// decode-only (ncols=1): ncols=2 would exceed the 256-VGPR limit on RDNA4.
const ggml_tensor * KQV = dst;
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (logit_softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_case_impl<512, 1, type_K, type_V, false>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_case_impl<512, 1, type_K, type_V, true>(ctx, dst);
}
}

#define DECL_FATTN_VEC_CASE(D, type_K, type_V) \
template void ggml_cuda_flash_attn_ext_vec_case \
<D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \

#define DECL_FATTN_VEC_CASE_D512(type_K, type_V) \
template void ggml_cuda_flash_attn_ext_vec_case_d512 \
<type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \

#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \
Expand Down Expand Up @@ -924,3 +942,11 @@ extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO2_0);
extern DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0);
extern DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0);
extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0);

// D=512 VEC instances (decode-only, K=q8_0; turbo-K excluded — VGPR-unsafe on RDNA4)
extern DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_F16);
extern DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
extern DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_BF16);
extern DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0);
extern DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0);
extern DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0);
29 changes: 28 additions & 1 deletion ggml/src/ggml-cuda/fattn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -268,11 +268,29 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
FATTN_VEC_CASE(128, type_K, type_V) \
FATTN_VEC_CASE(256, type_K, type_V) \

#define FATTN_VEC_CASE_D512(type_K, type_V) \
{ \
const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
if (Q->ne[0] == 512 && type_K_okay && type_V_okay) { \
ggml_cuda_flash_attn_ext_vec_case_d512<type_K, type_V>(ctx, dst); \
return; \
} \
} \

static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_tensor * Q = dst->src[0];
ggml_tensor * K = dst->src[1];
ggml_tensor * V = dst->src[2];

// D=512 decode (ncols=1) VEC path: K=q8_0 + common V types.
FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_F16)
FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_BF16)
FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0)
FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0)
FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0)

#ifdef GGML_CUDA_FA_ALL_QUANTS
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16)
Expand Down Expand Up @@ -535,7 +553,16 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const

// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
// 192 satisfies % 64 == 0 but has no vec instance (DKQ != DV); force it onto the MMA path.
#ifdef GGML_USE_HIP
// D=512 VEC is decode-only (ncols=1) with K=q8_0; turbo K types excluded (register-unsafe at D=512).
const bool d512_vec_safe = (Q->ne[0] == 512 && Q->ne[1] == 1 && K->type == GGML_TYPE_Q8_0);
const bool can_use_vector_kernel = (Q->ne[0] <= 256 || d512_vec_safe)
&& Q->ne[0] % 64 == 0
&& Q->ne[0] != 192
&& K->ne[1] % FATTN_KQ_STRIDE == 0;
#else
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && Q->ne[0] != 192 && K->ne[1] % FATTN_KQ_STRIDE == 0;
#endif

#ifdef GGML_USE_HIP
// HIP/ROCm: the TILE/MMA/WMMA FA paths allocate unbounded f16 temp buffers
Expand All @@ -544,7 +571,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
// This causes quantized KV to OOM before f16 on the same context length.
// Force VEC path which does inline dequant with zero temp buffer overhead.
// Trade-off: prefill is slower (sequential query processing).
// Limitation: head_dim > 256 cannot use VEC (falls through to TILE).
// D=512 with K=q8_0 decode (ne[1]==1) now uses VEC; other D>256 fall to TILE.
if ((ggml_is_quantized(K->type) || ggml_is_quantized(V->type)) && can_use_vector_kernel) {
return BEST_FATTN_KERNEL_VEC;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_BF16);
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_F16);
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0);
DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0);
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0);
DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0);
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0);
DECL_FATTN_VEC_CASE_D512(GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0);