llama : auto-batch preparation (#13845)

* llama : auto-batch

ggml-ci

* context : simplify if branching
Georgi Gerganov 2025-05-31 12:55:57 +03:00 committed by GitHub
parent 51fa76f172
commit 3f55f781f1
5 changed files with 67 additions and 54 deletions


@@ -424,28 +424,33 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }
 
-void llama_context::kv_self_update() {
+bool llama_context::kv_self_update() {
     if (!memory) {
-        return;
+        return false;
     }
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    if (kv_self->update(*this)) {
-        // if the KV cache did any computation, we have to reserve a new worst-case graph
-        const auto kv_state = kv_self->init_full();
-        if (!kv_state) {
-            throw std::runtime_error("failed to initialize KV cache");
-        }
-
-        const uint32_t n_seqs   = cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
-        if (!gf) {
-            LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
-        }
+    if (!kv_self->update(*this)) {
+        // no updates have been performed
+        return false;
     }
+
+    // if the KV cache did any computation, we have to reserve a new worst-case graph
+    const auto kv_state = kv_self->init_full();
+    if (!kv_state) {
+        throw std::runtime_error("failed to initialize KV cache");
+    }
+
+    const uint32_t n_seqs   = cparams.n_seq_max;
+    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+    if (!gf) {
+        LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
+    }
+
+    return true;
 }
 
 enum llama_pooling_type llama_context::pooling_type() const {
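The restructured function now tells its caller whether the cache actually did any work, which the decode() changes below rely on. A minimal standalone sketch of the same early-return pattern (illustrative only; Cache, update, and reserve_worst_case are stand-ins, not llama.cpp types):

    #include <cstdio>

    // stand-in for the memory module: update() reports whether a shift/defrag ran
    struct Cache {
        bool pending = true;
        bool update() { bool ran = pending; pending = false; return ran; }
        void reserve_worst_case() { std::puts("re-reserving worst-case graph"); }
    };

    // invert the condition and return early, so the success path stays un-nested
    // and the caller learns whether a retry is now worthwhile
    static bool update_and_reserve(Cache & cache) {
        if (!cache.update()) {
            return false;               // no updates have been performed
        }
        cache.reserve_worst_case();     // the cache changed: refresh dependent state
        return true;
    }

    int main() {
        Cache cache;
        std::printf("first call:  %d\n", update_and_reserve(cache));  // 1: work was done
        std::printf("second call: %d\n", update_and_reserve(cache));  // 0: nothing to do
    }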
@@ -933,24 +938,44 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();
 
-    auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-    if (!kv_state) {
-        return -2;
-    }
+    llama_memory_state_ptr kv_state;
 
-    switch (kv_state->get_status()) {
-        case LLAMA_MEMORY_STATUS_SUCCESS:
-            {
-            } break;
-        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
-            {
-                // not a fatal error, we can re-try with a different batch
-                return 1;
-            }
-        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
-            {
-                return -2;
-            }
+    bool did_defrag = false;
+
+    while (true) {
+        kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!kv_state) {
+            return -2;
+        }
+
+        switch (kv_state->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                } break;
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+                {
+                    if (!did_defrag) {
+                        did_defrag = true;
+
+                        kv_self->defrag_sched(-1.0f);
+                        if (kv_self_update()) {
+                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+
+                            continue;
+                        }
+                    }
+
+                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+                    return 1;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    return -2;
+                }
+        }
+
+        break;
     }
 
     // reserve output buffer
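The loop above tries to allocate a KV-cache slot, and on the first FAILED_PREPARE it schedules an immediate defrag and retries once, but only if kv_self_update() reports that the cache actually changed. A distilled sketch of that control flow (illustrative only; the two callbacks stand in for the llama.cpp internals shown in the hunk):

    #include <cstdio>
    #include <functional>

    enum class Status { Success, FailedPrepare, FailedCompute };

    // try_alloc         : attempt to find a KV-cache slot for the current batch
    // defrag_and_update : schedule a defrag and report whether the cache changed
    int prepare_with_retry(const std::function<Status()> & try_alloc,
                           const std::function<bool()>   & defrag_and_update) {
        bool did_defrag = false;
        while (true) {
            switch (try_alloc()) {
                case Status::Success:
                    return 0;                 // slot found, decoding can proceed
                case Status::FailedPrepare:
                    if (!did_defrag) {
                        did_defrag = true;
                        if (defrag_and_update()) {
                            continue;         // cache changed: retry the allocation once
                        }
                    }
                    return 1;                 // still no room: caller must shrink the batch
                case Status::FailedCompute:
                    return -2;                // fatal error
            }
        }
    }

    int main() {
        int attempts = 0;
        // simulate: first allocation fails, the defrag frees space, the retry succeeds
        const int ret = prepare_with_retry(
            [&] { return attempts++ == 0 ? Status::FailedPrepare : Status::Success; },
            [ ] { return true; });
        std::printf("ret = %d\n", ret);   // prints: ret = 0
    }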
@@ -2646,22 +2671,8 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
           llama_batch   batch) {
-    int ret = ctx->decode(batch);
-
-    // defrag and try again
-    // TODO: distinguish return code when we are sure that even after defrag there is no space available
-    if (ret == 1) {
-        llama_kv_self_defrag(ctx);
-        ret = ctx->decode(batch);
-
-        if (ret == 1) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
-            return ret;
-        }
-    }
-
-    if (ret != 0) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
 
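For API users, the removed fallback means llama_decode() now performs the defrag-and-retry internally, so a return value of 1 indicates the batch genuinely does not fit in the KV cache. A hedged caller-side sketch using only the public llama.h entry points (decode_with_fallback is a hypothetical helper, and splitting the batch is an application-level strategy, not part of this patch):

    #include <cstdio>
    #include "llama.h"

    // returns 0 on success, non-zero otherwise; ctx and batch come from the caller
    int decode_with_fallback(llama_context * ctx, llama_batch batch) {
        const int ret = llama_decode(ctx, batch);
        if (ret == 0) {
            return 0;              // decoded
        }
        if (ret == 1) {
            // no KV-cache slot even after the internal defrag retry:
            // the application should split the batch (or free sequences) and try again
            fprintf(stderr, "batch of %d tokens does not fit, split it and retry\n", batch.n_tokens);
            return 1;
        }
        fprintf(stderr, "llama_decode failed, ret = %d\n", ret);
        return ret;                // fatal error (e.g. allocation or compute failure)
    }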