kv-cache : simplify + fix warning for recurrent models (#12756)
ggml-ci
This commit is contained in:
parent
1be76e4620
commit
3e1d29348b
4 changed files with 80 additions and 173 deletions
|
@ -2474,7 +2474,12 @@ int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
||||||
return llama_kv_cache_n_tokens(ctx->get_kv_self());
|
const auto * kv = ctx->get_kv_self();
|
||||||
|
if (!kv) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return kv->get_n_tokens();
|
||||||
}
|
}
|
||||||
|
|
||||||
// deprecated
|
// deprecated
|
||||||
|
@ -2483,7 +2488,12 @@ int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
||||||
return llama_kv_cache_used_cells(ctx->get_kv_self());
|
const auto * kv = ctx->get_kv_self();
|
||||||
|
if (!kv) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return kv->get_used_cells();
|
||||||
}
|
}
|
||||||
|
|
||||||
// deprecated
|
// deprecated
|
||||||
|
@ -2492,7 +2502,12 @@ void llama_kv_cache_clear(llama_context * ctx) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_self_clear(llama_context * ctx) {
|
void llama_kv_self_clear(llama_context * ctx) {
|
||||||
llama_kv_cache_clear(ctx->get_kv_self());
|
auto * kv = ctx->get_kv_self();
|
||||||
|
if (!kv) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
kv->clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
// deprecated
|
// deprecated
|
||||||
|
@ -2509,7 +2524,12 @@ bool llama_kv_self_seq_rm(
|
||||||
llama_seq_id seq_id,
|
llama_seq_id seq_id,
|
||||||
llama_pos p0,
|
llama_pos p0,
|
||||||
llama_pos p1) {
|
llama_pos p1) {
|
||||||
return llama_kv_cache_seq_rm(ctx->get_kv_self(), seq_id, p0, p1);
|
auto * kv = ctx->get_kv_self();
|
||||||
|
if (!kv) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return kv->seq_rm(seq_id, p0, p1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// deprecated
|
// deprecated
|
||||||
|
@ -2528,7 +2548,12 @@ void llama_kv_self_seq_cp(
|
||||||
llama_seq_id seq_id_dst,
|
llama_seq_id seq_id_dst,
|
||||||
llama_pos p0,
|
llama_pos p0,
|
||||||
llama_pos p1) {
|
llama_pos p1) {
|
||||||
return llama_kv_cache_seq_cp(ctx->get_kv_self(), seq_id_src, seq_id_dst, p0, p1);
|
auto * kv = ctx->get_kv_self();
|
||||||
|
if (!kv) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// deprecated
|
// deprecated
|
||||||
|
@ -2539,7 +2564,12 @@ void llama_kv_cache_seq_keep(
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
|
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
|
||||||
return llama_kv_cache_seq_keep(ctx->get_kv_self(), seq_id);
|
auto * kv = ctx->get_kv_self();
|
||||||
|
if (!kv) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return kv->seq_keep(seq_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
// deprecated
|
// deprecated
|
||||||
|
@ -2558,7 +2588,12 @@ void llama_kv_self_seq_add(
|
||||||
llama_pos p0,
|
llama_pos p0,
|
||||||
llama_pos p1,
|
llama_pos p1,
|
||||||
llama_pos delta) {
|
llama_pos delta) {
|
||||||
return llama_kv_cache_seq_add(ctx->get_kv_self(), seq_id, p0, p1, delta);
|
auto * kv = ctx->get_kv_self();
|
||||||
|
if (!kv) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return kv->seq_add(seq_id, p0, p1, delta);
|
||||||
}
|
}
|
||||||
|
|
||||||
// deprecated
|
// deprecated
|
||||||
|
@ -2577,7 +2612,12 @@ void llama_kv_self_seq_div(
|
||||||
llama_pos p0,
|
llama_pos p0,
|
||||||
llama_pos p1,
|
llama_pos p1,
|
||||||
int d) {
|
int d) {
|
||||||
return llama_kv_cache_seq_div(ctx->get_kv_self(), seq_id, p0, p1, d);
|
auto * kv = ctx->get_kv_self();
|
||||||
|
if (!kv) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return kv->seq_div(seq_id, p0, p1, d);
|
||||||
}
|
}
|
||||||
|
|
||||||
// deprecated
|
// deprecated
|
||||||
|
@ -2586,7 +2626,12 @@ llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
||||||
return llama_kv_cache_seq_pos_max(ctx->get_kv_self(), seq_id);
|
const auto * kv = ctx->get_kv_self();
|
||||||
|
if (!kv) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return kv->seq_pos_max(seq_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
// deprecated
|
// deprecated
|
||||||
|
@ -2595,7 +2640,12 @@ void llama_kv_cache_defrag(llama_context * ctx) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_self_defrag(llama_context * ctx) {
|
void llama_kv_self_defrag(llama_context * ctx) {
|
||||||
llama_kv_cache_defrag(ctx->get_kv_self());
|
auto * kv = ctx->get_kv_self();
|
||||||
|
if (!kv) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return kv->defrag();
|
||||||
}
|
}
|
||||||
|
|
||||||
// deprecated
|
// deprecated
|
||||||
|
@ -2604,7 +2654,12 @@ bool llama_kv_cache_can_shift(const llama_context * ctx) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_kv_self_can_shift(const llama_context * ctx) {
|
bool llama_kv_self_can_shift(const llama_context * ctx) {
|
||||||
return llama_kv_cache_can_shift(ctx->get_kv_self());
|
const auto * kv = ctx->get_kv_self();
|
||||||
|
if (!kv) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return kv->get_can_shift();
|
||||||
}
|
}
|
||||||
|
|
||||||
// deprecated
|
// deprecated
|
||||||
|
|
|
@ -131,7 +131,7 @@ int32_t llama_kv_cache_unified::get_n_tokens() const {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t llama_kv_cache_unified::get_used_cells() const {
|
int32_t llama_kv_cache_unified::get_used_cells() const {
|
||||||
return used;
|
return used;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -428,7 +428,7 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) {
|
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
||||||
llama_pos result = 0;
|
llama_pos result = 0;
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
for (uint32_t i = 0; i < size; ++i) {
|
||||||
|
@ -481,6 +481,11 @@ void llama_kv_cache_unified::restore() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified::commit() {
|
void llama_kv_cache_unified::commit() {
|
||||||
|
// TODO: tmp - move to llama_kv_cache_recurrent
|
||||||
|
if (recurrent) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (pending.ranges.empty()) {
|
if (pending.ranges.empty()) {
|
||||||
LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
|
LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
|
||||||
__func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
|
__func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
|
||||||
|
@ -1273,117 +1278,6 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
// interface implementation
|
|
||||||
//
|
|
||||||
|
|
||||||
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
|
|
||||||
if (!kv) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
return kv->get_n_tokens();
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
|
|
||||||
if (!kv) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
return kv->get_used_cells();
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_kv_cache_clear(llama_kv_cache * kv) {
|
|
||||||
if (!kv) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
kv->clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool llama_kv_cache_seq_rm(
|
|
||||||
llama_kv_cache * kv,
|
|
||||||
llama_seq_id seq_id,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1) {
|
|
||||||
if (!kv) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return kv->seq_rm(seq_id, p0, p1);
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_kv_cache_seq_cp(
|
|
||||||
llama_kv_cache * kv,
|
|
||||||
llama_seq_id seq_id_src,
|
|
||||||
llama_seq_id seq_id_dst,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1) {
|
|
||||||
if (!kv) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id) {
|
|
||||||
if (!kv) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
kv->seq_keep(seq_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_kv_cache_seq_add(
|
|
||||||
llama_kv_cache * kv,
|
|
||||||
llama_seq_id seq_id,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1,
|
|
||||||
llama_pos delta) {
|
|
||||||
if (!kv) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
kv->seq_add(seq_id, p0, p1, delta);
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_kv_cache_seq_div(
|
|
||||||
llama_kv_cache * kv,
|
|
||||||
llama_seq_id seq_id,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1,
|
|
||||||
int d) {
|
|
||||||
if (!kv) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
kv->seq_div(seq_id, p0, p1, d);
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id) {
|
|
||||||
if (!kv) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
return kv->seq_pos_max(seq_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_kv_cache_defrag(llama_kv_cache * kv) {
|
|
||||||
if (!kv) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
kv->defrag();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool llama_kv_cache_can_shift(const llama_kv_cache * kv) {
|
|
||||||
if (!kv) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return kv->get_can_shift();
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// kv cache view
|
// kv cache view
|
||||||
//
|
//
|
||||||
|
@ -1393,7 +1287,7 @@ llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t
|
||||||
/*.n_cells = */ 0,
|
/*.n_cells = */ 0,
|
||||||
/*.n_seq_max = */ n_seq_max,
|
/*.n_seq_max = */ n_seq_max,
|
||||||
/*.token_count = */ 0,
|
/*.token_count = */ 0,
|
||||||
/*.used_cells = */ llama_kv_cache_used_cells(&kv),
|
/*.used_cells = */ kv.get_used_cells(),
|
||||||
/*.max_contiguous = */ 0,
|
/*.max_contiguous = */ 0,
|
||||||
/*.max_contiguous_idx = */ -1,
|
/*.max_contiguous_idx = */ -1,
|
||||||
/*.cells = */ nullptr,
|
/*.cells = */ nullptr,
|
||||||
|
|
|
@ -21,7 +21,7 @@ struct llama_kv_cache : public llama_memory_i {
|
||||||
virtual void commit() = 0; // call after successful batch processing - clears any pending state
|
virtual void commit() = 0; // call after successful batch processing - clears any pending state
|
||||||
|
|
||||||
virtual int32_t get_n_tokens() const = 0;
|
virtual int32_t get_n_tokens() const = 0;
|
||||||
virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
|
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
|
||||||
|
|
||||||
virtual bool get_can_shift() const = 0;
|
virtual bool get_can_shift() const = 0;
|
||||||
|
|
||||||
|
@ -90,7 +90,7 @@ public:
|
||||||
bool offload);
|
bool offload);
|
||||||
|
|
||||||
int32_t get_n_tokens() const override;
|
int32_t get_n_tokens() const override;
|
||||||
uint32_t get_used_cells() const override;
|
int32_t get_used_cells() const override;
|
||||||
|
|
||||||
size_t total_size() const;
|
size_t total_size() const;
|
||||||
|
|
||||||
|
@ -109,7 +109,7 @@ public:
|
||||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
||||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||||
|
|
||||||
llama_pos seq_pos_max(llama_seq_id seq_id) override;
|
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||||
|
|
||||||
bool get_can_shift() const override;
|
bool get_can_shift() const override;
|
||||||
|
|
||||||
|
@ -204,48 +204,6 @@ private:
|
||||||
// using llama_kv_cache_unified::llama_kv_cache_unified;
|
// using llama_kv_cache_unified::llama_kv_cache_unified;
|
||||||
//};
|
//};
|
||||||
|
|
||||||
// TODO: maybe become part of the public llama_kv_cache in the future
|
|
||||||
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);
|
|
||||||
|
|
||||||
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv);
|
|
||||||
|
|
||||||
void llama_kv_cache_clear(llama_kv_cache * kv);
|
|
||||||
|
|
||||||
bool llama_kv_cache_seq_rm(
|
|
||||||
llama_kv_cache * kv,
|
|
||||||
llama_seq_id seq_id,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1);
|
|
||||||
|
|
||||||
void llama_kv_cache_seq_cp(
|
|
||||||
llama_kv_cache * kv,
|
|
||||||
llama_seq_id seq_id_src,
|
|
||||||
llama_seq_id seq_id_dst,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1);
|
|
||||||
|
|
||||||
void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id);
|
|
||||||
|
|
||||||
void llama_kv_cache_seq_add(
|
|
||||||
llama_kv_cache * kv,
|
|
||||||
llama_seq_id seq_id,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1,
|
|
||||||
llama_pos delta);
|
|
||||||
|
|
||||||
void llama_kv_cache_seq_div(
|
|
||||||
llama_kv_cache * kv,
|
|
||||||
llama_seq_id seq_id,
|
|
||||||
llama_pos p0,
|
|
||||||
llama_pos p1,
|
|
||||||
int d);
|
|
||||||
|
|
||||||
llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id);
|
|
||||||
|
|
||||||
void llama_kv_cache_defrag(llama_kv_cache * kv);
|
|
||||||
|
|
||||||
bool llama_kv_cache_can_shift(const llama_kv_cache * kv);
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// kv cache view
|
// kv cache view
|
||||||
//
|
//
|
||||||
|
|
|
@ -15,7 +15,7 @@ public:
|
||||||
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
|
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
|
||||||
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
|
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
|
||||||
|
|
||||||
virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
|
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
|
||||||
|
|
||||||
virtual bool get_can_edit() const = 0;
|
virtual bool get_can_edit() const = 0;
|
||||||
};
|
};
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue