llama : rework embeddings logic (#14208)
* llama : rework embeddings logic ggml-ci * cont : fix rerank ggml-ci * cont : engrish [no ci] * cont : fix rerank ggml-ci * server : support both embeddings and completions with single model ggml-ci * cont : avoid embeddings_org ggml-ci
This commit is contained in:
parent
3ba0d843c6
commit
d3e64b9f49
16 changed files with 159 additions and 114 deletions
|
@ -359,9 +359,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
|||
return result;
|
||||
}
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
|
||||
GGML_UNUSED(embd_pooled);
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
|
||||
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
@ -369,8 +367,8 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch &
|
|||
while (sbatch.n_tokens > 0) {
|
||||
llama_ubatch ubatch;
|
||||
|
||||
if (embd_pooled) {
|
||||
// Pooled embeddings cannot be split across ubatches (yet)
|
||||
if (embd_all) {
|
||||
// if all tokens are output, split by sequence
|
||||
ubatch = sbatch.split_seq(n_ubatch);
|
||||
} else {
|
||||
ubatch = sbatch.split_equal(n_ubatch);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue