kv-cache : add SWA support (#13194)
* kv-cache : prepare for SWA ggml-ci
* kv-cache : initial iSWA implementation ggml-ci
* kv-cache : rework error recovery logic ggml-ci
* models : fix Phi-3 SWA parameters ggml-ci
* model : adjust Granite to rope factor changes ggml-ci
* server : check if context can do shifts ggml-ci
* iswa : for now, always enable shifts (experiment) ggml-ci
* kv-cache : simplify SWA logic ggml-ci
* kv-cache : apply defrag when we fail to find slots for the batch ggml-ci
* llama : update docs about llama_decode ggml-ci
* kv-cache : update warning logs when no space for the batch is available ggml-ci
* llama : add llama_kv_self_seq_pos_min()
* kv-cache : keep track of partial SWA computes and print warnings
* server : disallow use cases involving partial SWA context ggml-ci
* llama : add param to control SWA cache size ggml-ci
* minor : clean-up ggml-ci
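Note: the property that makes a dedicated SWA cache worthwhile is that a sliding window of size n_swa bounds how far back any token attends, so per-sequence cache entries older than the window can be evicted and the SWA cache can be allocated much smaller than the full context. A minimal sketch of that bound, using a hypothetical helper (illustrative only, not the code added by this commit):

#include <cstdint>

// hypothetical helper: oldest position that must still be resident for correct
// sliding-window attention over a window of size n_swa
static int32_t swa_pos_min_required(int32_t pos_max, uint32_t n_swa) {
    if (n_swa == 0) {
        return 0; // no sliding window - the full context is needed
    }
    const int32_t lo = pos_max - (int32_t) n_swa + 1;
    return lo > 0 ? lo : 0;
}

This is also what llama_kv_self_seq_pos_min() and the new "partial SWA" warnings relate to: once positions below the window have been evicted, operations that assume the full prefix is still cached can no longer be performed exactly.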
parent f0adb80bf7
commit e298d2fbd0
15 changed files with 1426 additions and 650 deletions
src/llama-hparams.h (excerpt):

@@ -14,6 +14,12 @@ enum llama_expert_gating_func_type {
     LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
 };
 
+enum llama_swa_type {
+    LLAMA_SWA_TYPE_NONE     = 0,
+    LLAMA_SWA_TYPE_STANDARD = 1,
+    LLAMA_SWA_TYPE_CHUNKED  = 2,
+};
+
 struct llama_hparams_posnet {
     uint32_t n_embd;
     uint32_t n_layer;
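The new enum distinguishes no sliding window, a standard rolling window (Gemma-style), and a chunked window (Llama 4-style attention chunks). A rough sketch of how the two SWA variants could differ at the masking level, with pos the query position and kv_pos the cached position (illustrative only, not the commit's actual mask construction):

#include <cstdint>

// mirrors the enum added above
enum llama_swa_type {
    LLAMA_SWA_TYPE_NONE     = 0,
    LLAMA_SWA_TYPE_STANDARD = 1,
    LLAMA_SWA_TYPE_CHUNKED  = 2,
};

// illustrative predicate: may the query at position pos attend to the cached
// position kv_pos? (assumes n_swa > 0 for the two SWA variants)
static bool swa_allowed(llama_swa_type type, uint32_t n_swa, int32_t pos, int32_t kv_pos) {
    if (kv_pos > pos) {
        return false; // causal masking applies in all three cases
    }
    switch (type) {
        case LLAMA_SWA_TYPE_NONE:     return true;                                              // full causal attention
        case LLAMA_SWA_TYPE_STANDARD: return pos - kv_pos < (int32_t) n_swa;                    // rolling window over the last n_swa positions
        case LLAMA_SWA_TYPE_CHUNKED:  return pos / (int32_t) n_swa == kv_pos / (int32_t) n_swa; // same fixed-size chunk
    }
    return false;
}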
@@ -35,8 +41,6 @@ struct llama_hparams {
     uint32_t n_embd_features = 0;
     uint32_t n_layer;
     uint32_t n_rot;
-    uint32_t n_swa = 0; // sliding window attention (SWA)
-    uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_expert = 0;
@@ -96,6 +100,12 @@ struct llama_hparams {
 
     std::array<int, 4> rope_sections;
 
+    // Sliding Window Attention (SWA)
+    llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
+
+    uint32_t n_swa = 0;         // the size of the sliding window (0 - no SWA)
+    uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
+
     // for State Space Models
     uint32_t ssm_d_conv = 0;
     uint32_t ssm_d_inner = 0;
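n_swa_pattern describes how SWA and full-attention layers interleave; the default of 1 keeps every layer on full attention. One plausible reading of the field, as a sketch rather than the exact convention used by the implementation:

#include <cstdint>

// sketch of one possible convention for n_swa_pattern: within each group of
// n_swa_pattern consecutive layers, all but the last use the sliding window and
// the last uses full attention; n_swa_pattern == 1 means no layer uses SWA.
static bool layer_is_swa(uint32_t il, uint32_t n_swa_pattern) {
    if (n_swa_pattern <= 1) {
        return false;
    }
    return il % n_swa_pattern < n_swa_pattern - 1;
}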
@@ -116,11 +126,10 @@ struct llama_hparams {
     bool causal_attn   = true;
     bool use_alibi     = false;
     bool attn_soft_cap = false;
+    bool use_kq_norm   = true;
+
     // llama4
     uint32_t n_moe_layer_step        = 0;
-    bool     use_kq_norm             = true;
-    uint32_t n_attn_chunk            = 0;
-    // values below seems to be fixed on llama4
     uint32_t n_no_rope_layer_step    = 4;
     uint32_t n_attn_temp_floor_scale = 8192;
     float    f_attn_temp_scale       = 0.1;
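The remaining llama4 fields, n_attn_temp_floor_scale and f_attn_temp_scale, correspond to Llama 4's attention temperature tuning, which grows the query scale logarithmically with position so attention does not flatten out over very long contexts. Roughly (a sketch using these defaults, not a copy of the implementation):

#include <cmath>
#include <cstdint>

// sketch of Llama 4-style attention temperature tuning: queries at larger
// positions get a slightly larger scale, growing logarithmically with position
static float attn_temp_scale(int32_t pos, uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale) {
    return logf(floorf((float)(pos + 1) / n_attn_temp_floor_scale) + 1.0f) * f_attn_temp_scale + 1.0f;
}

With the defaults above the factor stays at 1.0 for roughly the first 8k positions and only then starts to grow.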