Skip to content

Commit

Permalink
adjust alibi for non causal too
Browse files Browse the repository at this point in the history
  • Loading branch information
timt51 authored Feb 7, 2025
1 parent f124d98 commit 4f8b153
Showing 1 changed file with 1 addition and 7 deletions.
8 changes: 1 addition & 7 deletions csrc/flash_attn/src/mask.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,7 @@ struct Mask {
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
if constexpr (Has_alibi) {
if constexpr (Is_causal) {
tensor(make_coord(i, mi), make_coord(j, nj)) += ((col_idx == (col_idx_limit_right - 1)) ? 0 : alibi_slope);

} else {
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);

}
tensor(make_coord(i, mi), make_coord(j, nj)) += ((col_idx == (col_idx_limit_right - 1)) ? 0 : alibi_slope);
}
if constexpr (Causal_mask) {
if (col_idx >= col_idx_limit_right) {
Expand Down

0 comments on commit 4f8b153

Please sign in to comment.