kv-cache : avoid modifying recurrent cells when setting inputs (#13834)

* kv-cache : avoid modifying recurrent cells when setting inputs

* kv-cache : remove inp_s_mask

It was replaced with equivalent and simpler functionality
with rs_z (the first zeroed state) and the already-existing inp_s_copy.

* kv-cache : fix non-consecutive token pos warning for recurrent models

The problem was apparently caused by how the tail cells were swapped.

* graph : simplify logic for recurrent state copies

* kv-cache : use cell without src refs for rs_z in recurrent cache

* llama-graph : fix recurrent state copy

The `state_copy` shuffle assumes everything is moved at once,
which is not true when `states_extra` is copied back to the cache
before copying the range of states between `head` and `head + n_seqs`.
This is only a problem if any of the cells in [`head`, `head + n_seqs`)
have an `src` in [`head + n_seqs`, `head + n_kv`),
which does happen when `n_ubatch > 1` in the `llama-parallel` example.

Changing the order of the operations avoids the potential overwrite
before use, although when copies are avoided (like with Mamba2),
this will require further changes.

* llama-graph : rename n_state to state_size in build_recurrent_state

This naming should reduce confusion between the state size
and the number of states.
This commit is contained in:
compilade 2025-06-10 18:20:14 -04:00 committed by GitHub
parent 55f6b9fa65
commit dad5c44398
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 117 additions and 180 deletions

View file

@ -406,21 +406,12 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
bool success = true;
// TODO: here we have to verify that all ubatches can fit in the cells
// however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells
// during the compute of each ubatch. to reproduce, uncomment the following loop and run:
//
// $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
//
// recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
//
GGML_UNUSED(ubatches);
//for (const auto & ubatch : ubatches) {
// if (!find_slot(ubatch)) {
// success = false;
// break;
// }
//}
for (const auto & ubatch : ubatches) {
if (!find_slot(ubatch)) {
success = false;
break;
}
}
// restore the original state
cells = std::move(org_cells);
@ -431,14 +422,13 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
}
bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
const uint32_t n_tokens = ubatch.n_tokens;
const uint32_t n_seqs = ubatch.n_seqs;
const uint32_t n_seqs = ubatch.n_seqs;
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (head > used + 2*n_tokens) {
if (head > used + 2*n_seqs) {
head = 0;
}
@ -534,16 +524,16 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
empty_cell.src = orig_cell.src;
orig_cell.seq_id.erase(seq_id);
empty_cell.seq_id.insert(seq_id); // will be overwritten
GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
}
seq_meta.tail = next_empty_cell;
// find next empty cell
if (s + 1 < n_seqs) {
next_empty_cell += 1;
for (uint32_t i = 0; i < size; ++i) {
next_empty_cell += 1;
if (next_empty_cell >= size) { next_empty_cell -= size; }
kv_cell & cell = cells[next_empty_cell];
if (cell.is_empty()) { break; }
next_empty_cell += 1;
}
}
}
@ -553,8 +543,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
// gather and re-order
for (uint32_t s = 0; s < n_seqs; ++s) {
int32_t dst_id = s + min;
int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
const int32_t dst_id = s + min;
const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
if (dst_id != src_id) {
kv_cell & dst_cell = cells[dst_id];
kv_cell & src_cell = cells[src_id];
@ -563,12 +553,14 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
std::swap(dst_cell.src, src_cell.src);
std::swap(dst_cell.seq_id, src_cell.seq_id);
// swap tails (assuming they NEVER overlap)
for (const llama_seq_id seq_id : src_cell.seq_id) {
cells[seq_id].tail = src_id;
}
for (const llama_seq_id seq_id : dst_cell.seq_id) {
cells[seq_id].tail = dst_id;
// swap tails
for (uint32_t i = 0; i < size; ++i) {
int32_t & tail = cells[i].tail;
if (tail == src_id) {
tail = dst_id;
} else if (tail == dst_id) {
tail = src_id;
}
}
}
}
@ -576,7 +568,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
// update the pos of the used seqs
for (uint32_t s = 0; s < n_seqs; ++s) {
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
int32_t cell_id = s + min;
const int32_t cell_id = s + min;
kv_cell & cell = cells[cell_id];
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
@ -594,6 +586,38 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
}
}
// Find first cell without src refs, to use as the zero-ed state
{
// TODO: bake-in src refcounts in the cell metadata
std::vector<int32_t> refcounts(size, 0);
for (size_t i = 0; i < size; ++i) {
const int32_t src = cells[i].src;
if (src >= 0) {
refcounts[src] += 1;
}
}
rs_z = -1;
for (int i = min; i <= max; ++i) {
if (refcounts[i] == 0) {
rs_z = i;
break;
}
}
for (int i = min; i <= max; ++i) {
if (cells[i].src < 0) {
GGML_ASSERT(rs_z >= 0);
cells[i].src0 = rs_z;
} else {
// Stage the source ids for all used cells to allow correct seq_* behavior
// and still make these values available when setting the inputs
cells[i].src0 = cells[i].src;
}
cells[i].src = i; // avoid moving or clearing twice
}
}
// allow getting the range of used cells, from head to head + n
head = min;
n = max - min + 1;
@ -605,47 +629,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
}
bool llama_kv_cache_recurrent::get_can_shift() const {
return false;
}
int32_t llama_kv_cache_recurrent::s_copy(int i) const {
const uint32_t cell_id = i + head;
//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
// prevent out-of-bound sources
if (cell.src < 0 || (uint32_t) cell.src >= size) {
cell.src = cell_id;
}
int32_t res = cell.src;
// TODO: do not mutate the KV cache
// ensure copy only happens once
if (cell.src != (int32_t) cell_id) {
cell.src = cell_id;
}
return res;
}
float llama_kv_cache_recurrent::s_mask(int i) const {
const uint32_t cell_id = i + head;
//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
float res = (float) (cell.src >= 0);
// only clear once
if (cell.src < 0) {
cell.src = cell_id;
}
return res;
// shifting the pos is trivial for recurrent models
return true;
}
size_t llama_kv_cache_recurrent::total_size() const {
@ -1111,6 +1096,10 @@ uint32_t llama_kv_cache_recurrent_state::get_head() const {
return is_full ? 0 : kv->head;
}
int32_t llama_kv_cache_recurrent_state::get_rs_z() const {
return is_full ? 0 : kv->rs_z;
}
uint32_t llama_kv_cache_recurrent_state::get_size() const {
return kv->size;
}
@ -1124,9 +1113,5 @@ ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
}
int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
return kv->s_copy(i);
}
float llama_kv_cache_recurrent_state::s_mask(int i) const {
return kv->s_mask(i);
return kv->cells[i + kv->head].src0;
}