From 91a8ee6a6f1f4c8547ff7b745ef95c6edc1d2af6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C4=90inh=20Tr=E1=BB=8Dng=20Huy?= <77562200+huydt84@users.noreply.github.com>
Date: Mon, 9 Jun 2025 13:15:31 +0900
Subject: [PATCH] add geglu activation function (#14074)

Co-authored-by: dinhhuy
---
 src/llama-graph.cpp | 22 ++++++++++++++++++++++
 src/llama-graph.h   |  1 +
 2 files changed, 23 insertions(+)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index c4bdd660..55390d42 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -659,6 +659,28 @@ ggml_tensor * llm_graph_context::build_ffn(
                 cur = ggml_mul(ctx0, x0, x1);
                 cb(cur, "ffn_mul", il);
             } break;
+        case LLM_FFN_GEGLU:
+            {
+                // Split into two equal parts
+                int64_t split_point = cur->ne[0] / 2;
+                ggml_tensor * output_ffn_up = ggml_cont(ctx0, ggml_view_2d(
+                    ctx0, cur, split_point,
+                    cur->ne[1], cur->nb[1], 0
+                ));
+                ggml_tensor * output_ffn_gate = ggml_cont(ctx0, ggml_view_2d(
+                    ctx0, cur, split_point,
+                    cur->ne[1], cur->nb[1],
+                    split_point * ggml_element_size(cur)
+                ));
+
+                // Apply GELU activation function to the first part
+                output_ffn_up = ggml_gelu(ctx0, output_ffn_up);
+                cb(output_ffn_up, "ffn_gelu", il);
+
+                // Element-wise multiplication between the activated part and the gate part
+                cur = ggml_mul(ctx0, output_ffn_up, output_ffn_gate);
+                cb(cur, "ffn_geglu", il);
+            } break;
     }
 
     if (gate && type_gate == LLM_FFN_PAR) {
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 2b1cfa5b..28da6a52 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -36,6 +36,7 @@ enum llm_ffn_op_type {
     LLM_FFN_RELU,
     LLM_FFN_RELU_SQR,
     LLM_FFN_SWIGLU,
+    LLM_FFN_GEGLU,
 };
 
 enum llm_ffn_gate_type {
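
Note (not part of the patch): a minimal standalone sketch of what the new LLM_FFN_GEGLU branch computes on one row of the up-projection output — split the row in half, apply GELU to the first half, then multiply element-wise by the second half. The helper names (gelu_tanh, geglu_reference) are hypothetical, and the tanh-based GELU is only the commonly used approximation, not necessarily the exact formula behind ggml_gelu.

#include <cmath>
#include <cstdio>
#include <vector>

// Commonly used tanh approximation of GELU (assumed here; ggml_gelu may differ).
static float gelu_tanh(float x) {
    const float k = 0.7978845608f; // sqrt(2/pi)
    return 0.5f * x * (1.0f + std::tanh(k * (x + 0.044715f * x * x * x)));
}

// GEGLU on a single row of width 2*n:
//   out[i] = gelu(x[i]) * x[n + i]
// first half = activated "up" part, second half = "gate" part,
// matching the two views (offset 0 and offset split_point) taken in the patch.
static std::vector<float> geglu_reference(const std::vector<float> & x) {
    const size_t n = x.size() / 2;
    std::vector<float> out(n);
    for (size_t i = 0; i < n; ++i) {
        out[i] = gelu_tanh(x[i]) * x[n + i];
    }
    return out;
}

int main() {
    // toy row of width 8 -> GEGLU output of width 4
    const std::vector<float> x = {0.5f, -1.0f, 2.0f, 0.0f, 1.0f, 3.0f, -2.0f, 0.5f};
    for (float v : geglu_reference(x)) {
        std::printf("%.6f\n", v);
    }
    return 0;
}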