Commit 2f5a4e1

vulkan: move common FA code to flash_attn_base.comp (#13556)
* vulkan: move common FA code to flash_attn_base.comp
* vulkan: move common FA index/stride setup code to flash_attn_base.comp
* build fix
1 parent 4f41ee1 commit 2f5a4e1

File tree

4 files changed, +170 -417 lines changed

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 3 additions & 150 deletions
@@ -9,60 +9,13 @@
 #extension GL_KHR_shader_subgroup_shuffle : enable

 #include "types.comp"
+#include "flash_attn_base.comp"

-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
-layout (constant_id = 1) const uint32_t Br = 1;
-layout (constant_id = 2) const uint32_t Bc = 32;
-layout (constant_id = 3) const uint32_t D = 32;
-
-layout (constant_id = 5) const uint32_t D_split = 16;
 const uint32_t D_per_thread = D / D_split;

 const uint32_t cols_per_iter = WorkGroupSize / D_split;
 const uint32_t cols_per_thread = Bc / cols_per_iter;

-layout (push_constant) uniform parameter {
-    uint32_t N;
-    uint32_t KV;
-
-    uint32_t ne1;
-    uint32_t ne2;
-    uint32_t ne3;
-
-    uint32_t neq2;
-    uint32_t neq3;
-    uint32_t nek2;
-    uint32_t nek3;
-    uint32_t nev2;
-    uint32_t nev3;
-    uint32_t nem1;
-
-    uint32_t nb01;
-    uint32_t nb02;
-    uint32_t nb03;
-    uint32_t nb11;
-    uint32_t nb12;
-    uint32_t nb13;
-    uint32_t nb21;
-    uint32_t nb22;
-    uint32_t nb23;
-    uint32_t nb31;
-
-    float scale;
-    float max_bias;
-    float logit_softcap;
-
-    uint32_t mask;
-    uint32_t n_head_log2;
-    float m0;
-    float m1;
-
-    uint32_t gqa_ratio;
-    uint32_t split_kv;
-    uint32_t k_num;
-} p;

 layout (binding = 0) readonly buffer Q {float data_q[];};
 layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
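Note: the constants that stay in flash_attn.comp are pure arithmetic over the specialization constants that moved into the base file. A minimal C sketch, assuming the default values shown above, spells out the derivation; the program is hypothetical and not part of this commit.

#include <assert.h>
#include <stdint.h>

int main(void) {
    // Defaults of the specialization constants above (assumed).
    const uint32_t WorkGroupSize = 128, Bc = 32, D = 32, D_split = 16;

    const uint32_t D_per_thread    = D / D_split;              // head elements per thread: 2
    const uint32_t cols_per_iter   = WorkGroupSize / D_split;  // K/V columns covered per iteration: 8
    const uint32_t cols_per_thread = Bc / cols_per_iter;       // iterations to cover a Bc-wide tile: 4

    assert(D_per_thread == 2 && cols_per_iter == 8 && cols_per_thread == 4);
    return 0;
}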
@@ -71,39 +24,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
 layout (binding = 2) readonly buffer V {float16_t data_v[];};
 layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
 layout (binding = 3) readonly buffer M {float16_t data_m[];};
-layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
-
-#if defined(A_TYPE_PACKED16)
-#define BINDING_IDX_K 0
-#define BINDING_IDX_V 1
-layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
-#endif
-
-#if defined(DATA_A_Q4_0)
-#define BLOCK_BYTE_SIZE 18
-
-vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
-    uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
-    uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
-    uint shift = (iqs & 0x10) >> 2;
-    vui_lo >>= shift;
-    vui_hi >>= shift;
-
-    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
-}
-#endif
-
-#if defined(DATA_A_Q8_0)
-#define BLOCK_BYTE_SIZE 34
-vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
-    const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
-    const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
-
-    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
-}
-#endif
-
-#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

 // Store the output when doing grouped query attention.
 // Rows index by Q's dimension 2, and the first N rows are valid.
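Note: the Q4_0 dequantize4 moved here reads two 16-bit words, selects low or high nibbles via the shift, and maps each nibble q to d * (q - 8). Below is a standalone scalar C sketch of the same bit manipulation, assuming qs is the block's 16 quant bytes viewed as eight 16-bit words and d its decoded fp16 scale; the helper is hypothetical.

#include <stdint.h>

// Scalar mirror of the shader's Q4_0 dequantize4(): four consecutive
// 4-bit quants starting at index iqs, dequantized as d * (q - 8).
void dequantize4_q4_0(const uint16_t qs[8], float d, uint32_t iqs, float out[4]) {
    uint32_t vui_lo = qs[(iqs & 0xF) / 2 + 0];
    uint32_t vui_hi = qs[(iqs & 0xF) / 2 + 1];
    uint32_t shift  = (iqs & 0x10) >> 2;  // 0 selects the low nibbles, 4 the high nibbles
    vui_lo >>= shift;
    vui_hi >>= shift;

    out[0] = d * ((float)( vui_lo       & 0xF) - 8.0f);
    out[1] = d * ((float)((vui_lo >> 8) & 0xF) - 8.0f);
    out[2] = d * ((float)( vui_hi       & 0xF) - 8.0f);
    out[3] = d * ((float)((vui_hi >> 8) & 0xF) - 8.0f);
}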
@@ -114,27 +34,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
     return elem;
 }

-// Store column zero. This is used to save per-row m and L values for split_k.
-ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
-{
-    if (r < N && c == 0) {
-        uint32_t offset = iq2 + r;
-        data_o[o_offset + offset] = D_TYPE(elem);
-    }
-    return elem;
-}
-
-// Load the slope matrix, indexed by Q's dimension 2.
-ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
-{
-    const uint32_t h = iq2 + (r % p.gqa_ratio);
-
-    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
-    const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
-
-    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
-}
-
 shared FLOAT_TYPE tmpsh[WorkGroupSize];
 shared vec4 tmpshv4[WorkGroupSize];
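Note: perElemOpComputeSlope is the usual ALiBi slope schedule: heads below n_head_log2 use powers of m0, the rest odd powers of m1, with m0, m1, and n_head_log2 precomputed from max_bias on the host. A scalar C sketch, hypothetical but mirroring the function above; for example, with n_head_log2 = 8, head 3 gets m0^4 and head 10 gets m1^5.

#include <math.h>
#include <stdint.h>

// Per-head ALiBi slope, as computed by perElemOpComputeSlope().
float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) {
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? (int)h + 1
                                       : 2 * (int)(h - n_head_log2) + 1;
    return powf(base, (float)exph);
}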

@@ -146,58 +45,12 @@ void main() {
     init_iq_shmem(gl_WorkGroupSize);
 #endif

-    const uint32_t tid = gl_LocalInvocationIndex;
-    const uint32_t N = p.N;
-    const uint32_t KV = p.KV;
+    init_indices();

+    const uint32_t tid = gl_LocalInvocationIndex;
     const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
     const uint32_t col_tid = gl_LocalInvocationIndex / D_split;

-    uint32_t i = gl_WorkGroupID.x;
-    uint32_t split_k_index = 0;
-
-    if (p.k_num > 1) {
-        i = 0;
-        split_k_index = gl_WorkGroupID.x;
-    }
-
-    const uint32_t Tr = CEIL_DIV(N, Br);
-
-    const uint32_t start_j = split_k_index * p.split_kv / Bc;
-    const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
-
-    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
-    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
-    const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
-    const uint32_t iq3 = gl_WorkGroupID.z;
-
-    // broadcast factors
-    const uint32_t rk2 = p.neq2/p.nek2;
-    const uint32_t rk3 = p.neq3/p.nek3;
-
-    const uint32_t rv2 = p.neq2/p.nev2;
-    const uint32_t rv3 = p.neq3/p.nev3;
-
-    // k indices
-    const uint32_t ik3 = iq3 / rk3;
-    const uint32_t ik2 = iq2 / rk2;
-
-    // v indices
-    const uint32_t iv3 = iq3 / rv3;
-    const uint32_t iv2 = iq2 / rv2;
-
-    // nb?1 are already divided by the type size and are in units of elements.
-    // When using grouped query attention, Q is indexed by iq2, so the stride
-    // should be nb02 (which is in bytes).
-    uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
-    uint32_t k_stride = p.nb11;
-    uint32_t v_stride = p.nb21;
-    // When using grouped query attention, all rows use the same mask (stride 0).
-    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
-    // that prevents the compiler from folding the "&" through the select
-    // and breaking the alignment detection.
-    uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
-
     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;

     [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
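Note: the split-k bookkeeping that main() previously did inline now lives in init_indices(). For intuition, here is a hypothetical C sketch of the Bc-block range each split owns; e.g. with KV = 1000, Bc = 32, split_kv = 512, and k_num = 2, split 0 covers blocks [0, 16) and split 1 covers [16, 32).

#include <stdint.h>

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

// Range of Bc-wide K/V blocks [start_j, end_j) processed by one split,
// mirroring the start_j/end_j computation in init_indices().
void split_k_range(uint32_t KV, uint32_t Bc, uint32_t split_kv,
                   uint32_t split_k_index,
                   uint32_t *start_j, uint32_t *end_j) {
    const uint32_t kv_end = (split_k_index + 1) * split_kv;
    *start_j = split_k_index * split_kv / Bc;
    *end_j   = CEIL_DIV(kv_end < KV ? kv_end : KV, Bc);
}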
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp

Lines changed: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
+layout (constant_id = 1) const uint32_t Br = 1;
+layout (constant_id = 2) const uint32_t Bc = 32;
+layout (constant_id = 3) const uint32_t D = 32;
+layout (constant_id = 4) const uint32_t Clamp = 0;
+layout (constant_id = 5) const uint32_t D_split = 16;
+
+
+layout (push_constant) uniform parameter {
+    uint32_t N;
+    uint32_t KV;
+
+    uint32_t ne1;
+    uint32_t ne2;
+    uint32_t ne3;
+
+    uint32_t neq2;
+    uint32_t neq3;
+    uint32_t nek2;
+    uint32_t nek3;
+    uint32_t nev2;
+    uint32_t nev3;
+    uint32_t nem1;
+
+    uint32_t nb01;
+    uint32_t nb02;
+    uint32_t nb03;
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
+    uint32_t nb21;
+    uint32_t nb22;
+    uint32_t nb23;
+    uint32_t nb31;
+
+    float scale;
+    float max_bias;
+    float logit_softcap;
+
+    uint32_t mask;
+    uint32_t n_head_log2;
+    float m0;
+    float m1;
+
+    uint32_t gqa_ratio;
+    uint32_t split_kv;
+    uint32_t k_num;
+} p;
+
+layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
+
+#if defined(A_TYPE_PACKED16)
+#define BINDING_IDX_K 0
+#define BINDING_IDX_V 1
+layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
+#endif
+
+#if defined(DATA_A_Q4_0)
+#define BLOCK_BYTE_SIZE 18
+
+vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+    uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
+    uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
+    uint shift = (iqs & 0x10) >> 2;
+    vui_lo >>= shift;
+    vui_hi >>= shift;
+
+    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+}
+#endif
+
+#if defined(DATA_A_Q8_0)
+#define BLOCK_BYTE_SIZE 34
+vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+    const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
+    const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
+
+    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+}
+#endif
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+
+// Store column zero. This is used to save per-row m and L values for split_k.
+ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
+{
+    if (r < N && c == 0) {
+        uint32_t offset = iq2 + r;
+        data_o[o_offset + offset] = D_TYPE(elem);
+    }
+    return elem;
+}
+
+// Load the slope matrix, indexed by Q's dimension 2.
+ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
+{
+    const uint32_t h = iq2 + (r % p.gqa_ratio);
+
+    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
+    const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
+
+    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
+}
+
+uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
+    iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
+    q_stride, k_stride, v_stride, m_stride;
+
+void init_indices()
+{
+    N = p.N;
+    KV = p.KV;
+
+    i = gl_WorkGroupID.x;
+    split_k_index = 0;
+
+    if (p.k_num > 1) {
+        i = 0;
+        split_k_index = gl_WorkGroupID.x;
+    }
+
+    Tr = CEIL_DIV(N, Br);
+
+    start_j = split_k_index * p.split_kv / Bc;
+    end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
+
+    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
+    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
+    iq2 = gl_WorkGroupID.y * p.gqa_ratio;
+    iq3 = gl_WorkGroupID.z;
+
+    // broadcast factors
+    rk2 = p.neq2/p.nek2;
+    rk3 = p.neq3/p.nek3;
+
+    rv2 = p.neq2/p.nev2;
+    rv3 = p.neq3/p.nev3;
+
+    // k indices
+    ik3 = iq3 / rk3;
+    ik2 = iq2 / rk2;
+
+    // v indices
+    iv3 = iq3 / rv3;
+    iv2 = iq2 / rv2;
+
+    // nb?1 are already divided by the type size and are in units of elements.
+    // When using grouped query attention, Q is indexed by iq2, so the stride
+    // should be nb02 (which is in bytes).
+    q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
+    k_stride = p.nb11;
+    v_stride = p.nb21;
+    // When using grouped query attention, all rows use the same mask (stride 0).
+    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
+    // that prevents the compiler from folding the "&" through the select
+    // and breaking the alignment detection.
+    m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
+}
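Usage note on the broadcast factors above: rk2 = neq2/nek2 is the dim-2 GQA ratio, so rk2 consecutive Q heads share one K/V head, and ik2 = iq2 / rk2 picks it. A one-function C sketch of that mapping; the helper name is hypothetical and not part of the commit.

#include <stdint.h>

// Dim-2 GQA mapping from init_indices(): with neq2 Q heads and nek2 K heads,
// Q head iq2 reads K head iq2 / (neq2 / nek2).
// E.g. neq2 = 32, nek2 = 8: Q heads 0..3 -> K head 0, heads 4..7 -> K head 1.
uint32_t k_head_for_q_head(uint32_t iq2, uint32_t neq2, uint32_t nek2) {
    const uint32_t rk2 = neq2 / nek2;  // broadcast factor
    return iq2 / rk2;
}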
