llama : remove all_pos_0, all_pos_1, all_seq_id from llama_batch (#9745)
* refactor llama_batch_get_one * adapt all examples * fix simple.cpp * fix llama_bench * fix * fix context shifting * free batch before return * use common_batch_add, reuse llama_batch in loop * null terminated seq_id list * fix save-load-state example * fix perplexity * correct token pos in llama_batch_allocr
This commit is contained in:
parent
afd9909a64
commit
cda0e4b648
22 changed files with 205 additions and 118 deletions
137
src/llama.cpp
137
src/llama.cpp
|
@ -2949,9 +2949,6 @@ struct llama_sbatch_seq {
|
|||
llama_seq_id * seq_id;
|
||||
size_t offset;
|
||||
size_t length;
|
||||
|
||||
// helper for smoother batch API transition -- can be deprecated in the future
|
||||
llama_seq_id all_seq_id; // used if seq_id == NULL
|
||||
};
|
||||
|
||||
// sequence-length-aware batch splitting
|
||||
|
@ -3046,30 +3043,18 @@ struct llama_sbatch {
|
|||
} else {
|
||||
ubatch.embd = nullptr;
|
||||
}
|
||||
// from here on, the else branches are deprecated;
|
||||
// they are helpers for smoother batch API transition
|
||||
if (batch->pos) {
|
||||
if (ubatch.equal_seqs) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
|
||||
}
|
||||
} else {
|
||||
// simple split
|
||||
ubatch.pos = batch->pos + seq.offset;
|
||||
if (ubatch.equal_seqs) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
llama_pos bi = ids[seq.offset + i];
|
||||
ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
|
||||
}
|
||||
// simple split
|
||||
ubatch.pos = batch->pos + seq.offset;
|
||||
}
|
||||
if (ubatch.equal_seqs) {
|
||||
ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
|
||||
if (seq.seq_id) {
|
||||
ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
|
||||
} else {
|
||||
GGML_ASSERT(seq.n_seq_id == 1);
|
||||
ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
|
||||
}
|
||||
} else {
|
||||
// simple split
|
||||
|
@ -3082,10 +3067,6 @@ struct llama_sbatch {
|
|||
}
|
||||
if (batch->seq_id) {
|
||||
ubatch.seq_id = batch->seq_id + seq.offset;
|
||||
} else {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (logits_all) {
|
||||
|
@ -3204,7 +3185,6 @@ struct llama_sbatch {
|
|||
s.seq_id = nullptr;
|
||||
s.offset = 0;
|
||||
s.length = n_tokens;
|
||||
s.all_seq_id = batch.all_seq_id;
|
||||
return;
|
||||
}
|
||||
std::sort(ids.begin(), ids.end(),
|
||||
|
@ -3227,7 +3207,7 @@ struct llama_sbatch {
|
|||
if (batch.pos) {
|
||||
return batch.pos[a] < batch.pos[b];
|
||||
}
|
||||
// no pos, sort by id (assuming batch.all_pos_1 is positive)
|
||||
// no pos, sort by id
|
||||
return a < b;
|
||||
}
|
||||
// shared prompts go first
|
||||
|
@ -3237,30 +3217,25 @@ struct llama_sbatch {
|
|||
// init seq
|
||||
llama_sbatch_seq * last_seq = nullptr;
|
||||
|
||||
if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
|
||||
for (size_t i = 0; i < n_tokens; ++i) {
|
||||
const size_t bi = ids[i];
|
||||
const int32_t n_seqs = batch.n_seq_id[bi];
|
||||
llama_seq_id * seq_ids = batch.seq_id[bi];
|
||||
if (last_seq != nullptr) {
|
||||
bool same = n_seqs == last_seq->n_seq_id;
|
||||
for (int32_t j = 0; same && j < n_seqs; ++j) {
|
||||
if (seq_ids[j] != last_seq->seq_id[j]) {
|
||||
same = false;
|
||||
}
|
||||
}
|
||||
if (same) {
|
||||
last_seq->length += 1;
|
||||
continue;
|
||||
for (size_t i = 0; i < n_tokens; ++i) {
|
||||
const size_t bi = ids[i];
|
||||
const int32_t n_seqs = batch.n_seq_id[bi];
|
||||
llama_seq_id * seq_ids = batch.seq_id[bi];
|
||||
if (last_seq != nullptr) {
|
||||
bool same = n_seqs == last_seq->n_seq_id;
|
||||
for (int32_t j = 0; same && j < n_seqs; ++j) {
|
||||
if (seq_ids[j] != last_seq->seq_id[j]) {
|
||||
same = false;
|
||||
}
|
||||
}
|
||||
llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
|
||||
seq.push_back(new_seq);
|
||||
last_seq = &seq.back();
|
||||
if (same) {
|
||||
last_seq->length += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
|
||||
llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
|
||||
seq.push_back(new_seq);
|
||||
last_seq = &seq.back();
|
||||
}
|
||||
// keep shared prompts first at the end, then sort by length descending.
|
||||
std::sort(seq.begin(), seq.end(),
|
||||
|
@ -21096,9 +21071,7 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
|
|||
|
||||
struct llama_batch llama_batch_get_one(
|
||||
llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
llama_pos pos_0,
|
||||
llama_seq_id seq_id) {
|
||||
int32_t n_tokens) {
|
||||
return {
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*tokens =*/ tokens,
|
||||
|
@ -21107,9 +21080,6 @@ struct llama_batch llama_batch_get_one(
|
|||
/*n_seq_id =*/ nullptr,
|
||||
/*seq_id =*/ nullptr,
|
||||
/*logits =*/ nullptr,
|
||||
/*all_pos_0 =*/ pos_0,
|
||||
/*all_pos_1 =*/ 1,
|
||||
/*all_seq_id =*/ seq_id,
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -21122,9 +21092,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
|
|||
/*n_seq_id =*/ nullptr,
|
||||
/*seq_id =*/ nullptr,
|
||||
/*logits =*/ nullptr,
|
||||
/*all_pos_0 =*/ 0,
|
||||
/*all_pos_1 =*/ 0,
|
||||
/*all_seq_id =*/ 0,
|
||||
};
|
||||
|
||||
if (embd) {
|
||||
|
@ -21160,11 +21127,62 @@ void llama_batch_free(struct llama_batch batch) {
|
|||
if (batch.logits) free(batch.logits);
|
||||
}
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
static const llama_seq_id batch_default_seq_id = 0;
|
||||
struct llama_batch_allocr {
|
||||
std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id *> seq_id;
|
||||
std::vector<int8_t> logits;
|
||||
struct llama_batch batch;
|
||||
// optionally fulfill the batch returned by llama_batch_get_one
|
||||
llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) {
|
||||
batch = in_batch;
|
||||
if (!batch.pos) {
|
||||
// determine the last position in KV cache
|
||||
llama_pos last_pos = -1;
|
||||
for (const auto & cell : ctx->kv_self.cells) {
|
||||
if (cell.has_seq_id(batch_default_seq_id)) {
|
||||
last_pos = std::max(last_pos, cell.pos);
|
||||
}
|
||||
}
|
||||
last_pos++; // next position
|
||||
pos.resize(batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
pos[i] = i+last_pos;
|
||||
}
|
||||
batch.pos = pos.data();
|
||||
}
|
||||
if (!batch.n_seq_id) {
|
||||
n_seq_id.resize(batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
n_seq_id[i] = seq_id_0.size();
|
||||
}
|
||||
batch.n_seq_id = n_seq_id.data();
|
||||
}
|
||||
if (!batch.seq_id) {
|
||||
seq_id.resize(batch.n_tokens + 1);
|
||||
seq_id[batch.n_tokens] = NULL;
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
seq_id[i] = seq_id_0.data();
|
||||
}
|
||||
batch.seq_id = seq_id.data();
|
||||
}
|
||||
if (!batch.logits) {
|
||||
logits.resize(batch.n_tokens);
|
||||
logits[logits.size() - 1] = true;
|
||||
batch.logits = logits.data();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
int32_t llama_encode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch) {
|
||||
const int ret = llama_encode_internal(*ctx, batch);
|
||||
if (ret < 0) {
|
||||
llama_batch_allocr batch_allocr(ctx, batch);
|
||||
const int ret = llama_encode_internal(*ctx, batch_allocr.batch);
|
||||
if (ret != 0) {
|
||||
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
|
||||
}
|
||||
|
||||
|
@ -21174,8 +21192,9 @@ int32_t llama_encode(
|
|||
int32_t llama_decode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch) {
|
||||
const int ret = llama_decode_internal(*ctx, batch);
|
||||
if (ret < 0) {
|
||||
llama_batch_allocr batch_allocr(ctx, batch);
|
||||
const int ret = llama_decode_internal(*ctx, batch_allocr.batch);
|
||||
if (ret != 0) {
|
||||
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue