kv-cache : refactor + add llama_memory_state_i (#13746)
* kv-cache : simplify the "struct llama_kv_cache" interface ggml-ci * kv-cache : revert the (n_swa + n_ubatch) change (for next PR) ggml-ci * kv-cache : some comments ggml-ci * context : fix graph reserve for multiple sequences ggml-ci * kv-cache : fix typo [no ci] * kv-cache : fix find_slot() logic for free slots ggml-ci * llama : add TODO for deprecating the defrag API in the future * kv-cache : improve find_slot() using min/max seq pos info ggml-ci * llama : handle aborts and compute errors ggml-ci * memory : extract state into llama_memory_state ggml-ci * kv-cache : add comments ggml-ci * server : update batching logic to reset n_batch on successful decode * server : upon full re-processing, remove the sequence from the cache * kv-cache : add TODO for doing split_equal when split_simple fails ggml-ci
This commit is contained in:
parent
eb3949938e
commit
12d0188c0d
14 changed files with 1304 additions and 655 deletions
|
@ -6,9 +6,10 @@
|
|||
#include "llama-model.h"
|
||||
#include "llama-kv-cache.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <stdexcept>
|
||||
#include <cinttypes>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <stdexcept>
|
||||
|
||||
//
|
||||
// llama_context
|
||||
|
@ -259,15 +260,9 @@ llama_context::llama_context(
|
|||
|
||||
// reserve worst-case graph
|
||||
if (!hparams.vocab_only && memory) {
|
||||
const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
|
||||
// restore later
|
||||
// TODO: something cleaner
|
||||
const auto n_outputs_save = n_outputs;
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
||||
|
||||
int n_splits_pp = -1;
|
||||
|
@ -279,23 +274,17 @@ llama_context::llama_context(
|
|||
// simulate full KV cache
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->set_full();
|
||||
const auto kv_state = kv_self->init_full();
|
||||
if (!kv_state) {
|
||||
throw std::runtime_error("failed to initialize KV cache");
|
||||
}
|
||||
|
||||
cross.v_embd.clear();
|
||||
|
||||
// reserve pp graph first so that buffers are only allocated once
|
||||
{
|
||||
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
// max number of outputs
|
||||
n_outputs = ubatch_pp.n_tokens;
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
|
||||
|
||||
auto * gf = graph_init();
|
||||
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
|
||||
|
||||
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
}
|
||||
|
||||
|
@ -305,16 +294,8 @@ llama_context::llama_context(
|
|||
|
||||
// reserve with tg graph to get the number of splits and nodes
|
||||
{
|
||||
llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
n_outputs = ubatch_tg.n_tokens;
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
|
||||
|
||||
auto * gf = graph_init();
|
||||
graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
|
||||
|
||||
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
||||
auto * gf = graph_reserve(1, 1, 1, kv_state.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute tg buffers");
|
||||
}
|
||||
|
||||
|
@ -324,22 +305,12 @@ llama_context::llama_context(
|
|||
|
||||
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
||||
{
|
||||
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
n_outputs = ubatch_pp.n_tokens;
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
|
||||
|
||||
auto * gf = graph_init();
|
||||
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
|
||||
|
||||
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
}
|
||||
}
|
||||
|
||||
n_outputs = n_outputs_save;
|
||||
|
||||
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
|
||||
ggml_backend_t backend = backend_ptrs[i];
|
||||
ggml_backend_buffer_type_t buft = backend_buft[i];
|
||||
|
@ -454,33 +425,25 @@ const llama_kv_cache * llama_context::get_kv_self() const {
|
|||
}
|
||||
|
||||
void llama_context::kv_self_update() {
|
||||
bool need_reserve = false;
|
||||
if (!memory) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
need_reserve = kv_self->update(*this);
|
||||
if (kv_self->update(*this)) {
|
||||
// if the KV cache did any computation, we have to reserve a new worst-case graph
|
||||
const auto kv_state = kv_self->init_full();
|
||||
if (!kv_state) {
|
||||
throw std::runtime_error("failed to initialize KV cache");
|
||||
}
|
||||
|
||||
// reserve a worst case graph if needed
|
||||
if (need_reserve) {
|
||||
LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
// build worst-case graph
|
||||
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
||||
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
// simulate full KV cache
|
||||
kv_self->set_full();
|
||||
|
||||
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
auto * gf = graph_init();
|
||||
graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
|
||||
|
||||
// initialize scheduler with the worst-case graph
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
|
||||
if (!gf) {
|
||||
LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -676,6 +639,49 @@ bool llama_context::apply_adapter_cvec(
|
|||
return cvec.apply(model, data, len, n_embd, il_start, il_end);
|
||||
}
|
||||
|
||||
llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
|
||||
if (mstate && !mstate->apply()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
|
||||
ret = GGML_STATUS_FAILED;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto * gf = graph_init();
|
||||
if (!gf) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
|
||||
ret = GGML_STATUS_FAILED;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
|
||||
if (!res) {
|
||||
LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
|
||||
ret = GGML_STATUS_FAILED;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
||||
|
||||
if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
|
||||
ret = GGML_STATUS_ALLOC_FAILED;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
res->set_inputs(&ubatch);
|
||||
|
||||
const auto status = graph_compute(gf, ubatch.n_tokens > 1);
|
||||
if (status != GGML_STATUS_SUCCESS) {
|
||||
LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
|
||||
ret = status;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ret = GGML_STATUS_SUCCESS;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
int llama_context::encode(llama_batch & inp_batch) {
|
||||
if (inp_batch.n_tokens == 0) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
||||
|
@ -737,8 +743,6 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|||
|
||||
n_outputs = n_tokens;
|
||||
|
||||
//batch_manager->prepare(ubatch);
|
||||
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||
|
||||
|
@ -749,26 +753,18 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|||
// ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
|
||||
cparams.causal_attn = false;
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
|
||||
|
||||
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
||||
|
||||
res->set_inputs(&ubatch);
|
||||
ggml_status status;
|
||||
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
|
||||
|
||||
cparams.causal_attn = causal_attn_org;
|
||||
|
||||
const auto compute_status = graph_compute(gf, n_tokens > 1);
|
||||
switch (compute_status) {
|
||||
case GGML_STATUS_SUCCESS:
|
||||
break;
|
||||
case GGML_STATUS_ABORTED:
|
||||
return 2;
|
||||
case GGML_STATUS_ALLOC_FAILED:
|
||||
return -2;
|
||||
case GGML_STATUS_FAILED:
|
||||
default:
|
||||
return -3;
|
||||
if (!res) {
|
||||
switch (status) {
|
||||
case GGML_STATUS_ABORTED: return 2;
|
||||
case GGML_STATUS_ALLOC_FAILED: return -2;
|
||||
case GGML_STATUS_FAILED: return -3;
|
||||
case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
|
||||
}
|
||||
}
|
||||
|
||||
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
|
||||
|
@ -889,8 +885,6 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||
const int64_t n_tokens_all = batch.n_tokens;
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
llama_kv_cache_guard kv_guard(kv_self);
|
||||
|
||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
||||
|
||||
// TODO: move the validation to the llama_batch_allocr
|
||||
|
@ -936,7 +930,28 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||
n_outputs_all = 1;
|
||||
}
|
||||
|
||||
llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
|
||||
// handle any pending defrags/shifts
|
||||
kv_self_update();
|
||||
|
||||
auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
|
||||
if (!kv_state) {
|
||||
return -2;
|
||||
}
|
||||
|
||||
switch (kv_state->get_status()) {
|
||||
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||
{
|
||||
} break;
|
||||
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
||||
{
|
||||
// not a fatal error, we can re-try with a different batch
|
||||
return 1;
|
||||
}
|
||||
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
||||
{
|
||||
return -2;
|
||||
}
|
||||
}
|
||||
|
||||
// reserve output buffer
|
||||
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
||||
|
@ -944,13 +959,10 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||
return -2;
|
||||
};
|
||||
|
||||
// handle any pending defrags/shifts
|
||||
kv_self_update();
|
||||
|
||||
int64_t n_outputs_prev = 0;
|
||||
|
||||
while (sbatch.n_tokens > 0) {
|
||||
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
|
||||
do {
|
||||
const auto & ubatch = kv_state->get_ubatch();
|
||||
|
||||
// count the outputs in this u_batch
|
||||
{
|
||||
|
@ -969,33 +981,37 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||
n_outputs = n_outputs_new;
|
||||
}
|
||||
|
||||
// find KV slot
|
||||
if (!kv_self->find_slot(ubatch)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
|
||||
ggml_status status;
|
||||
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
|
||||
|
||||
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
||||
if (!res) {
|
||||
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
||||
llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };
|
||||
|
||||
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||
const auto & seq_id = ubatch.seq_id[i][0];
|
||||
|
||||
res->set_inputs(&ubatch);
|
||||
pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
|
||||
}
|
||||
|
||||
const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
|
||||
if (compute_status != GGML_STATUS_SUCCESS) {
|
||||
switch (compute_status) {
|
||||
case GGML_STATUS_ABORTED:
|
||||
return 2;
|
||||
case GGML_STATUS_ALLOC_FAILED:
|
||||
return -2;
|
||||
case GGML_STATUS_FAILED:
|
||||
default:
|
||||
return -3;
|
||||
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
||||
if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
|
||||
|
||||
llama_kv_self_seq_rm(this, s, pos_min[s], -1);
|
||||
}
|
||||
|
||||
switch (status) {
|
||||
case GGML_STATUS_ABORTED: return 2;
|
||||
case GGML_STATUS_ALLOC_FAILED: return -2;
|
||||
case GGML_STATUS_FAILED: return -3;
|
||||
case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1082,10 +1098,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||
}
|
||||
|
||||
n_outputs_prev += n_outputs;
|
||||
}
|
||||
|
||||
// finalize the batch processing
|
||||
kv_guard.commit();
|
||||
} while (kv_state->next());
|
||||
|
||||
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
||||
n_outputs = n_outputs_all;
|
||||
|
@ -1094,7 +1107,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||
{
|
||||
bool sorted_output = true;
|
||||
|
||||
auto & out_ids = sbatch.out_ids;
|
||||
auto & out_ids = kv_state->out_ids();
|
||||
|
||||
GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
|
||||
|
||||
|
@ -1254,11 +1267,52 @@ ggml_cgraph * llama_context::graph_init() {
|
|||
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
|
||||
}
|
||||
|
||||
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
|
||||
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
|
||||
|
||||
if (n_tokens % n_seqs != 0) {
|
||||
n_tokens = (n_tokens / n_seqs) * n_seqs;
|
||||
n_outputs = std::min(n_outputs, n_tokens);
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
|
||||
}
|
||||
|
||||
// store the n_outputs as it is, and restore it afterwards
|
||||
// TODO: not sure if needed, might simplify in the future by removing this
|
||||
const auto save_n_outputs = this->n_outputs;
|
||||
|
||||
this->n_outputs = n_outputs;
|
||||
|
||||
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
|
||||
|
||||
this->n_outputs = save_n_outputs;
|
||||
|
||||
if (!res) {
|
||||
LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
|
||||
// initialize scheduler with the specified graph
|
||||
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
llm_graph_result_ptr llama_context::graph_build(
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype) {
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype,
|
||||
const llama_memory_state_i * mstate) {
|
||||
return model.build_graph(
|
||||
{
|
||||
/*.ctx =*/ ctx,
|
||||
|
@ -1270,7 +1324,7 @@ llm_graph_result_ptr llama_context::graph_build(
|
|||
/*.backend_cpu =*/ backend_cpu,
|
||||
/*.cvec =*/ &cvec,
|
||||
/*.loras =*/ &loras,
|
||||
/*.memory =*/ memory.get(),
|
||||
/*.mstate =*/ mstate,
|
||||
/*.cross =*/ &cross,
|
||||
/*.n_outputs =*/ n_outputs,
|
||||
/*.cb =*/ graph_get_cb(),
|
||||
|
@ -1951,7 +2005,6 @@ void llama_context::opt_epoch_iter(
|
|||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->clear();
|
||||
llama_kv_cache_guard kv_guard(kv_self);
|
||||
|
||||
for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
|
||||
batch.n_tokens = n_batch;
|
||||
|
@ -1974,7 +2027,11 @@ void llama_context::opt_epoch_iter(
|
|||
|
||||
int64_t n_outputs_all = n_tokens_all;
|
||||
|
||||
llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
|
||||
auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
|
||||
if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
|
||||
break;
|
||||
}
|
||||
|
||||
// reserve output buffer
|
||||
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
||||
|
@ -1982,20 +2039,19 @@ void llama_context::opt_epoch_iter(
|
|||
GGML_ABORT("TODO: handle this error");
|
||||
};
|
||||
|
||||
for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
|
||||
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
|
||||
uint32_t pos_batch = 0;
|
||||
do {
|
||||
const auto & ubatch = kv_state->get_ubatch();
|
||||
|
||||
n_outputs = ubatch.n_tokens;
|
||||
|
||||
// TODO: not sure if this is needed
|
||||
if (!kv_self->find_slot(ubatch)) {
|
||||
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
|
||||
|
||||
GGML_ABORT("TODO: handle this error");
|
||||
if (!kv_state->apply()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
|
||||
break;
|
||||
}
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
|
||||
|
||||
struct ggml_context * ctx_compute_opt;
|
||||
{
|
||||
|
@ -2010,6 +2066,7 @@ void llama_context::opt_epoch_iter(
|
|||
}
|
||||
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
|
||||
ggml_opt_alloc(opt_ctx, train);
|
||||
|
||||
res->set_inputs(&ubatch);
|
||||
{
|
||||
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
|
||||
|
@ -2027,10 +2084,10 @@ void llama_context::opt_epoch_iter(
|
|||
callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
|
||||
}
|
||||
ggml_free(ctx_compute_opt);
|
||||
}
|
||||
}
|
||||
|
||||
kv_guard.commit();
|
||||
pos_batch += ubatch.n_tokens;
|
||||
} while (kv_state->next());
|
||||
}
|
||||
}
|
||||
|
||||
void llama_context::opt_epoch(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue