llama : Support llama 4 text-only (#12791)

* llama4 conversion

* initial support, no chat template

* clean up a bit

* fix tokenizer conversion

* correct hparams

* try this

* fix shexp

* ffn_inp_normed

* chat template

* clean up model conversion

* add_bos

* add scale_before_ffn

* fix order

* weight_before_ffn

* llm_graph_input_attn_temp

* add chunk attn mask

* build_inp_attn_scale()

* add comment about ggml_repeat

* clarify comments

* fix build
Xuan-Son Nguyen authored on 2025-04-07 23:06:44 +02:00, committed by GitHub
parent 82974011f3
commit 1466621e73
17 changed files with 532 additions and 22 deletions

src/llama-arch.cpp

@@ -6,6 +6,7 @@
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_LLAMA, "llama" },
{ LLM_ARCH_LLAMA4, "llama4" },
{ LLM_ARCH_DECI, "deci" },
{ LLM_ARCH_FALCON, "falcon" },
{ LLM_ARCH_GROK, "grok" },
@@ -114,6 +115,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
{ LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
{ LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -233,6 +235,35 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_LLAMA4,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_DECI,
{

src/llama-arch.h

@@ -10,6 +10,7 @@
enum llm_arch {
LLM_ARCH_LLAMA,
LLM_ARCH_LLAMA4,
LLM_ARCH_DECI,
LLM_ARCH_FALCON,
LLM_ARCH_BAICHUAN,
@@ -118,6 +119,7 @@ enum llm_kv {
LLM_KV_RESIDUAL_SCALE,
LLM_KV_EMBEDDING_SCALE,
LLM_KV_TOKEN_SHIFT_COUNT,
LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,

src/llama-chat.cpp

@@ -61,6 +61,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
{ "yandex", LLM_CHAT_TEMPLATE_YANDEX },
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
};
llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -174,6 +175,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_YANDEX;
} else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) {
return LLM_CHAT_TEMPLATE_BAILING;
} else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
return LLM_CHAT_TEMPLATE_LLAMA4;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@@ -608,7 +611,16 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<role>ASSISTANT</role>";
}
} else {
} else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA4) {
// Llama 4
for (auto message : chat) {
std::string role(message->role);
ss << "<|header_start|>" << role << "<|header_end|>\n\n" << trim(message->content) << "<|eot|>";
}
if (add_ass) {
ss << "<|header_start|>assistant<|header_end|>\n\n";
}
} else {
// template not supported
return -1;
}
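For reference, a minimal standalone sketch (not the llama.cpp API) of the string the new LLM_CHAT_TEMPLATE_LLAMA4 branch produces. The conversation content is made up; only the <|header_start|>/<|header_end|>/<|eot|> markers and the trailing assistant header come from the code above, and llama.cpp's trim() of the message content is omitted here.

#include <iostream>
#include <string>
#include <vector>

struct msg { std::string role; std::string content; };

int main() {
    // hypothetical conversation; only the markers mirror the template code above
    std::vector<msg> chat = {
        { "system", "You are a helpful assistant." },
        { "user",   "Hello!" },
    };
    std::string ss;
    for (const auto & m : chat) {
        ss += "<|header_start|>" + m.role + "<|header_end|>\n\n" + m.content + "<|eot|>";
    }
    // add_ass == true: open an assistant turn for the model to complete
    ss += "<|header_start|>assistant<|header_end|>\n\n";
    std::cout << ss << std::endl;
    return 0;
}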

src/llama-chat.h

@@ -40,6 +40,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_MEGREZ,
LLM_CHAT_TEMPLATE_YANDEX,
LLM_CHAT_TEMPLATE_BAILING,
LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_UNKNOWN,
};

src/llama-graph.cpp

@@ -59,6 +59,22 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
}
}
void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
if (ubatch->pos && attn_scale) {
const int64_t n_tokens = ubatch->n_tokens;
std::vector<float> attn_scale_data(n_tokens, 0.0f);
for (int i = 0; i < n_tokens; ++i) {
const float pos = ubatch->pos[i];
attn_scale_data[i] = std::log(
std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
) * f_attn_temp_scale + 1.0;
}
ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(attn_scale));
}
}
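A worked example of the scale computed above, as a standalone sketch using the fixed llama4 values that this commit adds to llama_hparams (n_attn_temp_floor_scale = 8192, f_attn_temp_scale = 0.1); the positions are arbitrary.

#include <cmath>
#include <cstdio>

int main() {
    // llama4 defaults declared in llama_hparams further down in this commit
    const float n_attn_temp_floor_scale = 8192.0f;
    const float f_attn_temp_scale       = 0.1f;
    for (float pos : { 0.0f, 8190.0f, 8191.0f, 65535.0f }) {   // arbitrary positions
        const float scale = std::log(std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0f)
                            * f_attn_temp_scale + 1.0f;
        std::printf("pos = %7.0f -> attn_scale = %.4f\n", pos, scale);
    }
    // prints 1.0000 up to pos 8190, then ~1.0693 from 8191 and ~1.2197 at 65535:
    // the scale grows with the log of the 8192-token bucket index plus one,
    // i.e. very slowly with distance.
    return 0;
}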
void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
if (pos_bucket) {
const int64_t n_tokens = ubatch->n_tokens;
@@ -458,9 +474,17 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
}
// may need to cut off old tokens for sliding window
// TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
if (data_swa) {
if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
f = -INFINITY;
if (hparams.n_attn_chunk) {
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
f = -INFINITY;
}
} else {
if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
f = -INFINITY;
}
}
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
}
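A small standalone sketch of the chunked-attention condition above; the value n_attn_chunk = 8192 is the one set for llama4 in load_hparams later in this commit. It only answers whether a KV position stays visible to a query position under the chunk rule (the regular causal mask still hides future positions).

#include <cstdint>
#include <cstdio>

// A KV cell is masked with -INFINITY when it falls before the start of the
// chunk that the query position belongs to.
static bool same_chunk(int32_t pos, int32_t kv_pos, uint32_t n_attn_chunk) {
    const int32_t pos_chunk_start = (pos / (int32_t) n_attn_chunk) * (int32_t) n_attn_chunk;
    return kv_pos >= pos_chunk_start;
}

int main() {
    const uint32_t n_attn_chunk = 8192; // llama4 value from load_hparams below
    std::printf("%d\n", same_chunk(  100,   50, n_attn_chunk)); // 1: both in chunk [0, 8191]
    std::printf("%d\n", same_chunk( 8200, 8191, n_attn_chunk)); // 0: kv cell is in the previous chunk
    std::printf("%d\n", same_chunk( 8200, 8192, n_attn_chunk)); // 1: chunk [8192, 16383]
    return 0;
}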
@@ -812,8 +836,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
float w_scale,
llama_expert_gating_func_type gating_op,
int il) const {
int64_t n_embd = cur->ne[0];
int64_t n_tokens = cur->ne[1];
const int64_t n_embd = cur->ne[0];
const int64_t n_tokens = cur->ne[1];
const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
cb(logits, "ffn_moe_logits", il);
@@ -841,6 +866,12 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cb(selection_probs, "ffn_moe_probs_biased", il);
}
// llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
// see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
if (arch == LLM_ARCH_LLAMA4) {
selection_probs = logits;
}
// select experts
ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
cb(selected_experts->src[0], "ffn_moe_argsort", il);
@@ -867,6 +898,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
}
cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
if (weight_before_ffn) {
// TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
cur = ggml_mul(ctx0, repeated, weights);
cb(cur, "ffn_moe_weighted", il);
}
ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(up, "ffn_moe_up", il);
@@ -894,7 +934,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
cb(experts, "ffn_moe_down", il);
experts = ggml_mul(ctx0, experts, weights);
if (!weight_before_ffn) {
experts = ggml_mul(ctx0, experts, weights);
cb(cur, "ffn_moe_weighted", il);
}
// aggregate experts
ggml_tensor * moe_out = nullptr;
@@ -914,6 +957,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
moe_out = ggml_cont(ctx0, moe_out);
}
cb(moe_out, "ffn_moe_out", il);
return moe_out;
}
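To make the llama4-specific ordering concrete: expert selection runs top_k on the raw router logits, the sigmoid is applied only to the selected scores, and the resulting weights multiply the expert input rather than the expert output. A toy scalar sketch (all numbers made up) of why that ordering matters:

#include <cmath>
#include <cstdio>

// Toy scalar stand-in for one routed expert, to illustrate the two orderings.
static float silu(float x) { return x / (1.0f + std::exp(-x)); }
static float expert_ffn(float x) { return silu(1.5f * x) * (0.8f * x); } // made-up gate/up/down weights

int main() {
    const float x     = 0.7f;  // toy token activation
    const float logit = 0.4f;  // toy router logit of the selected expert
    const float w     = 1.0f / (1.0f + std::exp(-logit)); // sigmoid applied after top-k, as in the llama4 path

    std::printf("weight before ffn (llama4):  %f\n", expert_ffn(w * x)); // weight scales the expert input
    std::printf("weight after  ffn (default): %f\n", w * expert_ffn(x)); // weight scales the expert output
    // The two differ because the expert FFN is non-linear, hence the weight_before_ffn
    // flag above; the routed result is later added to a shared-expert FFN in llm_build_llama.
    return 0;
}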
@@ -981,6 +1026,19 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
return cur;
}
ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
auto inp = std::make_unique<llm_graph_input_attn_temp>(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
auto & cur = inp->attn_scale;
cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token());
ggml_set_input(cur);
res->add_input(std::move(inp));
return cur;
}
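The tensor is allocated as F32 with shape [1, 1, n_tokens] so that, on the NoPE layers of llm_build_llama below, ggml_mul can broadcast one scalar per token across every head dimension of Qcur. A plain-C++ sketch of that broadcast, with made-up sizes:

#include <cstdio>
#include <vector>

int main() {
    // Toy sizes; in the real graph Qcur is [n_embd_head, n_head, n_tokens]
    // and the scale tensor built above is [1, 1, n_tokens].
    const int n_embd_head = 4, n_head = 2, n_tokens = 3;
    std::vector<float> q(n_embd_head * n_head * n_tokens, 1.0f);
    const std::vector<float> scale = { 1.0f, 1.0f, 1.0693f }; // one value per token (see set_input above)

    for (int t = 0; t < n_tokens; ++t) {
        for (int h = 0; h < n_head; ++h) {
            for (int d = 0; d < n_embd_head; ++d) {
                q[(t * n_head + h) * n_embd_head + d] *= scale[t]; // same factor for every head and dim of token t
            }
        }
    }
    std::printf("q[first] = %f, q[last] = %f\n", q.front(), q.back());
    return 0;
}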
ggml_tensor * llm_graph_context::build_inp_out_ids() const {
auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);

src/llama-graph.h

@@ -100,6 +100,23 @@ public:
const int64_t n_pos_per_token = 1;
};
// temperature tuning, used by llama4
class llm_graph_input_attn_temp : public llm_graph_input_i {
public:
llm_graph_input_attn_temp(int64_t n_pos_per_token, uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
: n_pos_per_token(n_pos_per_token), n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
virtual ~llm_graph_input_attn_temp() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
const int64_t n_pos_per_token = 1;
const uint32_t n_attn_temp_floor_scale;
const float f_attn_temp_scale;
};
class llm_graph_input_pos_bucket : public llm_graph_input_i {
public:
llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
@@ -470,6 +487,7 @@ struct llm_graph_context {
ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
ggml_tensor * build_inp_pos() const;
ggml_tensor * build_inp_attn_scale() const;
ggml_tensor * build_inp_out_ids() const;
ggml_tensor * build_inp_mean() const;
ggml_tensor * build_inp_cls() const;

src/llama-hparams.h

@@ -112,6 +112,14 @@ struct llama_hparams {
bool use_alibi = false;
bool attn_soft_cap = false;
uint32_t n_moe_layer_step = 0;
bool use_kq_norm = true;
uint32_t n_attn_chunk = 0;
// values below seem to be fixed for llama4
uint32_t n_no_rope_layer_step = 4;
uint32_t n_attn_temp_floor_scale = 8192;
float f_attn_temp_scale = 0.1;
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;

src/llama-model.cpp

@@ -90,6 +90,8 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_57B_A14B: return "57B.A14B";
case LLM_TYPE_27B: return "27B";
case LLM_TYPE_290B: return "290B";
case LLM_TYPE_17B_16E: return "17Bx16E (Scout)";
case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
default: return "?B";
}
}
@@ -550,6 +552,25 @@ void llama_model::load_hparams(llama_model_loader & ml) {
}
}
} break;
case LLM_ARCH_LLAMA4:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full
hparams.n_attn_chunk = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
hparams.n_swa = 1; // TODO @ngxson : this is added to trigger the SWA branch (we store the chunked attn mask in the SWA tensor), will need to clean this up later
switch (hparams.n_expert) {
case 16: type = LLM_TYPE_17B_16E; break;
case 128: type = LLM_TYPE_17B_128E; break;
default: type = LLM_TYPE_UNKNOWN;
}
if (type == LLM_TYPE_17B_128E) {
hparams.use_kq_norm = false;
}
} break;
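For orientation, the three ml.get_key() calls above resolve, through LLM_KV_NAMES and the llama4 architecture prefix, to GGUF metadata keys like the ones printed by this sketch; the first two key patterns follow llama.cpp's existing naming scheme, the third is the one added in this commit, and the sample values on the right are hypothetical.

#include <cstdio>

int main() {
    // The "%s" patterns come from LLM_KV_NAMES; the annotations on the right are
    // hypothetical, only meant to show what a converted llama4 GGUF header carries.
    const char * arch = "llama4";
    char key[128];

    std::snprintf(key, sizeof(key), "%s.attention.layer_norm_rms_epsilon", arch);
    std::printf("%-45s -> f_norm_rms_eps   (e.g. 1e-05)\n", key);

    std::snprintf(key, sizeof(key), "%s.expert_feed_forward_length", arch);
    std::printf("%-45s -> n_ff_exp         (expert FFN width)\n", key);

    std::snprintf(key, sizeof(key), "%s.interleave_moe_layer_step", arch);
    std::printf("%-45s -> n_moe_layer_step (which layers are MoE)\n", key);
    return 0;
}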
case LLM_ARCH_DECI:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -1690,6 +1711,56 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
}
} break;
case LLM_ARCH_LLAMA4:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Llama 4 requires n_moe_layer_step > 0");
for (int i = 0; i < n_layer; ++i) {
bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0;
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
if (is_moe_layer) {
int n_ff_exp = hparams.n_ff_exp;
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
// Shared expert
const int64_t n_ff_shexp = n_ff_exp;
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
} else {
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
}
} break;
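A sketch of the MoE-layer interleaving used above: with a step of 1 every block is an MoE block, with a step of 2 every other one is. The actual step is read from the GGUF, so both step values here are just examples; the tensor name stems match the LLM_TENSOR_NAMES entries added for LLM_ARCH_LLAMA4 earlier in this commit.

#include <cstdio>

int main() {
    // Same test as load_tensors above: layer i gets routed + shared expert tensors
    // when (i + 1) % n_moe_layer_step == 0, and a plain dense FFN otherwise.
    const int n_layer = 8; // toy depth
    for (int step : { 1, 2 }) {
        std::printf("n_moe_layer_step = %d:\n", step);
        for (int i = 0; i < n_layer; ++i) {
            const bool is_moe_layer = (i + 1) % step == 0;
            std::printf("  blk.%d -> %s\n", i,
                        is_moe_layer ? "ffn_{gate,up,down}_exps + ffn_{gate,up,down}_shexp"
                                     : "ffn_gate / ffn_up / ffn_down (dense)");
        }
    }
    return 0;
}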
case LLM_ARCH_DECI:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4203,12 +4274,22 @@ struct llm_build_llama : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
// temperature tuning
ggml_tensor * inp_attn_scale = nullptr;
if (arch == LLM_ARCH_LLAMA4) {
inp_attn_scale = build_inp_attn_scale();
}
auto * inp_attn = build_attn_inp_kv_unified();
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
bool use_rope = arch == LLM_ARCH_LLAMA4
? (il + 1) % hparams.n_no_rope_layer_step != 0
: true;
// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
@@ -4246,25 +4327,38 @@
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
if (use_rope) {
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
} else if (inp_attn_scale) {
Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
if (arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) {
// Llama4TextL2Norm
Qcur = ggml_rms_norm(ctx0, Qcur, 1e-6);
Kcur = ggml_rms_norm(ctx0, Kcur, 1e-6);
cb(Qcur, "Qcur_normed", il);
cb(Kcur, "Kcur_normed", il);
}
cur = build_attn(inp_attn, gf,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, kq_scale, il);
cb(cur, "attn_out", il);
}
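Summarising the per-layer attention path above as a standalone sketch: every fourth layer (the NoPE layers) skips RoPE and instead multiplies Qcur by the position-dependent temperature scale, while RoPE layers optionally get RMS-normalised Q and K. The layer count and loop below are illustrative only; the two constants are the fixed llama4 hparams added in this commit.

#include <cstdio>

int main() {
    // Same per-layer decisions as the llama4 branch above: n_no_rope_layer_step = 4,
    // and use_kq_norm, which is disabled only for the 17Bx128E type.
    const int  n_no_rope_layer_step = 4;
    const bool use_kq_norm          = true;

    for (int il = 0; il < 8; ++il) {
        const bool use_rope = (il + 1) % n_no_rope_layer_step != 0;
        if (use_rope) {
            std::printf("layer %d: RoPE on Q/K%s\n", il,
                        use_kq_norm ? ", then rms_norm on Q and K" : "");
        } else {
            // NoPE layer: no rotation; Q is scaled by the per-token attn_scale instead
            std::printf("layer %d: no RoPE, Qcur *= attn_scale(pos)\n", il);
        }
    }
    return 0;
}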
if (il == n_layer - 1) {
@@ -4282,7 +4376,7 @@
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
// feed-forward network (non-MoE)
if (model.layers[il].ffn_gate_inp == nullptr) {
cur = build_norm(ffn_inp,
@@ -4297,6 +4391,38 @@
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else if (arch == LLM_ARCH_LLAMA4) {
// llama4 MoE
ggml_tensor * ffn_inp_normed = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
ggml_tensor * moe_out = build_moe_ffn(ffn_inp_normed,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, false,
false, 0.0,
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID,
il);
// Shared experts
ggml_tensor * shexp_out = build_ffn(ffn_inp_normed,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(shexp_out, "ffn_moe_shexp", il);
cur = ggml_add(ctx0, moe_out, shexp_out);
cb(cur, "ffn_moe_out_merged", il);
} else {
// MoE branch
cur = build_norm(ffn_inp,
@@ -12091,6 +12217,7 @@ llm_graph_result_ptr llama_model::build_graph(
switch (arch) {
case LLM_ARCH_LLAMA:
case LLM_ARCH_LLAMA4:
case LLM_ARCH_MINICPM:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
@@ -12440,6 +12567,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
// use what we call a normal RoPE, operating on pairs of consecutive head values
case LLM_ARCH_LLAMA:
case LLM_ARCH_LLAMA4:
case LLM_ARCH_DECI:
case LLM_ARCH_BAICHUAN:
case LLM_ARCH_STARCODER:

src/llama-model.h

@@ -86,6 +86,8 @@ enum llm_type {
LLM_TYPE_57B_A14B,
LLM_TYPE_27B,
LLM_TYPE_290B,
LLM_TYPE_17B_16E, // llama4 Scout
LLM_TYPE_17B_128E, // llama4 Maverick
};
struct llama_layer_posnet {

src/llama-vocab.cpp

@@ -1616,7 +1616,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "megrez") {
pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
} else if (
tokenizer_pre == "gpt-4o") {
tokenizer_pre == "gpt-4o" ||
tokenizer_pre == "llama4") {
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
clean_spaces = false;
} else if (