context : allow cache-less context for embeddings (#13108)
* context : allow cache-less context for embeddings ggml-ci * context : enable reranking with encode() ggml-ci * context : encode() clears embd_seq ggml-ci * examples : use llama_encode() when appropriate ggml-ci * models : nomic bert moe does not require KV cache * llama : update comments for llama_decode/llama_encode ggml-ci * context : update warning log [no ci]
This commit is contained in:
parent
51fb96b1ff
commit
6562e5a4d6
5 changed files with 47 additions and 23 deletions
|
@ -251,7 +251,7 @@ llama_context::llama_context(
|
|||
}
|
||||
|
||||
// reserve worst-case graph
|
||||
if (!hparams.vocab_only) {
|
||||
if (!hparams.vocab_only && memory) {
|
||||
const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
|
@ -700,6 +700,8 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|||
t_compute_start_us = ggml_time_us();
|
||||
}
|
||||
|
||||
embd_seq.clear();
|
||||
|
||||
n_queued_tokens += n_tokens;
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
@ -761,12 +763,12 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|||
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
||||
GGML_ASSERT(backend_embd != nullptr);
|
||||
|
||||
GGML_ASSERT(embd != nullptr);
|
||||
|
||||
switch (cparams.pooling_type) {
|
||||
case LLAMA_POOLING_TYPE_NONE:
|
||||
{
|
||||
// extract token embeddings
|
||||
GGML_ASSERT(embd != nullptr);
|
||||
|
||||
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
|
||||
} break;
|
||||
|
@ -791,11 +793,18 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|||
} break;
|
||||
case LLAMA_POOLING_TYPE_RANK:
|
||||
{
|
||||
// TODO: this likely should be the same logic as in llama_decoder_internal, but better to
|
||||
// wait for an encoder model that requires this pooling type in order to test it
|
||||
// https://github.com/ggerganov/llama.cpp/pull/9510
|
||||
GGML_ABORT("RANK pooling not implemented yet");
|
||||
}
|
||||
// extract the rerank score - a single float per sequence
|
||||
auto & embd_seq_out = embd_seq;
|
||||
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
||||
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
||||
continue;
|
||||
}
|
||||
embd_seq_out[seq_id].resize(1);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||
{
|
||||
GGML_ABORT("unknown pooling type");
|
||||
|
@ -833,6 +842,11 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|||
}
|
||||
|
||||
int llama_context::decode(llama_batch & inp_batch) {
|
||||
if (!memory) {
|
||||
LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
|
||||
return encode(inp_batch);
|
||||
}
|
||||
|
||||
if (inp_batch.n_tokens == 0) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
||||
return -1;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue