memory : migrate from llama_kv_cache to more generic llama_memory (#14006)

* memory : merge llama_kv_cache into llama_memory + new `llama_memory` API

ggml-ci

* context : fix casts

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-06-05 15:29:22 +03:00 committed by GitHub
parent 3a077146a4
commit 7f37b6cf1e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 324 additions and 220 deletions

View file

@ -2,7 +2,7 @@
#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cache.h"
#include "llama-memory.h"
#include <set>
#include <vector>
@ -13,7 +13,7 @@
// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it
class llama_kv_cache_recurrent : public llama_kv_cache {
class llama_kv_cache_recurrent : public llama_memory_i {
public:
llama_kv_cache_recurrent(
const llama_model & model,
@ -29,6 +29,16 @@ public:
// llama_memory_i
//
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
llama_memory_state_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
@ -40,20 +50,6 @@ public:
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
//
// llama_kv_cache
//
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
llama_memory_state_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
bool prepare(const std::vector<llama_ubatch> & ubatches);
// find a contiguous slot of kv cells and emplace the ubatch there