context : allow cache-less context for embeddings (#13108)
* context : allow cache-less context for embeddings ggml-ci * context : enable reranking with encode() ggml-ci * context : encode() clears embd_seq ggml-ci * examples : use llama_encode() when appropriate ggml-ci * models : nomic bert moe does not require KV cache * llama : update comments for llama_decode/llama_encode ggml-ci * context : update warning log [no ci]
This commit is contained in:
parent
51fb96b1ff
commit
6562e5a4d6
5 changed files with 47 additions and 23 deletions
|
@ -922,14 +922,19 @@ extern "C" {
|
|||
// Frees a batch of tokens allocated with llama_batch_init()
|
||||
LLAMA_API void llama_batch_free(struct llama_batch batch);
|
||||
|
||||
// Processes a batch of tokens with the ecoder part of the encoder-decoder model.
|
||||
// Stores the encoder output internally for later use by the decoder cross-attention layers.
|
||||
// Process a batch of tokens.
|
||||
// In contrast to llama_decode() - this call does not use KV cache.
|
||||
// For encode-decoder contexts, processes the batch using the encoder.
|
||||
// Can store the encoder output internally for later use by the decoder's cross-attention layers.
|
||||
// 0 - success
|
||||
// < 0 - error. the KV cache state is restored to the state before this call
|
||||
LLAMA_API int32_t llama_encode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch);
|
||||
|
||||
// Process a batch of tokens.
|
||||
// Requires KV cache.
|
||||
// For encode-decoder contexts, processes the batch using the decoder.
|
||||
// Positive return values does not mean a fatal error, but rather a warning.
|
||||
// 0 - success
|
||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue