kv-cache : avoid modifying recurrent cells when setting inputs (#13834)

* kv-cache : avoid modifying recurrent cells when setting inputs

* kv-cache : remove inp_s_mask

It was replaced with equivalent but simpler functionality:
rs_z (the first zeroed state) combined with the already-existing inp_s_copy.
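
A minimal sketch of the idea, using hypothetical stand-in types (the real
logic lives in the recurrent cache and its graph inputs): instead of
multiplying each state by a 0/1 mask, a cell that must start from a blank
state takes the first zeroed cell (rs_z) as its copy source, so the existing
state-copy gather clears it as a side effect.

    #include <cstdint>
    #include <vector>

    // Stand-in cell type for illustration only; not the real llama.cpp struct.
    struct kv_cell {
        int32_t src = -1;         // where this cell's state is copied from
        bool needs_clear = false; // hypothetical flag: state must start from zero
    };

    // Build the s_copy input without a separate mask: cleared cells simply
    // point their copy source at the first zeroed state (rs_z).
    static void set_s_copy_input(std::vector<int32_t> & s_copy,
                                 const std::vector<kv_cell> & cells,
                                 uint32_t head, int32_t rs_z) {
        for (size_t i = 0; i < s_copy.size(); ++i) {
            const kv_cell & cell = cells[head + i];
            s_copy[i] = cell.needs_clear ? rs_z : cell.src;
        }
    }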

* kv-cache : fix non-consecutive token pos warning for recurrent models

The problem was apparently caused by how the tail cells were swapped.

* graph : simplify logic for recurrent state copies

* kv-cache : use cell without src refs for rs_z in recurrent cache

* llama-graph : fix recurrent state copy

The `state_copy` shuffle assumes everything is moved at once,
which is not true when `states_extra` is copied back to the cache
before copying the range of states between `head` and `head + n_seqs`.
This is only a problem if any of the cells in [`head`, `head + n_seqs`)
have an `src` in [`head + n_seqs`, `head + n_kv`),
which does happen when `n_ubatch > 1` in the `llama-parallel` example.

Changing the order of the operations avoids the potential overwrite
before use, although when copies are avoided (like with Mamba2),
this will require further changes.
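
To make the hazard concrete, here is a toy gather-then-scatter version of the
shuffle, with simplified types standing in for the actual ggml tensors;
reading every source before any write-back is exactly the invariant the
reordering restores:

    #include <cstdint>
    #include <vector>

    using State = std::vector<float>; // toy stand-in for one recurrent state

    // If the results for cells [head + n_seqs, head + n_kv) were written back
    // before the range [head, head + n_seqs) was read, a source living in the
    // extra range could be overwritten before use. Gathering everything first
    // makes the order of the write-backs irrelevant.
    static void shuffle_states(std::vector<State> & cache,
                               const std::vector<int32_t> & src,
                               uint32_t head, uint32_t n_kv) {
        std::vector<State> gathered(n_kv);
        for (uint32_t i = 0; i < n_kv; ++i) {
            gathered[i] = cache[src[i]];   // read all sources intact
        }
        for (uint32_t i = 0; i < n_kv; ++i) {
            cache[head + i] = gathered[i]; // then write back safely
        }
    }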

* llama-graph : rename n_state to state_size in build_recurrent_state

This naming should reduce confusion between the state size
and the number of states.
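
For reference, a plausible shape of the renamed declaration (simplified; the
surrounding parameters are recalled, not verbatim):

    // `state_size` is the per-state element count; `n_seqs` is the number of
    // states used this ubatch. The old name `n_state` was easy to misread as
    // the latter.
    ggml_tensor * build_recurrent_state(
             ggml_cgraph * gf,
             ggml_tensor * s,
             ggml_tensor * state_copy,
                 int32_t   state_size,
                 int32_t   n_seqs,
                    bool   avoid_copies = false) const;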
commit dad5c44398 (parent 55f6b9fa65)
Author: compilade
Date:   2025-06-10 18:20:14 -04:00
Committed by: GitHub
6 changed files with 117 additions and 180 deletions

@@ -57,10 +57,6 @@ public:
     bool get_can_shift() const override;
 
-    // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
-
-    int32_t s_copy(int i) const;
-    float   s_mask(int i) const;
 
     // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@@ -73,10 +69,14 @@ public:
     // computed before each graph build
     uint32_t n = 0;
 
+    // first zero-ed state
+    int32_t rs_z = -1;
+
     // TODO: optimize for recurrent state needs
     struct kv_cell {
         llama_pos pos = -1;
-        int32_t   src = -1; // used to copy states
+        int32_t   src  = -1; // used to know where states should be copied from
+        int32_t   src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
         int32_t   tail = -1;
 
         std::set<llama_seq_id> seq_id;
@@ -157,13 +157,13 @@ public:
     uint32_t get_n_kv() const;
     uint32_t get_head() const;
+    int32_t  get_rs_z() const;
     uint32_t get_size() const;
 
     ggml_tensor * get_k_l(int32_t il) const;
     ggml_tensor * get_v_l(int32_t il) const;
 
     int32_t s_copy(int i) const;
-    float   s_mask(int i) const;
 
 private:
 
     const llama_memory_status status;
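
A hedged sketch of how get_rs_z() and the read-only s_copy input could come
together at graph-build time, under the assumption of a 2D states tensor and
the usual ggml helpers (names and offsets simplified, not the verbatim
implementation):

    #include "ggml.h"

    // states:     2D tensor, one row of `state_size` elements per cell
    // state_copy: I32 graph input built from src0, without mutating the cells
    // rs_z:       from get_rs_z(); -1 when no state needs clearing
    static ggml_tensor * gather_states(
            ggml_context * ctx0,
            ggml_cgraph  * gf,
            ggml_tensor  * states,
            ggml_tensor  * state_copy,
            int64_t        state_size,
            int32_t        rs_z) {
        // zero exactly one state slot in-graph; the view has zero elements
        // (a no-op) when rs_z < 0
        ggml_tensor * state_zero = ggml_view_1d(ctx0, states,
                state_size*(rs_z >= 0),
                (rs_z >= 0 ? rs_z : 0)*state_size*ggml_element_size(states));
        ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0.0f));

        // cells whose copy source is rs_z now gather a cleared state,
        // which is what the removed inp_s_mask multiply used to achieve
        return ggml_get_rows(ctx0, states, state_copy);
    }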