feat: add GGML_UNARY_OP_ARGMAX Metal kernel (ggml/1019)

* implemented argmax kernel

* tpig -> tgpig

* change to strides

* contiguous assertions

* kernel working and tested

* argmax simd parallel implementation

* added 2 new tests for argmax in test-backend-ops

* cosmit

* added 3 tests cases for perf eval

* add test_argmax in make_test_cases_perf

* Update test-backend-ops.cpp

Co-authored-by: Diego Devesa <slarengh@gmail.com>

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
This commit is contained in:
PAB 2024-12-02 19:27:24 +01:00 committed by Georgi Gerganov
parent 667d70d170
commit efb6ae9630
3 changed files with 92 additions and 6 deletions

View file

@ -1366,6 +1366,63 @@ kernel void kernel_ssm_scan_f32(
}
}
kernel void kernel_argmax(
device const void * x,
device int32_t * dst,
constant int64_t & ncols,
constant uint64_t & nb01,
threadgroup float * shared_maxval [[threadgroup(0)]],
threadgroup int32_t * shared_argmax [[threadgroup(1)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01);
float lmax = -INFINITY;
int32_t larg = -1;
for (int i00 = tpitg; i00 < ncols; i00 += ntg) {
if (x_row[i00] > lmax) {
lmax = x_row[i00];
larg = i00;
}
}
// find the argmax value in the block
float max_val = simd_max(lmax);
int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
shared_maxval[tiisg] = -INFINITY;
shared_argmax[tiisg] = -1;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shared_maxval[sgitg] = max_val;
shared_argmax[sgitg] = arg_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = shared_maxval[tiisg];
arg_val = shared_argmax[tiisg];
float max_val_reduced = simd_max(max_val);
int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));
dst[tgpig] = arg_val_reduced;
return;
}
dst[tgpig] = arg_val;
}
kernel void kernel_norm(
constant ggml_metal_kargs_norm & args,
device const char * src0,