llama : fix Gemma3 SWA KV cache shift (#12373)

* llama : fix Gemma3 SWA KV cache shift

ggml-ci

* hparams : add comment [no ci]
Georgi Gerganov, 2025-03-13 19:08:07 +02:00, committed by GitHub
parent be7c303410
commit 84d5475541
6 changed files with 37 additions and 43 deletions
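
The gist of the fix: Gemma3's sliding-window (SWA) layers apply RoPE with the local frequency base (10000.0) and scale 1.0 rather than the model-wide cparams settings, so the KV cache shift has to use matching per-layer values. To make that possible, build_rope_shift now receives the frequency parameters explicitly instead of reading them from cparams. The header hunk is not shown on this page, so the declaration below is a sketch of the assumed resulting shape, not the verbatim change:

    // sketch of the updated declaration (assumed; see the hunks below)
    ggml_tensor * build_rope_shift(
            ggml_context * ctx0,
            ggml_tensor  * cur,        // view of the K cache for the current layer
            ggml_tensor  * shift,      // per-cell position deltas (inp->k_shift)
            ggml_tensor  * factors,    // optional RoPE scaling factors
                   float   freq_base,  // now passed per layer
                   float   freq_scale, // now passed per layer
            ggml_backend_buffer * bbuf) const;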

@@ -442,10 +442,10 @@ ggml_tensor * llama_context::build_rope_shift(
         ggml_tensor * cur,
         ggml_tensor * shift,
         ggml_tensor * factors,
+              float   freq_base,
+              float   freq_scale,
         ggml_backend_buffer * bbuf) const {
     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
-    const auto & freq_base  = cparams.rope_freq_base;
-    const auto & freq_scale = cparams.rope_freq_scale;
     const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
     const auto & yarn_attn_factor = cparams.yarn_attn_factor;
@@ -537,6 +537,17 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
         const int64_t n_head_kv    = hparams.n_head_kv(il);
         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+        float freq_base_l  = cparams.rope_freq_base;
+        float freq_scale_l = cparams.rope_freq_scale;
+
+        // TODO: improve
+        if (model.arch == LLM_ARCH_GEMMA3) {
+            const bool is_sliding = hparams.is_sliding(il);
+
+            freq_base_l  = is_sliding ? 10000.0f : cparams.rope_freq_base;
+            freq_scale_l = is_sliding ? 1.0f     : cparams.rope_freq_scale;
+        }
+
         ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
         ggml_tensor * k =
@@ -546,7 +557,7 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
                 ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
                 0);
-        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, kv_self->k_l[il]->buffer);
+        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
         ggml_build_forward_expand(gf, cur);
     }
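
Putting the hunks together, the per-layer K-shift now picks its RoPE frequencies before building the shift op. A condensed sketch of the resulting flow inside build_kv_self_shift (variable names follow the diff; the compressed control flow is illustrative, not the verbatim code):

    // per layer il (condensed from the hunks above)
    float freq_base_l  = cparams.rope_freq_base;
    float freq_scale_l = cparams.rope_freq_scale;

    // Gemma3's sliding-window layers are RoPE'd with the local base (10000)
    // and scale 1.0, so the cache shift must use the same values
    if (model.arch == LLM_ARCH_GEMMA3 && hparams.is_sliding(il)) {
        freq_base_l  = 10000.0f;
        freq_scale_l = 1.0f;
    }

    ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors,
            freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
    ggml_build_forward_expand(gf, cur);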