From a8d7fccd26e58eefe89f1b79454ea88886f89d21 Mon Sep 17 00:00:00 2001 From: timt51 Date: Fri, 7 Feb 2025 09:47:15 -0500 Subject: [PATCH] redo causal and noncausal alibi for diagonal --- csrc/flash_attn/src/mask.h | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/csrc/flash_attn/src/mask.h b/csrc/flash_attn/src/mask.h index 31df79f4e..709493331 100644 --- a/csrc/flash_attn/src/mask.h +++ b/csrc/flash_attn/src/mask.h @@ -177,7 +177,13 @@ struct Mask { for (int j = 0; j < size<1, 0>(tensor); ++j) { const int col_idx = col_idx_base + j; if constexpr (Has_alibi) { - tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope; + if constexpr (Is_causal) { + tensor(make_coord(i, mi), make_coord(j, nj)) += ((col_idx == (col_idx_limit_right - 1)) ? 0 : alibi_slope); + + } else { + tensor(make_coord(i, mi), make_coord(j, nj)) += ((col_idx == row_idx) ? 0 : alibi_slope); + + } } if constexpr (Causal_mask) { if (col_idx >= col_idx_limit_right) {