kv-cache : simplify + fix warning for recurrent models (#12756)

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-04-04 21:48:10 +03:00 committed by GitHub
parent 1be76e4620
commit 3e1d29348b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 80 additions and 173 deletions

View file

@ -131,7 +131,7 @@ int32_t llama_kv_cache_unified::get_n_tokens() const {
return result;
}
uint32_t llama_kv_cache_unified::get_used_cells() const {
int32_t llama_kv_cache_unified::get_used_cells() const {
return used;
}
@ -428,7 +428,7 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
}
}
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) {
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
llama_pos result = 0;
for (uint32_t i = 0; i < size; ++i) {
@ -481,6 +481,11 @@ void llama_kv_cache_unified::restore() {
}
void llama_kv_cache_unified::commit() {
// TODO: tmp - move to llama_kv_cache_recurrent
if (recurrent) {
return;
}
if (pending.ranges.empty()) {
LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
__func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
@ -1273,117 +1278,6 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
return true;
}
//
// interface implementation
//
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
if (!kv) {
return 0;
}
return kv->get_n_tokens();
}
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
if (!kv) {
return 0;
}
return kv->get_used_cells();
}
void llama_kv_cache_clear(llama_kv_cache * kv) {
if (!kv) {
return;
}
kv->clear();
}
bool llama_kv_cache_seq_rm(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
if (!kv) {
return true;
}
return kv->seq_rm(seq_id, p0, p1);
}
void llama_kv_cache_seq_cp(
llama_kv_cache * kv,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
if (!kv) {
return;
}
kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}
void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id) {
if (!kv) {
return;
}
kv->seq_keep(seq_id);
}
void llama_kv_cache_seq_add(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta) {
if (!kv) {
return;
}
kv->seq_add(seq_id, p0, p1, delta);
}
void llama_kv_cache_seq_div(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d) {
if (!kv) {
return;
}
kv->seq_div(seq_id, p0, p1, d);
}
llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id) {
if (!kv) {
return 0;
}
return kv->seq_pos_max(seq_id);
}
void llama_kv_cache_defrag(llama_kv_cache * kv) {
if (!kv) {
return;
}
kv->defrag();
}
bool llama_kv_cache_can_shift(const llama_kv_cache * kv) {
if (!kv) {
return false;
}
return kv->get_can_shift();
}
//
// kv cache view
//
@ -1393,7 +1287,7 @@ llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t
/*.n_cells = */ 0,
/*.n_seq_max = */ n_seq_max,
/*.token_count = */ 0,
/*.used_cells = */ llama_kv_cache_used_cells(&kv),
/*.used_cells = */ kv.get_used_cells(),
/*.max_contiguous = */ 0,
/*.max_contiguous_idx = */ -1,
/*.cells = */ nullptr,