llama : (mrope) allow using normal 1D position for text token (#13138)
* llama : (mrope) use normal position for text token
* rm n_pos_per_embd from llm_graph_input_attn_temp
This commit is contained in:
parent
5fa9e63be8
commit
d2b2031e5f
3 changed files with 24 additions and 22 deletions
|
@@ -92,20 +92,12 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
|
|||
|
||||
static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past, int * st_pos_id) {
|
||||
int N = (int) tokens.size();
|
||||
std::vector<llama_pos> pos;
|
||||
for (int i = 0; i < N; i += n_batch) {
|
||||
int n_eval = (int) tokens.size() - i;
|
||||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
auto batch = llama_batch_get_one(&tokens[i], n_eval);
|
||||
// TODO: add mrope pos ids somewhere else
|
||||
pos.resize(batch.n_tokens * 4);
|
||||
std::fill(pos.begin(), pos.end(), 0);
|
||||
for (int j = 0; j < batch.n_tokens * 3; j ++) {
|
||||
pos[j] = *st_pos_id + (j % batch.n_tokens);
|
||||
}
|
||||
batch.pos = pos.data();
|
||||
|
||||
if (llama_decode(ctx_llama, batch)) {
|
||||
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue