
* feat: Add llama_model_is_hybrid API call
  Also, split llama_model_is_recurrent into llm_arch_is_recurrent in llama-arch with llama_model_is_recurrent delegating to llm_arch_is_recurrent. The same split is done for hybrid. This is needed because there are places where the llama_model has not yet been initialized but we need to check if the model is recurrent (specifically for the per-layer recurrent check array in hparams).
  Branch: GraniteFour
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* feat: Add c++ side constants for attention layer indices hparam
  Branch: GraniteFour
* feat: Add support for distinguishing recurrent vs non-recurrent layers in hparams
  Branch: GraniteFour
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* feat: Auto-fill hparams.recurrent_layer_arr based on whether the model is recurrent
  Branch: GraniteFour
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* refactor: rename *_is_hybrid -> *_is_hybrid_recurrent
  The implementation of the hybrid cache intentionally does not specify the types of the child caches, so there was a naming mismatch with these predicate functions that used "hybrid" to imply "hybrid recurrent."
  Branch: HybridCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* feat: Add layer filter to recurrent cache
  Branch: HybridCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Use per-layer sizing everywhere in kv caches
  Branch: GraniteFour
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* feat: First pass at llama_kv_cache_hybrid_recurrent
  This follows the pattern in iswa where the two child caches are held explicitly to support the case where a model requires a single attention cache and a single recurrent cache where each layer uses exactly one of the caches.
  This is a rewrite of the more generic approach in the original hybrid cache PR: https://github.com/ggml-org/llama.cpp/pull/13276
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* feat: Construct hybrid recurrent cache for hybrid recurrent models
  This includes a refactor of the create_memory logic to avoid needing to use the arch enum explicitly unless a model needs explicit cache instantiation logic beyond the standard logic for recurrent, hybrid, unified, and iswa.
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Fix wrong bool condition for split equal in hybrid cache
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Fix shift logic to defer to unified cache
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* feat: Support hybrid recurrent in llama-graph
  NOTE: I intentionally did not add support for s_mask since it will be going away soon
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Fix logic for initializing inputs and attn layers for hybrid caches
  Branch: GraniteFour
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Update recurrent cache for changes to remove intermediate kv_cache interface
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Fix status for init_update sig for recurrent cache state
  Branch: GraniteFour
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Add missing padding to n_ctx for hybrid cache construction
  Branch: GraniteFour
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Update clear signature for data argument after rebase
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Remove errant virtual destructor leftover from previous impl attempt
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Use per-layer n_embd_k/v_s calls for mamba (1) layers
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* refactor: Remove n_embd_k/v_s from unified cache
  No longer needed now that unified isn't also supporting recurrent
  https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140761069
  Branch: HybridRecurrentCache
* refactor: Remove layer index from n_embd_k/v_s
  Now that it's not used at all in the unified cache, we don't need to use the layer index to zero it out for attention layers.
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* refactor: Remove n_embd_k/v_gqa from recurrent cache
  This is no longer needed now that there are separate implementations
  https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140825128
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* feat: Allow custom layer filters for hybrid recurrent
  This should help support architectures like Falcon H1 where there is overlap between layers that need attention and recurrent caches.
  https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140748922
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Remove logits_all after rebase
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Remove llama_model_is_hybrid_Recurrent public API
  https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2141728423
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* refactor: Use llama_memory_state_ptr for child states in hybrid memory state
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* feat: Overhaul build_recurrent_state / build_inp_s_copy to match attention pattern
  https://github.com/ggml-org/llama.cpp/pull/13979/files#r2141701738
  This is a big overhaul to bring consistency between how inputs and per-layer components are created for attention layers and recurrent layers. The main changes are:
  - Rename class llm_graph_input_s_copy -> llm_graph_input_rs
  - Add a corresponding llm_graph_input_rs_hybrid_recurrent
  - Rename build_inp_s_copy -> build_rs_inp_recurrent
  - Add a corresponding build_rs_inp_hybrid_recurrent
  - Rename build_recurrent_state -> build_rs to match build_attn w/ llm_graph_input_rs * as the first input
  - Add a corresponding overload of build_rs w/ llm_graph_input_rs_hybrid_recurrent * as the first input
  - Add a llm_graph_input_attn_kv_hybrid_recurrent analogous to llm_graph_input_attn_kv_unified
  - Add a build_attn override that takes llm_graph_input_attn_kv_hybrid_recurrent * as the first input
  This makes the two paradigms fully consistent. The main drawback is the code duplication in the build_attn and build_rs implementations where the only difference between implementations is how they cast the memory state.
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Fix resize vs reserve and skip null tensors in size computation
  https://github.com/ggml-org/llama.cpp/pull/13979/files#r2149469788
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
  Co-Authored-By: @younesbelkada
* fix: Fix initialization of child states
  Since initially writing this PR, the logic in the child state types changed such that using the "init full" signature and keeping the ubatches on the parent struct no longer worked.
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* refactor: Use a common build_recurrent_state method that is cache-agnostic
  This reduces the code duplication between the different build_rs impls and also retains a similar signature to the previous build_recurrent_state method while standardizing on the input-dispatched build_rs implementation.
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* recurrent : rework graph inputs + add TODOs
  ggml-ci
* refactor: Make status and child states const in hybrid and iswa
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* refactor: Rename llama_kv_cache_[recurrent|hybrid_recurrent] to remove kv cache
  This removes the notion of "kv" from the interface names for these memory types. There are still many references to kv in the implementation of the recurrent memory which will need further adjustment.
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* refactor!: Rename all k/v related values for recurrent/hybrid to r/s
  Anywhere that "kv_<state|cell|size|etc>" is used, I've used the more generic "mem_" prefix. The specifics of "k" (key) translate to "r" (recurrent state) and "v" (value) translate to "s" (state-space embedding states).
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* refactor: _recurrent -> _recr for brevity
  It just _happens_ to have the same number of letters as _attn!
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* style: Fix spacing for ref
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* refactor: recurrent_layer() -> is_recurrent()
  Branch: HybridRecurrentCache
  Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* style: Fix spacing for size_s_bytes declaration
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
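The commit message mentions layer filters that decide which child cache owns each layer: in the standard case every layer uses exactly one cache, while custom filters may overlap (the Falcon H1 note). A minimal sketch of both situations, assuming a hypothetical `layer_filter` alias modeled on a per-layer callback and a hypothetical `is_recr` flag vector standing in for the per-layer recurrent array in hparams:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical callback type: returns true if a child cache owns layer il.
using layer_filter = std::function<bool(int32_t il)>;

// Complementary pair built from one per-layer flag, so that every layer is
// owned by exactly one child cache.
struct filter_pair {
    layer_filter attn; // layers whose flag is false
    layer_filter recr; // layers whose flag is true
};

filter_pair complementary_filters(const std::vector<bool> & is_recr) {
    return {
        [is_recr](int32_t il) { return !is_recr[static_cast<std::size_t>(il)]; },
        [is_recr](int32_t il) { return  is_recr[static_cast<std::size_t>(il)]; },
    };
}

// List the layers a child cache would allocate under a given filter.
std::vector<int32_t> owned_layers(int32_t n_layer, const layer_filter & filter) {
    std::vector<int32_t> out;
    for (int32_t il = 0; il < n_layer; ++il) {
        if (filter(il)) {
            out.push_back(il);
        }
    }
    return out;
}
```

With flags `{true, false, true, true}`, the attention filter owns layer 1 and the recurrent filter owns layers 0, 2 and 3. A custom, non-complementary pair such as `il >= 2` / `il <= 2` makes layer 2 owned by both caches, which is the kind of overlap the Falcon H1 note describes.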
281 lines
9.1 KiB
C++
#include "llama-kv-cache-unified-iswa.h"

#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-model.h"

#include <algorithm>
#include <cassert>

//
// llama_kv_cache_unified_iswa
//

llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
        const llama_model & model,
                ggml_type   type_k,
                ggml_type   type_v,
                     bool   v_trans,
                     bool   offload,
                     bool   swa_full,
                 uint32_t   kv_size,
                 uint32_t   n_seq_max,
                 uint32_t   n_ubatch,
                 uint32_t   n_pad) : hparams(model.hparams) {
    llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
    llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };

    const uint32_t size_base = kv_size;

    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));

    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
    if (swa_full) {
        LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");

        size_swa = size_base;
    }

    LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);

    kv_base = std::make_unique<llama_kv_cache_unified>(
            model, std::move(filter_base), type_k, type_v,
            v_trans, offload, size_base, n_seq_max, n_pad,
            0, LLAMA_SWA_TYPE_NONE);

    LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);

    kv_swa = std::make_unique<llama_kv_cache_unified>(
            model, std::move(filter_swa), type_k, type_v,
            v_trans, offload, size_swa, n_seq_max, n_pad,
            hparams.n_swa, hparams.swa_type);
}
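The constructor sizes the SWA cache to hold the attention window for every sequence plus one ubatch of incoming tokens, rounded up to the padding and capped at the base cache size; `swa_full` forces it to the base size. A standalone sketch of that arithmetic, where `pad_up` is a hypothetical helper that gives the same result as `GGML_PAD` when `n_pad` is a power of two:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Round x up to a multiple of n (hypothetical helper; equals GGML_PAD for
// power-of-two n).
uint32_t pad_up(uint32_t x, uint32_t n) {
    return (x + n - 1) / n * n;
}

// Mirrors the sizing rule in the constructor above.
uint32_t swa_cache_size(uint32_t kv_size, uint32_t n_swa, uint32_t n_seq_max,
                        uint32_t n_ubatch, uint32_t n_pad, bool swa_full) {
    const uint32_t size_base = kv_size;
    if (swa_full) {
        return size_base; // full-size SWA cache: same size as the base cache
    }
    // window per sequence + one ubatch, padded, never larger than the base cache
    return std::min(size_base, pad_up(n_swa*n_seq_max + n_ubatch, n_pad));
}
```

For illustrative values `kv_size = 8192`, `n_swa = 500`, one sequence, `n_ubatch = 512` and `n_pad = 256`, the SWA cache gets `pad_up(1012, 256) = 1024` cells instead of 8192, which is the memory saving the split design is after.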

void llama_kv_cache_unified_iswa::clear(bool data) {
    kv_base->clear(data);
    kv_swa ->clear(data);
}

bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    bool res = true;

    res = res & kv_base->seq_rm(seq_id, p0, p1);
    res = res & kv_swa ->seq_rm(seq_id, p0, p1);

    return res;
}
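`seq_rm` (and `apply()` further down) combines the child results with `&` rather than `&&`. On `bool` operands `&` does not short-circuit, so the SWA cache still receives the call even when the base cache reports failure; reading this as deliberate, it keeps both children in step. A small sketch contrasting the two operators, using a hypothetical `toy_child` that counts its calls:

```cpp
#include <cassert>

// Hypothetical child that records how often seq_rm was called.
struct toy_child {
    bool result;
    int  calls = 0;
    bool seq_rm() { ++calls; return result; }
};

// Non-short-circuit combine, as in seq_rm above: both children are visited.
bool rm_both_bitwise(toy_child & a, toy_child & b) {
    bool res = true;
    res = res & a.seq_rm();
    res = res & b.seq_rm();
    return res;
}

// Short-circuit variant for contrast: b is skipped when a fails.
bool rm_both_logical(toy_child & a, toy_child & b) {
    return a.seq_rm() && b.seq_rm();
}
```

Both variants return `false` when the first child fails, but only the bitwise one still calls the second child.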

void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
    kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}

void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
    kv_base->seq_keep(seq_id);
    kv_swa ->seq_keep(seq_id);
}

void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
    kv_base->seq_add(seq_id, p0, p1, shift);
    kv_swa ->seq_add(seq_id, p0, p1, shift);
}

void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    kv_base->seq_div(seq_id, p0, p1, d);
    kv_swa ->seq_div(seq_id, p0, p1, d);
}
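Each sequence operation above is forwarded unchanged to both child caches, so the base and SWA caches always see the same sequence edits. A minimal sketch of that fan-out (composite) pattern, with hypothetical `toy_cache` and `toy_iswa` types that only record a per-sequence position shift and ignore the `p0`/`p1` range:

```cpp
#include <cassert>
#include <map>

// Hypothetical child cache: tracks a position offset per sequence id.
struct toy_cache {
    std::map<int, int> shift;
    void seq_add(int seq_id, int delta) { shift[seq_id] += delta; }
    void seq_cp (int src, int dst)      { shift[dst] = shift[src]; }
};

// Hypothetical composite: every call is applied to both children.
struct toy_iswa {
    toy_cache base;
    toy_cache swa;
    void seq_add(int seq_id, int delta) { base.seq_add(seq_id, delta); swa.seq_add(seq_id, delta); }
    void seq_cp (int src, int dst)      { base.seq_cp(src, dst);       swa.seq_cp(src, dst); }
};
```

Because the composite never edits only one child, queries such as `seq_pos_max` can be answered from either child without the two drifting apart.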

llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
    // the base cache is a superset of the SWA cache, so we can just check the SWA cache
    return kv_swa->seq_pos_min(seq_id);
}

llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
    return kv_swa->seq_pos_max(seq_id);
}
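The comment on `seq_pos_min` relies on the SWA cache holding a subset of the positions stored in the base cache: the window drops older positions, while the base cache keeps them. Under that assumption the positions present in both caches are exactly the SWA positions, so the SWA minimum and maximum are the combined answer. A sketch over plain position sets, where returning `{-1, -1}` for an empty set is an assumed convention for "no positions":

```cpp
#include <algorithm>
#include <cassert>
#include <set>
#include <utility>

// Return the [min, max] positions present in BOTH caches, assuming swa is a
// subset of base (checked by the assert).
std::pair<int, int> common_pos_range(const std::set<int> & base, const std::set<int> & swa) {
    assert(std::includes(base.begin(), base.end(), swa.begin(), swa.end()));
    if (swa.empty()) {
        return {-1, -1}; // assumed "empty" convention
    }
    return {*swa.begin(), *swa.rbegin()};
}
```

With base positions 0..9 and a window that kept 6..9, the common range is [6, 9]: the minimum comes from the window, and the maximum is the same in both caches.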

llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
    GGML_UNUSED(embd_all);

    // first try simple split
    do {
        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);

        std::vector<llama_ubatch> ubatches;

        while (sbatch.n_tokens > 0) {
            auto ubatch = sbatch.split_simple(n_ubatch);

            ubatches.push_back(ubatch);
        }

        auto heads_base = kv_base->prepare(ubatches);
        if (heads_base.empty()) {
            break;
        }

        auto heads_swa = kv_swa->prepare(ubatches);
        if (heads_swa.empty()) {
            break;
        }

        assert(heads_base.size() == heads_swa.size());

        return std::make_unique<llama_kv_cache_unified_iswa_state>(
                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
    } while (false);

    // if it fails, try equal split
    do {
        auto sbatch = llama_sbatch(batch, hparams.n_embd, false);

        std::vector<llama_ubatch> ubatches;

        while (sbatch.n_tokens > 0) {
            auto ubatch = sbatch.split_equal(n_ubatch);

            ubatches.push_back(ubatch);
        }

        auto heads_base = kv_base->prepare(ubatches);
        if (heads_base.empty()) {
            break;
        }

        auto heads_swa = kv_swa->prepare(ubatches);
        if (heads_swa.empty()) {
            break;
        }

        assert(heads_base.size() == heads_swa.size());

        return std::make_unique<llama_kv_cache_unified_iswa_state>(
                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
    } while (false);

    // TODO: if we fail again, we should attempt different splitting strategies
    //       but to do that properly, we first have to refactor the batches to be more flexible

    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
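`init_batch` tries a sequence of split strategies and accepts the first one that both child caches can prepare; each `do { ... } while (false)` block uses `break` to fall through to the next attempt, and the function reports a failed prepare if none succeeds. The same control flow written as a loop over strategies, as a sketch with hypothetical `split_fn` / `prepare_fn` stand-ins (the toy "equal" strategy below is only a smaller chunk size, not the real `split_equal` semantics):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical stand-ins: a split turns a token count into ubatch sizes; a
// prepare step returns one head per ubatch, or an empty vector on failure.
using split_fn   = std::function<std::vector<uint32_t>(uint32_t n_tokens)>;
using prepare_fn = std::function<std::vector<uint32_t>(const std::vector<uint32_t> & ubatches)>;

struct batch_plan {
    std::vector<uint32_t> ubatches;
    std::vector<uint32_t> heads_base;
    std::vector<uint32_t> heads_swa;
};

// Chop n_tokens into chunks of at most max_size (toy split).
std::vector<uint32_t> split_chunks(uint32_t n_tokens, uint32_t max_size) {
    assert(max_size > 0);
    std::vector<uint32_t> out;
    while (n_tokens > 0) {
        const uint32_t n = std::min(n_tokens, max_size);
        out.push_back(n);
        n_tokens -= n;
    }
    return out;
}

// Keep the first strategy that BOTH caches accept.
std::optional<batch_plan> plan_batch(uint32_t n_tokens,
                                     const std::vector<split_fn> & strategies,
                                     const prepare_fn & prepare_base,
                                     const prepare_fn & prepare_swa) {
    for (const auto & split : strategies) {
        batch_plan plan;
        plan.ubatches   = split(n_tokens);
        plan.heads_base = prepare_base(plan.ubatches);
        if (plan.heads_base.empty()) {
            continue; // same role as `break` out of a do { } while (false) block
        }
        plan.heads_swa = prepare_swa(plan.ubatches);
        if (plan.heads_swa.empty()) {
            continue;
        }
        assert(plan.heads_base.size() == plan.heads_swa.size());
        return plan;
    }
    return std::nullopt; // plays the role of LLAMA_MEMORY_STATUS_FAILED_PREPARE
}
```

The loop form makes it easy to append further strategies, which is what the TODO above anticipates; the original keeps two explicit blocks, which reads linearly and avoids the type-erased callbacks.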

llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
    return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
}

llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
    return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
}

bool llama_kv_cache_unified_iswa::get_can_shift() const {
    return kv_base->get_size() == kv_swa->get_size();
}

void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
    kv_base->state_write(io, seq_id);
    kv_swa ->state_write(io, seq_id);
}

void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
    kv_base->state_read(io, seq_id);
    kv_swa ->state_read(io, seq_id);
}
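`state_write` and `state_read` serialize the two children back to back, base first. Nothing in this file marks where one child's data ends and the other's begins, so the reader has to consume the children in exactly the order the writer produced them. A toy sketch of that symmetry, assuming a hypothetical `toy_stream` in place of `llama_io_write_i` / `llama_io_read_i`:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stream standing in for the real io interfaces.
struct toy_stream {
    std::vector<uint32_t> data;
    std::size_t           pos = 0;
    void     write(uint32_t v) { data.push_back(v); }
    uint32_t read()            { return data.at(pos++); }
};

// Hypothetical child whose whole state is one number.
struct toy_child_state {
    uint32_t n_cells = 0;
    void state_write(toy_stream & io) const { io.write(n_cells); }
    void state_read (toy_stream & io)       { n_cells = io.read(); }
};

// Writer and reader visit the children in the same order: base, then swa.
void write_both(toy_stream & io, const toy_child_state & base, const toy_child_state & swa) {
    base.state_write(io);
    swa .state_write(io);
}

void read_both(toy_stream & io, toy_child_state & base, toy_child_state & swa) {
    base.state_read(io);
    swa .state_read(io);
}
```

Swapping the order in only one of the two functions would silently load the SWA data into the base child and vice versa.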

llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
    return kv_base.get();
}

llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
    return kv_swa.get();
}

//
// llama_kv_cache_unified_iswa_state
//

llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}

llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
        llama_kv_cache_unified_iswa * kv) :
    state_base(kv->get_base()->init_full()),
    state_swa (kv->get_swa ()->init_full()),
    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
}
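Each state constructor derives its own `status` from the two child statuses via `llama_memory_status_combine`, whose body is not part of this file. The sketch below shows one plausible combining rule, using a hypothetical `mem_status` enum: any failure is propagated, otherwise the result is success if at least one child has work to do and "no update" if neither does. Treat the exact rule as an assumption; the real function may differ in detail.

```cpp
#include <cassert>
#include <initializer_list>

// Hypothetical mirror of the memory status values used above.
enum class mem_status { success, no_update, failed_prepare, failed_compute };

mem_status combine(mem_status s0, mem_status s1) {
    bool has_update = false;
    for (mem_status s : {s0, s1}) {
        switch (s) {
            case mem_status::success:        has_update = true; break;
            case mem_status::no_update:      break;
            case mem_status::failed_prepare:
            case mem_status::failed_compute: return s; // first failure wins
        }
    }
    return has_update ? mem_status::success : mem_status::no_update;
}
```

Under this rule, a parent built from one child that succeeded and one with nothing to do is still a success, which matches the idea that the SWA and base caches may need different amounts of work for the same update.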

llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
        llama_kv_cache_unified_iswa * kv,
        llama_context * lctx,
        bool optimize) :
    state_base(kv->get_base()->init_update(lctx, optimize)),
    state_swa (kv->get_swa ()->init_update(lctx, optimize)),
    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
}

llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
        llama_kv_cache_unified_iswa * kv,
        llama_sbatch sbatch,
        std::vector<uint32_t> heads_base,
        std::vector<uint32_t> heads_swa,
        std::vector<llama_ubatch> ubatches) :
    sbatch(std::move(sbatch)),
    ubatches(std::move(ubatches)),
    // note: here we copy the ubatches. not sure if this is ideal
    state_base(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches)),
    state_swa (new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa),  this->ubatches)),
    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
}
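The last constructor builds `state_base` and `state_swa` from `this->ubatches`, and `status` from both child states. C++ initializes non-static members in the order they are declared in the class, not in initializer-list order, so this is correct only if the header (not part of this listing) declares `ubatches` before the two child states and those before `status`. The `this->` prefix matters too: the parameter `ubatches` shadows the member and has already been moved from. A minimal sketch of the same pattern with a hypothetical `toy_batch_state`:

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Members are initialized in DECLARATION order; n_ubatches may read the
// ubatches member only because ubatches is declared first.
struct toy_batch_state {
    std::vector<int> ubatches;   // initialized first
    std::size_t      n_ubatches; // initialized second, reads the member above

    explicit toy_batch_state(std::vector<int> ubatches) :
        ubatches(std::move(ubatches)),       // member <- parameter (moved)
        n_ubatches(this->ubatches.size()) {} // this-> selects the member, not the moved-from parameter
};
```

Swapping the two member declarations would make `n_ubatches` read a not-yet-constructed vector, even though the initializer list itself would look unchanged.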

llama_kv_cache_unified_iswa_state::~llama_kv_cache_unified_iswa_state() = default;

bool llama_kv_cache_unified_iswa_state::next() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    state_base->next();
    state_swa ->next();

    if (++i_next >= ubatches.size()) {
        return false;
    }

    return true;
}
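`next()` advances both child states in lockstep with the parent's own `i_next` and returns `false` once the last ubatch has been consumed, while `get_ubatch()` returns the ubatch at `i_next`. A sketch of the assumed caller loop (process the current ubatch, then call `next()`), using a hypothetical `toy_iter_state` that holds only ubatch token counts and omits the child states:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical state mirroring the i_next / next() / get_ubatch() trio.
struct toy_iter_state {
    std::vector<int> ubatches; // token count per ubatch; assumed non-empty
    std::size_t      i_next = 0;

    const int & get_ubatch() const { return ubatches[i_next]; }

    // Advance; return false once every ubatch has been visited.
    bool next() {
        if (++i_next >= ubatches.size()) {
            return false;
        }
        return true;
    }
};

// Assumed caller pattern: handle the current ubatch, then advance.
int process_all(toy_iter_state & st) {
    int n_tokens = 0;
    do {
        n_tokens += st.get_ubatch();
    } while (st.next());
    return n_tokens;
}
```

The `do`/`while` shape fits because a successful state always has a current ubatch before the first call to `next()`.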

bool llama_kv_cache_unified_iswa_state::apply() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    bool res = true;

    res = res & state_base->apply();
    res = res & state_swa ->apply();

    return res;
}

std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return sbatch.out_ids;
}

llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
    return status;
}

const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return ubatches[i_next];
}

const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
}

const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
}
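The two getters downcast the stored child-state pointers with `static_cast`. That is well-defined only if the stored objects really are `llama_kv_cache_unified_state` instances; this sketch assumes that is guaranteed because the child `init_*` calls and the explicit `new` in the batch constructor produce that type. A `dynamic_cast` would verify the type at runtime, at some cost. The interface and concrete type below (`toy_state_i`, `toy_unified_state`, `toy_owner`) are hypothetical:

```cpp
#include <cassert>
#include <memory>

// Hypothetical interface and concrete state.
struct toy_state_i {
    virtual ~toy_state_i() = default;
    virtual int get_status() const = 0;
};

struct toy_unified_state : toy_state_i {
    int get_status() const override { return 0; }
};

// The owner stores the child through the interface but created it itself,
// so it knows the concrete type and can downcast without a runtime check.
struct toy_owner {
    std::unique_ptr<toy_state_i> state_base = std::make_unique<toy_unified_state>();

    const toy_unified_state * get_base() const {
        return static_cast<const toy_unified_state *>(state_base.get());
    }
};
```

If a different state type were ever stored in `state_base`, the `static_cast` would give undefined behavior rather than a null pointer, which is the trade-off accepted for the unchecked cast.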