ggml: add GGML_SET Metal kernel + i32 CPU kernel (ggml/1037)
* implemented cpu kernel * add i32 test cases in test-backend-ops * typedef `ggml_metal_kargs_set` * implemented `kernel_set` * memcpy
This commit is contained in:
parent
c2082d93a8
commit
a8cbab201d
5 changed files with 206 additions and 1 deletions
|
|
@ -3927,6 +3927,38 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_
|
|||
|
||||
#undef FA_TYPES
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_set(
|
||||
constant ggml_metal_kargs_set & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i13 = tgpig[2];
|
||||
const int i12 = tgpig[1];
|
||||
const int i11 = tgpig[0];
|
||||
|
||||
const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10;
|
||||
|
||||
const int64_t i3 = n / (args.ne12*args.ne11*args.ne10);
|
||||
const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10);
|
||||
const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10;
|
||||
|
||||
device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs);
|
||||
|
||||
for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) {
|
||||
device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10);
|
||||
dst_data[i10] = (T) src[0];
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_set<float>) kernel_set_t;
|
||||
|
||||
template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set<float>;
|
||||
template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set<int32_t>;
|
||||
|
||||
template<typename T0, typename T1>
|
||||
kernel void kernel_cpy(
|
||||
constant ggml_metal_kargs_cpy & args,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue