llama : fix Gemma3 SWA KV cache shift (#12373)
* llama : fix Gemma3 SWA KV cache shift

ggml-ci

* hparams : add comment

[no ci]
parent be7c303410
commit 84d5475541
6 changed files with 37 additions and 43 deletions
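For context, the gist of the fix: Gemma3 applies sliding-window attention on most layers with a different RoPE frequency base than its global layers, so a KV cache shift has to re-rotate each layer's keys with that layer's own frequency parameters rather than the model-wide ones. Below is a minimal, standalone sketch of that per-layer selection; the helper name pick_rope_shift_params, the hardcoded 10000.0f local base, and the 1000000.0f global base are illustrative assumptions mirroring the diff that follows, not llama.cpp API.

    #include <cstdio>

    // Illustrative-only helper: choose the RoPE parameters used when shifting
    // the K cache of a single layer. Sliding-window (local) layers keep their
    // own, smaller frequency base instead of the model-wide one.
    struct rope_shift_params {
        float freq_base;
        float freq_scale;
    };

    static rope_shift_params pick_rope_shift_params(bool is_sliding, float freq_base, float freq_scale) {
        if (is_sliding) {
            // assumption: Gemma3 local (SWA) layers use a 10k base and no extra scaling
            return { 10000.0f, 1.0f };
        }
        return { freq_base, freq_scale };
    }

    int main() {
        // assumption: Gemma3 global layers use a much larger base (~1M) than local ones
        const rope_shift_params local  = pick_rope_shift_params(true,  1000000.0f, 1.0f);
        const rope_shift_params global = pick_rope_shift_params(false, 1000000.0f, 1.0f);
        std::printf("local:  base=%.0f scale=%.1f\n", local.freq_base,  local.freq_scale);
        std::printf("global: base=%.0f scale=%.1f\n", global.freq_base, global.freq_scale);
        return 0;
    }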
@@ -442,10 +442,10 @@ ggml_tensor * llama_context::build_rope_shift(
         ggml_tensor * cur,
         ggml_tensor * shift,
         ggml_tensor * factors,
+              float   freq_base,
+              float   freq_scale,
         ggml_backend_buffer * bbuf) const {
     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
-    const auto & freq_base  = cparams.rope_freq_base;
-    const auto & freq_scale = cparams.rope_freq_scale;

     const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
     const auto & yarn_attn_factor = cparams.yarn_attn_factor;
@@ -537,6 +537,17 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
         const int64_t n_head_kv    = hparams.n_head_kv(il);
         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

+        float freq_base_l  = cparams.rope_freq_base;
+        float freq_scale_l = cparams.rope_freq_scale;
+
+        // TODO: improve
+        if (model.arch == LLM_ARCH_GEMMA3) {
+            const bool is_sliding = hparams.is_sliding(il);
+
+            freq_base_l  = is_sliding ? 10000.0f : cparams.rope_freq_base;
+            freq_scale_l = is_sliding ? 1.0f     : cparams.rope_freq_scale;
+        }
+
         ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);

         ggml_tensor * k =
@@ -546,7 +557,7 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
                 ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
                 0);

-        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, kv_self->k_l[il]->buffer);
+        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);

         ggml_build_forward_expand(gf, cur);
     }