metal : add gqa8 kernel to allow llama-2-70B on metal (#2459)

* Added gqa8 kernel to allow llama-2-70B on metal

* Update ggml-metal.m

Co-authored-by: Cebtenzzre <cebtenzzre@gmail.com>

* Extend kernel_mul_mat_f16_f32 to handle gqa broadcast

* Added ne03==ne13 assertion

---------

Co-authored-by: Cebtenzzre <cebtenzzre@gmail.com>
This commit is contained in:
Matteo Boschini 2023-08-01 09:43:12 +02:00 committed by GitHub
parent 49e7cb5bb1
commit 1873ff586b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 17 deletions

View file

@ -509,11 +509,13 @@ kernel void kernel_mul_mat_f16_f32(
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
@ -529,7 +531,7 @@ kernel void kernel_mul_mat_f16_f32(
const int64_t r1 = tgpig.y;
const int64_t im = tgpig.z;
device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02);
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
sum[tpitg.x] = 0.0f;
@ -552,6 +554,7 @@ kernel void kernel_mul_mat_f16_f32(
}
}
kernel void kernel_alibi_f32(
device const float * src0,
device float * dst,