feat: add GGML_UNARY_OP_ARGMAX Metal kernel (ggml/1019)
* implemented argmax kernel * tpig -> tgpig * change to strides * contiguous assertions * kernel working and tested * argmax simd parallel implementation * added 2 new tests for argmax in test-backend-ops * cosmit * added 3 tests cases for perf eval * add test_argmax in make_test_cases_perf * Update test-backend-ops.cpp Co-authored-by: Diego Devesa <slarengh@gmail.com> --------- Co-authored-by: Diego Devesa <slarengh@gmail.com>
This commit is contained in:
parent
667d70d170
commit
efb6ae9630
3 changed files with 92 additions and 6 deletions
|
|
@ -1366,6 +1366,63 @@ kernel void kernel_ssm_scan_f32(
|
|||
}
|
||||
}
|
||||
|
||||
kernel void kernel_argmax(
|
||||
device const void * x,
|
||||
device int32_t * dst,
|
||||
constant int64_t & ncols,
|
||||
constant uint64_t & nb01,
|
||||
threadgroup float * shared_maxval [[threadgroup(0)]],
|
||||
threadgroup int32_t * shared_argmax [[threadgroup(1)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01);
|
||||
|
||||
float lmax = -INFINITY;
|
||||
int32_t larg = -1;
|
||||
|
||||
for (int i00 = tpitg; i00 < ncols; i00 += ntg) {
|
||||
if (x_row[i00] > lmax) {
|
||||
lmax = x_row[i00];
|
||||
larg = i00;
|
||||
}
|
||||
}
|
||||
|
||||
// find the argmax value in the block
|
||||
float max_val = simd_max(lmax);
|
||||
int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
|
||||
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
shared_maxval[tiisg] = -INFINITY;
|
||||
shared_argmax[tiisg] = -1;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shared_maxval[sgitg] = max_val;
|
||||
shared_argmax[sgitg] = arg_val;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
max_val = shared_maxval[tiisg];
|
||||
arg_val = shared_argmax[tiisg];
|
||||
|
||||
float max_val_reduced = simd_max(max_val);
|
||||
int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));
|
||||
|
||||
dst[tgpig] = arg_val_reduced;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
dst[tgpig] = arg_val;
|
||||
}
|
||||
|
||||
kernel void kernel_norm(
|
||||
constant ggml_metal_kargs_norm & args,
|
||||
device const char * src0,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue