Skip to content

MLA + FA now only uses K-cache - 47% saving on KV-cache size (only for use with #13435 for now) #13529

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
52 changes: 33 additions & 19 deletions src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1432,21 +1432,24 @@ ggml_tensor * llm_graph_context::build_attn(

v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);

ggml_tensor * v_cache_view = nullptr;
// note: MLA with flash attention now uses the last 512 elements of K-cache in place of a V-cache
if (!v_mla || v_trans) {
ggml_tensor * v_cache_view = nullptr;

if (!v_trans) {
v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
} else {
// note: the V cache is transposed when not using flash attention
v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
( n_ctx)*ggml_element_size(kv_self->v_l[il]),
(kv_head)*ggml_element_size(kv_self->v_l[il]));
if (!v_trans) {
v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
} else {
// note: the V cache is transposed when not using flash attention
v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
( n_ctx)*ggml_element_size(kv_self->v_l[il]),
(kv_head)*ggml_element_size(kv_self->v_l[il]));

v_cur = ggml_transpose(ctx0, v_cur);
}
//cb(v_cache_view, "v_cache_view", il);
v_cur = ggml_transpose(ctx0, v_cur);
}
//cb(v_cache_view, "v_cache_view", il);

ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
}
}

const bool is_swa = hparams.is_swa(il);
Expand All @@ -1471,17 +1474,28 @@ ggml_tensor * llm_graph_context::build_attn(
0);
//cb(k, "k", il);

ggml_tensor * v = !v_trans ?
ggml_view_3d(ctx0, kv_self->v_l[il],
n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
0) :
ggml_view_3d(ctx0, kv_self->v_l[il],
ggml_tensor * v = nullptr;

if (v_trans) {
v = ggml_view_3d(ctx0, kv_self->v_l[il],
n_kv, n_embd_head_v, n_head_kv,
ggml_element_size(kv_self->v_l[il])*n_ctx,
ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
0);
} else if (!v_mla) {
v = ggml_view_3d(ctx0, kv_self->v_l[il],
n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
0);
} else {
// note: MLA with flash attention now uses the last 512 elements of K-cache in place of a V-cache
v = ggml_view_3d(ctx0, kv_self->k_l[il],
n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
n_embd_head_k-n_embd_head_v); // offset by n_rot elements
}

ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
cb(cur, "kqv_out", il);
Expand Down
7 changes: 5 additions & 2 deletions src/llama-kv-cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
const int32_t n_layer = hparams.n_layer;

const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);

has_shift = false;
can_shift = true;
can_shift = !is_mla || v_trans; // TODO: allow context shifting for MLA with flash attention

LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d, padding = %d\n",
__func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift, padding);
Expand Down Expand Up @@ -100,8 +102,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
throw std::runtime_error("failed to create ggml context for kv cache");
}

// note: MLA with flash attention now uses the last 512 elements of K-cache in place of a V-cache
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, !is_mla || v_trans ? n_embd_v_gqa*kv_size : 0);
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
k_l.push_back(k);
Expand Down
Loading