// Patch overview (files: src/llama-kv-cache.cpp, src/llama-kv-cache.h, src/llama-kv-cells.h)
//
//  * llama_kv_cache_unified::seq_pos_min() / seq_pos_max() no longer loop over every cell; they
//    forward to cells.seq_pos_min(seq_id) / cells.seq_pos_max(seq_id).
//  * llama_kv_cache_unified::cell_max() is removed (declaration and definition); find_slot() and
//    defrag_prepare() call cells.used_max_p1() instead.
//  * In llama-kv-cells.h the `uint32_t used` counter is replaced by a std::set `used` holding the
//    indices of used cells: get_used() returns used.size(), used_min() returns *used.begin() and
//    used_max_p1() returns *used.rbegin() + 1 (both return 0 when the set is empty).
//  * A new array of sets `seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES]` records, per sequence, the positions
//    currently stored in the cells; seq_pos_min()/seq_pos_max() return *begin()/*rbegin() of that
//    set, or -1 when it is empty. The private helpers seq_pos_rm(i)/seq_pos_add(i) erase/insert
//    pos[i] for every sequence bit set in seq[i], and are called around each update of pos[i]/seq[i].
//
// NOTE(review): this copy of the patch has lost its line breaks and everything that was written in
//   angle brackets (e.g. `std::numeric_limits::max()`, the bare `#include` lines, `std::set used`,
//   `std::vector> seq`, `using bits_t = std::bitset;`). It will not apply or compile as-is;
//   regenerate it from the original commit before use.
// NOTE(review): seq_pos[s] is a std::set and entries are erased by value (`seq_pos[s].erase(pos[i])`).
//   If two cells of the same sequence can ever hold the same position, removing one cell drops the
//   position while the other cell still has it - presumably positions are unique per sequence; confirm.
// NOTE(review): in the `pos[i] += d` hunk, seq_pos_add(i) runs before the `pos[i] < 0` check and is
//   immediately undone by seq_pos_rm(i) on that path, so a negative position is inserted and erased
//   again; calling seq_pos_add(i) only after the check would avoid the extra work.
// NOTE(review): used_max_p1() carries an `#if 0` printf block hard-coded to sequences 0..2; it looks
//   like leftover debugging and could be dropped.
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index ae2d2684..4a42d6ec 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -286,31 +286,11 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po } llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { - llama_pos result = std::numeric_limits::max(); - - for (uint32_t i = 0; i < cells.size(); ++i) { - if (cells.seq_has(i, seq_id)) { - result = std::min(result, cells.pos_get(i)); - } - } - - if (result == std::numeric_limits::max()) { - result = -1; - } - - return result; + return cells.seq_pos_min(seq_id); } llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { - llama_pos result = -1; - - for (uint32_t i = 0; i < cells.size(); ++i) { - if (cells.seq_has(i, seq_id)) { - result = std::max(result, cells.pos_get(i)); - } - } - - return result; + return cells.seq_pos_max(seq_id); } void llama_kv_cache_unified::restore() { @@ -504,7 +484,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cell_max(), n_pad))); + n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); #ifdef FIND_SLOT_DEBUG LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa); @@ -1018,7 +998,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { const uint32_t n_layer = layers.size(); - const uint32_t n_kv = cell_max(); + const uint32_t n_kv = cells.used_max_p1(); const uint32_t n_used = cells.get_used(); assert(n_used <= n_kv); @@ -1144,16 +1124,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t 
n_max_nodes) { return true; } -uint32_t llama_kv_cache_unified::cell_max() const { - for (uint32_t i = cells.size(); i > 0; --i) { - if (!cells.is_empty(i - 1)) { - return i; - } - } - - return 0; -} - bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { assert(p0 >= 0 && p1 >= 0); diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 86a96820..ce6261e4 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -246,10 +246,6 @@ private: // return true if cells have been moved bool defrag_prepare(int32_t n_max_nodes); - // find how many cells are currently in use - // TODO: optimize - uint32_t cell_max() const; - size_t total_size() const; size_t size_k_bytes() const; diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h index 13854553..dbbd03fc 100644 --- a/src/llama-kv-cells.h +++ b/src/llama-kv-cells.h @@ -6,6 +6,7 @@ #include #include #include +#include // meta information about KV cells that can be part of multiple sequences at the same time // TODO: add unit tests @@ -18,8 +19,13 @@ public: seq[i].reset(); } - used = 0; has_shift = false; + + used.clear(); + + for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + seq_pos[s].clear(); + } } void reset_shift() { @@ -50,7 +56,25 @@ public: } uint32_t get_used() const { - return used; + return used.size(); + } + + // the index of the first cell that is used + // return 0 if no cells are used + uint32_t used_min() const { + return used.empty() ? 
0 : *used.begin(); + } + + // the index of the last cell that is used + 1 + // return 0 if no cells are used + uint32_t used_max_p1() const { +#if 0 + if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin()); + if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin()); + if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin()); +#endif + + return used.empty() ? 0 : *used.rbegin() + 1; } bool get_has_shift() const { @@ -69,6 +93,9 @@ public: pos [isrc] = -1; shift[isrc] = 0; seq [isrc].reset(); + + used.erase (isrc); + used.insert(idst); } // copy the state of cells [i, i + n) (used for save/restore the state of the cells) @@ -95,16 +122,24 @@ public: for (uint32_t j = 0; j < other.pos.size(); ++j) { if (pos[i + j] == -1 && other.pos[j] != -1) { - used++; + used.insert(i + j); } if (pos[i + j] != -1 && other.pos[j] == -1) { - used--; + used.erase(i + j); + } + + if (pos[i + j] != -1) { + seq_pos_rm(i + j); } pos[i + j] = other.pos[j]; seq[i + j] = other.seq[j]; + if (pos[i + j] != -1) { + seq_pos_add(i + j); + } + assert(shift[i + j] == 0); } } @@ -118,11 +153,12 @@ public: assert(seq_id >= 0); seq[i].reset(seq_id); + seq_pos[seq_id].erase(pos[i]); if (seq[i].none()) { pos[i] = -1; - used--; + used.erase(i); return true; } @@ -135,17 +171,22 @@ public: assert(i < pos.size()); if (seq[i].test(seq_id)) { + seq_pos_rm(i); seq[i].reset(); + seq[i].set(seq_id); + seq_pos[seq_id].insert(pos[i]); return false; } if (seq[i].any()) { + seq_pos_rm(i); seq[i].reset(); + pos[i] = -1; - used--; + used.erase(i); return true; } @@ -169,6 +210,33 @@ public: assert(!seq[i].test(seq_id)); seq[i].set(seq_id); + seq_pos[seq_id].insert(pos[i]); + } + + // the minimum position of sequence seq_id currently present in any of the cells + // return -1 if the sequence is not present + llama_pos 
seq_pos_min(llama_seq_id seq_id) const { + assert(seq_id >= 0); + assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES); + + if (seq_pos[seq_id].empty()) { + return -1; + } + + return *seq_pos[seq_id].begin(); + } + + // the maximum position of sequence seq_id currently present in any of the cells + // return -1 if the sequence is not present + llama_pos seq_pos_max(llama_seq_id seq_id) const { + assert(seq_id >= 0); + assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES); + + if (seq_pos[seq_id].empty()) { + return -1; + } + + return *seq_pos[seq_id].rbegin(); } // note: call only if the cell is not empty @@ -202,7 +270,8 @@ public: assert(pos[i] == -1); pos[i] = p; - used++; + + used.insert(i); } // pos[i] = pos[i] + d @@ -212,16 +281,22 @@ public: assert(i < pos.size()); assert(pos[i] != -1); + seq_pos_rm(i); + pos[i] += d; shift[i] += d; + seq_pos_add(i); + has_shift = true; if (pos[i] < 0) { - pos[i] = -1; - seq[i].reset(); + seq_pos_rm(i); - used--; + seq[i].reset(); + pos[i] = -1; + + used.erase(i); return true; } @@ -238,17 +313,22 @@ public: const llama_pos p_old = pos[i]; + seq_pos_rm(i); + pos[i] /= d; shift[i] += p_old - pos[i]; + seq_pos_add(i); + has_shift = true; } private: - uint32_t used = 0; // used cells (i.e. pos[i] != -1, allowed to not have any seq_id) - bool has_shift = false; + // set of indices of used cells (i.e. 
pos[i] != -1, allowed to not have any seq_id) + std::set used; + std::vector pos; // this array accumulates any applied shifts to the pos array since the last reset_shift() call @@ -268,6 +348,32 @@ private: // std::vector shift; - std::vector> seq; -}; + using bits_t = std::bitset; + // the bitset seq[i] tells us which sequences are currently occupying the i-th cell + std::vector seq; + + // the set seq_pos[s] tells us which positions are currently present for sequence s + // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache + std::set seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES]; + + // helper functions for updating `seq_pos`, one cell at a time: + + // remove cell i + void seq_pos_rm(uint32_t i) { + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (seq[i].test(s)) { + seq_pos[s].erase(pos[i]); + } + } + } + + // add cell i + void seq_pos_add(uint32_t i) { + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (seq[i].test(s)) { + seq_pos[s].insert(pos[i]); + } + } + } +};