kv-cache : refactor + add llama_memory_state_i (#13746)
* kv-cache : simplify the "struct llama_kv_cache" interface ggml-ci * kv-cache : revert the (n_swa + n_ubatch) change (for next PR) ggml-ci * kv-cache : some comments ggml-ci * context : fix graph reserve for multiple sequences ggml-ci * kv-cache : fix typo [no ci] * kv-cache : fix find_slot() logic for free slots ggml-ci * llama : add TODO for deprecating the defrag API in the future * kv-cache : improve find_slot() using min/max seq pos info ggml-ci * llama : handle aborts and compute errors ggml-ci * memory : extract state into llama_memory_state ggml-ci * kv-cache : add comments ggml-ci * server : update batching logic to reset n_batch on successful decode * server : upon full re-processing, remove the sequence from the cache * kv-cache : add TODO for doing split_equal when split_simple fails ggml-ci
This commit is contained in:
parent
eb3949938e
commit
12d0188c0d
14 changed files with 1304 additions and 655 deletions
|
@ -259,9 +259,9 @@ extern "C" {
|
|||
llama_token * token;
|
||||
float * embd;
|
||||
llama_pos * pos;
|
||||
int32_t * n_seq_id;
|
||||
llama_seq_id ** seq_id;
|
||||
int8_t * logits; // TODO: rename this to "output"
|
||||
int32_t * n_seq_id; // TODO: remove, should belong to only 1 sequence
|
||||
llama_seq_id ** seq_id; // TODO: become llama_seq_id * seq_id;
|
||||
int8_t * logits; // TODO: rename this to "output"
|
||||
} llama_batch;
|
||||
|
||||
enum llama_model_kv_override_type {
|
||||
|
@ -677,12 +677,14 @@ extern "C" {
|
|||
|
||||
// Returns the smallest position present in the KV cache for the specified sequence
|
||||
// This is typically non-zero only for SWA caches
|
||||
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
|
||||
// Return -1 if the sequence is empty
|
||||
LLAMA_API llama_pos llama_kv_self_seq_pos_min(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Returns the largest position present in the KV cache for the specified sequence
|
||||
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
|
||||
// Return -1 if the sequence is empty
|
||||
LLAMA_API llama_pos llama_kv_self_seq_pos_max(
|
||||
struct llama_context * ctx,
|
||||
|
@ -692,12 +694,14 @@ extern "C" {
|
|||
// This will be applied:
|
||||
// - lazily on next llama_decode()
|
||||
// - explicitly with llama_kv_self_update()
|
||||
// TODO: deprecate and always update the cache lazily [TAG: API_KV_NO_DEFRAG]
|
||||
LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
|
||||
|
||||
// Check if the context supports KV cache shifting
|
||||
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
|
||||
|
||||
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
|
||||
// TODO: deprecate and always update the cache lazily [TAG: API_KV_NO_DEFRAG]
|
||||
LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
|
||||
|
||||
//
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue