From a0554c3cdc295dd749925d5b39b76bac2e9f7cff Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 18 Mar 2025 11:14:48 +0200 Subject: [PATCH 1/2] context : always use non-causal attention for encoder graphs ggml-ci --- src/llama-context.cpp | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index abb7e526f6171..a0b3b7d0db2ac 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1627,7 +1627,16 @@ llm_graph_result_ptr llama_context::graph_build( ggml_cgraph * gf, const llama_ubatch & ubatch, llm_graph_type gtype) { - return model.build_graph( + const auto causal_attn_org = cparams.causal_attn; + + // always use non-causal attention for encoder graphs + // TODO: this is a tmp solution until we have a proper way to support enc-dec models + // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223 + if (gtype == LLM_GRAPH_TYPE_ENCODER) { + cparams.causal_attn = false; + } + + auto res = model.build_graph( { /*.ctx =*/ ctx, /*.arch =*/ model.arch, @@ -1643,6 +1652,12 @@ llm_graph_result_ptr llama_context::graph_build( /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), }, gf, gtype); + + if (gtype == LLM_GRAPH_TYPE_ENCODER) { + cparams.causal_attn = causal_attn_org; + } + + return res; } ggml_status llama_context::graph_compute( From 29acf2cf05d5ddb83b881ec1f5343939098a6760 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 18 Mar 2025 11:55:19 +0200 Subject: [PATCH 2/2] context : move the change to llama_context::encode() ggml-ci --- src/llama-context.cpp | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a0b3b7d0db2ac..42332acf1e39d 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1057,6 +1057,13 @@ int llama_context::encode(llama_batch & inp_batch) { ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); + const auto causal_attn_org = cparams.causal_attn; + + // always use non-causal attention for encoder graphs + // TODO: this is a tmp solution until we have a proper way to support enc-dec models + // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223 + cparams.causal_attn = false; + auto * gf = graph_init(); auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER); @@ -1064,6 +1071,8 @@ int llama_context::encode(llama_batch & inp_batch) { res->set_inputs(&ubatch); + cparams.causal_attn = causal_attn_org; + const auto compute_status = graph_compute(gf, n_tokens > 1); switch (compute_status) { case GGML_STATUS_SUCCESS: @@ -1627,16 +1636,7 @@ llm_graph_result_ptr llama_context::graph_build( ggml_cgraph * gf, const llama_ubatch & ubatch, llm_graph_type gtype) { - const auto causal_attn_org = cparams.causal_attn; - - // always use non-causal attention for encoder graphs - // TODO: this is a tmp solution until we have a proper way to support enc-dec models - // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223 - if (gtype == LLM_GRAPH_TYPE_ENCODER) { - cparams.causal_attn = false; - } - - auto res = model.build_graph( + return model.build_graph( { /*.ctx =*/ ctx, /*.arch =*/ model.arch, @@ -1652,12 +1652,6 @@ llm_graph_result_ptr llama_context::graph_build( /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), }, gf, gtype); - - if (gtype == LLM_GRAPH_TYPE_ENCODER) { - cparams.causal_attn = causal_attn_org; - } - - return res; } ggml_status llama_context::graph_compute(