llama : rework embeddings logic (#14208)

* llama : rework embeddings logic

ggml-ci

* cont : fix rerank

ggml-ci

* cont : engrish [no ci]

* cont : fix rerank

ggml-ci

* server : support both embeddings and completions with single model

ggml-ci

* cont : avoid embeddings_org

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-06-16 14:14:00 +03:00 committed by GitHub
parent 3ba0d843c6
commit d3e64b9f49
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 159 additions and 114 deletions

View file

@ -299,7 +299,8 @@ llama_batch_allocr::llama_batch_allocr() {
bool llama_batch_allocr::init(
const llama_batch & batch_inp,
const llama_vocab & vocab,
const llama_memory_i * memory) {
const llama_memory_i * memory,
bool embd_all) {
clear();
batch = batch_inp;
@ -378,10 +379,31 @@ bool llama_batch_allocr::init(
}
if (!batch.logits) {
// by default return the output only for the last token
output.resize(batch.n_tokens);
output[output.size() - 1] = true;
if (embd_all) {
// return the output for all tokens
output.resize(batch.n_tokens, true);
} else {
// return the output only for the last token
output.resize(batch.n_tokens, false);
output[output.size() - 1] = true;
}
batch.logits = output.data();
} else if (embd_all) {
bool warn = false;
for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (batch.logits[i] == 0) {
warn = true;
}
}
if (warn) {
LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
output.resize(batch.n_tokens, true);
batch.logits = output.data();
}
}
//

View file

@ -88,7 +88,8 @@ public:
bool init(
const llama_batch & batch_inp,
const llama_vocab & vocab,
const llama_memory_i * memory);
const llama_memory_i * memory,
bool embd_all);
const llama_batch & get_batch() const;

View file

@ -728,7 +728,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
}
// note: during encode, we always pass the full sequence starting from pos = 0
if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
@ -894,7 +894,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
return -1;
}
if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
// when computing embeddings, all tokens are output
const bool embd_all = cparams.embeddings;
if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
@ -911,12 +914,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
// this indicates we are doing pooled embedding
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
if (embd_pooled) {
if (embd_all) {
// require that all tokens are output
if (n_outputs_all != n_tokens_all) {
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@ -945,7 +945,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
llama_memory_state_ptr mstate;
while (true) {
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
if (!mstate) {
return -2;
}
@ -1058,7 +1058,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
//}
auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
auto * t_logits = res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
if (t_embd && res->get_embd_pooled()) {
@ -1222,9 +1222,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
const auto n_vocab = vocab.n_tokens();
const auto n_embd = hparams.n_embd;
// TODO: use a per-batch flag for logits presence instead
bool has_logits = !cparams.embeddings;
bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
bool has_logits = true;
bool has_embd = cparams.embeddings;
// TODO: hacky enc-dec support
if (model.arch == LLM_ARCH_T5) {
@ -2044,14 +2043,11 @@ void llama_context::opt_epoch_iter(
n_queued_tokens += n_tokens_all;
// this indicates we are doing pooled embedding
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
embd_seq.clear();
uint32_t n_outputs_all = n_tokens_all;
auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
break;

View file

@ -359,9 +359,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
return result;
}
llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
GGML_UNUSED(embd_pooled);
llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
std::vector<llama_ubatch> ubatches;
@ -369,8 +367,8 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch &
while (sbatch.n_tokens > 0) {
llama_ubatch ubatch;
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
if (embd_all) {
// if all tokens are output, split by sequence
ubatch = sbatch.split_seq(n_ubatch);
} else {
ubatch = sbatch.split_equal(n_ubatch);

View file

@ -32,7 +32,7 @@ public:
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) override;
bool embd_all) override;
llama_memory_state_ptr init_full() override;

View file

@ -95,8 +95,8 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
return kv_swa->seq_pos_max(seq_id);
}
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
GGML_UNUSED(embd_pooled);
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
GGML_UNUSED(embd_all);
// first try simple split
do {

View file

@ -34,7 +34,7 @@ public:
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) override;
bool embd_all) override;
llama_memory_state_ptr init_full() override;

View file

@ -310,8 +310,8 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
llama_memory_state_ptr llama_kv_cache_unified::init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) {
GGML_UNUSED(embd_pooled);
bool embd_all) {
GGML_UNUSED(embd_all);
do {
auto sbatch = llama_sbatch(batch, hparams.n_embd, true);

View file

@ -59,7 +59,7 @@ public:
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) override;
bool embd_all) override;
llama_memory_state_ptr init_full() override;

View file

@ -73,7 +73,7 @@ struct llama_memory_i {
virtual llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) = 0;
bool embd_all) = 0;
// simulate full cache, used for allocating worst-case compute buffers
virtual llama_memory_state_ptr init_full() = 0;