llama : refactor llama_context, llama_kv_cache, llm_build_context (#12181)

* llama : refactor llama_context, llama_kv_cache, llm_build_context

ggml-ci

* graph : don't mutate the KV cache during defrag

ggml-ci

* context : reduce virtuals + remove test function

ggml-ci

* context : move interface implementation to source file + factory

ggml-ci

* graph : move KV cache build functions to llama_context impl

ggml-ci

* graph : remove model reference from build_pooling

ggml-ci

* graph : remove llama_model reference

ggml-ci

* kv_cache : provide rope factors

ggml-ci

* graph : rework inputs to use only unique_ptr, remove attn input abstraction

ggml-ci

* context : remove llama_context_i abstraction

ggml-ci

* context : clean-up

ggml-ci

* graph : clean-up

ggml-ci

* llama : remove redundant keywords (struct, enum)

ggml-ci

* model : adapt gemma3

ggml-ci

* graph : restore same attention ops as on master

ggml-ci

* llama : remove TODO + fix indent

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-03-13 12:35:44 +02:00 committed by GitHub
parent 2048b5913d
commit e0dbec0bc6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
46 changed files with 13903 additions and 12190 deletions

View file

@ -15,18 +15,21 @@ add_library(llama
llama-chat.cpp
llama-context.cpp
llama-grammar.cpp
llama-graph.cpp
llama-hparams.cpp
llama-impl.cpp
llama-io.cpp
llama-kv-cache.cpp
llama-memory.cpp
llama-mmap.cpp
llama-model-loader.cpp
llama-model.cpp
llama-quant.cpp
llama-sampling.cpp
llama-vocab.cpp
unicode.h
unicode.cpp
unicode-data.cpp
unicode.cpp
unicode.h
)
target_include_directories(llama PUBLIC . ../include ../common)

View file

@ -4,14 +4,13 @@
#include "llama-mmap.h"
#include "llama-model.h"
#include <algorithm>
#include <map>
#include <cassert>
#include <stdexcept>
// vec
struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
return nullptr;
}
@ -19,7 +18,7 @@ struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
return tensors[il];
}
struct ggml_tensor * llama_adapter_cvec::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const {
ggml_tensor * llama_adapter_cvec::apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const {
ggml_tensor * layer_dir = tensor_for(il);
if (layer_dir != nullptr) {
cur = ggml_add(ctx, cur, layer_dir);
@ -40,7 +39,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
struct ggml_init_params params = {
ggml_init_params params = {
/*.mem_size =*/ hparams.n_layer*ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
@ -91,7 +90,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
return true;
}
int32_t llama_adapter_cvec::apply(
bool llama_adapter_cvec::apply(
const llama_model & model,
const float * data,
size_t len,
@ -104,17 +103,17 @@ int32_t llama_adapter_cvec::apply(
// disable the current control vector (but leave allocated for later)
layer_start = -1;
layer_end = -1;
return 0;
return true;
}
if (n_embd != (int) hparams.n_embd) {
LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
return 1;
return false;
}
if (tensors.empty()) {
if (!init(model)) {
return 1;
return false;
}
}
@ -130,12 +129,12 @@ int32_t llama_adapter_cvec::apply(
}
}
return 0;
return true;
}
// lora
llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor * w) {
llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) {
const std::string name(w->name);
const auto pos = ab_map.find(name);
@ -146,11 +145,11 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor *
return nullptr;
}
static void llama_adapter_lora_init_impl(struct llama_model & model, const char * path_lora, struct llama_adapter_lora & adapter) {
static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) {
LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
ggml_context * ctx_init;
struct gguf_init_params meta_gguf_params = {
gguf_init_params meta_gguf_params = {
/* .no_alloc = */ true,
/* .ctx = */ &ctx_init,
};
@ -201,7 +200,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
// add a new context
struct ggml_init_params params = {
ggml_init_params params = {
/*.mem_size =*/ n_tensors*ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
@ -264,7 +263,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)");
}
struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
// validate tensor shape
if (is_token_embd) {
// expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd()
@ -281,8 +280,8 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
}
// save tensor to adapter
struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
ggml_set_name(tensor_a, w.a->name);
ggml_set_name(tensor_b, w.b->name);
adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b);
@ -308,7 +307,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
{
llama_file gguf_file(path_lora, "rb");
std::vector<uint8_t> read_buf;
auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) {
auto set_tensor = [&](ggml_tensor * orig, ggml_tensor * dev) {
size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name));
size_t size = ggml_nbytes(orig);
read_buf.resize(size);
@ -327,8 +326,8 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
}
struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model, const char * path_lora) {
struct llama_adapter_lora * adapter = new llama_adapter_lora();
llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) {
llama_adapter_lora * adapter = new llama_adapter_lora();
try {
llama_adapter_lora_init_impl(*model, path_lora, *adapter);
@ -342,6 +341,6 @@ struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model,
return nullptr;
}
void llama_adapter_lora_free(struct llama_adapter_lora * adapter) {
void llama_adapter_lora_free(llama_adapter_lora * adapter) {
delete adapter;
}

View file

@ -15,11 +15,11 @@
//
struct llama_adapter_cvec {
struct ggml_tensor * tensor_for(int il) const;
ggml_tensor * tensor_for(int il) const;
struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const;
ggml_tensor * apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const;
int32_t apply(
bool apply(
const llama_model & model,
const float * data,
size_t len,
@ -36,7 +36,7 @@ private:
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
std::vector<struct ggml_tensor *> tensors; // per layer
std::vector<ggml_tensor *> tensors; // per layer
};
//
@ -44,8 +44,8 @@ private:
//
struct llama_adapter_lora_weight {
struct ggml_tensor * a = nullptr;
struct ggml_tensor * b = nullptr;
ggml_tensor * a = nullptr;
ggml_tensor * b = nullptr;
// get actual scale based on rank and alpha
float get_scale(float alpha, float adapter_scale) const {
@ -55,12 +55,12 @@ struct llama_adapter_lora_weight {
}
llama_adapter_lora_weight() = default;
llama_adapter_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
llama_adapter_lora_weight(ggml_tensor * a, ggml_tensor * b) : a(a), b(b) {}
};
struct llama_adapter_lora {
// map tensor name to lora_a_b
std::unordered_map<std::string, struct llama_adapter_lora_weight> ab_map;
std::unordered_map<std::string, llama_adapter_lora_weight> ab_map;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
@ -70,5 +70,7 @@ struct llama_adapter_lora {
llama_adapter_lora() = default;
~llama_adapter_lora() = default;
llama_adapter_lora_weight * get_weight(struct ggml_tensor * w);
llama_adapter_lora_weight * get_weight(ggml_tensor * w);
};
using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>;

View file

@ -42,9 +42,9 @@ struct llama_sbatch {
bool logits_all; // TODO: remove once lctx.logits_all is removed too
// sorted indices into the batch
std::vector<size_t> ids;
std::vector<int64_t> ids;
// batch indices of the output
std::vector<size_t> out_ids;
std::vector<int64_t> out_ids;
std::vector<llama_sbatch_seq> seq;
const llama_batch * batch = nullptr;

File diff suppressed because it is too large Load diff

View file

@ -3,66 +3,210 @@
#include "llama.h"
#include "llama-batch.h"
#include "llama-cparams.h"
#include "llama-model.h"
#include "llama-kv-cache.h"
#include "llama-graph.h"
#include "llama-adapter.h"
#include "ggml-cpp.h"
#include <map>
#include <unordered_map>
#include <vector>
#include <set>
struct llama_model;
struct llama_kv_cache;
class llama_io_read_i;
class llama_io_write_i;
struct llama_context {
llama_context(const llama_model & model)
: model(model)
, t_start_us(model.t_start_us)
, t_load_us(model.t_load_us) {}
// init scheduler and compute buffers, reserve worst-case graphs
llama_context(
const llama_model & model,
llama_context_params params);
const struct llama_model & model;
~llama_context();
struct llama_cparams cparams;
struct llama_sbatch sbatch; // TODO: revisit if needed
struct llama_kv_cache kv_self;
struct llama_adapter_cvec cvec;
void synchronize();
std::unordered_map<struct llama_adapter_lora *, float> lora;
const llama_model & get_model() const;
std::vector<ggml_backend_ptr> backends;
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
uint32_t n_ctx() const;
uint32_t n_ctx_per_seq() const;
uint32_t n_batch() const;
uint32_t n_ubatch() const;
uint32_t n_seq_max() const;
ggml_backend_t backend_cpu = nullptr;
uint32_t n_threads() const;
uint32_t n_threads_batch() const;
ggml_threadpool_t threadpool = nullptr;
ggml_threadpool_t threadpool_batch = nullptr;
llama_kv_cache * get_kv_self();
const llama_kv_cache * get_kv_self() const;
bool has_evaluated_once = false;
void kv_self_update();
mutable int64_t t_start_us;
mutable int64_t t_load_us;
mutable int64_t t_p_eval_us = 0;
mutable int64_t t_eval_us = 0;
enum llama_pooling_type pooling_type() const;
mutable int64_t t_compute_start_us = 0;
mutable int64_t n_queued_tokens = 0;
float * get_logits();
float * get_logits_ith(int32_t i);
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
mutable int32_t n_eval = 0; // number of eval calls
float * get_embeddings();
float * get_embeddings_ith(int32_t i);
float * get_embeddings_seq(llama_seq_id seq_id);
// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_ptr buf_output;
void attach_threadpool(
ggml_threadpool_t threadpool,
ggml_threadpool_t threadpool_batch);
void detach_threadpool();
void set_n_threads(int32_t n_threads, int32_t n_threads_batch);
void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
void set_embeddings (bool value);
void set_causal_attn(bool value);
void set_adapter_lora(
llama_adapter_lora * adapter,
float scale);
bool rm_adapter_lora(
llama_adapter_lora * adapter);
void clear_adapter_lora();
bool apply_adapter_cvec(
const float * data,
size_t len,
int32_t n_embd,
int32_t il_start,
int32_t il_end);
int encode(llama_batch & inp_batch);
int decode(llama_batch & inp_batch);
//
// state save/load
//
size_t state_get_size();
size_t state_get_data( uint8_t * dst, size_t size);
size_t state_set_data(const uint8_t * src, size_t size);
size_t state_seq_get_size(llama_seq_id seq_id);
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
bool state_load_file(
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out);
bool state_save_file(
const char * filepath,
const llama_token * tokens,
size_t n_token_count);
size_t state_seq_load_file(
llama_seq_id seq_id,
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out);
size_t state_seq_save_file(
llama_seq_id seq_id,
const char * filepath,
const llama_token * tokens,
size_t n_token_count);
//
// perf
//
llama_perf_context_data perf_get_data() const;
void perf_reset();
private:
//
// output
//
// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
int32_t output_reserve(int32_t n_outputs);
// make the outputs have the same order they had in the user-provided batch
// TODO: maybe remove this
void output_reorder();
//
// graph
//
int32_t graph_max_nodes() const;
// zero-out inputs and create the ctx_compute for the compute graph
ggml_cgraph * graph_init();
llm_graph_result_ptr graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype);
// returns the result of ggml_backend_sched_graph_compute_async execution
ggml_status graph_compute(
ggml_cgraph * gf,
bool batched);
llm_graph_cb graph_get_cb() const;
// used by kv_self_update()
ggml_tensor * build_rope_shift(
ggml_context * ctx0,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
ggml_backend_buffer * bbuf) const;
llm_graph_result_ptr build_kv_self_shift(
ggml_context * ctx0,
ggml_cgraph * gf) const;
llm_graph_result_ptr build_kv_self_defrag(
ggml_context * ctx0,
ggml_cgraph * gf) const;
// TODO: read/write lora adapters and cvec
size_t state_write_data(llama_io_write_i & io);
size_t state_read_data (llama_io_read_i & io);
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);
//
// members
//
const llama_model & model;
llama_cparams cparams;
llama_adapter_cvec cvec;
llama_adapter_loras loras;
llama_sbatch sbatch;
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
std::unique_ptr<llama_kv_cache_unified> kv_self;
// TODO: remove
bool logits_all = false;
// decode output (2-dimensional array: [n_outputs][n_vocab])
size_t logits_size = 0; // capacity (of floats) for logits
float * logits = nullptr;
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
size_t output_size = 0; // capacity (of tokens positions) for the output buffers
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
bool logits_all = false;
// embeddings output (2-dimensional array: [n_outputs][n_embd])
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
size_t embd_size = 0; // capacity (of floats) for embeddings
@ -72,57 +216,47 @@ struct llama_context {
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
std::map<llama_seq_id, std::vector<float>> embd_seq;
// whether we are computing encoder output or decoder output
bool is_encoding = false;
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
int32_t n_outputs_max = 0; // capacity (of tokens positions) for the output buffers
// TODO: find a better way to accommodate mutli-dimension position encoding methods
// number of position id each token get, 1 for each token in most cases.
// when using m-rope, it will be 3 position ids per token to representing 3 dimension coordinate.
int n_pos_per_token = 1;
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
// output of the encoder part of the encoder-decoder models
std::vector<float> embd_enc;
std::vector<std::set<llama_seq_id>> seq_ids_enc;
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
ggml_backend_sched_ptr sched;
ggml_backend_t backend_cpu = nullptr;
std::vector<ggml_backend_ptr> backends;
ggml_context_ptr ctx_compute;
ggml_threadpool_t threadpool = nullptr;
ggml_threadpool_t threadpool_batch = nullptr;
ggml_abort_callback abort_callback = nullptr;
void * abort_callback_data = nullptr;
// input tensors
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch]
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
// buffer types used for the compute buffer of each backend
std::vector<ggml_backend_t> backend_ptrs;
std::vector<ggml_backend_buffer_type_t> backend_buft;
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_ptr buf_output;
bool has_evaluated_once = false;
// perf
mutable int64_t t_start_us = 0;
mutable int64_t t_load_us = 0;
mutable int64_t t_p_eval_us = 0;
mutable int64_t t_eval_us = 0;
mutable int64_t t_compute_start_us = 0;
mutable int64_t n_queued_tokens = 0;
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
mutable int32_t n_eval = 0; // number of eval calls
};
// TODO: make these methods of llama_context
void llama_set_k_shift(struct llama_context & lctx);
void llama_set_s_copy(struct llama_context & lctx);
void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch);
// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs);
// make the outputs have the same order they had in the user-provided batch
void llama_output_reorder(struct llama_context & ctx);
// For internal test use
// TODO: remove
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);

1695
src/llama-graph.cpp Normal file

File diff suppressed because it is too large Load diff

576
src/llama-graph.h Normal file
View file

@ -0,0 +1,576 @@
#pragma once
#include "llama-arch.h"
#include "llama-hparams.h"
#include "llama-adapter.h"
#include <cstdint>
#include <vector>
#include <memory>
#include <set>
#include <functional>
struct ggml_cgraph;
struct ggml_context;
struct ggml_tensor;
struct llama_ubatch;
struct llama_cparams;
class llama_memory_i;
class llama_kv_cache_unified;
// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
LLM_GRAPH_TYPE_DEFAULT,
LLM_GRAPH_TYPE_ENCODER,
LLM_GRAPH_TYPE_DECODER,
};
enum llm_ffn_op_type {
LLM_FFN_SILU,
LLM_FFN_GELU,
LLM_FFN_RELU,
LLM_FFN_RELU_SQR,
LLM_FFN_SWIGLU,
};
enum llm_ffn_gate_type {
LLM_FFN_SEQ,
LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
};
enum llm_norm_type {
LLM_NORM,
LLM_NORM_RMS,
LLM_NORM_GROUP,
};
// TODO: tmp - need something better to pass the data from the encoder to the decoder
struct llama_cross {
// the output embeddings from the encoder as a ggml tensor
// TODO: this needs more work to be correct, for now copy the embeddings data to host memory
// ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
//ggml_tensor * t_embd = nullptr;
int64_t n_embd = 0;
int64_t n_enc = 0;
// embeddings data copied to host memory (tmp)
std::vector<float> v_embd;
// needed to construct the cross-attention mask in the decoder
std::vector<std::set<llama_seq_id>> seq_ids_enc;
};
//
// llm_graph_input
//
class llm_graph_input_i {
public:
virtual ~llm_graph_input_i() = default;
virtual void set_input(const llama_ubatch * ubatch) = 0;
};
using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
class llm_graph_input_embd : public llm_graph_input_i {
public:
llm_graph_input_embd() = default;
virtual ~llm_graph_input_embd() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * tokens = nullptr; // I32 [n_batch]
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
};
class llm_graph_input_pos : public llm_graph_input_i {
public:
llm_graph_input_pos(int64_t n_pos_per_token) : n_pos_per_token(n_pos_per_token) {}
virtual ~llm_graph_input_pos() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * pos = nullptr; // I32 [n_batch]
const int64_t n_pos_per_token = 1;
};
class llm_graph_input_pos_bucket : public llm_graph_input_i {
public:
llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
virtual ~llm_graph_input_pos_bucket() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
const llama_hparams & hparams;
};
class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
public:
llm_graph_input_pos_bucket_kv(
const llama_hparams & hparams,
const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
virtual ~llm_graph_input_pos_bucket_kv() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
const llama_hparams & hparams;
const llama_kv_cache_unified * kv_self;
};
class llm_graph_input_out_ids : public llm_graph_input_i {
public:
llm_graph_input_out_ids(
const llama_hparams & hparams,
const llama_cparams & cparams,
int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
virtual ~llm_graph_input_out_ids() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * out_ids; // I32 [n_outputs]
const llama_hparams & hparams;
const llama_cparams & cparams;
const int32_t n_outputs;
};
class llm_graph_input_mean : public llm_graph_input_i {
public:
llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
virtual ~llm_graph_input_mean() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * mean; // F32 [n_batch, n_batch]
const llama_cparams & cparams;
};
class llm_graph_input_cls : public llm_graph_input_i {
public:
llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
virtual ~llm_graph_input_cls() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * cls; // I32 [n_batch]
const llama_cparams & cparams;
};
class llm_graph_input_s_copy : public llm_graph_input_i {
public:
llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_copy() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_copy; // I32 [kv_size]
const llama_kv_cache_unified * kv_self;
};
class llm_graph_input_s_mask : public llm_graph_input_i {
public:
llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_mask() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_mask; // F32 [1, n_kv]
const llama_kv_cache_unified * kv_self;
};
class llm_graph_input_cross_embd : public llm_graph_input_i {
public:
llm_graph_input_cross_embd(
const llama_cross * cross) : cross(cross) {}
virtual ~llm_graph_input_cross_embd() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
const llama_cross * cross;
};
class llm_graph_input_attn_no_cache : public llm_graph_input_i {
public:
llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
hparams(hparams),
cparams(cparams) {
}
~llm_graph_input_attn_no_cache() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
const llama_hparams & hparams;
const llama_cparams & cparams;
};
class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
public:
llm_graph_input_attn_kv_unified(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified * kv_self) :
hparams(hparams),
cparams(cparams),
kv_self(kv_self) {
}
~llm_graph_input_attn_kv_unified() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_kv_cache_unified * kv_self;
};
class llm_graph_input_attn_cross : public llm_graph_input_i {
public:
llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
~llm_graph_input_attn_cross() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
const llama_cross * cross = nullptr;
};
//
// llm_graph_result
//
// these objects deliver the result from the graph build process back to the llama_context
// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
// specific data, by calling the set_inputs() method
// along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
// these are used by the llama_context to extact the relevant data, based on the compute parameters
class llm_graph_result_i {
public:
virtual ~llm_graph_result_i() = default;
virtual ggml_tensor * get_logits() = 0;
virtual ggml_tensor * get_embd() = 0;
virtual ggml_tensor * get_embd_pooled() = 0;
virtual void set_inputs(const llama_ubatch * ubatch) = 0;
};
using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
class llm_graph_result : public llm_graph_result_i {
public:
virtual ~llm_graph_result() = default;
ggml_tensor * get_logits() override { return t_logits; }
ggml_tensor * get_embd() override { return t_embd; }
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
void set_inputs(const llama_ubatch * ubatch) override {
for (auto & input : inputs) {
input->set_input(ubatch);
}
}
llm_graph_input_i * add_input(llm_graph_input_ptr input) {
inputs.emplace_back(std::move(input));
return inputs.back().get();
}
// important graph nodes
ggml_tensor * t_logits = nullptr;
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
std::vector<llm_graph_input_ptr> inputs;
};
//
// llm_graph_context
//
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
struct llm_graph_params {
ggml_context * ctx;
const llm_arch arch;
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_ubatch & ubatch;
ggml_backend_sched * sched;
ggml_backend * backend_cpu;
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_i * memory;
const llama_cross * cross;
int32_t n_outputs;
const llm_graph_cb & cb;
};
struct llm_graph_context {
const llm_arch arch;
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_ubatch & ubatch;
const int64_t n_embd;
const int64_t n_layer;
const int64_t n_rot;
const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
const int64_t n_ctx_per_seq;
const int64_t n_head;
const int64_t n_head_kv;
const int64_t n_embd_head_k;
const int64_t n_embd_k_gqa;
const int64_t n_embd_head_v;
const int64_t n_embd_v_gqa;
const int64_t n_expert;
const int64_t n_expert_used;
const float freq_base;
const float freq_scale;
const float ext_factor;
const float attn_factor;
const float beta_fast;
const float beta_slow;
const float norm_eps;
const float norm_rms_eps;
const int32_t n_tokens;
const int32_t n_outputs;
const int32_t n_ctx_orig; // yarn
const enum llama_pooling_type pooling_type;
const enum llama_rope_type rope_type;
ggml_context * ctx0 = nullptr;
ggml_backend_sched * sched;
ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_i * memory;
const llama_cross * cross;
const llm_graph_cb & cb_func;
std::unique_ptr<llm_graph_result> res;
llm_graph_context(const llm_graph_params & params);
int64_t n_pos_per_token() const;
void cb(ggml_tensor * cur, const char * name, int il) const;
//
// common
//
ggml_tensor * build_cvec(
ggml_tensor * cur,
int il) const;
// do mat_mul, while optionally apply lora
ggml_tensor * build_lora_mm(
ggml_tensor * w,
ggml_tensor * cur) const;
// do mat_mul_id, while optionally apply lora
ggml_tensor * build_lora_mm_id(
ggml_tensor * w, // ggml_tensor * as
ggml_tensor * cur, // ggml_tensor * b
ggml_tensor * ids) const;
ggml_tensor * build_norm(
ggml_tensor * cur,
ggml_tensor * mw,
ggml_tensor * mb,
llm_norm_type type,
int il) const;
ggml_tensor * build_ffn(
ggml_tensor * cur,
ggml_tensor * up,
ggml_tensor * up_b,
ggml_tensor * up_s,
ggml_tensor * gate,
ggml_tensor * gate_b,
ggml_tensor * gate_s,
ggml_tensor * down,
ggml_tensor * down_b,
ggml_tensor * down_s,
ggml_tensor * act_scales,
llm_ffn_op_type type_op,
llm_ffn_gate_type type_gate,
int il) const;
ggml_tensor * build_moe_ffn(
ggml_tensor * cur,
ggml_tensor * gate_inp,
ggml_tensor * up_exps,
ggml_tensor * gate_exps,
ggml_tensor * down_exps,
ggml_tensor * exp_probs_b,
int64_t n_expert,
int64_t n_expert_used,
llm_ffn_op_type type_op,
bool norm_w,
bool scale_w,
float w_scale,
llama_expert_gating_func_type gating_op,
int il) const;
//
// inputs
//
ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
ggml_tensor * build_inp_pos() const;
ggml_tensor * build_inp_out_ids() const;
ggml_tensor * build_inp_mean() const;
ggml_tensor * build_inp_cls() const;
ggml_tensor * build_inp_s_copy() const;
ggml_tensor * build_inp_s_mask() const;
ggml_tensor * build_inp_cross_embd() const;
ggml_tensor * build_inp_pos_bucket_enc() const;
ggml_tensor * build_inp_pos_bucket_dec() const;
ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
//
// attention
//
ggml_tensor * build_attn_mha(
ggml_cgraph * gf,
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
bool v_trans,
float kq_scale) const;
llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
ggml_tensor * build_attn(
llm_graph_input_attn_no_cache * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) const;
llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified(
bool causal,
bool swa) const;
ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) const;
llm_graph_input_attn_cross * build_attn_inp_cross() const;
ggml_tensor * build_attn(
llm_graph_input_attn_cross * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) const;
//
// recurrent
//
ggml_tensor * build_copy_mask_state(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
int32_t n_state,
int32_t n_seqs) const;
ggml_tensor * build_rwkv_token_shift_load(
ggml_cgraph * gf,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const;
ggml_tensor * build_rwkv_token_shift_store(
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il) const;
//
// pooling
//
void build_pooling(
ggml_cgraph * gf,
ggml_tensor * cls,
ggml_tensor * cls_b,
ggml_tensor * cls_out,
ggml_tensor * cls_out_b) const;
};

15
src/llama-io.cpp Normal file
View file

@ -0,0 +1,15 @@
#include "llama-io.h"
void llama_io_write_i::write_string(const std::string & str) {
uint32_t str_size = str.size();
write(&str_size, sizeof(str_size));
write(str.data(), str_size);
}
void llama_io_read_i::read_string(std::string & str) {
uint32_t str_size;
read_to(&str_size, sizeof(str_size));
str.assign((const char *) read(str_size), str_size);
}

35
src/llama-io.h Normal file
View file

@ -0,0 +1,35 @@
#pragma once
#include <cstddef>
#include <cstdint>
#include <string>
struct ggml_tensor;
class llama_io_write_i {
public:
llama_io_write_i() = default;
virtual ~llama_io_write_i() = default;
virtual void write(const void * src, size_t size) = 0;
virtual void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) = 0;
// bytes written so far
virtual size_t n_bytes() = 0;
void write_string(const std::string & str);
};
class llama_io_read_i {
public:
llama_io_read_i() = default;
virtual ~llama_io_read_i() = default;
virtual const uint8_t * read(size_t size) = 0;
virtual void read_to(void * dst, size_t size) = 0;
// bytes read so far
virtual size_t n_bytes() = 0;
void read_string(std::string & str);
};

File diff suppressed because it is too large Load diff

View file

@ -1,12 +1,29 @@
#pragma once
#include "llama.h"
#include "llama-io.h"
#include "llama-memory.h"
#include "ggml-cpp.h"
#include <functional>
#include <set>
#include <vector>
#include <algorithm>
struct llama_cparams;
struct llama_hparams;
struct llama_ubatch;
struct llama_kv_cache : public llama_memory_i {
using llama_memory_i::llama_memory_i;
virtual int32_t get_n_tokens() const = 0;
virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
virtual bool get_can_shift() const = 0;
bool get_can_edit() const override { return get_can_shift(); }
};
struct llama_kv_cell {
llama_pos pos = -1;
@ -29,11 +46,105 @@ struct llama_kv_cell {
}
};
// a structure holds information about the slot found in llama_kv_cache_find_slot
struct llama_kv_cache_slot_info {
std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
bool found = false; // the slot was found
explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
operator bool() const { return found; }
};
// ring-buffer of cached KV data
struct llama_kv_cache {
// TODO: pimpl
// TODO: add notion of max sequences
class llama_kv_cache_unified : public llama_kv_cache {
public:
// can be used to query data from the model if needed
struct callbacks {
std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
};
llama_kv_cache_unified(
const llama_hparams & hparams,
callbacks cbs);
virtual ~llama_kv_cache_unified() = default;
// TODO: become constructor
bool init(
const llama_model & model, // TODO: do not reference the model
const llama_cparams & cparams,
ggml_type type_k,
ggml_type type_v,
uint32_t kv_size,
bool offload);
int32_t get_n_tokens() const override;
uint32_t get_used_cells() const override;
size_t total_size() const;
// TODO: better data structures to reduce the cost of this operation
llama_pos pos_max() const;
void clear() override;
void defrag() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_max(llama_seq_id seq_id) override;
bool get_can_shift() const override;
// find an empty slot of size "n_tokens" in the cache
// updates the cache head
// returns a structure holding information about the slot found
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
llama_kv_cache_slot_info find_slot(const llama_ubatch & batch);
// TODO: maybe not needed
uint32_t get_padding(const llama_cparams & cparams) const;
// find how many cells are currently in use
uint32_t cell_max() const;
size_t size_k_bytes() const;
size_t size_v_bytes() const;
// defrag
struct {
std::vector<uint32_t> ids;
} defrag_info;
// return true if cells have been moved
bool defrag_prepare(int32_t n_max_nodes);
// state save/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
// members
const llama_hparams & hparams;
callbacks cbs;
bool has_shift = false;
bool do_defrag = false;
// TODO: remove this and implement llama_kv_cache_recurrent instead
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
bool v_trans = true; // the value tensor is transposed
bool can_shift = false;
@ -47,124 +158,30 @@ struct llama_kv_cache {
// computed before each graph build
uint32_t n = 0;
std::vector<llama_kv_cell> cells;
std::vector<ggml_tensor *> k_l; // per layer
std::vector<ggml_tensor *> v_l;
private:
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;
std::vector<llama_kv_cell> cells;
std::vector<struct ggml_tensor *> k_l; // per layer
std::vector<struct ggml_tensor *> v_l;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
size_t total_size() const {
size_t size = 0;
for (const auto & buf : bufs) {
size += ggml_backend_buffer_get_size(buf.get());
}
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
return size;
}
// TODO: better data structures to reduce the cost of this operation
llama_pos max_pos() const {
llama_pos max_pos = -1;
for (const auto & cell : cells) {
max_pos = std::max(max_pos, cell.pos);
}
return max_pos;
}
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
// a structure holds information about the slot found in llama_kv_cache_find_slot
struct llama_kv_cache_slot_info {
std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
bool found = false; // the slot was found
explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
operator bool() const { return found; }
};
// TODO: maybe not needed
uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams);
bool llama_kv_cache_init(
struct llama_kv_cache & cache,
const llama_model & model,
const llama_cparams & cparams,
ggml_type type_k,
ggml_type type_v,
uint32_t kv_size,
bool offload);
// find an empty slot of size "n_tokens" in the cache
// updates the cache head
// returns a structure holding information about the slot found
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
struct llama_kv_cache & cache,
const struct llama_ubatch & batch);
// find how many cells are currently in use
uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache);
void llama_kv_cache_clear(struct llama_kv_cache & cache);
bool llama_kv_cache_seq_rm(
struct llama_kv_cache & cache,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1);
void llama_kv_cache_seq_cp(
struct llama_kv_cache & cache,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1);
void llama_kv_cache_seq_keep(
struct llama_kv_cache & cache,
llama_seq_id seq_id);
void llama_kv_cache_seq_add(
struct llama_kv_cache & cache,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta);
void llama_kv_cache_seq_div(
struct llama_kv_cache & cache,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d);
llama_pos llama_kv_cache_seq_pos_max(
struct llama_kv_cache & cache,
llama_seq_id seq_id);
void llama_kv_cache_defrag(struct llama_kv_cache & cache);
int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv);
int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv);
bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv);
//
// kv cache view
//
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max);
void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv);
// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
//class llama_kv_cache_recurrent : public llama_kv_cache_unified {
//public:
// using llama_kv_cache_unified::llama_kv_cache_unified;
//};
//
// kv cache restore
@ -184,13 +201,15 @@ struct llama_kv_slot_restorer {
bool do_restore = false;
explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
llama_kv_cache_unified & cache;
explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) {
old_state.head = cache.head;
old_state.n = cache.n;
}
// saves a slot information for future restoration
void save(const struct llama_kv_cache_slot_info & slot) {
void save(const llama_kv_cache_slot_info & slot) {
if (slot) {
do_restore = true;
if (slot.boundaries.first != slot.boundaries.second) {
@ -201,19 +220,68 @@ struct llama_kv_slot_restorer {
// must be explicitly called to restore the kv_cache state
// and rollback changes from all llama_kv_cache_find_slot calls
void restore(struct llama_kv_cache & cache) {
void restore() {
if (do_restore) {
cache.head = old_state.head;
cache.n = old_state.n;
if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
llama_kv_cache_seq_rm(cache, -1, -1, -1);
cache.seq_rm(-1, -1, -1);
} else {
for (auto & slot : slot_boundaries) {
llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
cache.seq_rm(-1, slot.first, slot.second);
}
}
}
}
};
// TODO: maybe become part of the public llama_kv_cache in the future
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv);
void llama_kv_cache_clear(llama_kv_cache * kv);
bool llama_kv_cache_seq_rm(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1);
void llama_kv_cache_seq_cp(
llama_kv_cache * kv,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1);
void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id);
void llama_kv_cache_seq_add(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta);
void llama_kv_cache_seq_div(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d);
llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id);
void llama_kv_cache_defrag(llama_kv_cache * kv);
bool llama_kv_cache_can_shift(const llama_kv_cache * kv);
//
// kv cache view
//
llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max);
void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv);

1
src/llama-memory.cpp Normal file
View file

@ -0,0 +1 @@
#include "llama-memory.h"

21
src/llama-memory.h Normal file
View file

@ -0,0 +1,21 @@
#pragma once
#include "llama.h"
// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
class llama_memory_i {
public:
virtual void clear() = 0;
virtual void defrag() = 0;
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
virtual void seq_keep(llama_seq_id seq_id) = 0;
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
virtual bool get_can_edit() const = 0;
};

File diff suppressed because it is too large Load diff

View file

@ -2,7 +2,9 @@
#include "llama.h"
#include "llama-arch.h"
#include "llama-graph.h"
#include "llama-hparams.h"
#include "llama-memory.h"
#include "llama-vocab.h"
#include <memory>
@ -10,6 +12,8 @@
#include <unordered_map>
#include <vector>
struct llama_cparams;
struct llama_ubatch;
struct llama_model_loader;
// available models
@ -347,7 +351,7 @@ struct llama_model {
std::string desc() const;
size_t size() const;
size_t max_nodes() const;
size_t n_tensors() const;
size_t n_devices() const;
// total number of parameters in the model
@ -362,9 +366,22 @@ struct llama_model {
const struct ggml_tensor * get_tensor(const char * name) const;
// TODO: move this to new llm_arch_model_i interface
llama_memory_i * create_memory() const; // TODO: params
// TODO: move this to new llm_arch_model_i interface
llm_graph_result_ptr build_graph(
const llm_graph_params & params,
ggml_cgraph * gf,
llm_graph_type type) const;
private:
struct impl;
std::unique_ptr<impl> pimpl;
};
const char * llm_type_name(llm_type type);
// For internal test use
// TODO: remove
const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model);

File diff suppressed because it is too large Load diff