kv-cache : add SWA support (#13194)

* kv-cache : prepare for SWA

ggml-ci

* kv-cache : initial iSWA implementation

ggml-ci

* kv-cache : rework error recovery logic

ggml-ci

* models : fix Phi-3 SWA parameters

ggml-ci

* model : adjust Granite to rope factor changes

ggml-ci

* server : check if context can do shifts

ggml-ci

* iswa : for now, always enable shifts (experiment)

ggml-ci

* kv-cache : simplify SWA logic

ggml-ci

* kv-cache : apply defrag when we fail to find slots for the batch

ggml-ci

* llama : update docs about llama_decode

ggml-ci

* kv-cache : update warning logs when no space for the batch is available

ggml-ci

* llama : add llama_kv_self_seq_pos_min()

* kv-cache : keep track of partial SWA computes and print warnings

* server : disallow use cases involving partial SWA context

ggml-ci

* llama : add param to control SWA cache size

ggml-ci

* minor : clean-up

ggml-ci
Georgi Gerganov 2025-05-20 08:05:46 +03:00 committed by GitHub
parent f0adb80bf7
commit e298d2fbd0
15 changed files with 1426 additions and 650 deletions

src/llama-context.cpp

@@ -93,6 +93,7 @@ llama_context::llama_context(
     }
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+    cparams.op_offload = params.op_offload;
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
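The per-sequence context above is a plain integer division of the total cache size by the number of sequences; a quick illustration with made-up values:

    // illustrative values only, not from this commit
    const uint32_t n_ctx         = 8192; // total KV cells
    const uint32_t n_seq_max     = 4;    // parallel sequences
    const uint32_t n_ctx_per_seq = n_ctx / n_seq_max; // == 2048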
@@ -176,8 +177,9 @@ llama_context::llama_context(
     // init the memory module
     if (!hparams.vocab_only) {
         llama_memory_params params_mem = {
-            /*.type_k =*/ params.type_k,
-            /*.type_v =*/ params.type_v,
+            /*.type_k   =*/ params.type_k,
+            /*.type_v   =*/ params.type_v,
+            /*.swa_full =*/ params.swa_full,
         };
 
         memory.reset(model.create_memory(params_mem, cparams));
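The new swa_full flag ultimately decides how large the cache for the sliding-window attention (SWA) layers is. A rough sketch of the idea, with hypothetical names (swa_cache_size, n_swa, n_batch are illustrative, not the actual implementation):

    // sketch: when swa_full is set, SWA layers keep a full-sized cache so the
    // context can still be shifted/rewound; otherwise the cache only needs to
    // cover the attention window plus one batch of headroom per sequence.
    uint32_t swa_cache_size(bool swa_full, uint32_t n_ctx, uint32_t n_swa,
                            uint32_t n_seq_max, uint32_t n_batch) {
        if (swa_full) {
            return n_ctx;
        }
        return n_seq_max * (n_swa + n_batch);
    }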
@@ -947,8 +949,6 @@ int llama_context::decode(llama_batch & inp_batch) {
 
         // find KV slot
         if (!kv_self->find_slot(ubatch)) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
             return 1;
         }
@@ -2093,6 +2093,7 @@ llama_context_params llama_context_default_params() {
         /*.flash_attn =*/ false,
         /*.no_perf    =*/ true,
         /*.op_offload =*/ true,
+        /*.swa_full   =*/ true,
     };
 
     return result;
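From the public API side, the new parameter is opt-out: it defaults to true. A hedged usage sketch (assumes an already-loaded llama_model * model):

    llama_context_params cparams = llama_context_default_params();

    cparams.n_ctx    = 8192;
    cparams.swa_full = false; // window-sized caches for SWA layers, saves memory

    llama_context * ctx = llama_init_from_model(model, cparams);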
@@ -2467,6 +2468,15 @@ void llama_kv_self_seq_div(
     kv->seq_div(seq_id, p0, p1, d);
 }
 
+llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return -1;
+    }
+
+    return kv->seq_pos_min(seq_id);
+}
+
 // deprecated
 llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
     return llama_kv_self_seq_pos_max(ctx, seq_id);
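With the minimum position exposed, callers can tell whether the start of a sequence has already been evicted from a window-sized SWA cache, which is what the "disallow partial SWA context" server change builds on. A minimal sketch (seq_has_full_prefix is a hypothetical helper):

    // a minimum position > 0 means tokens before it were dropped by the
    // sliding window, so the cached prefix is no longer complete
    bool seq_has_full_prefix(llama_context * ctx, llama_seq_id seq_id) {
        const llama_pos pos_min = llama_kv_self_seq_pos_min(ctx, seq_id);
        return pos_min <= 0; // -1: nothing cached for this sequence yet
    }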
@@ -2475,7 +2485,7 @@ llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
-        return 0;
+        return -1;
     }
 
     return kv->seq_pos_max(seq_id);
@@ -2637,7 +2647,21 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
           llama_batch   batch) {
-    const int ret = ctx->decode(batch);
+    int ret = ctx->decode(batch);
+
+    // defrag and try again
+    // TODO: distinguish return code when we are sure that even after defrag there is no space available
+    if (ret == 1) {
+        llama_kv_self_defrag(ctx);
+
+        ret = ctx->decode(batch);
+
+        if (ret == 1) {
+            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+            return ret;
+        }
+    }
+
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);