Skip to content

Commit 16d7eee

Browse files
Gitty BursteinGittyBursteinyael-works
committed
SparseK: fix op_params indices + add runtime guards
Co-authored-by: Gitty <[email protected]> Co-authored-by: Yael Shuker <[email protected]>
1 parent 596508b commit 16d7eee

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8003,10 +8003,10 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
80038003
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
80048004

80058005
// -------- SparseK op_params (לא משנה שום דבר חוץ מקריאת הפרמטרים) --------
8006-
const bool use_sparsek = ggml_get_op_params_i32(dst, 30) != 0;
8007-
const int32_t k_top = ggml_get_op_params_i32(dst, 31);
8008-
const int32_t win_local = ggml_get_op_params_i32(dst, 32);
8009-
const int32_t stride_glb = ggml_get_op_params_i32(dst, 33);
8006+
const bool use_sparsek = ggml_get_op_params_i32(dst, 28) != 0;
8007+
const int32_t k_top = ggml_get_op_params_i32(dst, 29);
8008+
const int32_t win_local = ggml_get_op_params_i32(dst, 30);
8009+
const int32_t stride_glb = ggml_get_op_params_i32(dst, 31);
80108010
// ----------------------------------------------------------------------------
80118011

80128012
ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;

ggml/src/ggml.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5215,10 +5215,10 @@ struct ggml_tensor * ggml_flash_attn_back(
52155215
return result;
52165216
}
52175217

5218-
#define GGML_FA_EXT_PARAM_SPARSEK_FLAG 30
5219-
#define GGML_FA_EXT_PARAM_SPARSEK_KTOP 31
5220-
#define GGML_FA_EXT_PARAM_SPARSEK_WIN 32
5221-
#define GGML_FA_EXT_PARAM_SPARSEK_STRIDE 33
5218+
#define GGML_FA_EXT_PARAM_SPARSEK_FLAG 28
5219+
#define GGML_FA_EXT_PARAM_SPARSEK_KTOP 29
5220+
#define GGML_FA_EXT_PARAM_SPARSEK_WIN 30
5221+
#define GGML_FA_EXT_PARAM_SPARSEK_STRIDE 31
52225222

52235223
void ggml_flash_attn_ext_set_sparsek(struct ggml_tensor * a,
52245224
bool use_sparsek,

0 commit comments

Comments
 (0)