Skip to content

Commit a8d7fcc

Browse files
authored
redo causal and noncausal alibi for diagonal
1 parent c42bf6b commit a8d7fcc

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

csrc/flash_attn/src/mask.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,13 @@ struct Mask {
177177
for (int j = 0; j < size<1, 0>(tensor); ++j) {
178178
const int col_idx = col_idx_base + j;
179179
if constexpr (Has_alibi) {
180-
tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope;
180+
if constexpr (Is_causal) {
181+
tensor(make_coord(i, mi), make_coord(j, nj)) += ((col_idx == (col_idx_limit_right - 1)) ? 0 : alibi_slope);
182+
183+
} else {
184+
tensor(make_coord(i, mi), make_coord(j, nj)) += ((col_idx == row_idx) ? 0 : alibi_slope);
185+
186+
}
181187
}
182188
if constexpr (Causal_mask) {
183189
if (col_idx >= col_idx_limit_right) {

0 commit comments

Comments
 (0)