Skip to content

Commit

Permalink
redo causal and noncausal alibi for diagonal
Browse files Browse the repository at this point in the history
  • Loading branch information
timt51 authored Feb 7, 2025
1 parent c42bf6b commit a8d7fcc
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion csrc/flash_attn/src/mask.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,13 @@ struct Mask {
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
if constexpr (Has_alibi) {
tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope;
if constexpr (Is_causal) {
tensor(make_coord(i, mi), make_coord(j, nj)) += ((col_idx == (col_idx_limit_right - 1)) ? 0 : alibi_slope);

} else {
tensor(make_coord(i, mi), make_coord(j, nj)) += ((col_idx == row_idx) ? 0 : alibi_slope);

}
}
if constexpr (Causal_mask) {
if (col_idx >= col_idx_limit_right) {
Expand Down

0 comments on commit a8d7fcc

Please sign in to comment.