server : fix cache_tokens bug with no cache_prompt (#13533)

commit 360a9c98e1
parent 09d13d94fb
Author: Xuan-Son Nguyen
Date:   2025-05-14 13:35:07 +02:00
Committed by: GitHub
3 changed files with 25 additions and 11 deletions
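
The change removes the `if (slot.params.cache_prompt)` guards around every update of `slot.cache_tokens` and instead wipes the slot's KV cache together with `cache_tokens` at the start of prompt processing when prompt caching is disabled. In other words, `cache_tokens` now always mirrors what is actually in the KV cache, and `cache_prompt` only controls whether that state is reused for the next request. Below is a minimal, self-contained model of that invariant; `Slot`, `begin_request`, `decode_token` and the plain `int` token type are illustrative stand-ins, not the real server structures.

// Toy model of the fixed bookkeeping: cache_tokens always mirrors the tokens
// held in the (simulated) KV cache; cache_prompt only decides whether that
// state is kept around for reuse or wiped before the next request.
#include <cassert>
#include <cstdio>
#include <initializer_list>
#include <vector>

struct Slot {
    bool             cache_prompt = false;
    int              n_past       = 0;   // number of tokens in the KV cache
    std::vector<int> kv_cache;           // simulated KV cache contents
    std::vector<int> cache_tokens;       // bookkeeping that must stay in sync
};

static void begin_request(Slot & slot, bool cache_prompt) {
    slot.cache_prompt = cache_prompt;
    if (!cache_prompt) {
        // analogue of the added else-branch: remove the entire KV cache and
        // the token bookkeeping instead of letting them diverge
        slot.kv_cache.clear();
        slot.cache_tokens.clear();
        slot.n_past = 0;
    }
}

static void decode_token(Slot & slot, int tok) {
    slot.kv_cache.push_back(tok);
    slot.n_past += 1;
    slot.cache_tokens.push_back(tok); // unconditional after the fix
}

int main() {
    Slot slot;
    begin_request(slot, /*cache_prompt=*/false);
    for (int tok : {1, 2, 3, 4}) {
        decode_token(slot, tok);
    }
    // the invariant the old guarded updates could break: the bookkeeping
    // tracks the KV cache regardless of cache_prompt
    assert((int) slot.cache_tokens.size() == slot.n_past);
    assert(slot.cache_tokens.size() == slot.kv_cache.size());
    std::printf("n_past = %d, cached = %zu\n", slot.n_past, slot.cache_tokens.size());
    return 0;
}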


@@ -2951,7 +2951,8 @@ struct server_context {
     llama_kv_self_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
     llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);

-    if (slot.params.cache_prompt) {
+    // add generated tokens to cache
+    {
         llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
         for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
             new_tokens[i - n_discard] = new_tokens[i];
@@ -2996,10 +2997,7 @@ struct server_context {
     common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);

     slot.n_past += 1;
-    if (slot.params.cache_prompt) {
-        slot.cache_tokens.push_back(slot.sampled);
-    }
+    slot.cache_tokens.push_back(slot.sampled);

     SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
             slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
@@ -3171,6 +3169,11 @@ struct server_context {
                 SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
             }
+        } else {
+            // if we don't cache the prompt, we have to remove the entire KV cache
+            llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
+            slot.n_past = 0;
+            slot.cache_tokens.clear();
         }
     }
@@ -3204,7 +3207,7 @@ struct server_context {
     SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);

     // remove the non-common part from the cache
-    slot.cache_tokens.resize(slot.n_past);
+    slot.cache_tokens.keep_first(slot.n_past);

     // check if we should process the image
     if (slot.n_past < slot.n_prompt_tokens
@@ -3221,7 +3224,8 @@ struct server_context {
         continue;
     }

-    if (slot.params.cache_prompt) {
+    // add the image chunk to cache
+    {
         const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past);
         slot.cache_tokens.push_back(chunk.get()); // copy
     }
@@ -3242,9 +3246,7 @@ struct server_context {
     const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;

     common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
-    if (slot.params.cache_prompt) {
-        slot.cache_tokens.push_back(cur_tok);
-    }
+    slot.cache_tokens.push_back(cur_tok);

     slot.n_prompt_tokens_processed++;
     slot.n_past++;
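
Taken together, these hunks leave `slot.cache_tokens` updated on every path that adds tokens to the slot (context shift, decoding a sampled token, batching prompt tokens, and image chunks), while the non-cached case is handled once, up front, by clearing the KV cache and the token list. The rename of `resize` to `keep_first` (see the last file below) belongs to the same cleanup: the call site now states explicitly that it only truncates the cache to the common prefix.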


@@ -196,6 +196,18 @@ def test_cache_vs_nocache_prompt():
     assert res_cache.body["content"] == res_no_cache.body["content"]


+def test_nocache_long_input_prompt():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is"*32,
+        "seed": 42,
+        "temperature": 1.0,
+        "cache_prompt": False,
+    })
+    assert res.status_code == 200
+
+
 def test_completion_with_tokens_input():
     global server
     server.temperature = 0.0
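
The new test_nocache_long_input_prompt exercises the previously broken path: a longer prompt (the phrase repeated 32 times) sent with "cache_prompt": False. Because temperature is 1.0 the output is not deterministic, so unlike test_cache_vs_nocache_prompt above it does not compare content and only asserts that the request completes with HTTP 200.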


@@ -1153,7 +1153,7 @@ public:
         tokens.clear();
     }

-    void resize(size_t n) {
+    void keep_first(size_t n) {
         GGML_ASSERT(n <= tokens.size());
         if (has_mtmd) {
             // we throw an error if we try to remove a token in the middle of an image
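
The hunk is cut off here, but the visible part explains the rename: unlike a generic resize, keep_first only ever shrinks the token list (GGML_ASSERT(n <= tokens.size())), and when multimodal data is present (has_mtmd) it must refuse to cut through an image chunk. The sketch below is a stand-alone toy version of that contract; toy_server_tokens and the negative-value image markers are invented for illustration and do not match the real chunk representation.

#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Illustrative stand-in for the server token container: non-negative values
// are text tokens, negative values mark positions occupied by an image chunk.
struct toy_server_tokens {
    std::vector<int> tokens;
    bool has_mtmd = false;

    // keep only the first n tokens; never grows the container
    void keep_first(size_t n) {
        assert(n <= tokens.size()); // mirrors the GGML_ASSERT in the real code
        if (has_mtmd && n > 0 && n < tokens.size()) {
            // refuse to cut an image chunk in half: the first removed position
            // must not continue an image that started before the cut
            if (tokens[n] < 0 && tokens[n - 1] < 0) {
                throw std::runtime_error("cannot truncate inside an image chunk");
            }
        }
        tokens.resize(n);
    }
};

int main() {
    toy_server_tokens cache;
    cache.tokens = {101, 102, 103, 104};
    cache.keep_first(2); // keeps {101, 102}; n larger than the size would assert
    assert(cache.tokens.size() == 2);
    return 0;
}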