llama : fix Gemma3 SWA KV cache shift (#12373)
* llama : fix Gemma3 SWA KV cache shift ggml-ci * hparams : add comment [no ci]
This commit is contained in:
parent
be7c303410
commit
84d5475541
6 changed files with 37 additions and 43 deletions
|
@ -1403,34 +1403,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
|
||||
}
|
||||
|
||||
// TODO: improve
|
||||
bool is_sliding = false;
|
||||
|
||||
switch (arch) {
|
||||
case LLM_ARCH_COHERE2:
|
||||
{
|
||||
const int32_t sliding_window_pattern = 4;
|
||||
is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA2:
|
||||
{
|
||||
const int32_t sliding_window_pattern = 2;
|
||||
is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA3:
|
||||
{
|
||||
const int32_t sliding_window_pattern = 6;
|
||||
is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
|
||||
} break;
|
||||
case LLM_ARCH_PHI3:
|
||||
{
|
||||
is_sliding = hparams.n_swa > 0;
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
is_sliding = false;
|
||||
}
|
||||
};
|
||||
const bool is_sliding = hparams.is_sliding(il);
|
||||
|
||||
const auto & kq_mask = is_sliding ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue