llama : refactor kv cache guard (#12695)

* llama : refactor kv cache guard

ggml-ci

* cont : fix comment [no ci]

* llama : fix kv_cache restore logic

ggml-ci

* context : simplify kv cache updates

ggml-ci

* cont : better name [no ci]

* llama : fix llama_decode return code when a KV slot could not be found

ggml-ci

* context : change log err -> warn [no ci]

* kv-cache : add comment + warning
Author: Georgi Gerganov
Date:   2025-04-02 14:32:59 +03:00 (committed via GitHub)
Parent: 83a88bd6af
Commit: a10b36c91a
4 changed files with 107 additions and 127 deletions
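
The heart of the change: the ad-hoc batch_guard defined inline in llama_context::decode() is replaced by a reusable llama_kv_cache_guard. Below is a minimal, self-contained sketch of the commit-or-rollback RAII pattern it follows; the cache type and method bodies are illustrative stubs, not the real llama.cpp types — only the guard shape (construct before mutating, commit() on success, restore in the destructor otherwise) mirrors the refactor.

// Minimal sketch of the commit-or-rollback RAII pattern used by the new
// llama_kv_cache_guard. kv_cache_stub is an illustrative stand-in.
#include <cstdio>
#include <vector>

struct kv_cache_stub {
    std::vector<int> pending;                  // cells written since last commit

    void restore() { pending.clear(); std::printf("rolled back\n"); }
    void commit()  { pending.clear(); std::printf("committed\n");   }
};

class kv_cache_guard {
public:
    explicit kv_cache_guard(kv_cache_stub * kv) : kv(kv) {}

    // any early return before commit() unwinds through here and rolls back
    ~kv_cache_guard() {
        if (!committed) {
            kv->restore();
        }
    }

    void commit() {
        kv->commit();
        committed = true;
    }

private:
    kv_cache_stub * kv;
    bool committed = false;
};

int main() {
    kv_cache_stub kv;
    {
        kv_cache_guard guard(&kv);
        kv.pending.push_back(42);              // simulate a slot reservation
        guard.commit();                        // success path keeps the state
    }
    return 0;
}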

@@ -1201,33 +1201,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd = hparams.n_embd;
 
-    // TODO: remove this stuff
-    class batch_guard {
-    public:
-        batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
-        }
-
-        ~batch_guard() {
-            if (!is_done) {
-                kv_slot_restorer.restore();
-            }
-        }
-
-        void done() {
-            is_done = true;
-        }
-
-        void save(const llama_kv_cache_slot_info & slot_info) {
-            kv_slot_restorer.save(slot_info);
-        }
-
-    private:
-        bool is_done = false;
-
-        llama_kv_slot_restorer kv_slot_restorer;
-    };
-
-    batch_guard bg(*kv_self);
+    llama_kv_cache_guard kv_guard(kv_self.get());
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
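
The deleted batch_guard had to be fed every slot_info via save(); the replacement can be a one-liner because the cache now tracks its own uncommitted writes. A sketch of that idea, with illustrative names rather than the real llama_kv_cache_unified internals:

// Sketch: the cache records which cells a batch touched, so it can undo
// them itself and no per-callsite slot restorer is needed. All names here
// are illustrative stand-ins.
#include <cstdint>
#include <vector>

struct kv_cell {
    int32_t pos = -1;                    // -1 marks an unused cell
};

struct kv_cache_sketch {
    std::vector<kv_cell>  cells = std::vector<kv_cell>(256);
    std::vector<uint32_t> pending;       // indices written since last commit

    void write(uint32_t idx, int32_t pos) {
        pending.push_back(idx);
        cells[idx].pos = pos;
    }

    void restore() {                     // called by the guard's destructor
        for (uint32_t idx : pending) {
            cells[idx].pos = -1;
        }
        pending.clear();
    }

    void commit() {                      // called on the success path
        pending.clear();
    }
};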
@@ -1280,6 +1254,9 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -2;
     };
 
+    // handle any pending defrags/shifts
+    kv_self_update();
+
     int64_t n_outputs_prev = 0;
 
     while (sbatch.n_tokens > 0) {
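
This hunk hoists cache maintenance out of the per-ubatch loop: pending defrags/shifts are applied once per decode() call, and the matching kv_self_update() inside the per-ubatch "find KV slot" block is removed in the next hunk. A stub sketch of the resulting control flow (update() and decode_sketch() are illustrative names):

// Sketch of the hoisting in this hunk: cache maintenance runs once per
// decode call instead of once per ubatch.
#include <vector>

struct kv_maintenance {
    int n_updates = 0;
    void update() { ++n_updates; }       // defrag / K-shift would happen here
};

struct ubatch_stub { int n_tokens = 0; };

void decode_sketch(kv_maintenance & kv, const std::vector<ubatch_stub> & ubatches) {
    kv.update();                         // after: exactly one update per batch

    for (const auto & ub : ubatches) {
        // before: kv.update() ran here, inside every iteration
        (void) ub;                       // ... find slot, build graph, compute
    }
}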
@@ -1319,22 +1296,12 @@ int llama_context::decode(llama_batch & inp_batch) {
 
         // find KV slot
         {
-            kv_self_update();
-
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
-                kv_self->head = 0;
-            }
-
-            const auto slot_info = kv_self->find_slot(ubatch);
-            if (!slot_info) {
-                LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
-                return -3;
-            }
-
-            bg.save(slot_info);
+            if (!kv_self->find_slot(ubatch)) {
+                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+
+                return 1;
+            }
 
             if (!kv_self->recurrent) {
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
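
A failed slot search is now a warning plus return code 1 (batch not processed, caller may retry) instead of the hard error -3. A sketch of caller-side handling against the public llama.h API; llama_decode is the real entry point, the recovery policy shown is illustrative:

// Sketch of caller-side handling for the changed return code: 1 means the
// batch was not processed because no KV slot was found (recoverable),
// negative values remain hard errors.
#include "llama.h"
#include <cstdio>

// returns true if the batch was decoded successfully
static bool decode_once(llama_context * ctx, llama_batch batch) {
    const int ret = llama_decode(ctx, batch);

    if (ret == 0) {
        return true;
    }

    if (ret > 0) {
        // KV cache full: the caller can free space (e.g. drop old sequences)
        // or retry with a smaller batch, without tearing down the context
        std::fprintf(stderr, "decode: no KV slot, batch not processed\n");
        return false;
    }

    std::fprintf(stderr, "decode: fatal error %d\n", ret);
    return false;
}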
@@ -1371,16 +1338,6 @@ int llama_context::decode(llama_batch & inp_batch) {
             }
         }
 
-        // update the kv ring buffer
-        {
-            kv_self->head += ubatch.n_tokens;
-
-            // Ensure kv cache head points to a valid index.
-            if (kv_self->head >= kv_self->size) {
-                kv_self->head = 0;
-            }
-        }
-
         // plot the computation graph in dot format (for debugging purposes)
         //if (n_past%100 == 0) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
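
decode() no longer advances the ring-buffer head itself; that bookkeeping now belongs to the cache implementation. For reference, the wrap-around logic the deleted block implemented, as a stand-alone sketch with illustrative names:

// Sketch of the ring-buffer head update removed from decode(): advance the
// head past the tokens just placed and wrap to 0 at the end of the buffer.
#include <cstdint>
#include <cassert>

struct kv_head_sketch {
    uint32_t head = 0;
    uint32_t size = 0;

    void advance(uint32_t n_tokens) {
        head += n_tokens;
        if (head >= size) {              // keep head at a valid index
            head = 0;
        }
    }
};

int main() {
    kv_head_sketch kv;
    kv.size = 8;
    kv.advance(6); assert(kv.head == 6);
    kv.advance(3); assert(kv.head == 0); // 9 >= 8 wraps to the start
    return 0;
}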
@@ -1467,7 +1424,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     }
 
     // finalize the batch processing
-    bg.done();
+    kv_guard.commit();
 
     // set output mappings
     {
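
This is the success-path counterpart of the guard constructed at the top of decode(): commit() runs only after all ubatches succeeded, so every early return above unwinds through the destructor and restores the cache. A compact end-to-end sketch of that flow (stub types, not the real ones):

// End-to-end sketch of the new control flow in decode(): guard first,
// early returns roll back via the destructor, commit() only on success.
#include <vector>

struct kv_stub {
    bool ok = true;                      // pretend slot-search result
    void restore() {}
    void commit()  {}
    bool find_slot(int /*n_tokens*/) { return ok; }
};

struct guard_stub {
    kv_stub * kv;
    bool committed = false;

    explicit guard_stub(kv_stub * kv) : kv(kv) {}
    ~guard_stub() { if (!committed) kv->restore(); }
    void commit()  { kv->commit(); committed = true; }
};

int decode_sketch(kv_stub & kv, const std::vector<int> & ubatch_sizes) {
    guard_stub guard(&kv);

    for (int n_tokens : ubatch_sizes) {
        if (!kv.find_slot(n_tokens)) {
            return 1;                    // guard restores the cache here
        }
        // ... graph build + compute would happen per ubatch ...
    }

    guard.commit();                      // keep everything; no rollback
    return 0;
}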