@@ -1918,19 +1918,19 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
19181918    const  bool  has_mask = op->src [3 ] != nullptr ;
19191919
19201920    if  (ggml_metal_op_flash_attn_ext_use_vec (op)) {
1921-         const  bool  has_kvpad = ne11 % 32  != 0 ;
1921+         const  bool  has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG  != 0 ;
19221922
19231923        if  (has_kvpad) {
1924-             res += 32 *(
1924+             res += OP_FLASH_ATTN_EXT_VEC_NCPSG *(
19251925                nb11*ne12*ne13 +
19261926                nb21*ne22*ne23 +
19271927                (has_mask ? ggml_type_size (GGML_TYPE_F16)*ne31*ne32*ne33 : 0 ));
19281928        }
19291929    } else  {
1930-         const  bool  has_kvpad = ne11 % 64  != 0 ;
1930+         const  bool  has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG  != 0 ;
19311931
19321932        if  (has_kvpad) {
1933-             res += 64 *(
1933+             res += OP_FLASH_ATTN_EXT_NCPSG *(
19341934                nb11*ne12*ne13 +
19351935                nb21*ne22*ne23 +
19361936                (has_mask ? ggml_type_size (GGML_TYPE_F16)*ne31*ne32*ne33 : 0 ));
@@ -1940,6 +1940,44 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
19401940    return  res;
19411941}
19421942
1943+ size_t  ggml_metal_op_flash_attn_ext_extra_blk (const  ggml_tensor * op) {
1944+     assert (op->op  == GGML_OP_FLASH_ATTN_EXT);
1945+ 
1946+     GGML_TENSOR_LOCALS ( int32_t , ne0, op->src [0 ], ne);
1947+   // GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1948+   // GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1949+   // GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1950+   // GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1951+   // GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1952+     GGML_TENSOR_LOCALS ( int32_t , ne3, op->src [3 ], ne);
1953+     GGML_TENSOR_LOCALS (uint64_t , nb3, op->src [3 ], nb);
1954+ 
1955+     size_t  res = 0 ;
1956+ 
1957+     const  bool  has_mask = op->src [3 ] != nullptr ;
1958+ 
1959+     if  (!has_mask) {
1960+         return  res;
1961+     }
1962+ 
1963+     const  bool  is_vec = ggml_metal_op_flash_attn_ext_use_vec (op);
1964+ 
1965+     //  this optimization is not useful for the vector kernels
1966+     if  (is_vec) {
1967+         return  res;
1968+     }
1969+ 
1970+     const  int  nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
1971+     const  int  ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
1972+ 
1973+     const  int64_t  ne1 = (ne01 + nqptg - 1 )/nqptg;
1974+     const  int64_t  ne0 = (ne30 + ncpsg - 1 )/ncpsg;
1975+ 
1976+     res += GGML_PAD (ggml_type_size (GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32 );
1977+ 
1978+     return  res;
1979+ }
1980+ 
19431981size_t  ggml_metal_op_flash_attn_ext_extra_tmp (const  ggml_tensor * op) {
19441982    assert (op->op  == GGML_OP_FLASH_ATTN_EXT);
19451983
@@ -2034,18 +2072,23 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
20342072    ggml_metal_buffer_id bid_pad = bid_dst;
20352073    bid_pad.offs  += ggml_nbytes (op);
20362074
2037-     ggml_metal_buffer_id bid_tmp = bid_pad;
2038-     bid_tmp.offs  += ggml_metal_op_flash_attn_ext_extra_pad (op);
2075+     ggml_metal_buffer_id bid_blk = bid_pad;
2076+     bid_blk.offs  += ggml_metal_op_flash_attn_ext_extra_pad (op);
2077+ 
2078+     ggml_metal_buffer_id bid_tmp = bid_blk;
2079+     bid_tmp.offs  += ggml_metal_op_flash_attn_ext_extra_blk (op);
20392080
20402081    if  (!ggml_metal_op_flash_attn_ext_use_vec (op)) {
20412082        //  half8x8 kernel
2042-         const  int64_t  nqptg = 8 ;   //  queries per threadgroup    !! sync with kernel template arguments !! 
2043-         const  int64_t  ncpsg = 64 ; //  cache values per simdgroup !! sync with kernel template arguments !! 
2083+         const  int  nqptg = OP_FLASH_ATTN_EXT_NQPTG;  //  queries per threadgroup
2084+         const  int  ncpsg = OP_FLASH_ATTN_EXT_NCPSG ; //  cache values per simdgroup
20442085
20452086        GGML_ASSERT (nqptg <= 32 );
20462087        GGML_ASSERT (nqptg  % 8   == 0 );
20472088        GGML_ASSERT (ncpsg  % 32  == 0 );
20482089
2090+         bool  need_sync = false ;
2091+ 
20492092        const  bool  has_kvpad = ne11 % ncpsg != 0 ;
20502093
20512094        if  (has_kvpad) {
@@ -2083,11 +2126,46 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
20832126
20842127            ggml_metal_encoder_dispatch_threadgroups (enc, ncpsg, std::max (ne12, ne32), std::max (ne13, ne33), 32 , 1 , 1 );
20852128
2086-             ggml_metal_op_concurrency_reset (ctx) ;
2129+             need_sync =  true ;
20872130        } else  {
20882131            assert (ggml_metal_op_flash_attn_ext_extra_pad (op) == 0 );
20892132        }
20902133
2134+         if  (has_mask) {
2135+             assert (ggml_metal_op_flash_attn_ext_extra_blk (op) != 0 );
2136+ 
2137+             ggml_metal_kargs_flash_attn_ext_blk args0 = {
2138+                 /* .ne01 =*/ 
2139+                 /* .ne30 =*/ 
2140+                 /* .ne31 =*/ 
2141+                 /* .ne32 =*/ 
2142+                 /* .ne33 =*/ 
2143+                 /* .nb31 =*/ 
2144+                 /* .nb32 =*/ 
2145+                 /* .nb33 =*/ 
2146+             };
2147+ 
2148+             ggml_metal_pipeline_t  pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk (lib, op, nqptg, ncpsg);
2149+ 
2150+             ggml_metal_encoder_set_pipeline (enc, pipeline0);
2151+             ggml_metal_encoder_set_bytes    (enc, &args0, sizeof (args0), 0 );
2152+             ggml_metal_encoder_set_buffer   (enc, bid_src3, 1 );
2153+             ggml_metal_encoder_set_buffer   (enc, bid_blk,  2 );
2154+ 
2155+             const  int32_t  nblk1 = ((ne01 + nqptg - 1 )/nqptg);
2156+             const  int32_t  nblk0 = ((ne30 + ncpsg - 1 )/ncpsg);
2157+ 
2158+             ggml_metal_encoder_dispatch_threadgroups (enc, nblk0, nblk1, ne32*ne33, 32 , 1 , 1 );
2159+ 
2160+             need_sync = true ;
2161+         } else  {
2162+             assert (ggml_metal_op_flash_attn_ext_extra_blk (op) == 0 );
2163+         }
2164+ 
2165+         if  (need_sync) {
2166+             ggml_metal_op_concurrency_reset (ctx);
2167+         }
2168+ 
20912169        const  int  is_q = ggml_is_quantized (op->src [1 ]->type ) ? 1  : 0 ;
20922170
20932171        //  2*(2*ncpsg)
@@ -2164,22 +2242,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
21642242        ggml_metal_encoder_set_buffer   (enc, bid_src3, 4 );
21652243        ggml_metal_encoder_set_buffer   (enc, bid_src4, 5 );
21662244        ggml_metal_encoder_set_buffer   (enc, bid_pad,  6 );
2167-         ggml_metal_encoder_set_buffer   (enc, bid_dst,  7 );
2245+         ggml_metal_encoder_set_buffer   (enc, bid_blk,  7 );
2246+         ggml_metal_encoder_set_buffer   (enc, bid_dst,  8 );
21682247
21692248        ggml_metal_encoder_set_threadgroup_memory_size (enc, smem, 0 );
21702249
21712250        ggml_metal_encoder_dispatch_threadgroups (enc, (ne01 + nqptg - 1 )/nqptg, ne02, ne03, 32 , nsg, 1 );
21722251#undef  FATTN_SMEM
21732252    } else  {
21742253        //  half4x4 kernel
2175-         const  int64_t  nqptg = 1 ;   //  queries per threadgroup    !! sync with kernel template arguments !! 
2176-         const  int64_t  ncpsg = 32 ; //  cache values per simdgroup !! sync with kernel template arguments !!
2177-         const  int64_t  nkpsg = 1 *ncpsg;
2254+         const  int  nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG;  //  queries per threadgroup
2255+         const  int  ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG ; //  cache values per simdgroup !! sync with kernel template arguments !!
2256+         const  int  nkpsg = 1 *ncpsg;
21782257
21792258        GGML_ASSERT (nqptg <= 32 );
21802259        GGML_ASSERT (nqptg  % 1   == 0 );
21812260        GGML_ASSERT (ncpsg  % 32  == 0 );
21822261
2262+         bool  need_sync = false ;
2263+ 
21832264        const  bool  has_kvpad = ne11 % ncpsg != 0 ;
21842265
21852266        if  (has_kvpad) {
@@ -2217,11 +2298,15 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
22172298
22182299            ggml_metal_encoder_dispatch_threadgroups (enc, ncpsg, std::max (ne12, ne32), std::max (ne13, ne33), 32 , 1 , 1 );
22192300
2220-             ggml_metal_op_concurrency_reset (ctx) ;
2301+             need_sync =  true ;
22212302        } else  {
22222303            assert (ggml_metal_op_flash_attn_ext_extra_pad (op) == 0 );
22232304        }
22242305
2306+         if  (need_sync) {
2307+             ggml_metal_op_concurrency_reset (ctx);
2308+         }
2309+ 
22252310        //  ne00 + 2*ncpsg*(nsg)
22262311        //  for each query, we load it as f16 in shared memory (ne00)
22272312        //  and store the soft_max values and the mask
0 commit comments