@@ -1918,19 +1918,19 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
     const bool has_mask = op->src[3] != nullptr;

     if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
-        const bool has_kvpad = ne11 % 32 != 0;
+        const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;

         if (has_kvpad) {
-            res += 32*(
+            res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
                 nb11*ne12*ne13 +
                 nb21*ne22*ne23 +
                 (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
         }
     } else {
-        const bool has_kvpad = ne11 % 64 != 0;
+        const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;

         if (has_kvpad) {
-            res += 64*(
+            res += OP_FLASH_ATTN_EXT_NCPSG*(
                 nb11*ne12*ne13 +
                 nb21*ne22*ne23 +
                 (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
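
For reference, the OP_FLASH_ATTN_EXT_* constants that replace the magic numbers above are presumably defined in a header shared by the host code and the kernel templates, so the two can no longer drift apart. A minimal sketch, with the values taken from the literals they replace in this diff:

    // sketch of the assumed shared header -- names from this diff,
    // values from the magic numbers they replace
    #define OP_FLASH_ATTN_EXT_NQPTG      8  // queries per threadgroup, half8x8 kernel
    #define OP_FLASH_ATTN_EXT_NCPSG     64  // cache values per simdgroup, half8x8 kernel
    #define OP_FLASH_ATTN_EXT_VEC_NQPTG  1  // queries per threadgroup, vec kernel
    #define OP_FLASH_ATTN_EXT_VEC_NCPSG 32  // cache values per simdgroup, vec kernel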
@@ -1940,6 +1940,44 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
     return res;
 }

+size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    // GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    // GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    // GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    // GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+    // GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+
+    size_t res = 0;
+
+    const bool has_mask = op->src[3] != nullptr;
+
+    if (!has_mask) {
+        return res;
+    }
+
+    const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
+
+    // this optimization is not useful for the vector kernels
+    if (is_vec) {
+        return res;
+    }
+
+    const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
+    const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
+
+    const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
+    const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
+
+    res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
+
+    return res;
+}
+
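
To make the sizing concrete, here is a self-contained restatement of the same computation with hypothetical tensor shapes (the shapes below are illustrative, not from the commit):

    // worked example of the extra_blk sizing -- one int8 flag per mask block
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne01 = 512;  // query rows
        const int64_t ne30 = 4096; // mask row length (KV positions)
        const int64_t ne32 = 32;   // mask broadcast dim 2
        const int64_t ne33 = 1;    // mask broadcast dim 3

        const int nqptg = 8;  // OP_FLASH_ATTN_EXT_NQPTG
        const int ncpsg = 64; // OP_FLASH_ATTN_EXT_NCPSG

        const int64_t ne1 = (ne01 + nqptg - 1)/nqptg; // query blocks -> 64
        const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg; // KV blocks    -> 64

        // GGML_PAD(x, 32) rounds up to a multiple of 32 bytes
        const int64_t res = ((1*ne0*ne1*ne32*ne33 + 31)/32)*32;

        printf("%lld bytes\n", (long long) res); // 131072
    }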
 size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
     assert(op->op == GGML_OP_FLASH_ATTN_EXT);

@@ -2034,18 +2072,23 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
     ggml_metal_buffer_id bid_pad = bid_dst;
     bid_pad.offs += ggml_nbytes(op);

-    ggml_metal_buffer_id bid_tmp = bid_pad;
-    bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
+    ggml_metal_buffer_id bid_blk = bid_pad;
+    bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
+
+    ggml_metal_buffer_id bid_tmp = bid_blk;
+    bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);

     if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
         // half8x8 kernel
-        const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
-        const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
+        const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
+        const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup

         GGML_ASSERT(nqptg <= 32);
         GGML_ASSERT(nqptg % 8  == 0);
         GGML_ASSERT(ncpsg % 32 == 0);

+        bool need_sync = false;
+
         const bool has_kvpad = ne11 % ncpsg != 0;

         if (has_kvpad) {
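
The bid_* offsets chained at the top of this hunk carve consecutive scratch regions out of the destination buffer; laid out in order (a reading of the code, not verbatim from it):

    // workspace layout inside the dst buffer:
    //   bid_dst.offs                         -> dst data, ggml_nbytes(op) bytes
    //   bid_pad.offs = dst + ggml_nbytes(op) -> padded K/V/mask copies, extra_pad(op) bytes
    //   bid_blk.offs = pad + extra_pad(op)   -> per-block mask flags,   extra_blk(op) bytes
    //   bid_tmp.offs = blk + extra_blk(op)   -> temporary results,      extra_tmp(op) bytes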
@@ -2083,11 +2126,46 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {

             ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);

-            ggml_metal_op_concurrency_reset(ctx);
+            need_sync = true;
         } else {
             assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
         }

+        if (has_mask) {
+            assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
+
+            ggml_metal_kargs_flash_attn_ext_blk args0 = {
+                /*.ne01 =*/ ne01,
+                /*.ne30 =*/ ne30,
+                /*.ne31 =*/ ne31,
+                /*.ne32 =*/ ne32,
+                /*.ne33 =*/ ne33,
+                /*.nb31 =*/ nb31,
+                /*.nb32 =*/ nb32,
+                /*.nb33 =*/ nb33,
+            };
+
+            ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
+
+            ggml_metal_encoder_set_pipeline(enc, pipeline0);
+            ggml_metal_encoder_set_bytes(enc, &args0, sizeof(args0), 0);
+            ggml_metal_encoder_set_buffer(enc, bid_src3, 1);
+            ggml_metal_encoder_set_buffer(enc, bid_blk, 2);
+
+            const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
+            const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
+
+            ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
+
+            need_sync = true;
+        } else {
+            assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
+        }
+
+        if (need_sync) {
+            ggml_metal_op_concurrency_reset(ctx);
+        }
+
         const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;

         // 2*(2*ncpsg)
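
The grid of the new blk pass mirrors the scratch sizing in ggml_metal_op_flash_attn_ext_extra_blk() one-to-one, assuming the kernel writes one flag per mask tile (a reading of the dispatch, not of the Metal source):

    // dispatch geometry of the blk pass vs. the scratch reserved for it:
    //   grid.x = nblk0 = ceil(ne30/ncpsg) -> KV blocks    (ne0 in extra_blk)
    //   grid.y = nblk1 = ceil(ne01/nqptg) -> query blocks (ne1 in extra_blk)
    //   grid.z = ne32*ne33                -> mask broadcast dims
    // each 32-thread threadgroup covers one nqptg x ncpsg tile of the mask and
    // stores one int8 flag in bid_blk, i.e. ne0*ne1*ne32*ne33 bytes in total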
@@ -2164,22 +2242,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
         ggml_metal_encoder_set_buffer(enc, bid_src3, 4);
         ggml_metal_encoder_set_buffer(enc, bid_src4, 5);
         ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
-        ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
+        ggml_metal_encoder_set_buffer(enc, bid_blk, 7);
+        ggml_metal_encoder_set_buffer(enc, bid_dst, 8);

         ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

         ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03, 32, nsg, 1);
 #undef FATTN_SMEM
     } else {
         // half4x4 kernel
-        const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
-        const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
-        const int64_t nkpsg = 1*ncpsg;
+        const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
+        const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
+        const int nkpsg = 1*ncpsg;

         GGML_ASSERT(nqptg <= 32);
         GGML_ASSERT(nqptg % 1  == 0);
         GGML_ASSERT(ncpsg % 32 == 0);

+        bool need_sync = false;
+
         const bool has_kvpad = ne11 % ncpsg != 0;

         if (has_kvpad) {
@@ -2217,11 +2298,15 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {

             ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);

-            ggml_metal_op_concurrency_reset(ctx);
+            need_sync = true;
         } else {
             assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
         }

+        if (need_sync) {
+            ggml_metal_op_concurrency_reset(ctx);
+        }
+
         // ne00 + 2*ncpsg*(nsg)
         // for each query, we load it as f16 in shared memory (ne00)
         // and store the soft_max values and the mask
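
Both kernel branches now share the same deferred-synchronization pattern: rather than resetting the concurrency tracking after every prologue dispatch, the resets are coalesced into a single one before the main attention kernel is encoded. In outline (a condensed restatement, not verbatim):

    // condensed view of the new synchronization flow:
    //
    //   bool need_sync = false;
    //   if (has_kvpad) { /* encode pad kernel */ need_sync = true; }
    //   if (has_mask)  { /* encode blk kernel (half8x8 branch only) */ need_sync = true; }
    //   if (need_sync) {
    //       ggml_metal_op_concurrency_reset(ctx); // one reset instead of one per prologue
    //   }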