parent
c95fa362b3
commit
2d77d88e70
1 changed files with 21 additions and 4 deletions
|
@ -294,10 +294,7 @@ llama_context::llama_context(
|
||||||
// TODO: something cleaner
|
// TODO: something cleaner
|
||||||
const auto n_outputs_save = n_outputs;
|
const auto n_outputs_save = n_outputs;
|
||||||
|
|
||||||
// max number of outputs
|
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
||||||
n_outputs = n_tokens;
|
|
||||||
|
|
||||||
LLAMA_LOG_DEBUG("%s: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
||||||
|
|
||||||
int n_splits_pp = -1;
|
int n_splits_pp = -1;
|
||||||
int n_nodes_pp = -1;
|
int n_nodes_pp = -1;
|
||||||
|
@ -313,8 +310,15 @@ llama_context::llama_context(
|
||||||
// reserve pp graph first so that buffers are only allocated once
|
// reserve pp graph first so that buffers are only allocated once
|
||||||
{
|
{
|
||||||
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||||
|
|
||||||
|
// max number of outputs
|
||||||
|
n_outputs = ubatch_pp.n_tokens;
|
||||||
|
|
||||||
|
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
|
||||||
|
|
||||||
auto * gf = graph_init();
|
auto * gf = graph_init();
|
||||||
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
|
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
|
||||||
|
|
||||||
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
||||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||||
}
|
}
|
||||||
|
@ -326,11 +330,18 @@ llama_context::llama_context(
|
||||||
// reserve with tg graph to get the number of splits and nodes
|
// reserve with tg graph to get the number of splits and nodes
|
||||||
{
|
{
|
||||||
llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||||
|
|
||||||
|
n_outputs = ubatch_tg.n_tokens;
|
||||||
|
|
||||||
|
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
|
||||||
|
|
||||||
auto * gf = graph_init();
|
auto * gf = graph_init();
|
||||||
graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
|
graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
|
||||||
|
|
||||||
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
||||||
throw std::runtime_error("failed to allocate compute tg buffers");
|
throw std::runtime_error("failed to allocate compute tg buffers");
|
||||||
}
|
}
|
||||||
|
|
||||||
n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
|
n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
|
||||||
n_nodes_tg = ggml_graph_n_nodes(gf);
|
n_nodes_tg = ggml_graph_n_nodes(gf);
|
||||||
}
|
}
|
||||||
|
@ -338,8 +349,14 @@ llama_context::llama_context(
|
||||||
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
||||||
{
|
{
|
||||||
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||||
|
|
||||||
|
n_outputs = ubatch_pp.n_tokens;
|
||||||
|
|
||||||
|
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
|
||||||
|
|
||||||
auto * gf = graph_init();
|
auto * gf = graph_init();
|
||||||
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
|
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
|
||||||
|
|
||||||
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
||||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue