mtmd : fix batch_view for m-rope (#13397)

* mtmd : fix batch_view for m-rope

* nits : fix comment
This commit is contained in:
Xuan-Son Nguyen 2025-05-09 11:18:02 +02:00 committed by GitHub
parent 3f96aeff39
commit 2189fd3b63
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -554,14 +554,19 @@ struct decode_embd_batch {
llama_batch get_view(int offset, int n_tokens) {
llama_pos * pos_ptr;
pos_view.clear();
pos_view.resize(n_tokens * n_pos_per_embd);
pos_view.reserve(n_tokens * n_pos_per_embd);
if (n_pos_per_embd > 1) {
// mrope
// for example, with layout of src: 1234...1234...1234...1234...
// offset 2 will give us dst: 34...34...34...34...
for (int i = 0; i < n_pos_per_embd; i++) {
auto src = pos.begin() + i * batch.n_tokens + offset;
pos_view.insert(pos_view.end(), src, src + n_tokens);
// assume n_tokens is less than or equal to batch.n_tokens
// batch.n_tokens is number of **total** tokens
// n_tokens is number of viewed token
size_t src_idx = i * batch.n_tokens + offset;
pos_view.insert(pos_view.end(),
pos.data() + src_idx,
pos.data() + src_idx + n_tokens);
}
pos_ptr = pos_view.data();
} else {