llama : rework embeddings logic (#14208)
* llama : rework embeddings logic ggml-ci
* cont : fix rerank ggml-ci
* cont : engrish [no ci]
* cont : fix rerank ggml-ci
* server : support both embeddings and completions with single model ggml-ci
* cont : avoid embeddings_org ggml-ci
parent 3ba0d843c6
commit d3e64b9f49
16 changed files with 159 additions and 114 deletions
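The practical upshot of this rework is that a context always reserves logits and additionally returns embeddings when `cparams.embeddings` is set, which is what lets the server handle completions and embeddings with a single loaded model. Below is a minimal sketch of that usage against the public `llama.h` API; the helper `run_request` and its parameters are illustrative, not part of this commit.

```cpp
// Sketch: one loaded model serving both completions and embeddings.
// Assumes `ctx` was created with a pooling type (e.g. LLAMA_POOLING_TYPE_MEAN)
// and `tokens` holds the tokenized prompt. `run_request` is a hypothetical helper.
#include "llama.h"

#include <vector>

float * run_request(llama_context * ctx, std::vector<llama_token> & tokens, bool want_embd) {
    // toggle embedding output per request instead of loading two models
    llama_set_embeddings(ctx, want_embd);

    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());
    if (llama_decode(ctx, batch) != 0) {
        return nullptr;
    }

    if (want_embd) {
        // pooled embedding for sequence 0
        return llama_get_embeddings_seq(ctx, 0);
    }

    // logits of the last output token, e.g. for sampling the next token
    return llama_get_logits_ith(ctx, -1);
}
```

Because `llama_batch_get_one` leaves `batch.logits` unset, the reworked `llama_batch_allocr::init` below picks the outputs automatically: all tokens when embeddings are requested, only the last token otherwise.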
src/llama-batch.cpp

@@ -299,7 +299,8 @@ llama_batch_allocr::llama_batch_allocr() {
 bool llama_batch_allocr::init(
         const llama_batch & batch_inp,
         const llama_vocab & vocab,
-        const llama_memory_i * memory) {
+        const llama_memory_i * memory,
+        bool embd_all) {
     clear();
 
     batch = batch_inp;
@@ -378,10 +379,31 @@ bool llama_batch_allocr::init(
     }
 
     if (!batch.logits) {
-        // by default return the output only for the last token
-        output.resize(batch.n_tokens);
-        output[output.size() - 1] = true;
+        if (embd_all) {
+            // return the output for all tokens
+            output.resize(batch.n_tokens, true);
+        } else {
+            // return the output only for the last token
+            output.resize(batch.n_tokens, false);
+            output[output.size() - 1] = true;
+        }
+
         batch.logits = output.data();
+    } else if (embd_all) {
+        bool warn = false;
+
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            if (batch.logits[i] == 0) {
+                warn = true;
+            }
+        }
+
+        if (warn) {
+            LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
+
+            output.resize(batch.n_tokens, true);
+            batch.logits = output.data();
+        }
     }
 
     //
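The core behavioral change is this default output mask: with `embd_all` set, every token becomes an output row; otherwise only the last token is, and an explicit caller-supplied mask gets overridden (with a warning) when embeddings require all tokens. A standalone sketch of the selection rule, independent of llama.cpp internals (names here are illustrative):

```cpp
// Sketch of the reworked output-selection rule in llama_batch_allocr::init.
// `logits` mirrors batch.logits: nullptr means "caller did not choose outputs".
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<int8_t> select_outputs(const int8_t * logits, int32_t n_tokens, bool embd_all) {
    std::vector<int8_t> output;
    if (!logits) {
        if (embd_all) {
            output.assign(n_tokens, true);  // every token is an output
        } else {
            output.assign(n_tokens, false); // only the last token is
            output[n_tokens - 1] = true;
        }
    } else {
        output.assign(logits, logits + n_tokens);
        if (embd_all) {
            for (int32_t i = 0; i < n_tokens; ++i) {
                if (!output[i]) {
                    // embeddings need all tokens -> override, with a warning
                    fprintf(stderr, "warn: token %d not marked as output, overriding\n", i);
                    output.assign(n_tokens, true);
                    break;
                }
            }
        }
    }
    return output;
}
```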
src/llama-batch.h

@@ -88,7 +88,8 @@ public:
     bool init(
             const llama_batch & batch_inp,
             const llama_vocab & vocab,
-            const llama_memory_i * memory);
+            const llama_memory_i * memory,
+            bool embd_all);
 
     const llama_batch & get_batch() const;
 
src/llama-context.cpp

@@ -728,7 +728,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     }
 
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -894,7 +894,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return -1;
     }
 
-    if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
+    // when computing embeddings, all tokens are output
+    const bool embd_all = cparams.embeddings;
+
+    if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
        return -1;
    }
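Note how the gating condition simplifies: previously the all-tokens-output path keyed on *pooled* embeddings only, now any embedding request triggers it. A minimal runnable sketch of the contrast (the struct and names stand in for llama.cpp's `cparams` and are assumptions for illustration):

```cpp
// Old vs. new gate for "output every token".
#include <cstdio>

struct cparams_t {
    bool embeddings;
    bool pooled; // stands in for cparams.pooling_type != LLAMA_POOLING_TYPE_NONE
};

int main() {
    const cparams_t cparams = { /*embeddings=*/true, /*pooled=*/false };

    const bool embd_pooled = cparams.embeddings && cparams.pooled; // old gate
    const bool embd_all    = cparams.embeddings;                   // new gate

    // token-level (non-pooled) embeddings: the old gate kept only the last
    // token's output, the new gate outputs every token
    printf("old=%d new=%d\n", embd_pooled, embd_all);
}
```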
@@ -911,12 +914,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    // this indicates we are doing pooled embedding
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
     const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
 
-    if (embd_pooled) {
+    if (embd_all) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -945,7 +945,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
         if (!mstate) {
             return -2;
         }
@@ -1058,7 +1058,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // ggml_graph_dump_dot(gf, NULL, "llama.dot");
     //}
 
-    auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
+    auto * t_logits = res->get_logits();
     auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
 
     if (t_embd && res->get_embd_pooled()) {
@@ -1222,9 +1222,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto n_vocab = vocab.n_tokens();
     const auto n_embd  = hparams.n_embd;
 
-    // TODO: use a per-batch flag for logits presence instead
-    bool has_logits = !cparams.embeddings;
-    bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    bool has_logits = true;
+    bool has_embd   = cparams.embeddings;
 
     // TODO: hacky enc-dec support
     if (model.arch == LLM_ARCH_T5) {
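Since logits are now always reserved, the output buffer for an embedding context grows: it holds both a logits row and an embedding row per output token. A back-of-the-envelope sketch of the reservation size (the real `output_reserve` does more bookkeeping; the numbers and `n_outputs_max` here are assumptions for illustration):

```cpp
// Rough size estimate of the output buffer after this change.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_vocab       = 32000; // example vocab size
    const uint32_t n_embd        = 4096;  // example embedding width
    const uint32_t n_outputs_max = 512;   // example max outputs per batch
    const bool     embeddings    = true;  // cparams.embeddings

    const bool has_logits = true;        // new: always reserved
    const bool has_embd   = embeddings;  // new: whenever embeddings are on

    const size_t logits_size = has_logits ? (size_t) n_vocab * n_outputs_max : 0;
    const size_t embd_size   = has_embd   ? (size_t) n_embd  * n_outputs_max : 0;

    printf("output buffer: %zu floats (%.1f MiB)\n",
           logits_size + embd_size,
           (logits_size + embd_size) * sizeof(float) / (1024.0 * 1024.0));
}
```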
@@ -2044,14 +2043,11 @@ void llama_context::opt_epoch_iter(
 
         n_queued_tokens += n_tokens_all;
 
-        // this indicates we are doing pooled embedding
-        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
         embd_seq.clear();
 
         uint32_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
src/llama-kv-cache-recurrent.cpp

@@ -359,9 +359,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
-    GGML_UNUSED(embd_pooled);
-
+llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
     auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
 
     std::vector<llama_ubatch> ubatches;
@@ -369,8 +367,8 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch &
     while (sbatch.n_tokens > 0) {
         llama_ubatch ubatch;
 
-        if (embd_pooled) {
-            // Pooled embeddings cannot be split across ubatches (yet)
+        if (embd_all) {
+            // if all tokens are output, split by sequence
             ubatch = sbatch.split_seq(n_ubatch);
         } else {
             ubatch = sbatch.split_equal(n_ubatch);
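For the recurrent cache the flag genuinely changes the split policy: when every token is an output, the batch is split per sequence (`split_seq`) rather than into equal slices (`split_equal`). A toy standalone illustration of the two policies; the token struct and the simplified equal split are made up for the example and do not reproduce `llama_sbatch` exactly:

```cpp
// Toy illustration of split_seq vs split_equal over a mixed batch.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <map>
#include <vector>

struct tok { int32_t seq_id; int32_t pos; };

int main() {
    // two sequences interleaved in one batch
    const std::vector<tok> batch = {
        {0, 0}, {1, 0}, {0, 1}, {1, 1}, {0, 2}, {1, 2},
    };

    // split_seq-style: one ubatch per sequence (used when all tokens are output)
    std::map<int32_t, std::vector<tok>> by_seq;
    for (const auto & t : batch) {
        by_seq[t.seq_id].push_back(t);
    }
    for (const auto & [seq, toks] : by_seq) {
        printf("seq ubatch: seq_id=%d, %zu tokens\n", seq, toks.size());
    }

    // split_equal-style (simplified): slices of up to n_ubatch tokens
    const size_t n_ubatch = 4;
    for (size_t i = 0; i < batch.size(); i += n_ubatch) {
        const size_t n = std::min(n_ubatch, batch.size() - i);
        printf("equal ubatch: %zu tokens\n", n);
    }
}
```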
src/llama-kv-cache-recurrent.h

@@ -32,7 +32,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled) override;
+            bool embd_all) override;
 
     llama_memory_state_ptr init_full() override;
 
src/llama-kv-cache-unified-iswa.cpp

@@ -95,8 +95,8 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
-    GGML_UNUSED(embd_pooled);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
+    GGML_UNUSED(embd_all);
 
     // first try simple split
     do {
src/llama-kv-cache-unified-iswa.h

@@ -34,7 +34,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled) override;
+            bool embd_all) override;
 
     llama_memory_state_ptr init_full() override;
 
src/llama-kv-cache-unified.cpp

@@ -310,8 +310,8 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified::init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled) {
-    GGML_UNUSED(embd_pooled);
+            bool embd_all) {
+    GGML_UNUSED(embd_all);
 
     do {
         auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
src/llama-kv-cache-unified.h

@@ -59,7 +59,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled) override;
+            bool embd_all) override;
 
     llama_memory_state_ptr init_full() override;
 
src/llama-memory.h

@@ -73,7 +73,7 @@ struct llama_memory_i {
     virtual llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled) = 0;
+            bool embd_all) = 0;
 
     // simulate full cache, used for allocating worst-case compute buffers
     virtual llama_memory_state_ptr init_full() = 0;