Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 17 additions & 27 deletions ggml/src/ggml-cuda/fattn-tile.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -540,10 +540,12 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
}

KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) && (!oob_check || i_KQ < k_VKQ_sup) ?
slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
if (!oob_check || i_KQ < k_VKQ_sup) {
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;

KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
}
}

KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);
Expand Down Expand Up @@ -581,10 +583,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
float KQ_sum_add = 0.0f;
#pragma unroll
for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
const float val = expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]);
if (!oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup) {
KQ_sum_add += val;
}
const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
KQ_sum_add += val;
tmp[i0/(np*warp_size)][jc1] = val;
}
KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;
Expand Down Expand Up @@ -975,26 +976,6 @@ static __global__ void flash_attn_tile(
}
}

if (gridDim.y == 1) {
#pragma unroll
for (int jc0 = 0; jc0 < cpw; ++jc0) {
#ifdef FAST_FP16_AVAILABLE
const half2 KQ_sum_jc_inv = make_half2(1.0f/KQ_sum[jc0], 1.0f/KQ_sum[jc0]);
#pragma unroll
for (int i = 0; i < (DVp/2)/warp_size; ++i) {
VKQ[jc0*((DVp/2)/warp_size) + i] *= KQ_sum_jc_inv;
}
#else
const float KQ_sum_jc_inv = 1.0f/KQ_sum[jc0];
#pragma unroll
for (int i = 0; i < (DVp/2)/warp_size; ++i) {
VKQ[jc0*((DVp/2)/warp_size) + i].x *= KQ_sum_jc_inv;
VKQ[jc0*((DVp/2)/warp_size) + i].y *= KQ_sum_jc_inv;
}
#endif // FAST_FP16_AVAILABLE
}
}

// Write back results:
#pragma unroll
for (int jc0 = 0; jc0 < cpw; ++jc0) {
Expand All @@ -1007,6 +988,8 @@ static __global__ void flash_attn_tile(
return;
}

const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;

const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;

#ifdef FAST_FP16_AVAILABLE
Expand All @@ -1017,6 +1000,8 @@ static __global__ void flash_attn_tile(
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
tmp[i1].x *= scale;
tmp[i1].y *= scale;
}
if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) {
ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
Expand All @@ -1027,6 +1012,11 @@ static __global__ void flash_attn_tile(
#pragma unroll
for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) {
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale;
VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale;
}
ggml_cuda_memcpy_1<cpy_ne_D*4>(
&dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D],
&VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);
Expand Down
Loading