kv-cache : fix out-of-bounds view during reserve graph (#13547)
* kv-cache : fix reserve graph out-of-bounds access

ggml-ci

* cont : add comment

* cont : fix comments [no ci]

* cont : more correct comment [no ci]
parent 5ab5d5fb25
commit e3a9421b78
2 changed files with 12 additions and 10 deletions
@@ -441,6 +441,13 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
 
 void llama_kv_cache_unified::set_full() {
     n = size;
+
+    // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
+    // affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
+    // we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so
+    // setting it to 0 is the simplest way to achieve that
+    // ref: https://github.com/ggml-org/llama.cpp/issues/13359
+    head = 0;
 }
 
 llama_sbatch llama_kv_cache_unified::sbatch_init(

@@ -1712,6 +1719,7 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) {
 
 void llama_kv_cache_recurrent::set_full() {
     n = size;
+    head = 0;
 }
 
 llama_sbatch llama_kv_cache_recurrent::sbatch_init(
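The reasoning in the added comment can be made concrete with a small standalone C++ sketch. This is not llama.cpp code: the names (cache_size, row_bytes, view_begin, view_end) and the shapes are illustrative assumptions. It shows how a K/V view spanning n cells starting at head can run past the end of the cache buffer when n is forced to size while head is non-zero, and why resetting head to 0 always keeps the view in bounds:

// Minimal sketch (hypothetical names, not the llama.cpp API) of the
// out-of-bounds condition fixed by this commit.
#include <cassert>
#include <cstdio>

int main() {
    const int cache_size = 4096;   // total KV cells in the cache buffer
    int head = 512;                // current write position in the cache
    int n    = cache_size;         // set_full(): pretend every cell is in use

    const size_t row_bytes = 256;  // bytes per KV cell (hypothetical)

    // A K/V view during graph reserve spans n cells starting at `head`:
    size_t view_begin = static_cast<size_t>(head) * row_bytes;
    size_t view_end   = view_begin + static_cast<size_t>(n) * row_bytes;
    size_t buf_bytes  = static_cast<size_t>(cache_size) * row_bytes;

    printf("head=%d: view ends at %zu, buffer holds %zu -> %s\n",
           head, view_end, buf_bytes,
           view_end > buf_bytes ? "OUT OF BOUNDS" : "ok");

    // The fix: with head == 0, head + n <= cache_size whenever n == cache_size,
    // so the view always fits within the buffer.
    head = 0;
    view_begin = static_cast<size_t>(head) * row_bytes;
    view_end   = view_begin + static_cast<size_t>(n) * row_bytes;
    assert(view_end <= buf_bytes);
    printf("head=0: view fits within the buffer\n");
}

As the added comment notes, the reserve graph only needs the tensor shapes to be correct, and those depend on n, not on head; head only shifts the view offsets, so zeroing it is the simplest way to guarantee the views stay in bounds.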