metal : add gqa8 kernel to allow llama-2-70B on metal (#2459)
* Added gqa8 kernel to allow llama-2-70B on metal * Update ggml-metal.m Co-authored-by: Cebtenzzre <cebtenzzre@gmail.com> * Extend kernel_mul_mat_f16_f32 to handle gqa broadcast * Added ne03==ne13 assertion --------- Co-authored-by: Cebtenzzre <cebtenzzre@gmail.com>
This commit is contained in:
parent
49e7cb5bb1
commit
1873ff586b
2 changed files with 21 additions and 17 deletions
|
@ -509,11 +509,13 @@ kernel void kernel_mul_mat_f16_f32(
|
|||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
|
@ -529,7 +531,7 @@ kernel void kernel_mul_mat_f16_f32(
|
|||
const int64_t r1 = tgpig.y;
|
||||
const int64_t im = tgpig.z;
|
||||
|
||||
device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02);
|
||||
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
||||
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||
|
||||
sum[tpitg.x] = 0.0f;
|
||||
|
@ -552,6 +554,7 @@ kernel void kernel_mul_mat_f16_f32(
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
kernel void kernel_alibi_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue