kv-cache : refactor + add llama_memory_state_i (#13746)

* kv-cache : simplify the "struct llama_kv_cache" interface

ggml-ci

* kv-cache : revert the (n_swa + n_ubatch) change (for next PR)

ggml-ci

* kv-cache : some comments

ggml-ci

* context : fix graph reserve for multiple sequences

ggml-ci

* kv-cache : fix typo [no ci]

* kv-cache : fix find_slot() logic for free slots

ggml-ci

* llama : add TODO for deprecating the defrag API in the future

* kv-cache : improve find_slot() using min/max seq pos info

ggml-ci

* llama : handle aborts and compute errors

ggml-ci

* memory : extract state into llama_memory_state

ggml-ci

* kv-cache : add comments

ggml-ci

* server : update batching logic to reset n_batch on successful decode

* server : upon full re-processing, remove the sequence from the cache

* kv-cache : add TODO for doing split_equal when split_simple fails

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-05-31 10:24:04 +03:00 committed by GitHub
parent eb3949938e
commit 12d0188c0d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 1304 additions and 655 deletions

View file

@ -8892,9 +8892,9 @@ struct llm_build_mamba : public llm_graph_context {
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
const auto kv_head = kv_self->head;
const auto kv_head = kv_state->get_head();
const int64_t d_conv = hparams.ssm_d_conv;
const int64_t d_inner = hparams.ssm_d_inner;
@ -8912,8 +8912,8 @@ struct llm_build_mamba : public llm_graph_context {
GGML_ASSERT(ubatch.equal_seqs);
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
ggml_tensor * conv_states_all = kv_self->k_l[il];
ggml_tensor * ssm_states_all = kv_self->v_l[il];
ggml_tensor * conv_states_all = kv_state->get_k_l(il);
ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
// (ab)using the KV cache to store the states
ggml_tensor * conv = build_copy_mask_state(
@ -11640,7 +11640,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs;
@ -11650,7 +11650,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
const auto n_head = n_embd / head_size;
const auto n_head_kv = hparams.n_head_kv(il);
const auto kv_head = kv_self->head;
const auto kv_head = kv_state->get_head();
const auto & layer = model.layers[il];
@ -11762,7 +11762,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
}
ggml_tensor * wkv_state = build_copy_mask_state(
gf, kv_self->v_l[il], state_copy, state_mask,
gf, kv_state->get_v_l(il), state_copy, state_mask,
hparams.n_embd_v_s(), n_seqs);
ggml_tensor * wkv_output;
@ -11781,9 +11781,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
wkv_state,
ggml_view_1d(
ctx0,
kv_self->v_l[il],
kv_state->get_v_l(il),
hparams.n_embd_v_s() * n_seqs,
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
)
)
);
@ -12036,7 +12036,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
ggml_tensor *& first_layer_value,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs;
@ -12045,7 +12045,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
const auto head_count = n_embd / head_size;
const auto n_seq_tokens = ubatch.n_seq_tokens;
const auto kv_head = kv_self->head;
const auto kv_head = kv_state->get_head();
const auto & layer = model.layers[il];
@ -12116,7 +12116,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
ggml_tensor * wkv_state = build_copy_mask_state(
gf, kv_self->v_l[il], state_copy, state_mask,
gf, kv_state->get_v_l(il), state_copy, state_mask,
hparams.n_embd_v_s(), n_seqs);
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
@ -12130,9 +12130,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
wkv_state,
ggml_view_1d(
ctx0,
kv_self->v_l[il],
kv_state->get_v_l(il),
hparams.n_embd_v_s() * n_seqs,
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
)
)
);