ggml-cuda : support stablelm rope (#4156)

* ggml-cuda : support stablelm rope

* remove unused freq_base kernel parameter

* add n_dims parameter to llm_build_k_shift, default to n_rot via overload

* llama : fix llm_build_k_shift args

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
slaren 2023-11-24 18:04:31 +01:00 committed by GitHub
parent 189d68446e
commit 8a052c131e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 17 deletions

View file

@ -3469,7 +3469,7 @@ static void llm_build_k_shift(
struct ggml_cgraph * graph,
llm_rope_type type,
int64_t n_ctx,
int64_t n_rot,
int n_rot,
float freq_base,
float freq_scale,
const llm_build_cb & cb) {
@ -3501,7 +3501,7 @@ static void llm_build_k_shift(
// we rotate only the first n_rot dimensions
ggml_rope_custom_inplace(ctx,
ggml_view_3d(ctx, kv.k,
n_rot, n_head_kv, n_ctx,
n_embd_head, n_head_kv, n_ctx,
ggml_element_size(kv.k)*n_embd_head,
ggml_element_size(kv.k)*n_embd_gqa,
ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il),