From 0418e3415bb3d0f44cb91425f1dd900fbc07fa35 Mon Sep 17 00:00:00 2001
From: Siyu Wu
Date: Sun, 2 Nov 2025 15:52:32 +0000
Subject: [PATCH] bugfix: Fix precision issue with Triton operator
 token_att_fwd

---
 .../models/llama/triton_kernel/token_attention_nopad_att1.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
index 45de83e98..eb5af6fec 100644
--- a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
+++ b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@@ -60,8 +60,7 @@ def _fwd_kernel_token_att1(
         ).to(tl.int64)
         off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd
         k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)
-        att_value = tl.sum(q[None, :] * k, 1)
-        att_value = att_value.to(tl.float32)
+        att_value = tl.sum(q[None, :] * k, 1, dtype=tl.float32)
         att_value *= sm_scale
         off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs
         tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
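
Note on the fix (commentary, not part of the patch): reducing in the input
dtype and casting afterwards keeps the rounding error of the low-precision
accumulation, whereas passing dtype=tl.float32 to tl.sum makes Triton carry
the reduction itself in fp32. Below is a minimal NumPy sketch of the effect,
under the assumption that q and k arrive as fp16; the variable names are
illustrative and do not come from the patched kernel:

    import numpy as np

    rng = np.random.default_rng(0)
    q = rng.standard_normal(4096).astype(np.float16)
    k = rng.standard_normal(4096).astype(np.float16)

    # Old behavior: reduce in fp16, then cast the already-rounded result.
    old = np.float32((q * k).sum(dtype=np.float16))

    # New behavior: accumulate the reduction directly in fp32,
    # analogous to tl.sum(..., dtype=tl.float32) in the kernel.
    new = (q * k).sum(dtype=np.float32)

    print(old, new)  # the two results diverge as the sequence grows

The key point is that the post-hoc cast in the old code cannot recover bits
already discarded during the fp16 accumulation; the dtype argument moves the
widening before the reduction instead of after it.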