kv-cells : track min/max used cells and per-sequence positions (#13808)
* kv-cells : track min/max used cells and per-sequence positions ggml-ci * kv-cells : fix pos-modification updates for seq_pos ggml-ci * kv-cells : add comments ggml-ci
This commit is contained in:
parent
f9cd68398b
commit
81713121ee
3 changed files with 124 additions and 52 deletions
|
@ -286,31 +286,11 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
|
|||
}
|
||||
|
||||
llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
|
||||
llama_pos result = std::numeric_limits<llama_pos>::max();
|
||||
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (cells.seq_has(i, seq_id)) {
|
||||
result = std::min(result, cells.pos_get(i));
|
||||
}
|
||||
}
|
||||
|
||||
if (result == std::numeric_limits<llama_pos>::max()) {
|
||||
result = -1;
|
||||
}
|
||||
|
||||
return result;
|
||||
return cells.seq_pos_min(seq_id);
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
||||
llama_pos result = -1;
|
||||
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (cells.seq_has(i, seq_id)) {
|
||||
result = std::max(result, cells.pos_get(i));
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
return cells.seq_pos_max(seq_id);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::restore() {
|
||||
|
@ -504,7 +484,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
|
|||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// after enough generations, the benefit from this heuristic disappears
|
||||
// if we start defragmenting the cache, the benefit from this will be more important
|
||||
n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
|
||||
n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
|
||||
|
||||
#ifdef FIND_SLOT_DEBUG
|
||||
LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
|
||||
|
@ -1018,7 +998,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
|||
bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
||||
const uint32_t n_layer = layers.size();
|
||||
|
||||
const uint32_t n_kv = cell_max();
|
||||
const uint32_t n_kv = cells.used_max_p1();
|
||||
const uint32_t n_used = cells.get_used();
|
||||
|
||||
assert(n_used <= n_kv);
|
||||
|
@ -1144,16 +1124,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|||
return true;
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_unified::cell_max() const {
|
||||
for (uint32_t i = cells.size(); i > 0; --i) {
|
||||
if (!cells.is_empty(i - 1)) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
||||
assert(p0 >= 0 && p1 >= 0);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue