Skip to content

Commit ac97d9b

Browse files
jeffbolznv authored and pwilkin committed
vulkan: Support FA with K/V in F32 (ggml-org#16543)
1 parent 9f9ee43 commit ac97d9b

File tree

4 files changed

+49
-8
lines changed

4 files changed

+49
-8
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2649,18 +2649,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
26492649
} \
26502650
}
26512651

2652+
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
26522653
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
26532654
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
26542655
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
26552656
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
26562657
if (device->coopmat1_fa_support) {
2658+
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
26572659
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
26582660
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
26592661
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
26602662
}
26612663
#endif
26622664
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
26632665
if (device->coopmat2) {
2666+
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
26642667
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
26652668
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
26662669
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
@@ -7457,8 +7460,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
74577460
}
74587461

74597462
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
7460-
const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
7461-
const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
7463+
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
7464+
uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
7465+
7466+
// For F32, the shader treats it as a block of size 4 (for vec4 loads)
7467+
if (k->type == GGML_TYPE_F32) {
7468+
k_stride /= 4;
7469+
}
7470+
if (v->type == GGML_TYPE_F32) {
7471+
v_stride /= 4;
7472+
}
74627473

74637474
uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
74647475
bool aligned = (KV % alignment) == 0 &&
@@ -12660,6 +12671,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1266012671
}
1266112672
switch (op->src[1]->type) {
1266212673
case GGML_TYPE_F16:
12674+
case GGML_TYPE_F32:
1266312675
case GGML_TYPE_Q4_0:
1266412676
case GGML_TYPE_Q8_0:
1266512677
// supported in scalar and coopmat2 paths

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11

22
#include "types.glsl"
33

4+
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 {
5+
vec4 block;
6+
};
7+
8+
float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
9+
{
10+
const vec4 v = bl.block;
11+
const uint idx = coordInBlock[1];
12+
const f16vec4 vf16 = f16vec4(v);
13+
return vf16[idx];
14+
}
15+
416
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
517
block_q4_0_packed16 block;
618
};
@@ -717,4 +729,6 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
717729
#define dequantFuncA dequantFuncIQ4_NL
718730
#elif defined(DATA_A_MXFP4)
719731
#define dequantFuncA dequantFuncMXFP4
732+
#elif defined(DATA_A_F32)
733+
#define dequantFuncA dequantFuncF32
720734
#endif

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,31 @@ layout (binding = 4) readonly buffer S {float data_s[];};
6464

6565
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
6666

67-
#if defined(A_TYPE_PACKED16)
6867
#define BINDING_IDX_K 0
6968
#define BINDING_IDX_V 1
69+
#if defined(DATA_A_F32)
70+
layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;
71+
layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;
72+
#elif defined(A_TYPE_PACKED16)
7073
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
7174
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
7275
#endif
7376

77+
#if defined(DATA_A_F32)
78+
#undef BLOCK_SIZE
79+
#define BLOCK_SIZE 4
80+
#define BLOCK_BYTE_SIZE 16
81+
82+
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
83+
// iqs is currently always zero in the flash attention shaders
84+
if (binding_idx == BINDING_IDX_K) {
85+
return k_packed.k_data_packed[a_offset + ib];
86+
} else {
87+
return v_packed.v_data_packed[a_offset + ib];
88+
}
89+
}
90+
#endif
91+
7492
#if defined(DATA_A_Q4_0)
7593
#define BLOCK_BYTE_SIZE 18
7694

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -611,9 +611,6 @@ void process_shaders() {
611611
}
612612

613613
for (const auto& tname : type_names) {
614-
if (tname == "f32") {
615-
continue;
616-
}
617614
if (tname == "bf16") continue;
618615

619616
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@@ -630,7 +627,7 @@ void process_shaders() {
630627
if (tname == "f16") {
631628
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
632629
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
633-
} else if (tname == "q4_0" || tname == "q8_0") {
630+
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
634631
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
635632
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
636633
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
@@ -639,7 +636,7 @@ void process_shaders() {
639636
if (tname == "f16") {
640637
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
641638
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
642-
} else if (tname == "q4_0" || tname == "q8_0") {
639+
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
643640
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
644641
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
645642
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);

0 commit comments

Comments
 (0)