Skip to content

Commit 4f41ee1

Browse files
authored
vulkan: use scalar FA rather than coopmat2 when N==1 (#13554)
1 parent 3e0be1c commit 4f41ee1

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5872,10 +5872,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
old   new
5872  5872      vk_pipeline *pipelines;
5873  5873      bool small_rows = N <= get_fa_num_small_rows(path);
5874  5874
      5875  +   // coopmat1 does not actually support "small rows" (it needs 16 rows).
      5876  +   // So use scalar instead.
5875  5877      if (small_rows && path == FA_COOPMAT1) {
5876  5878          path = FA_SCALAR;
5877  5879      }
5878  5880
      5881  +   // scalar is faster than coopmat2 when N==1
      5882  +   if (N == 1 && path == FA_COOPMAT2) {
      5883  +       path = FA_SCALAR;
      5884  +   }
      5885  +
5879  5886      bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
5880  5887
5881  5888      switch (path) {

0 commit comments

Comments
 (0)