kv-cache : refactor + add llama_memory_state_i (#13746)
* kv-cache : simplify the "struct llama_kv_cache" interface ggml-ci
* kv-cache : revert the (n_swa + n_ubatch) change (for next PR) ggml-ci
* kv-cache : some comments ggml-ci
* context : fix graph reserve for multiple sequences ggml-ci
* kv-cache : fix typo [no ci]
* kv-cache : fix find_slot() logic for free slots ggml-ci
* llama : add TODO for deprecating the defrag API in the future
* kv-cache : improve find_slot() using min/max seq pos info ggml-ci
* llama : handle aborts and compute errors ggml-ci
* memory : extract state into llama_memory_state ggml-ci
* kv-cache : add comments ggml-ci
* server : update batching logic to reset n_batch on successful decode
* server : upon full re-processing, remove the sequence from the cache
* kv-cache : add TODO for doing split_equal when split_simple fails ggml-ci
Commit: 12d0188c0d
Parent: eb3949938e
14 changed files with 1304 additions and 655 deletions
@@ -3214,9 +3214,12 @@ struct server_context {
                     }

                     if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
-                        if (llama_kv_self_seq_pos_min(ctx, slot.id) > 0) {
+                        const auto pos_min = llama_kv_self_seq_pos_min(ctx, slot.id);
+                        if (pos_min > 0) {
+                            SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min);
                             SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
                                     "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+                            llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
                             slot.n_past = 0;
                         }
                     }
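For readability, the updated check roughly reads as follows once this hunk is applied. This is just the added and kept lines from the diff assembled into one block (indentation and surrounding context approximated), not additional code:

// force a full prompt re-processing when part of the cached sequence has
// been discarded (e.g. by a SWA-limited KV cache)
if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
    const auto pos_min = llama_kv_self_seq_pos_min(ctx, slot.id);
    if (pos_min > 0) {
        // the earliest cached position is no longer 0 -> the prompt prefix is gone
        SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n",
                slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min);
        SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
                "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");

        // drop the whole sequence from the cache before re-processing from position 0
        llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
        slot.n_past = 0;
    }
}

Per the commit message ("server : upon full re-processing, remove the sequence from the cache"), the new llama_kv_self_seq_rm() call clears the sequence so that partially populated cells cannot be reused when the prompt is processed again from position 0.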
@@ -3379,8 +3382,10 @@ struct server_context {
             }
         }

+        int32_t i_next = 0;
+
         // process the created batch of tokens
-        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
+        for (int32_t i = 0; i < batch.n_tokens; i = i_next) {
             const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

             llama_batch batch_view = {
@@ -3425,13 +3430,18 @@ struct server_context {

                 // retry with half the batch size to try to find a free slot in the KV cache
                 n_batch /= 2;
-                i -= n_batch;

                 SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);

                 continue; // continue loop of n_batch
             }

+            // move the head of the batch forward with the number of tokens we just processed
+            i_next = i + n_tokens;
+
+            // on successful decode, restore the original batch size
+            n_batch = llama_n_batch(ctx);
+
             for (auto & slot : slots) {
                 if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
                     continue; // continue loop of slots
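Taken together, the last two hunks change the decode loop's control flow: instead of stepping by n_batch and rewinding i on failure, the loop advances to an explicit i_next that is set only after a successful decode, and the (possibly halved) batch size is restored once a decode succeeds. The following is a small self-contained toy that mimics only this control flow; fake_decode, n_tokens_total and n_batch_config are made-up stand-ins, not llama.cpp APIs, and the real server additionally aborts when the batch cannot shrink any further:

#include <algorithm>
#include <cstdio>

// hypothetical stand-in for llama_decode(): fail once when asked to process
// more than 16 tokens at offset 32, to exercise the retry path
static int fake_decode(int i, int n_tokens) {
    return (i == 32 && n_tokens > 16) ? 1 : 0;
}

int main() {
    const int n_tokens_total = 64;  // stand-in for batch.n_tokens
    const int n_batch_config = 32;  // stand-in for llama_n_batch(ctx)

    int n_batch = n_batch_config;
    int i_next  = 0;

    for (int i = 0; i < n_tokens_total; i = i_next) {
        const int n_tokens = std::min(n_batch, n_tokens_total - i);

        if (fake_decode(i, n_tokens) != 0) {
            // retry the same position with half the batch size
            n_batch /= 2;
            printf("decode failed at i = %d, retrying with n_batch = %d\n", i, n_batch);
            continue;
        }

        // move the head of the batch forward with the number of tokens just processed
        i_next = i + n_tokens;

        // on successful decode, restore the original batch size
        n_batch = n_batch_config;

        printf("decoded [%d, %d)\n", i, i_next);
    }

    return 0;
}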