From 3e63a58ef7addec35408e2eb67850d7cdc935dc3 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 4 Jun 2025 18:58:20 +0300
Subject: [PATCH] kv-cache : refactor the update/defrag mechanism (#13988)

* kv-cache : refactor update mechanism

ggml-ci

* memory : improve status handling

* defrag : reset head + add comments

ggml-ci

* cont : minor fixes

ggml-ci
---
 src/llama-context.cpp               |  83 ++++++++----
 src/llama-context.h                 |   6 +-
 src/llama-kv-cache-recurrent.cpp    |  19 ++-
 src/llama-kv-cache-recurrent.h      |   4 +-
 src/llama-kv-cache-unified-iswa.cpp |  59 ++++----
 src/llama-kv-cache-unified-iswa.h   |  18 +--
 src/llama-kv-cache-unified.cpp      | 200 +++++++++++++++++-----------
 src/llama-kv-cache-unified.h        |  73 +++++++---
 src/llama-kv-cache.h                |  19 ++-
 src/llama-memory.cpp                |  41 ++++++
 src/llama-memory.h                  |   9 +-
 11 files changed, 340 insertions(+), 191 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 4ab57438..7c1a642c 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -429,22 +429,54 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }
 
-bool llama_context::kv_self_update() {
+void llama_context::kv_self_defrag_sched() {
+    if (!memory) {
+        return;
+    }
+
+    memory_force_optimize = true;
+}
+
+bool llama_context::kv_self_update(bool optimize) {
     if (!memory) {
         return false;
     }
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    if (!kv_self->update(*this)) {
-        // no updates have been performed
-        return false;
+    {
+        // TODO: remove in the future
+        optimize |= memory_force_optimize;
+        memory_force_optimize = false;
+
+        const auto kv_state = kv_self->init_update(this, optimize);
+        switch (kv_state->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                    // noop
+                } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    // no updates need to be performed
+                    return false;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
+                    return false;
+                }
+        }
+
+        if (!kv_state->apply()) {
+            LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
+        }
     }
 
     // if the KV cache did any computation, we have to reserve a new worst-case graph
     const auto kv_state = kv_self->init_full();
     if (!kv_state) {
-        throw std::runtime_error("failed to initialize KV cache");
+        throw std::runtime_error("failed to initialize memory state");
     }
 
     const uint32_t n_seqs = cparams.n_seq_max;
@@ -452,7 +484,7 @@ bool llama_context::kv_self_update() {
 
     auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
     if (!gf) {
-        LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
+        LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
     }
 
     return true;
@@ -940,13 +972,13 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }
 
+    bool did_optimize = false;
+
     // handle any pending defrags/shifts
-    kv_self_update();
+    kv_self_update(false);
 
     llama_memory_state_ptr kv_state;
 
-    bool did_defrag = false;
-
     while (true) {
         kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
         if (!kv_state) {
@@ -957,25 +989,32 @@ int llama_context::decode(llama_batch & inp_batch) {
            case LLAMA_MEMORY_STATUS_SUCCESS:
                {
                } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, kv_state->get_status());
+
+                    return -2;
+                }
            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
                {
-                    if (!did_defrag) {
-                        did_defrag = true;
+                    if (!did_optimize) {
+                        did_optimize = true;
 
-                        kv_self->defrag_sched(-1.0f);
-                        if (kv_self_update()) {
-                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+                        if (kv_self_update(true)) {
+                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
 
                            continue;
                        }
                    }
 
-                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+                    LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
 
                    return 1;
                }
            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
                {
+                    LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
+
                    return -2;
                }
        }
@@ -1189,11 +1228,6 @@ int llama_context::decode(llama_batch & inp_batch) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();
 
-    // decide if we need to defrag the kv cache
-    if (cparams.defrag_thold > 0.0f) {
-        kv_self->defrag_sched(cparams.defrag_thold);
-    }
-
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
     ggml_backend_sched_reset(sched.get());
@@ -2283,7 +2317,7 @@ llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
 
 // deprecated
 void llama_kv_self_update(llama_context * ctx) {
-    ctx->kv_self_update();
+    ctx->kv_self_update(false);
 }
 
 enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
@@ -2538,13 +2572,8 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
 
 // deprecated
 void llama_kv_self_defrag(llama_context * ctx) {
-    auto * kv = ctx->get_kv_self();
-    if (!kv) {
-        return;
-    }
-
     // force defrag
-    kv->defrag_sched(-1.0f);
+    ctx->kv_self_defrag_sched();
 }
 
 bool llama_kv_self_can_shift(const llama_context * ctx) {
diff --git a/src/llama-context.h b/src/llama-context.h
index 3b880286..c1c7efb3 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -52,7 +52,8 @@ struct llama_context {
 
     // return true of the KV cache was updated
     // TODO: remove
-    bool kv_self_update();
+    bool kv_self_update(bool optimize);
+    void kv_self_defrag_sched();
 
     enum llama_pooling_type pooling_type() const;
@@ -231,6 +232,9 @@ private:
 
     std::unique_ptr<llama_memory_i> memory;
 
+    // TODO: temporary, until the llama_kv_self_defrag() API is removed
+    bool memory_force_optimize = false;
+
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t  logits_size = 0; // capacity (of floats) for logits
     float * logits      = nullptr;
diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp
index 641eab2f..77bd5706 100644
--- a/src/llama-kv-cache-recurrent.cpp
+++ b/src/llama-kv-cache-recurrent.cpp
@@ -1,6 +1,7 @@
 #include "llama-kv-cache-recurrent.h"
 
 #include "llama-impl.h"
+#include "llama-io.h"
 #include "llama-batch.h"
 #include "llama-model.h"
 
@@ -386,6 +387,13 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
     return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
 }
 
+llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
+    GGML_UNUSED(lctx);
+    GGML_UNUSED(optimize);
+
+    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
+}
+
 bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
     // simply remember the full state because it is very small for this type of cache
     // TODO: optimize
@@ -419,17 +427,6 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
     return success;
 }
 
-bool llama_kv_cache_recurrent::update(llama_context & lctx) {
-    GGML_UNUSED(lctx);
-    // noop
-    return false;
-}
-
-void llama_kv_cache_recurrent::defrag_sched(float thold) {
-    GGML_UNUSED(thold);
-    // noop
-}
-
 bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     const uint32_t n_tokens = ubatch.n_tokens;
     const uint32_t n_seqs   = ubatch.n_seqs;
diff --git a/src/llama-kv-cache-recurrent.h b/src/llama-kv-cache-recurrent.h
index a178ae85..b32f258f 100644
--- a/src/llama-kv-cache-recurrent.h
+++ b/src/llama-kv-cache-recurrent.h
@@ -52,9 +52,7 @@ public:
 
     llama_memory_state_ptr init_full() override;
 
-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool prepare(const std::vector<llama_ubatch> & ubatches);
diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp
index 0eb04563..3aa606c8 100644
--- a/src/llama-kv-cache-unified-iswa.cpp
+++ b/src/llama-kv-cache-unified-iswa.cpp
@@ -123,26 +123,16 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
 
     assert(heads_base.size() == heads_swa.size());
 
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(
             this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
 }
 
 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
 }
 
-bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
-    bool res = false;
-
-    res = res | kv_base->update(lctx);
-    res = res | kv_swa ->update(lctx);
-
-    return res;
-}
-
-void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
-    kv_base->defrag_sched(thold);
-    kv_swa ->defrag_sched(thold);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
 }
 
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -174,26 +164,38 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_memory_status status,
-        llama_kv_cache_unified_iswa * kv) : status(status) {
-    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
-    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
+        llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_full();
+    state_swa  = kv->get_swa ()->init_full();
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+}
+
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+        llama_kv_cache_unified_iswa * kv,
+        llama_context * lctx,
+        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_update(lctx, optimize);
+    state_swa  = kv->get_swa ()->init_update(lctx, optimize);
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
 }
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_memory_status status,
         llama_kv_cache_unified_iswa * kv,
         llama_sbatch sbatch,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
         std::vector<llama_ubatch> ubatches)
-    : status(status),
-    sbatch(std::move(sbatch)),
-    ubatches(std::move(ubatches)) {
-        // note: here we copy the ubatches. not sure if this is ideal
-        state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
-        state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
-    }
+    : status(LLAMA_MEMORY_STATUS_SUCCESS),
+    sbatch(std::move(sbatch)),
+    ubatches(std::move(ubatches)) {
+    // note: here we copy the ubatches. not sure if this is ideal
+    state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
+    state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+}
 
 llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
@@ -233,17 +235,18 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
 
 const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
     return ubatches[i_next];
 }
 
 const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return state_base.get();
+    return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
 }
 
 const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return state_swa.get();
+    return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
 }
diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h
index 8b067da0..cba5bbe9 100644
--- a/src/llama-kv-cache-unified-iswa.h
+++ b/src/llama-kv-cache-unified-iswa.h
@@ -54,9 +54,7 @@ public:
 
     llama_memory_state_ptr init_full() override;
 
-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool get_can_shift() const override;
@@ -86,12 +84,16 @@ public:
 
     // used to create a full-cache state
     llama_kv_cache_unified_iswa_state(
-            llama_memory_status status,
             llama_kv_cache_unified_iswa * kv);
 
+    // used to create an update state
+    llama_kv_cache_unified_iswa_state(
+            llama_kv_cache_unified_iswa * kv,
+            llama_context * lctx,
+            bool optimize);
+
     // used to create a state from a batch
     llama_kv_cache_unified_iswa_state(
-            llama_memory_status status,
             llama_kv_cache_unified_iswa * kv,
             llama_sbatch sbatch,
             std::vector<uint32_t> heads_base,
@@ -120,7 +122,7 @@ public:
     const llama_kv_cache_unified_state * get_swa() const;
 
 private:
-    const llama_memory_status status;
+    llama_memory_status status;
 
     //llama_kv_cache_unified_iswa * kv;
@@ -131,6 +133,6 @@ private:
 
     std::vector<llama_ubatch> ubatches;
 
-    std::unique_ptr<llama_kv_cache_unified_state> state_base;
-    std::unique_ptr<llama_kv_cache_unified_state> state_swa;
+    llama_memory_state_ptr state_base;
+    llama_memory_state_ptr state_swa;
 };
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index 4007f202..5354f808 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -1,6 +1,7 @@
 #include "llama-kv-cache-unified.h"
 
 #include "llama-impl.h"
+#include "llama-io.h"
 #include "llama-model.h"
 #include "llama-context.h"
 
@@ -320,16 +321,49 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
         return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }
 
-    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS,
+    return std::make_unique<llama_kv_cache_unified_state>(
             this, std::move(sbatch), std::move(heads), std::move(ubatches));
 }
 
 llama_memory_state_ptr llama_kv_cache_unified::init_full() {
-    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+    return std::make_unique<llama_kv_cache_unified_state>(this);
 }
 
-std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    std::vector<uint32_t> res;
+llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
+    bool do_shift = get_has_shift();
+
+    defrag_info dinfo;
+
+    // see if we need to defrag
+    {
+        bool do_defrag = optimize;
+
+        const auto thold = lctx->get_cparams().defrag_thold;
+
+        if (!do_defrag && thold > 0.0f) {
+            const auto n_kv = cells.used_max_p1();
+
+            // - do not defrag small contexts (i.e. < 2048 tokens)
+            // - count the padding towards the number of used tokens
+            const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
+
+            if (fragmentation > thold) {
+                LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
+
+                do_defrag = true;
+            }
+        }
+
+        if (do_defrag) {
+            dinfo = defrag_prepare(lctx->graph_max_nodes());
+        }
+    }
+
+    return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
+}
+
+llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache_unified::ubatch_heads res;
 
     struct state {
         uint32_t head_old; // old position of the head, before placing the ubatch
@@ -374,12 +408,12 @@ std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_u
     return res;
 }
 
-bool llama_kv_cache_unified::update(llama_context & lctx) {
+bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
     bool updated = false;
 
-    auto * sched = lctx.get_sched();
+    auto * sched = lctx->get_sched();
 
-    if (cells.get_has_shift()) {
+    if (do_shift) {
         if (!get_can_shift()) {
             GGML_ABORT("The current KV cache / model configuration does not support K-shift");
         }
@@ -390,9 +424,9 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
         if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
             ggml_backend_sched_reset(sched);
 
-            auto * gf = lctx.graph_init();
+            auto * gf = lctx->graph_init();
 
-            auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
+            auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
             if (!res) {
                 LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
                 return updated;
@@ -405,7 +439,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 
             res->set_inputs(nullptr);
 
-            if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+            if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
                 LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
                 return updated;
             }
@@ -416,56 +450,55 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
         cells.reset_shift();
     }
 
-    if (do_defrag) {
+    if (!dinfo.empty()) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
 
-        if (defrag_prepare(lctx.graph_max_nodes())) {
-            ggml_backend_sched_reset(sched);
+        // apply moves:
+        {
+            const auto n_kv = dinfo.ids.size();
 
-            auto * gf = lctx.graph_init();
+            for (uint32_t i = 0; i < n_kv; ++i) {
+                assert(dinfo.ids[i] <= n_kv);
 
-            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
-            if (!res) {
-                LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
-                return updated;
+                if (dinfo.ids[i] == n_kv) {
+                    continue;
+                }
+
+                cells.mv(i, dinfo.ids[i]);
             }
 
-            if (!ggml_backend_sched_alloc_graph(sched, gf)) {
-                LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
-                return updated;
-            }
-
-            res->set_inputs(nullptr);
-
-            if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
-                LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
-                return updated;
-            }
-
-            updated = true;
+            // reset the head so we can find the first free slot during the next ubatch
+            head = 0;
         }
 
-        do_defrag = false;
+        ggml_backend_sched_reset(sched);
+
+        auto * gf = lctx->graph_init();
+
+        auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
+        if (!res) {
+            LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
+            return updated;
+        }
+
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
+            return updated;
+        }
+
+        res->set_inputs(nullptr);
+
+        if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+            LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
+            return updated;
+        }
+
+        updated = true;
     }
 
     return updated;
 }
 
-void llama_kv_cache_unified::defrag_sched(float thold) {
-    const auto n_kv = cells.used_max_p1();
-
-    // - do not defrag small contexts (i.e. < 2048 tokens)
-    // - count the padding towards the number of used tokens
-    const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
-
-    // queue defragmentation for next llama_kv_cache_update
-    if (fragmentation > thold) {
-        LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
-
-        do_defrag = true;
-    }
-}
-
 int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
@@ -612,6 +645,10 @@ uint32_t llama_kv_cache_unified::get_size() const {
     return cells.size();
 }
 
+bool llama_kv_cache_unified::get_has_shift() const {
+    return cells.get_has_shift();
+}
+
 uint32_t llama_kv_cache_unified::get_n_kv() const {
     return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
 }
@@ -941,12 +978,13 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 }
 
 llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
-        const llama_cparams & cparams,
-        ggml_context * ctx,
-        ggml_cgraph * gf) const {
+        const llama_cparams & cparams,
+        ggml_context * ctx,
+        ggml_cgraph * gf,
+        const defrag_info & dinfo) const {
     auto res = std::make_unique<llm_graph_result>();
 
-    const auto & ids = defrag_info.ids;
+    const auto & ids = dinfo.ids;
 
 #if 0
     // CPU defrag
@@ -1087,7 +1125,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
     return res;
 }
 
-bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
+llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
     const uint32_t n_layer = layers.size();
 
     const uint32_t n_kv   = cells.used_max_p1();
@@ -1108,14 +1146,9 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
 
     // determine which KV cells to move where
-    //
-    //  cell i moves to ids[i]
-    //
-    //  if ids[i] == i || ids[i] == n_kv, then cell i is not moved
-    //
-    auto & ids = defrag_info.ids;
+    defrag_info res;
+    auto & ids = res.ids;
 
-    ids.clear();
     ids.resize(n_kv, n_kv);
 
     for (uint32_t i0 = 0; i0 < n_used; ++i0) {
@@ -1179,11 +1212,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
             // this cell goes to (i0 + nf)
             ids[i1] = i0 + nf;
 
-            // move the cell meta data
-            cells.mv(i1, i0 + nf);
-
-            head = n_used;
-
             if (!cont) {
                 n_moves++;
                 cont = true;
@@ -1206,14 +1234,14 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     }
 
     if (n_moves == 0) {
-        return false;
+        return {};
     }
 
     LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
 
     LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
 
-    return true;
+    return res;
 }
 
 bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
@@ -1636,24 +1664,27 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(
-        llama_memory_status status,
-        llama_kv_cache_unified * kv) : status(status), kv(kv) {
-            n_kv = kv->get_size();
-            head = 0;
-        }
+        llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
+    n_kv = kv->get_size();
+    head = 0;
+}
 
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(
-        llama_memory_status status,
-        llama_kv_cache_unified * kv,
-        llama_sbatch sbatch,
-        std::vector<uint32_t> heads,
-        std::vector<llama_ubatch> ubatches)
-    : status(status),
-    kv(kv),
-    sbatch(std::move(sbatch)),
-    heads(std::move(heads)),
-    ubatches(std::move(ubatches)) {
+        llama_kv_cache_unified * kv,
+        llama_context * lctx,
+        bool do_shift,
+        defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
+    if (!do_shift && dinfo.empty()) {
+        status = LLAMA_MEMORY_STATUS_NO_UPDATE;
     }
+}
+
+llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+        llama_kv_cache_unified * kv,
+        llama_sbatch sbatch,
+        llama_kv_cache_unified::ubatch_heads heads,
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
+}
 
 llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
@@ -1670,6 +1701,13 @@ bool llama_kv_cache_unified_state::next() {
 bool llama_kv_cache_unified_state::apply() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
+    // no ubatches -> this is a KV cache update
+    if (ubatches.empty()) {
+        kv->update(lctx, do_shift, dinfo);
+
+        return true;
+    }
+
     kv->apply_ubatch(heads[i_next], ubatches[i_next]);
 
     n_kv = kv->get_n_kv();
diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h
index 1f1d44b9..6ff388a8 100644
--- a/src/llama-kv-cache-unified.h
+++ b/src/llama-kv-cache-unified.h
@@ -24,6 +24,19 @@ public:
     // this callback is used to filter out layers that should not be included in the cache
     using layer_filter_cb = std::function<bool(int32_t il)>;
 
+    using ubatch_heads = std::vector<uint32_t>;
+
+    struct defrag_info {
+        bool empty() const {
+            return ids.empty();
+        }
+
+        // contains information about which cell moves where:
+        //  - cell i moves to ids[i]
+        //  - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
+        std::vector<uint32_t> ids;
+    };
+
     llama_kv_cache_unified(
             const llama_model & model,
             layer_filter_cb && filter,
@@ -66,9 +79,7 @@ public:
 
     llama_memory_state_ptr init_full() override;
 
-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool get_can_shift() const override;
@@ -83,6 +94,8 @@ public:
 
     uint32_t get_size() const;
 
+    bool get_has_shift() const;
+
     //
     // graph_build API
     //
@@ -103,7 +116,9 @@ public:
 
     // find places for the provided ubatches in the cache, returns the head locations
     // return empty vector on failure
-    std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);
+    ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
+
+    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
 
     // return the cell position where we can insert the ubatch
     // return -1 on failure to find a contiguous slot of kv cells
@@ -133,8 +148,7 @@ private:
         ggml_tensor * v;
     };
 
-    bool do_defrag = false;
-    bool v_trans   = true;  // the value tensor is transposed
+    bool v_trans = true; // the value tensor is transposed
 
     // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
     // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
@@ -160,13 +174,8 @@ private:
     // model layer id -> KV cache layer id
     std::unordered_map<int32_t, int32_t> map_layer_ids;
 
-    // defrag
-    struct {
-        std::vector<uint32_t> ids;
-    } defrag_info;
-
-    // return true if cells have been moved
-    bool defrag_prepare(int32_t n_max_nodes);
+    // return non-empty vector if cells have been moved
+    defrag_info defrag_prepare(int32_t n_max_nodes) const;
 
     size_t total_size() const;
@@ -192,7 +201,8 @@ private:
     llm_graph_result_ptr build_graph_defrag(
             const llama_cparams & cparams,
             ggml_context * ctx,
-            ggml_cgraph * gf) const;
+            ggml_cgraph * gf,
+            const defrag_info & dinfo) const;
 
     void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
     void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
@@ -203,20 +213,29 @@ private:
 
 class llama_kv_cache_unified_state : public llama_memory_state_i {
 public:
+    // some shorthands
+    using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
+    using defrag_info  = llama_kv_cache_unified::defrag_info;
+
     // used for errors
     llama_kv_cache_unified_state(llama_memory_status status);
 
     // used to create a full-cache state
     llama_kv_cache_unified_state(
-            llama_memory_status status,
             llama_kv_cache_unified * kv);
 
-    // used to create a state from a batch
+    // used to create an update state
+    llama_kv_cache_unified_state(
+            llama_kv_cache_unified * kv,
+            llama_context * lctx,
+            bool do_shift,
+            defrag_info dinfo);
+
+    // used to create a decode state from a batch
     llama_kv_cache_unified_state(
-            llama_memory_status status,
             llama_kv_cache_unified * kv,
             llama_sbatch sbatch,
-            std::vector<uint32_t> heads,
+            ubatch_heads heads,
             std::vector<llama_ubatch> ubatches);
 
     virtual ~llama_kv_cache_unified_state();
@@ -253,16 +272,30 @@ public:
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
 private:
-    const llama_memory_status status;
+    llama_memory_status status;
 
     llama_kv_cache_unified * kv;
+    llama_context * lctx;
+
+    //
+    // update state
+    //
+
+    bool do_shift = false;
+
+    defrag_info dinfo;
+
+    //
+    // batch processing state
+    //
 
     llama_sbatch sbatch;
 
     // the index of the next ubatch to process
     size_t i_next = 0;
 
-    std::vector<uint32_t> heads;
+    ubatch_heads heads;
+
     std::vector<llama_ubatch> ubatches;
 
     //
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index 2d04705f..17a5e5cb 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -1,12 +1,16 @@
 #pragma once
 
 #include "llama.h"
-#include "llama-io.h"
 #include "llama-memory.h"
 
+class llama_io_write_i;
+class llama_io_read_i;
+
 struct llama_kv_cache : public llama_memory_i {
     virtual ~llama_kv_cache() = default;
 
+    // TODO: move the init_ interfaces to llama_memory_i
+
     // split the input batch into a set of ubatches and verify that they can fit into the cache
     // return a state object containing the ubatches and KV cache state required to process them
     // check the llama_memory_state_i::get_status() for the result
@@ -19,16 +23,9 @@ struct llama_kv_cache : public llama_memory_i {
     // simulate full cache, used for allocating worst-case compute buffers
     virtual llama_memory_state_ptr init_full() = 0;
 
-    // process any pending defrag/shift/etc. operations
-    // optionally call once before processing a new batch
-    // return true if any operations were performed
-    virtual bool update(llama_context & lctx) = 0;
-
-    // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
-    // TODO: change to
-    //   llama_memory_state_ptr init_defrag(float thold) = 0;
-    //
-    virtual void defrag_sched(float thold) = 0;
+    // prepare for any pending memory updates, such as shifts, defrags, etc.
+    // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
+    virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
 
     // getters
     virtual bool get_can_shift() const = 0;
diff --git a/src/llama-memory.cpp b/src/llama-memory.cpp
index 10173253..f1107672 100644
--- a/src/llama-memory.cpp
+++ b/src/llama-memory.cpp
@@ -1 +1,42 @@
 #include "llama-memory.h"
+
+llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) {
+    bool has_update = false;
+
+    switch (s0) {
+        case LLAMA_MEMORY_STATUS_SUCCESS:
+            {
+                has_update = true;
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_NO_UPDATE:
+            {
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+            {
+                return s0;
+            }
+    }
+
+    switch (s1) {
+        case LLAMA_MEMORY_STATUS_SUCCESS:
+            {
+                has_update = true;
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_NO_UPDATE:
+            {
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+            {
+                return s1;
+            }
+    }
+
+    // if either status has an update, then the combined status has an update
+    return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
+}
diff --git a/src/llama-memory.h b/src/llama-memory.h
index b3799d66..ab0d399c 100644
--- a/src/llama-memory.h
+++ b/src/llama-memory.h
@@ -36,12 +36,19 @@ public:
     virtual bool get_can_edit() const = 0;
 };
 
+using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
+
 enum llama_memory_status {
     LLAMA_MEMORY_STATUS_SUCCESS = 0,
+    LLAMA_MEMORY_STATUS_NO_UPDATE,
     LLAMA_MEMORY_STATUS_FAILED_PREPARE,
     LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
 };
 
+// helper function for combining the status of two memory states
+// useful for implementing hybrid memory types (e.g. iSWA)
+llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
+
 // the interface for managing the memory state during batch processing
 // this interface is implemented per memory type. see:
 //   - llama_kv_cache_unified_state
@@ -69,7 +76,7 @@ public:
     // get the current ubatch
     virtual const llama_ubatch & get_ubatch() const = 0;
 
-    // get the status of the memory state
+    // get the status of the memory state - used for error handling and checking if any updates would be applied
    virtual llama_memory_status get_status() const = 0;
 };
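
Note (added for illustration, not part of the patch): a minimal sketch of how a caller drives
the new init_update()/apply() flow, modeled on llama_context::kv_self_update() above. The
pointers `kv` (a llama_kv_cache *) and `lctx` (its llama_context *) are assumed to be valid;
error handling is abbreviated.

    // prepare any pending shift/defrag work; pass optimize = true to force the defrag check
    llama_memory_state_ptr state = kv->init_update(lctx, /*optimize =*/ false);

    switch (state->get_status()) {
        case LLAMA_MEMORY_STATUS_SUCCESS:
            // there is pending work - apply() builds and computes the shift/defrag graphs
            if (!state->apply()) {
                // computation failed
            }
            break;
        case LLAMA_MEMORY_STATUS_NO_UPDATE:
            // nothing to do - the caller can skip reserving a new worst-case graph
            break;
        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
            // preparation failed - report the error and bail out
            break;
    }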