examples : switch retrieval to llama_encode (#13685)
* switch retrieval to llama_encode * enable --no-warmup for retrieval
This commit is contained in:
parent
eb0f5c28d3
commit
2aa777d86d
2 changed files with 7 additions and 7 deletions
|
@ -1678,7 +1678,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
[](common_params & params) {
|
[](common_params & params) {
|
||||||
params.warmup = false;
|
params.warmup = false;
|
||||||
}
|
}
|
||||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING}));
|
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL}));
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
{"--spm-infill"},
|
{"--spm-infill"},
|
||||||
string_format(
|
string_format(
|
||||||
|
|
|
@ -81,14 +81,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
|
static void batch_encode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
|
||||||
// clear previous kv_cache values (irrelevant for embeddings)
|
// clear previous kv_cache values (irrelevant for embeddings)
|
||||||
llama_kv_self_clear(ctx);
|
llama_kv_self_clear(ctx);
|
||||||
|
|
||||||
// run model
|
// run model
|
||||||
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
|
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
|
||||||
if (llama_decode(ctx, batch) < 0) {
|
if (llama_encode(ctx, batch) < 0) {
|
||||||
LOG_ERR("%s : failed to decode\n", __func__);
|
LOG_ERR("%s : failed to encode\n", __func__);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < batch.n_tokens; i++) {
|
for (int i = 0; i < batch.n_tokens; i++) {
|
||||||
|
@ -233,7 +233,7 @@ int main(int argc, char ** argv) {
|
||||||
// encode if at capacity
|
// encode if at capacity
|
||||||
if (batch.n_tokens + n_toks > n_batch) {
|
if (batch.n_tokens + n_toks > n_batch) {
|
||||||
float * out = emb + p * n_embd;
|
float * out = emb + p * n_embd;
|
||||||
batch_decode(ctx, batch, out, s, n_embd);
|
batch_encode(ctx, batch, out, s, n_embd);
|
||||||
common_batch_clear(batch);
|
common_batch_clear(batch);
|
||||||
p += s;
|
p += s;
|
||||||
s = 0;
|
s = 0;
|
||||||
|
@ -246,7 +246,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// final batch
|
// final batch
|
||||||
float * out = emb + p * n_embd;
|
float * out = emb + p * n_embd;
|
||||||
batch_decode(ctx, batch, out, s, n_embd);
|
batch_encode(ctx, batch, out, s, n_embd);
|
||||||
|
|
||||||
// save embeddings to chunks
|
// save embeddings to chunks
|
||||||
for (int i = 0; i < n_chunks; i++) {
|
for (int i = 0; i < n_chunks; i++) {
|
||||||
|
@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
|
||||||
batch_add_seq(query_batch, query_tokens, 0);
|
batch_add_seq(query_batch, query_tokens, 0);
|
||||||
|
|
||||||
std::vector<float> query_emb(n_embd, 0);
|
std::vector<float> query_emb(n_embd, 0);
|
||||||
batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
|
batch_encode(ctx, query_batch, query_emb.data(), 1, n_embd);
|
||||||
|
|
||||||
common_batch_clear(query_batch);
|
common_batch_clear(query_batch);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue