llama : refactor kv cache guard (#12695)
* llama : refactor kv cache guard (ggml-ci)
* cont : fix comment [no ci]
* llama : fix kv_cache restore logic (ggml-ci)
* context : simplify kv cache updates (ggml-ci)
* cont : better name [no ci]
* llama : fix llama_decode return code when could not find KV slot (ggml-ci)
* context : change log err -> warn [no ci]
* kv-cache : add comment + warning
This commit is contained in:
  parent 83a88bd6af
  commit a10b36c91a

4 changed files with 107 additions and 127 deletions
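The diff below makes KV-slot allocation transactional: find_slot() now records the cell ranges it claims in pending.ranges and returns a plain bool, commit() accepts those ranges, and restore() rolls them back if the batch fails part-way through. As a rough caller-side sketch (the actual llama_context changes live in the other files of this commit and are not shown here; the wrapper function and run_graph() are illustrative placeholders), the protocol looks like this:

// Sketch only: the intended find_slot / commit / restore sequence, assuming the
// llama.cpp internal headers. decode_one_ubatch() and run_graph() are
// hypothetical; only the three kv-cache calls come from this diff.
static int32_t decode_one_ubatch(llama_kv_cache_unified & kv_self, const llama_ubatch & ubatch) {
    if (!kv_self.find_slot(ubatch)) {
        // no free slot was found; nothing was claimed, so the cache is untouched.
        // llama_decode signals this with the non-fatal return code 1
        // ("could not find a KV slot for the batch").
        return 1;
    }

    if (!run_graph(ubatch)) {      // hypothetical compute step
        kv_self.restore();         // rewind the cells that find_slot() reserved
        return -1;
    }

    kv_self.commit();              // accept the pending ranges
    return 0;
}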
@@ -11,8 +11,6 @@
 #include <map>
 #include <stdexcept>

-static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
-
 llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
 }

@@ -206,6 +204,8 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
                 return false;
             }
         }

         return true;
     }

     for (uint32_t i = 0; i < size; ++i) {

@@ -446,16 +446,66 @@ void llama_kv_cache_unified::defrag() {
     }
 }

+void llama_kv_cache_unified::restore() {
+    if (pending.ranges.empty()) {
+        return;
+    }
+
+    // TODO: tmp - move to llama_kv_cache_recurrent
+    if (recurrent) {
+        seq_rm(-1, -1, -1);
+        return;
+    }
+
+    uint32_t new_head = size;
+
+    for (auto & range : pending.ranges) {
+        for (uint32_t i = range.c0; i < range.c1; ++i) {
+            cells[i].seq_id.clear();
+
+            // keep count of the number of used cells
+            if (cells[i].pos >= 0) {
+                used--;
+            }
+
+            cells[i].pos = -1;
+            cells[i].src = -1;
+        }
+
+        new_head = std::min(new_head, range.c0);
+    }
+
+    if (new_head != size && new_head < head) {
+        head = new_head;
+    }
+}
+
+void llama_kv_cache_unified::commit() {
+    if (pending.ranges.empty()) {
+        LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
+            __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
+        return;
+    }
+
+    pending.ranges.clear();
+}
+
 bool llama_kv_cache_unified::get_can_shift() const {
     return can_shift;
 }

-llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
+bool llama_kv_cache_unified::find_slot(
         const llama_ubatch & ubatch) {
     const uint32_t n_tokens     = ubatch.n_tokens;
     const uint32_t n_seqs       = ubatch.n_seqs;
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;

     // if we have enough unused cells before the current head ->
     //   better to start searching from the beginning of the cache, hoping to fill it
     if (head > used + 2*ubatch.n_tokens) {
         head = 0;
     }

     if (recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.

@@ -477,7 +527,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
             // too big seq_id
             // TODO: would it be possible to resize the cache instead?
             LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
-            return llama_kv_cache_slot_info_failed;
+            return false;
         }
         if (j > 0) {
             llama_kv_cell & seq = cells[seq_id];

@@ -616,14 +666,14 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
                 [](const llama_kv_cell& cell){ return !cell.is_empty(); });

         // sanity check
-        return llama_kv_cache_slot_info(n >= n_seqs);
+        return n >= n_seqs;
     }

     // otherwise, one cell per token.

     if (n_tokens > size) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
-        return llama_kv_cache_slot_info_failed;
+        return false;
     }

     uint32_t n_tested = 0;

@@ -651,7 +701,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(

         if (n_tested >= size) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return llama_kv_cache_slot_info_failed;
+            return false;
         }
     }

@@ -668,7 +718,9 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(

     used += n_tokens;

-    return llama_kv_cache_slot_info(head, head + n_tokens);
+    pending.ranges.push_back({head, head + n_tokens});
+
+    return true;
 }

 uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {

@@ -1033,6 +1085,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
         return false;
     }
+    commit();

     // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
     // Assume that this is one contiguous block of cells
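The "guard" named in the title ties restore() and commit() together so that every exit path is covered. The guard type itself is added in one of the other changed files and is not part of this excerpt; the sketch below only illustrates the shape such a guard can take given the semantics above, so the struct and member names are placeholders rather than the code from this commit.

// Illustrative RAII-style guard over the commit/restore protocol shown in the
// diff; names are placeholders, not the guard actually added by this PR.
struct kv_cache_guard_sketch {
    explicit kv_cache_guard_sketch(llama_kv_cache_unified & kv) : kv(kv) {}

    // success path: accept the pending ranges recorded by find_slot()
    void commit() {
        kv.commit();
    }

    // failure or early-return path: pending.ranges is still populated, so
    // restore() rewinds the claimed cells; after a commit() it is a no-op
    ~kv_cache_guard_sketch() {
        kv.restore();
    }

private:
    llama_kv_cache_unified & kv;
};

Because commit() clears pending.ranges and restore() returns early when that list is empty, an unconditional restore() in a destructor is safe on the success path as well; the LLAMA_LOG_WARN in commit() then only fires if commit() is called twice or without a preceding find_slot().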