From 8551c44d840a7db50adb958ccaf464dc3ded82e7 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 18 Mar 2025 13:05:49 +0200
Subject: [PATCH] context : always use non-causal attention for encoder graphs (#12447)

* context : always use non-causal attention for encoder graphs

ggml-ci

* context : move the change to llama_context::encode()

ggml-ci
---
 src/llama-context.cpp | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index abb7e526..42332acf 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1057,6 +1057,13 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
+    const auto causal_attn_org = cparams.causal_attn;
+
+    // always use non-causal attention for encoder graphs
+    // TODO: this is a tmp solution until we have a proper way to support enc-dec models
+    // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
+    cparams.causal_attn = false;
+
     auto * gf = graph_init();
     auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
 
@@ -1064,6 +1071,8 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     res->set_inputs(&ubatch);
 
+    cparams.causal_attn = causal_attn_org;
+
     const auto compute_status = graph_compute(gf, n_tokens > 1);
     switch (compute_status) {
         case GGML_STATUS_SUCCESS:
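
The change above saves cparams.causal_attn, forces it to false while the encoder graph is built and its inputs are set, and then restores the original value so subsequent decoder graphs keep causal masking. Below is a minimal standalone sketch of this save/override/restore pattern; the names here (CParams, build_encoder_graph) are hypothetical illustrations, not the llama.cpp API.

    // Minimal standalone sketch of the save/override/restore pattern from the
    // patch above. CParams and build_encoder_graph are hypothetical stand-ins,
    // not llama.cpp symbols.
    #include <cstdio>

    struct CParams {
        bool causal_attn = true; // decoder default: causal (autoregressive) masking
    };

    // Stand-in for graph_build(): reads cparams.causal_attn while constructing
    // the attention-mask inputs of the graph.
    static void build_encoder_graph(const CParams & cparams) {
        std::printf("building graph with %s attention\n",
                    cparams.causal_attn ? "causal" : "non-causal");
    }

    int main() {
        CParams cparams;

        // Save the original flag, force non-causal attention while the encoder
        // graph is built, then restore it, mirroring llama_context::encode()
        // in the patch.
        const bool causal_attn_org = cparams.causal_attn;
        cparams.causal_attn = false;

        build_encoder_graph(cparams); // encoder attends to all tokens (non-causal)

        cparams.causal_attn = causal_attn_org; // decoder graphs are unaffected
        return 0;
    }

Note that the patch restores the flag manually after res->set_inputs(&ubatch) and before graph_compute() runs; an RAII guard that restores the flag in its destructor would be a more defensive variant if early returns were possible between the override and the restore.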