memory : migrate from llama_kv_cache to more generic llama_memory (#14006)
* memory : merge llama_kv_cache into llama_memory + new `llama_memory` API ggml-ci * context : fix casts ggml-ci
This commit is contained in:
parent
3a077146a4
commit
7f37b6cf1e
11 changed files with 324 additions and 220 deletions
100
include/llama.h
100
include/llama.h
|
@ -61,7 +61,10 @@ extern "C" {
|
|||
struct llama_model;
|
||||
struct llama_context;
|
||||
struct llama_sampler;
|
||||
struct llama_kv_cache;
|
||||
|
||||
typedef struct llama_memory_i * llama_memory_t;
|
||||
|
||||
struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
|
||||
|
||||
typedef int32_t llama_pos;
|
||||
typedef int32_t llama_token;
|
||||
|
@ -493,9 +496,11 @@ extern "C" {
|
|||
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
|
||||
|
||||
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
|
||||
LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx);
|
||||
LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
|
||||
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
|
||||
|
||||
DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
|
||||
|
||||
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
|
||||
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
|
||||
|
||||
|
@ -609,7 +614,78 @@ extern "C" {
|
|||
int32_t il_end);
|
||||
|
||||
//
|
||||
// KV cache
|
||||
// Memory
|
||||
//
|
||||
|
||||
// Clear the memory contents
|
||||
LLAMA_API void llama_memory_clear(llama_memory_t mem);
|
||||
|
||||
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
|
||||
// seq_id < 0 : match any sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API bool llama_memory_seq_rm(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1);
|
||||
|
||||
// Copy all tokens that belong to the specified sequence to another sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_memory_seq_cp(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1);
|
||||
|
||||
// Removes all tokens that do not belong to the specified sequence
|
||||
LLAMA_API void llama_memory_seq_keep(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_memory_seq_add(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta);
|
||||
|
||||
// Integer division of the positions by factor of `d > 1`
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_memory_seq_div(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d);
|
||||
|
||||
// Returns the smallest position present in the memory for the specified sequence
|
||||
// This is typically non-zero only for SWA caches
|
||||
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
|
||||
// Return -1 if the sequence is empty
|
||||
LLAMA_API llama_pos llama_memory_seq_pos_min(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Returns the largest position present in the memory for the specified sequence
|
||||
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
|
||||
// Return -1 if the sequence is empty
|
||||
LLAMA_API llama_pos llama_memory_seq_pos_max(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Check if the memory supports shifting
|
||||
LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
|
||||
|
||||
//
|
||||
// KV cache for self-attention (TODO: deprecate in favor of llama_memory)
|
||||
//
|
||||
|
||||
// Returns the number of tokens in the KV cache (slow, use only for debug)
|
||||
|
@ -623,7 +699,7 @@ extern "C" {
|
|||
|
||||
// Clear the KV cache - both cell info is erased and KV data is zeroed
|
||||
LLAMA_API void llama_kv_self_clear(
|
||||
struct llama_context * ctx);
|
||||
struct llama_context * ctx);
|
||||
|
||||
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
|
||||
|
@ -694,14 +770,14 @@ extern "C" {
|
|||
// Defragment the KV cache
|
||||
// This will be applied:
|
||||
// - lazily on next llama_decode()
|
||||
LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx),
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
|
||||
"simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
|
||||
|
||||
// Check if the context supports KV cache shifting
|
||||
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
|
||||
|
||||
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
|
||||
LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx),
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
|
||||
"simply remove this call, updates are applied lazily on the next llama_decode()");
|
||||
|
||||
//
|
||||
|
@ -709,7 +785,7 @@ extern "C" {
|
|||
//
|
||||
|
||||
// Returns the *actual* size in bytes of the state
|
||||
// (logits, embedding and kv_cache)
|
||||
// (logits, embedding and memory)
|
||||
// Only use when saving the state, not when restoring it, otherwise the size may be too small.
|
||||
LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
|
||||
LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
|
||||
|
@ -765,12 +841,12 @@ extern "C" {
|
|||
size_t n_token_count),
|
||||
"use llama_state_save_file instead");
|
||||
|
||||
// Get the exact size needed to copy the KV cache of a single sequence
|
||||
// Get the exact size needed to copy the state of a single sequence
|
||||
LLAMA_API size_t llama_state_seq_get_size(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Copy the KV cache of a single sequence into the specified buffer
|
||||
// Copy the state of a single sequence into the specified buffer
|
||||
LLAMA_API size_t llama_state_seq_get_data(
|
||||
struct llama_context * ctx,
|
||||
uint8_t * dst,
|
||||
|
@ -836,16 +912,16 @@ extern "C" {
|
|||
// For encode-decoder contexts, processes the batch using the encoder.
|
||||
// Can store the encoder output internally for later use by the decoder's cross-attention layers.
|
||||
// 0 - success
|
||||
// < 0 - error. the KV cache state is restored to the state before this call
|
||||
// < 0 - error. the memory state is restored to the state before this call
|
||||
LLAMA_API int32_t llama_encode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch);
|
||||
|
||||
// Process a batch of tokens.
|
||||
// Requires KV cache.
|
||||
// Requires the context to have a memory.
|
||||
// For encode-decoder contexts, processes the batch using the decoder.
|
||||
// Positive return values does not mean a fatal error, but rather a warning.
|
||||
// Upon non-zero return values, the KV cache state is restored to the state before this call
|
||||
// Upon non-zero return values, the memory state is restored to the state before this call
|
||||
// 0 - success
|
||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||
// 2 - aborted
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue