examples : allow extracting embeddings from decoder contexts (#13797)
ggml-ci
This commit is contained in:
parent
22229314fc
commit
79c137f776
4 changed files with 10 additions and 16 deletions
|
@ -41,8 +41,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
||||||
|
|
||||||
// run model
|
// run model
|
||||||
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
|
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
|
||||||
if (llama_encode(ctx, batch) < 0) {
|
if (llama_decode(ctx, batch) < 0) {
|
||||||
LOG_ERR("%s : failed to encode\n", __func__);
|
LOG_ERR("%s : failed to process\n", __func__);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < batch.n_tokens; i++) {
|
for (int i = 0; i < batch.n_tokens; i++) {
|
||||||
|
|
|
@ -81,14 +81,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void batch_encode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
|
static void batch_process(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
|
||||||
// clear previous kv_cache values (irrelevant for embeddings)
|
// clear previous kv_cache values (irrelevant for embeddings)
|
||||||
llama_kv_self_clear(ctx);
|
llama_kv_self_clear(ctx);
|
||||||
|
|
||||||
// run model
|
// run model
|
||||||
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
|
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
|
||||||
if (llama_encode(ctx, batch) < 0) {
|
if (llama_decode(ctx, batch) < 0) {
|
||||||
LOG_ERR("%s : failed to encode\n", __func__);
|
LOG_ERR("%s : failed to process\n", __func__);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < batch.n_tokens; i++) {
|
for (int i = 0; i < batch.n_tokens; i++) {
|
||||||
|
@ -233,7 +233,7 @@ int main(int argc, char ** argv) {
|
||||||
// encode if at capacity
|
// encode if at capacity
|
||||||
if (batch.n_tokens + n_toks > n_batch) {
|
if (batch.n_tokens + n_toks > n_batch) {
|
||||||
float * out = emb + p * n_embd;
|
float * out = emb + p * n_embd;
|
||||||
batch_encode(ctx, batch, out, s, n_embd);
|
batch_process(ctx, batch, out, s, n_embd);
|
||||||
common_batch_clear(batch);
|
common_batch_clear(batch);
|
||||||
p += s;
|
p += s;
|
||||||
s = 0;
|
s = 0;
|
||||||
|
@ -246,7 +246,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// final batch
|
// final batch
|
||||||
float * out = emb + p * n_embd;
|
float * out = emb + p * n_embd;
|
||||||
batch_encode(ctx, batch, out, s, n_embd);
|
batch_process(ctx, batch, out, s, n_embd);
|
||||||
|
|
||||||
// save embeddings to chunks
|
// save embeddings to chunks
|
||||||
for (int i = 0; i < n_chunks; i++) {
|
for (int i = 0; i < n_chunks; i++) {
|
||||||
|
@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
|
||||||
batch_add_seq(query_batch, query_tokens, 0);
|
batch_add_seq(query_batch, query_tokens, 0);
|
||||||
|
|
||||||
std::vector<float> query_emb(n_embd, 0);
|
std::vector<float> query_emb(n_embd, 0);
|
||||||
batch_encode(ctx, query_batch, query_emb.data(), 1, n_embd);
|
batch_process(ctx, query_batch, query_emb.data(), 1, n_embd);
|
||||||
|
|
||||||
common_batch_clear(query_batch);
|
common_batch_clear(query_batch);
|
||||||
|
|
||||||
|
|
|
@ -852,7 +852,7 @@ int llama_context::encode(llama_batch & inp_batch) {
|
||||||
|
|
||||||
int llama_context::decode(llama_batch & inp_batch) {
|
int llama_context::decode(llama_batch & inp_batch) {
|
||||||
if (!memory) {
|
if (!memory) {
|
||||||
LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
|
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
|
||||||
return encode(inp_batch);
|
return encode(inp_batch);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3394,13 +3394,7 @@ struct server_context {
|
||||||
batch.logits + i,
|
batch.logits + i,
|
||||||
};
|
};
|
||||||
|
|
||||||
int ret = 0;
|
const int ret = llama_decode(ctx, batch_view);
|
||||||
|
|
||||||
if (do_encode) {
|
|
||||||
ret = llama_encode(ctx, batch_view);
|
|
||||||
} else {
|
|
||||||
ret = llama_decode(ctx, batch_view);
|
|
||||||
}
|
|
||||||
|
|
||||||
metrics.on_decoded(slots);
|
metrics.on_decoded(slots);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue