kv-cache : fix shift and defrag logic (#14081)
* kv-cache : fix shift ggml-ci * cont : reset shift[i] ggml-ci * cont : fix defrag erasing cells that didn't move ggml-ci
This commit is contained in:
parent
7f4fbe5183
commit
40cbf571c9
2 changed files with 12 additions and 9 deletions
|
@ -462,7 +462,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
|
||||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||||
assert(dinfo.ids[i] <= n_kv);
|
assert(dinfo.ids[i] <= n_kv);
|
||||||
|
|
||||||
if (dinfo.ids[i] == n_kv) {
|
if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -944,11 +944,9 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
|
||||||
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
||||||
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
||||||
|
|
||||||
//GGML_ASSERT(kv_self->size == n_ctx);
|
|
||||||
|
|
||||||
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
|
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
|
||||||
|
|
||||||
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
|
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
|
||||||
ggml_set_input(inp->k_shift);
|
ggml_set_input(inp->k_shift);
|
||||||
|
|
||||||
for (const auto & layer : layers) {
|
for (const auto & layer : layers) {
|
||||||
|
|
|
@ -80,6 +80,9 @@ public:
|
||||||
assert(isrc < pos.size());
|
assert(isrc < pos.size());
|
||||||
assert(idst < pos.size());
|
assert(idst < pos.size());
|
||||||
|
|
||||||
|
assert(pos[idst] == -1);
|
||||||
|
assert(pos[isrc] != -1);
|
||||||
|
|
||||||
pos [idst] = pos [isrc];
|
pos [idst] = pos [isrc];
|
||||||
shift[idst] = shift[isrc];
|
shift[idst] = shift[isrc];
|
||||||
seq [idst] = seq [isrc];
|
seq [idst] = seq [isrc];
|
||||||
|
@ -144,9 +147,10 @@ public:
|
||||||
assert(pos[i] != -1);
|
assert(pos[i] != -1);
|
||||||
|
|
||||||
seq_pos_rm(i);
|
seq_pos_rm(i);
|
||||||
|
seq[i].reset();
|
||||||
|
|
||||||
pos[i] = -1;
|
pos[i] = -1;
|
||||||
seq[i].reset();
|
shift[i] = 0;
|
||||||
|
|
||||||
used.erase(i);
|
used.erase(i);
|
||||||
}
|
}
|
||||||
|
@ -164,6 +168,7 @@ public:
|
||||||
|
|
||||||
if (seq[i].none()) {
|
if (seq[i].none()) {
|
||||||
pos[i] = -1;
|
pos[i] = -1;
|
||||||
|
shift[i] = 0;
|
||||||
|
|
||||||
used.erase(i);
|
used.erase(i);
|
||||||
|
|
||||||
|
@ -192,6 +197,7 @@ public:
|
||||||
seq[i].reset();
|
seq[i].reset();
|
||||||
|
|
||||||
pos[i] = -1;
|
pos[i] = -1;
|
||||||
|
shift[i] = 0;
|
||||||
|
|
||||||
used.erase(i);
|
used.erase(i);
|
||||||
|
|
||||||
|
@ -317,21 +323,20 @@ public:
|
||||||
pos[i] += d;
|
pos[i] += d;
|
||||||
shift[i] += d;
|
shift[i] += d;
|
||||||
|
|
||||||
seq_pos_add(i);
|
|
||||||
|
|
||||||
has_shift = true;
|
has_shift = true;
|
||||||
|
|
||||||
if (pos[i] < 0) {
|
if (pos[i] < 0) {
|
||||||
seq_pos_rm(i);
|
|
||||||
|
|
||||||
seq[i].reset();
|
seq[i].reset();
|
||||||
pos[i] = -1;
|
pos[i] = -1;
|
||||||
|
shift[i] = 0;
|
||||||
|
|
||||||
used.erase(i);
|
used.erase(i);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
seq_pos_add(i);
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue