kv-cells : track min/max used cells and per-sequence positions (#13808)
* kv-cells : track min/max used cells and per-sequence positions ggml-ci * kv-cells : fix pos-modification updates for seq_pos ggml-ci * kv-cells : add comments ggml-ci
This commit is contained in:
parent
f9cd68398b
commit
81713121ee
3 changed files with 124 additions and 52 deletions
|
@ -286,31 +286,11 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
|
|||
}
|
||||
|
||||
llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
|
||||
llama_pos result = std::numeric_limits<llama_pos>::max();
|
||||
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (cells.seq_has(i, seq_id)) {
|
||||
result = std::min(result, cells.pos_get(i));
|
||||
}
|
||||
}
|
||||
|
||||
if (result == std::numeric_limits<llama_pos>::max()) {
|
||||
result = -1;
|
||||
}
|
||||
|
||||
return result;
|
||||
return cells.seq_pos_min(seq_id);
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
||||
llama_pos result = -1;
|
||||
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (cells.seq_has(i, seq_id)) {
|
||||
result = std::max(result, cells.pos_get(i));
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
return cells.seq_pos_max(seq_id);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::restore() {
|
||||
|
@ -504,7 +484,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
|
|||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// after enough generations, the benefit from this heuristic disappears
|
||||
// if we start defragmenting the cache, the benefit from this will be more important
|
||||
n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
|
||||
n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
|
||||
|
||||
#ifdef FIND_SLOT_DEBUG
|
||||
LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
|
||||
|
@ -1018,7 +998,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
|||
bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
||||
const uint32_t n_layer = layers.size();
|
||||
|
||||
const uint32_t n_kv = cell_max();
|
||||
const uint32_t n_kv = cells.used_max_p1();
|
||||
const uint32_t n_used = cells.get_used();
|
||||
|
||||
assert(n_used <= n_kv);
|
||||
|
@ -1144,16 +1124,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|||
return true;
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_unified::cell_max() const {
|
||||
for (uint32_t i = cells.size(); i > 0; --i) {
|
||||
if (!cells.is_empty(i - 1)) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
||||
assert(p0 >= 0 && p1 >= 0);
|
||||
|
||||
|
|
|
@ -246,10 +246,6 @@ private:
|
|||
// return true if cells have been moved
|
||||
bool defrag_prepare(int32_t n_max_nodes);
|
||||
|
||||
// find how many cells are currently in use
|
||||
// TODO: optimize
|
||||
uint32_t cell_max() const;
|
||||
|
||||
size_t total_size() const;
|
||||
|
||||
size_t size_k_bytes() const;
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
#include <bitset>
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
|
||||
// meta information about KV cells that can be part of multiple sequences at the same time
|
||||
// TODO: add unit tests
|
||||
|
@ -18,8 +19,13 @@ public:
|
|||
seq[i].reset();
|
||||
}
|
||||
|
||||
used = 0;
|
||||
has_shift = false;
|
||||
|
||||
used.clear();
|
||||
|
||||
for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
||||
seq_pos[s].clear();
|
||||
}
|
||||
}
|
||||
|
||||
void reset_shift() {
|
||||
|
@ -50,7 +56,25 @@ public:
|
|||
}
|
||||
|
||||
uint32_t get_used() const {
|
||||
return used;
|
||||
return used.size();
|
||||
}
|
||||
|
||||
// the index of the first cell that is used
|
||||
// return 0 if no cells are used
|
||||
uint32_t used_min() const {
|
||||
return used.empty() ? 0 : *used.begin();
|
||||
}
|
||||
|
||||
// the index of the last cell that is used + 1
|
||||
// return 0 if no cells are used
|
||||
uint32_t used_max_p1() const {
|
||||
#if 0
|
||||
if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin());
|
||||
if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin());
|
||||
if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin());
|
||||
#endif
|
||||
|
||||
return used.empty() ? 0 : *used.rbegin() + 1;
|
||||
}
|
||||
|
||||
bool get_has_shift() const {
|
||||
|
@ -69,6 +93,9 @@ public:
|
|||
pos [isrc] = -1;
|
||||
shift[isrc] = 0;
|
||||
seq [isrc].reset();
|
||||
|
||||
used.erase (isrc);
|
||||
used.insert(idst);
|
||||
}
|
||||
|
||||
// copy the state of cells [i, i + n) (used for save/restore the state of the cells)
|
||||
|
@ -95,16 +122,24 @@ public:
|
|||
|
||||
for (uint32_t j = 0; j < other.pos.size(); ++j) {
|
||||
if (pos[i + j] == -1 && other.pos[j] != -1) {
|
||||
used++;
|
||||
used.insert(i + j);
|
||||
}
|
||||
|
||||
if (pos[i + j] != -1 && other.pos[j] == -1) {
|
||||
used--;
|
||||
used.erase(i + j);
|
||||
}
|
||||
|
||||
if (pos[i + j] != -1) {
|
||||
seq_pos_rm(i + j);
|
||||
}
|
||||
|
||||
pos[i + j] = other.pos[j];
|
||||
seq[i + j] = other.seq[j];
|
||||
|
||||
if (pos[i + j] != -1) {
|
||||
seq_pos_add(i + j);
|
||||
}
|
||||
|
||||
assert(shift[i + j] == 0);
|
||||
}
|
||||
}
|
||||
|
@ -118,11 +153,12 @@ public:
|
|||
assert(seq_id >= 0);
|
||||
|
||||
seq[i].reset(seq_id);
|
||||
seq_pos[seq_id].erase(pos[i]);
|
||||
|
||||
if (seq[i].none()) {
|
||||
pos[i] = -1;
|
||||
|
||||
used--;
|
||||
used.erase(i);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -135,17 +171,22 @@ public:
|
|||
assert(i < pos.size());
|
||||
|
||||
if (seq[i].test(seq_id)) {
|
||||
seq_pos_rm(i);
|
||||
seq[i].reset();
|
||||
|
||||
seq[i].set(seq_id);
|
||||
seq_pos[seq_id].insert(pos[i]);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
if (seq[i].any()) {
|
||||
seq_pos_rm(i);
|
||||
seq[i].reset();
|
||||
|
||||
pos[i] = -1;
|
||||
|
||||
used--;
|
||||
used.erase(i);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -169,6 +210,33 @@ public:
|
|||
assert(!seq[i].test(seq_id));
|
||||
|
||||
seq[i].set(seq_id);
|
||||
seq_pos[seq_id].insert(pos[i]);
|
||||
}
|
||||
|
||||
// the minimum position of sequence seq_id currently present in any of the cells
|
||||
// return -1 if the sequence is not present
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const {
|
||||
assert(seq_id >= 0);
|
||||
assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
|
||||
|
||||
if (seq_pos[seq_id].empty()) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return *seq_pos[seq_id].begin();
|
||||
}
|
||||
|
||||
// the maximum position of sequence seq_id currently present in any of the cells
|
||||
// return -1 if the sequence is not present
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const {
|
||||
assert(seq_id >= 0);
|
||||
assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
|
||||
|
||||
if (seq_pos[seq_id].empty()) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return *seq_pos[seq_id].rbegin();
|
||||
}
|
||||
|
||||
// note: call only if the cell is not empty
|
||||
|
@ -202,7 +270,8 @@ public:
|
|||
assert(pos[i] == -1);
|
||||
|
||||
pos[i] = p;
|
||||
used++;
|
||||
|
||||
used.insert(i);
|
||||
}
|
||||
|
||||
// pos[i] = pos[i] + d
|
||||
|
@ -212,16 +281,22 @@ public:
|
|||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
|
||||
seq_pos_rm(i);
|
||||
|
||||
pos[i] += d;
|
||||
shift[i] += d;
|
||||
|
||||
seq_pos_add(i);
|
||||
|
||||
has_shift = true;
|
||||
|
||||
if (pos[i] < 0) {
|
||||
pos[i] = -1;
|
||||
seq[i].reset();
|
||||
seq_pos_rm(i);
|
||||
|
||||
used--;
|
||||
seq[i].reset();
|
||||
pos[i] = -1;
|
||||
|
||||
used.erase(i);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -238,17 +313,22 @@ public:
|
|||
|
||||
const llama_pos p_old = pos[i];
|
||||
|
||||
seq_pos_rm(i);
|
||||
|
||||
pos[i] /= d;
|
||||
shift[i] += p_old - pos[i];
|
||||
|
||||
seq_pos_add(i);
|
||||
|
||||
has_shift = true;
|
||||
}
|
||||
|
||||
private:
|
||||
uint32_t used = 0; // used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
|
||||
|
||||
bool has_shift = false;
|
||||
|
||||
// set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
|
||||
std::set<uint32_t> used;
|
||||
|
||||
std::vector<llama_pos> pos;
|
||||
|
||||
// this array accumulates any applied shifts to the pos array since the last reset_shift() call
|
||||
|
@ -268,6 +348,32 @@ private:
|
|||
//
|
||||
std::vector<llama_pos> shift;
|
||||
|
||||
std::vector<std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>> seq;
|
||||
};
|
||||
using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
|
||||
|
||||
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
|
||||
std::vector<bits_t> seq;
|
||||
|
||||
// the set seq_pos[s] tells us which positions are currently present for sequence s
|
||||
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
|
||||
std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
|
||||
|
||||
// helper functions for updating `seq_pos`, once cell at a time:
|
||||
|
||||
// remove cell i
|
||||
void seq_pos_rm(uint32_t i) {
|
||||
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
||||
if (seq[i].test(s)) {
|
||||
seq_pos[s].erase(pos[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add cell i
|
||||
void seq_pos_add(uint32_t i) {
|
||||
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
||||
if (seq[i].test(s)) {
|
||||
seq_pos[s].insert(pos[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue