Skip to content

Commit e2e9fab

Browse files
authored
bugfix: Fix precision issue with Triton operator token_att_fwd (#1092)
1 parent de8dc64 commit e2e9fab

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ def _fwd_kernel_token_att1(
6060
).to(tl.int64)
6161
off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd
6262
k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)
63-
att_value = tl.sum(q[None, :] * k, 1)
64-
att_value = att_value.to(tl.float32)
63+
att_value = tl.sum(q[None, :] * k, 1, dtype=tl.float32)
6564
att_value *= sm_scale
6665
off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs
6766
tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)

0 commit comments

Comments (0)