kv-cells : track min/max used cells and per-sequence positions (#13808)

* kv-cells : track min/max used cells and per-sequence positions

ggml-ci

* kv-cells : fix pos-modification updates for seq_pos

ggml-ci

* kv-cells : add comments

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-05-27 13:49:41 +03:00 committed by GitHub
parent f9cd68398b
commit 81713121ee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 124 additions and 52 deletions

View file

@ -286,31 +286,11 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
}
// Returns the smallest position stored in any cell that contains `seq_id`,
// or -1 when no cell contains that sequence.
//
// NOTE(review): as written, the body holds two implementations back to back:
// a linear scan over all cells that ends in `return result;`, followed by a
// second `return` that delegates to `cells.seq_pos_min(seq_id)` and is therefore
// unreachable. This text looks like a diff rendering with removed and added lines
// shown together; presumably the delegating return is meant to replace the
// scan - confirm against the actual file before relying on either form.
llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
// start from the largest representable position; doubles as the "nothing found" sentinel
llama_pos result = std::numeric_limits<llama_pos>::max();
// visit every cell index and keep the minimum position among cells holding seq_id
for (uint32_t i = 0; i < cells.size(); ++i) {
if (cells.seq_has(i, seq_id)) {
result = std::min(result, cells.pos_get(i));
}
}
// sentinel untouched -> report -1 instead of max()
if (result == std::numeric_limits<llama_pos>::max()) {
result = -1;
}
return result;
// unreachable as written (see NOTE above): delegates the query to the cells container
return cells.seq_pos_min(seq_id);
}
// Returns the largest position stored in any cell that contains `seq_id`,
// or -1 when no cell contains that sequence.
//
// NOTE(review): as written, the body holds two implementations back to back:
// a linear scan over all cells that ends in `return result;`, followed by a
// second `return` that delegates to `cells.seq_pos_max(seq_id)` and is therefore
// unreachable. This text looks like a diff rendering with removed and added lines
// shown together; presumably the delegating return is meant to replace the
// scan - confirm against the actual file before relying on either form.
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
// -1 is both the initial value and the "nothing found" result
llama_pos result = -1;
// visit every cell index and keep the maximum position among cells holding seq_id
for (uint32_t i = 0; i < cells.size(); ++i) {
if (cells.seq_has(i, seq_id)) {
result = std::max(result, cells.pos_get(i));
}
}
return result;
// unreachable as written (see NOTE above): delegates the query to the cells container
return cells.seq_pos_max(seq_id);
}
void llama_kv_cache_unified::restore() {
@ -504,7 +484,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
#ifdef FIND_SLOT_DEBUG
LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
@ -1018,7 +998,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
const uint32_t n_layer = layers.size();
const uint32_t n_kv = cell_max();
const uint32_t n_kv = cells.used_max_p1();
const uint32_t n_used = cells.get_used();
assert(n_used <= n_kv);
@ -1144,16 +1124,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
return true;
}
// Returns one past the index of the highest non-empty cell, i.e. the length of
// the shortest prefix of `cells` that contains every non-empty cell.
// Returns 0 when every cell is empty (or when there are no cells at all).
uint32_t llama_kv_cache_unified::cell_max() const {
    uint32_t n_prefix = cells.size();

    // shrink the prefix from the top while its last cell is empty
    while (n_prefix > 0 && cells.is_empty(n_prefix - 1)) {
        --n_prefix;
    }

    return n_prefix;
}
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
assert(p0 >= 0 && p1 >= 0);