Skip to content

Commit 5798c33

Browse files
committed
Add SparseK KQ mask unit test
1 parent b7315fc commit 5798c33

File tree

2 files changed

+245
-0
lines changed

2 files changed

+245
-0
lines changed

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ endif()
182182
llama_build_and_test(test-chat-parser.cpp)
183183
llama_build_and_test(test-chat-template.cpp)
184184
llama_build_and_test(test-json-partial.cpp)
185+
llama_build_and_test(test-sparsek_kq_mask.cpp)
185186
llama_build_and_test(test-log.cpp)
186187
llama_build_and_test(test-regex-partial.cpp)
187188

tests/test-sparsek_kq_mask.cpp

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <vector>
7+
8+
// Small helper: abort the test unless the value is -INF.
//
// The check is written out explicitly instead of using assert(): assert()
// expands to nothing when NDEBUG is defined, which would silently turn this
// helper (and every scenario relying on it) into a no-op in such builds.
static void assert_is_neginf(float x) {
    const bool is_neginf = std::isinf(x) && x < 0.0f;
    if (!is_neginf) {
        std::cerr << "assertion failed: expected -INF, got " << x << "\n";
        std::abort();
    }
}
12+
13+
// Small helper: abort the test unless the value is 0.0f (within a 1e-8 tolerance).
//
// The check is written out explicitly instead of using assert(): assert()
// expands to nothing when NDEBUG is defined, which would silently turn this
// helper (and every scenario relying on it) into a no-op in such builds.
// NaN and +/-INF fail the comparison and therefore abort as well.
static void assert_is_zero(float x) {
    const float eps = 1e-8f;
    if (!(std::fabs(x) < eps)) {
        std::cerr << "assertion failed: expected 0.0f, got " << x << "\n";
        std::abort();
    }
}
18+
19+
// This helper mirrors the SparseK row logic used at the end of
// llama_kv_cache::set_input_kq_mask in src/llama-kv-cache.cpp.
//
// It operates on a single mask row of length n_kv for a specific token index.
//
// Configuration is read from the environment on every call:
//   LLAMA_SPARSEK_ENABLE         - master switch (default: off)
//   LLAMA_SPARSEK_WIN            - local window radius, clamped to >= 0 (default: 64)
//   LLAMA_SPARSEK_STRIDE         - stride step, clamped to >= 0 (default: 128)
//   LLAMA_SPARSEK_ENABLE_LOCAL   - enable the local window pattern (default: on)
//   LLAMA_SPARSEK_ENABLE_STRIDE  - enable the stride pattern (default: on)
//
// Parameters:
//   row          - mask row to update in place; must hold at least n_kv floats
//   n_kv         - number of entries in the row; values <= 0 leave the row untouched
//   token_index  - position the patterns are anchored at; it may lie outside
//                  [0, n_kv), in which case only the in-range part of each
//                  pattern is applied
//   causal_attn  - when true the stride pattern only walks backwards
//
// The row is left unchanged when SparseK is disabled or both patterns are disabled.
static void apply_sparsek_row(float * row, int64_t n_kv, int token_index, bool causal_attn) {
    // An empty row has nothing to mask; this also keeps the vector size below non-negative.
    if (n_kv <= 0) {
        return;
    }

    // Reads an integer from the environment, falling back when the variable is unset.
    const auto env_int = [](const char * name, int fallback) {
        const char * s = std::getenv(name);
        return s ? std::atoi(s) : fallback;
    };

    const bool sparsek_enable    = env_int("LLAMA_SPARSEK_ENABLE", 0) != 0;
    const int  sparsek_win_local = std::max(0, env_int("LLAMA_SPARSEK_WIN", 64));
    const int  sparsek_stride    = std::max(0, env_int("LLAMA_SPARSEK_STRIDE", 128));
    const bool sparsek_en_local  = env_int("LLAMA_SPARSEK_ENABLE_LOCAL", 1) != 0;
    const bool sparsek_en_stride = env_int("LLAMA_SPARSEK_ENABLE_STRIDE", 1) != 0;

    // Same intended gating as in the SparseK block:
    // if SparseK is disabled, or all patterns are disabled, leave the row unchanged.
    if (!sparsek_enable || (!sparsek_en_local && !sparsek_en_stride)) {
        return;
    }

    std::vector<uint8_t> allow(static_cast<std::size_t>(n_kv), 0);

    // All index arithmetic is done in 64 bits so that large window/stride
    // values cannot overflow when added to the token index.
    const int64_t pos = token_index;

    // Local window pattern (symmetric around the current token index,
    // independent of causal_attn). The range is clamped to the row, so an
    // out-of-range token index simply yields a shorter (possibly empty) window.
    if (sparsek_en_local && sparsek_win_local > 0) {
        const int64_t j0 = std::max<int64_t>(0, pos - sparsek_win_local);
        const int64_t j1 = std::min<int64_t>(n_kv - 1, pos + sparsek_win_local);
        for (int64_t j = j0; j <= j1; ++j) {
            allow[j] = 1;
        }
    }

    // Stride pattern (backward only for causal, both directions for non-causal)
    if (sparsek_en_stride && sparsek_stride > 0) {
        const int64_t step = sparsek_stride;

        // Backward walk. If the anchor lies past the end of the row, jump to
        // the first in-range position on the same stride grid instead of
        // writing out of bounds.
        int64_t back = pos;
        if (back >= n_kv) {
            back -= ((back - n_kv) / step + 1) * step;
        }
        for (int64_t j = back; j >= 0; j -= step) {
            allow[j] = 1;
        }

        if (!causal_attn) {
            // Forward walk. A negative anchor is advanced to the first
            // non-negative position on the same stride grid.
            int64_t fwd = pos;
            if (fwd < 0) {
                fwd += ((-fwd + step - 1) / step) * step;
            }
            for (int64_t j = fwd; j < n_kv; j += step) {
                allow[j] = 1;
            }
        }
    }

    // Final mask update: disallowed positions get -INF,
    // allowed positions reset any negative infinity back to 0.0f.
    // Allowed positions holding any other value are left as they are.
    for (int64_t j = 0; j < n_kv; ++j) {
        if (!allow[j]) {
            row[j] = -INFINITY;
        } else if (std::isinf(row[j]) && row[j] < 0.0f) {
            row[j] = 0.0f;
        }
    }
}
88+
89+
// Debugging aid: writes one mask row to stdout as "<name>: v0 v1 ...",
// rendering negative infinity as "-INF" and every other value with the
// stream's default float formatting. Not needed by the tests themselves.
static void dump_row(const char * name, const std::vector<float> & row) {
    std::cout << name << ":";
    for (std::size_t idx = 0; idx < row.size(); ++idx) {
        const float value  = row[idx];
        const bool  masked = value < 0.0f && std::isinf(value);
        if (!masked) {
            std::cout << " " << value;
            continue;
        }
        std::cout << " -INF";
    }
    std::cout << "\n";
}
101+
102+
// Scenario 1: SparseK disabled -> row must remain unchanged.
103+
static void test_sparsek_disabled_keeps_row() {
104+
const int64_t n_kv = 8;
105+
std::vector<float> row(n_kv, 0.0f);
106+
107+
// Configure environment: disabled SparseK.
108+
setenv("LLAMA_SPARSEK_ENABLE", "0", 1);
109+
setenv("LLAMA_SPARSEK_WIN", "2", 1);
110+
setenv("LLAMA_SPARSEK_STRIDE", "2", 1);
111+
setenv("LLAMA_SPARSEK_ENABLE_LOCAL", "1", 1);
112+
setenv("LLAMA_SPARSEK_ENABLE_STRIDE", "1", 1);
113+
114+
apply_sparsek_row(row.data(), n_kv, /*token_index=*/3, /*causal_attn=*/true);
115+
116+
for (int64_t j = 0; j < n_kv; ++j) {
117+
assert_is_zero(row[j]);
118+
}
119+
}
120+
121+
// Scenario 2: Local window only, causal attention.
122+
// With n_kv = 8, token_index = 3 and window = 1, we expect positions {2,3,4} to be allowed.
123+
static void test_sparsek_local_window_only() {
124+
const int64_t n_kv = 8;
125+
std::vector<float> row(n_kv, -INFINITY);
126+
127+
setenv("LLAMA_SPARSEK_ENABLE", "1", 1);
128+
setenv("LLAMA_SPARSEK_WIN", "1", 1);
129+
setenv("LLAMA_SPARSEK_STRIDE", "0", 1);
130+
setenv("LLAMA_SPARSEK_ENABLE_LOCAL", "1", 1);
131+
setenv("LLAMA_SPARSEK_ENABLE_STRIDE", "0", 1);
132+
133+
const int token_index = 3;
134+
apply_sparsek_row(row.data(), n_kv, token_index, /*causal_attn=*/true);
135+
136+
// Optional debug print:
137+
// dump_row("local_window_only", row);
138+
139+
for (int64_t j = 0; j < n_kv; ++j) {
140+
bool should_allow = (j == 2 || j == 3 || j == 4);
141+
if (should_allow) {
142+
assert_is_zero(row[j]);
143+
} else {
144+
assert_is_neginf(row[j]);
145+
}
146+
}
147+
}
148+
149+
// Scenario 3: Stride only, causal attention.
150+
// With n_kv = 8, token_index = 5, stride = 2, causal:
151+
// allowed positions should be {5, 3, 1}.
152+
static void test_sparsek_stride_causal() {
153+
const int64_t n_kv = 8;
154+
std::vector<float> row(n_kv, -INFINITY);
155+
156+
setenv("LLAMA_SPARSEK_ENABLE", "1", 1);
157+
setenv("LLAMA_SPARSEK_WIN", "0", 1);
158+
setenv("LLAMA_SPARSEK_STRIDE", "2", 1);
159+
setenv("LLAMA_SPARSEK_ENABLE_LOCAL", "0", 1);
160+
setenv("LLAMA_SPARSEK_ENABLE_STRIDE", "1", 1);
161+
162+
const int token_index = 5;
163+
apply_sparsek_row(row.data(), n_kv, token_index, /*causal_attn=*/true);
164+
165+
// dump_row("stride_causal", row);
166+
167+
for (int64_t j = 0; j < n_kv; ++j) {
168+
bool should_allow = (j == 1 || j == 3 || j == 5);
169+
if (should_allow) {
170+
assert_is_zero(row[j]);
171+
} else {
172+
assert_is_neginf(row[j]);
173+
}
174+
}
175+
}
176+
177+
// Scenario 4: Stride only, non-causal.
178+
// With n_kv = 8, token_index = 5, stride = 2, non-causal:
179+
// allowed positions should be {1, 3, 5, 7}.
180+
static void test_sparsek_stride_noncausal() {
181+
const int64_t n_kv = 8;
182+
std::vector<float> row(n_kv, -INFINITY);
183+
184+
setenv("LLAMA_SPARSEK_ENABLE", "1", 1);
185+
setenv("LLAMA_SPARSEK_WIN", "0", 1);
186+
setenv("LLAMA_SPARSEK_STRIDE", "2", 1);
187+
setenv("LLAMA_SPARSEK_ENABLE_LOCAL", "0", 1);
188+
setenv("LLAMA_SPARSEK_ENABLE_STRIDE", "1", 1);
189+
190+
const int token_index = 5;
191+
apply_sparsek_row(row.data(), n_kv, token_index, /*causal_attn=*/false);
192+
193+
// dump_row("stride_noncausal", row);
194+
195+
for (int64_t j = 0; j < n_kv; ++j) {
196+
bool should_allow = (j == 1 || j == 3 || j == 5 || j == 7);
197+
if (should_allow) {
198+
assert_is_zero(row[j]);
199+
} else {
200+
assert_is_neginf(row[j]);
201+
}
202+
}
203+
}
204+
205+
// Scenario 5: Combined local window + stride.
206+
// This checks that both patterns are OR'ed together.
207+
static void test_sparsek_combined_patterns() {
208+
const int64_t n_kv = 16;
209+
std::vector<float> row(n_kv, -INFINITY);
210+
211+
setenv("LLAMA_SPARSEK_ENABLE", "1", 1);
212+
setenv("LLAMA_SPARSEK_WIN", "1", 1);
213+
setenv("LLAMA_SPARSEK_STRIDE", "4", 1);
214+
setenv("LLAMA_SPARSEK_ENABLE_LOCAL", "1", 1);
215+
setenv("LLAMA_SPARSEK_ENABLE_STRIDE", "1", 1);
216+
217+
const int token_index = 8;
218+
apply_sparsek_row(row.data(), n_kv, token_index, /*causal_attn=*/true);
219+
220+
// Local window (radius 1) -> {7,8,9}
221+
// Stride (4, causal, backward) from 8 -> {8,4,0}
222+
// Union -> {0,4,7,8,9}
223+
for (int64_t j = 0; j < n_kv; ++j) {
224+
bool should_allow = (j == 0 || j == 4 || j == 7 || j == 8 || j == 9);
225+
if (should_allow) {
226+
assert_is_zero(row[j]);
227+
} else {
228+
assert_is_neginf(row[j]);
229+
}
230+
}
231+
}
232+
233+
int main() {
234+
std::cout << "Running SparseK KQ mask row tests...\n";
235+
236+
test_sparsek_disabled_keeps_row();
237+
test_sparsek_local_window_only();
238+
test_sparsek_stride_causal();
239+
test_sparsek_stride_noncausal();
240+
test_sparsek_combined_patterns();
241+
242+
std::cout << "All SparseK KQ mask tests passed.\n";
243+
return 0;
244+
}

0 commit comments

Comments
 (0)