Commit 4dff236

ggml : remove GGML_KQ_MASK_PAD constant (#17910)
* ggml : remove GGML_KQ_MASK_PAD constant
* cont : remove comment
1 parent 4df6e85 commit 4dff236

7 files changed: +19 −36 lines changed
ggml/include/ggml.h

Lines changed: 5 additions & 7 deletions
@@ -2305,13 +2305,11 @@ extern "C" {
             float                 stop,
             float                 step);
 
-#define GGML_KQ_MASK_PAD 1
-
-    // q: [n_embd_k, n_batch, n_head, ne3 ]
-    // k: [n_embd_k, n_kv, n_head_kv, ne3 ]
-    // v: [n_embd_v, n_kv, n_head_kv, ne3 ] !! not transposed !!
-    // mask: [n_kv, n_batch_pad, ne32, ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
-    // res: [n_embd_v, n_head, n_batch, ne3 ] !! permuted !!
+    // q: [n_embd_k, n_batch, n_head, ne3 ]
+    // k: [n_embd_k, n_kv, n_head_kv, ne3 ]
+    // v: [n_embd_v, n_kv, n_head_kv, ne3 ] !! not transposed !!
+    // mask: [n_kv, n_batch, ne32, ne33]
+    // res: [n_embd_v, n_head, n_batch, ne3 ] !! permuted !!
     //
     // broadcast:
     //   n_head % n_head_kv == 0
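For orientation, here is a minimal sketch of how a caller shapes these tensors after the change: the mask's second dimension is simply n_batch, with no GGML_PAD rounding. The sizes, the F16 choice for k/v/mask, and the ggml_flash_attn_ext(ctx, q, k, v, mask, scale, max_bias, logit_softcap) signature are assumptions for illustration, not part of this commit.

// Sketch only -- illustrative sizes and an assumed ggml_flash_attn_ext(ctx, q, k, v,
// mask, scale, max_bias, logit_softcap) signature; this is not code from the commit.
#include <math.h>
#include "ggml.h"

int main(void) {
    const int64_t n_embd_k  = 128; // K/Q head size (illustrative)
    const int64_t n_embd_v  = 128; // V head size   (illustrative)
    const int64_t n_head    = 32;
    const int64_t n_head_kv = 8;   // n_head % n_head_kv == 0, per the broadcast rule above
    const int64_t n_batch   = 5;   // no longer rounded up with GGML_PAD
    const int64_t n_kv      = 256;

    struct ggml_init_params params = {
        /*.mem_size   =*/ 64*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_k, n_batch, n_head,    1);
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_k, n_kv,    n_head_kv, 1);
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_v, n_kv,    n_head_kv, 1);

    // the mask is now exactly [n_kv, n_batch, 1, 1]; before this commit its ne[1]
    // had to be GGML_PAD(n_batch, GGML_KQ_MASK_PAD)
    struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_kv, n_batch, 1, 1);

    struct ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask,
            1.0f/sqrtf((float) n_embd_k), /*max_bias =*/ 0.0f, /*logit_softcap =*/ 0.0f);
    (void) out;

    ggml_free(ctx);
    return 0;
}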

ggml/src/ggml.c

Lines changed: 0 additions & 2 deletions
@@ -5260,8 +5260,6 @@ struct ggml_tensor * ggml_flash_attn_ext(
 
     if (mask) {
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
-                "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
         //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
 
         GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
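A note on why dropping this assert is behavior-preserving: GGML_KQ_MASK_PAD was defined as 1 (see the ggml.h hunk above), and GGML_PAD rounds up to a multiple of its second argument, so padding to a multiple of 1 never changed the size. A standalone check, with GGML_PAD reproduced locally (matching its ggml.h definition) so it compiles without the ggml headers:

// Standalone check: with GGML_KQ_MASK_PAD == 1, GGML_PAD is the identity,
// so every padded mask size in the old code already equaled the raw size.
#include <assert.h>
#include <stdio.h>

#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
#define GGML_KQ_MASK_PAD 1 // the value this commit removes

int main(void) {
    for (int n_tokens = 1; n_tokens <= 4096; ++n_tokens) {
        assert(GGML_PAD(n_tokens, GGML_KQ_MASK_PAD) == n_tokens);
    }
    printf("GGML_PAD(x, 1) == x for all tested sizes\n");
    return 0;
}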

src/llama-context.cpp

Lines changed: 0 additions & 8 deletions
@@ -93,14 +93,6 @@ llama_context::llama_context(
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
 
-    // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
-    // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
-    // ref: https://github.com/ggerganov/llama.cpp/pull/5021
-    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory
-    if (cparams.n_batch < GGML_KQ_MASK_PAD) {
-        LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
-        cparams.n_batch = GGML_KQ_MASK_PAD;
-    }
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
     cparams.op_offload = params.op_offload;

src/llama-graph.cpp

Lines changed: 10 additions & 10 deletions
@@ -385,7 +385,7 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
     //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
 
     res &= self_kq_mask->ne[0] == mctx->get_n_kv();
-    res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+    res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
 
     return res;
 }
@@ -416,10 +416,10 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
     //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
 
     res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
-    res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+    res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
 
     res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
-    res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+    res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
 
     return res;
 }
@@ -452,7 +452,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
             }
         }
 
-        for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+        for (int i = n_tokens; i < n_tokens; ++i) {
             for (int j = 0; j < n_enc; ++j) {
                 data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
             }
@@ -1470,13 +1470,13 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
 
     if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@@ -1558,7 +1558,7 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
     inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
-    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1701,7 +1701,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
 
     const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
 
-    inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+    inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1);
     ggml_set_input(inp->cross_kq_mask);
 
     inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@@ -1767,7 +1767,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
     inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
-    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1781,7 +1781,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
     inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
-    inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
+    inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
     ggml_set_input(inp->self_kq_mask_swa);
 
     inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;

src/llama-kv-cache.cpp

Lines changed: 2 additions & 3 deletions
@@ -1232,8 +1232,7 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
     GGML_ASSERT(n_tokens%n_stream == 0);
 
     // n_tps == n_tokens_per_stream
-    const int64_t n_tps     = n_tokens/n_stream;
-    const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
+    const int64_t n_tps = n_tokens/n_stream;
 
     std::fill(data, data + ggml_nelements(dst), -INFINITY);
 
@@ -1266,7 +1265,7 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
             const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
             const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
 
-            const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
+            const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii);
 
             for (uint32_t j = 0; j < n_kv; ++j) {
                 if (cells.is_empty(j)) {
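With n_tps_pad gone, each (stream, token) pair owns a contiguous run of n_kv mask entries, addressed by the idst expression above. A small standalone sketch of that addressing with made-up sizes; only the formula mirrors the new code, and h stands in for the outer slice index, fixed at 0 here:

// Illustrative only: mirrors the idst formula from the diff with made-up sizes,
// to show how the unpadded mask rows are laid out.
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const uint64_t n_kv     = 256; // KV cells visible to the ubatch (illustrative)
    const uint64_t n_stream = 2;   // number of streams               (illustrative)
    const uint64_t n_tps    = 8;   // tokens per stream; no longer padded
    const uint64_t h        = 0;   // outer slice index, 0 in this sketch

    for (uint64_t s = 0; s < n_stream; ++s) {
        for (uint64_t ii = 0; ii < n_tps; ++ii) {
            // same formula as the new code: one n_kv-sized row per (stream, token)
            const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii);
            printf("stream %llu, token %llu -> row offset %llu\n",
                   (unsigned long long) s, (unsigned long long) ii,
                   (unsigned long long) idst);
        }
    }
    return 0;
}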

tests/test-backend-ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -5875,7 +5875,7 @@ struct test_flash_attn_ext : public test_case {
 
         ggml_tensor * m = nullptr;
         if (mask) {
-            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, nr23[1]);
+            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, nb, 1, nr23[1]);
             ggml_set_name(m, "m");
         }
 
tools/mtmd/clip.cpp

Lines changed: 1 addition & 5 deletions
@@ -775,10 +775,6 @@ struct clip_graph {
 
         // if flash attn is used, we need to pad the mask and cast to f16
         if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
-            int n_pad = GGML_PAD(window_mask->ne[1], GGML_KQ_MASK_PAD) - window_mask->ne[1];
-            if (n_pad > 0) {
-                window_mask = ggml_pad(ctx0, window_mask, 0, n_pad, 0, 0);
-            }
             window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16);
         }
 
@@ -791,7 +787,7 @@ struct clip_graph {
 
         // loop over layers
         for (int il = 0; il < n_layer; il++) {
-            auto & layer = model.layers[il];
+            const auto & layer = model.layers[il];
             const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true;
 
             ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
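After this change, the only flash-attention-specific handling of the window mask left in clip.cpp is the cast to F16; there is no ggml_pad step. A hedged sketch of that remaining step, with the context, mask tensor, and boolean flag passed in as illustrative stand-ins for what clip_graph already has in scope:

// Sketch: cast an F32 attention mask to F16 for the flash-attention path.
// ctx0 and window_mask stand in for the graph context and mask tensor that
// clip.cpp already holds at this point; the flag replaces the enum check.
#include "ggml.h"

struct ggml_tensor * prepare_fa_mask(struct ggml_context * ctx0,
                                     struct ggml_tensor  * window_mask,
                                     bool                  flash_attn_enabled) {
    if (flash_attn_enabled) {
        // no ggml_pad() of ne[1] anymore -- the mask keeps its exact token count
        window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16);
    }
    return window_mask;
}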
