diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 01ff6763..71f70087 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -41,8 +41,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
 
     // run model
     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
-    if (llama_encode(ctx, batch) < 0) {
-        LOG_ERR("%s : failed to encode\n", __func__);
+    if (llama_decode(ctx, batch) < 0) {
+        LOG_ERR("%s : failed to process\n", __func__);
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp
index e3d0c954..754da141 100644
--- a/examples/retrieval/retrieval.cpp
+++ b/examples/retrieval/retrieval.cpp
@@ -81,14 +81,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
     }
 }
 
-static void batch_encode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
+static void batch_process(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_self_clear(ctx);
 
     // run model
     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
-    if (llama_encode(ctx, batch) < 0) {
-        LOG_ERR("%s : failed to encode\n", __func__);
+    if (llama_decode(ctx, batch) < 0) {
+        LOG_ERR("%s : failed to process\n", __func__);
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
@@ -233,7 +233,7 @@ int main(int argc, char ** argv) {
         // encode if at capacity
         if (batch.n_tokens + n_toks > n_batch) {
             float * out = emb + p * n_embd;
-            batch_encode(ctx, batch, out, s, n_embd);
+            batch_process(ctx, batch, out, s, n_embd);
             common_batch_clear(batch);
             p += s;
             s = 0;
@@ -246,7 +246,7 @@ int main(int argc, char ** argv) {
 
     // final batch
     float * out = emb + p * n_embd;
-    batch_encode(ctx, batch, out, s, n_embd);
+    batch_process(ctx, batch, out, s, n_embd);
 
     // save embeddings to chunks
     for (int i = 0; i < n_chunks; i++) {
@@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
         batch_add_seq(query_batch, query_tokens, 0);
 
         std::vector<float> query_emb(n_embd, 0);
-        batch_encode(ctx, query_batch, query_emb.data(), 1, n_embd);
+        batch_process(ctx, query_batch, query_emb.data(), 1, n_embd);
 
         common_batch_clear(query_batch);
 
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 98ecb7c8..ad77cae2 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -852,7 +852,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
 int llama_context::decode(llama_batch & inp_batch) {
     if (!memory) {
-        LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
+        LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
         return encode(inp_batch);
     }
 
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 07b61312..fcab1dfa 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -3394,13 +3394,7 @@ struct server_context {
                 batch.logits + i,
             };
 
-            int ret = 0;
-
-            if (do_encode) {
-                ret = llama_encode(ctx, batch_view);
-            } else {
-                ret = llama_decode(ctx, batch_view);
-            }
+            const int ret = llama_decode(ctx, batch_view);
 
             metrics.on_decoded(slots);