Introducing experimental OpenCL backend with support for Qualcomm Adreno GPUs (#10693)
* [cl][adreno] Add Adreno GPU support

  Add new OpenCL backend to support Adreno GPUs.

* [cl][ci] Add workflow for CL
* [cl][adreno] Fix memory leak for non-SMALL_ALLOC path
* opencl: integrate backend dyn.load interface and fix compiler and format warnings
* opencl: remove small-alloc support and fix build errors for non-OpenCL platforms
* opencl: fix merge conflict (MUSA added twice in CMake)
* opencl-ci: use RUNNER_TEMP instead of github.workspace
* opencl: fix embed tool invocation with python3
* opencl: CI workflow fixes
* opencl: clean up small-alloc in CMake files
* opencl: clean up ggml-opencl2 header file
* opencl: use ulong for offsets and strides in ADD kernel
* opencl: use cl_ulong for all offsets
* opencl: use cl_ulong for sizes and strides
* opencl: use `GGML_LOG_xxx` instead of `fprintf(stderr, ...)`
* opencl: rename backend `opencl2` -> `opencl`
* opencl: rename kernel files `ggml-opencl2` -> `ggml-opencl`
* opencl: make OpenCL required; remove redundant lib and inc directories
  (`ggml-base`, `..` and `.` are added by `ggml_add_backend_library`)
* opencl: rename backend funcs, structs, etc. `opencl2` -> `opencl`
* opencl: remove copyright marker since the main license already covers it
* opencl: replace some more OPENCL2 leftovers
* opencl: remove limits on `tensor_extra`
* opencl: use pools for `tensor_extra`
* opencl: fix compiler warnings with GCC and Clang
  (still getting the warning about `clCreateCommandQueue` being deprecated; will fix that separately)
* opencl: fail gracefully if OpenCL devices are not available, also for unsupported GPUs (see the sketch below)
* opencl: fix MSVC builds (string length error)
* opencl: check for various requirements; allow deprecated API
* opencl: update log message for unsupported GPUs

Co-authored-by: Skyler Szot <quic_sszot@quicinc.com>
Co-authored-by: Shangqing Gu <quic_shawngu@quicinc.com>
Co-authored-by: Alexander Angus <quic_aangus@quicinc.com>
Co-authored-by: Hongqiang Wang <quic_wangh@quicinc.com>
Co-authored-by: Max Krasnyansky <quic_maxk@quicinc.com>
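The "fail gracefully" bullet refers to probing for devices at backend init instead of aborting. As a rough illustration only — this is not code from the commit, the helper name is invented, and the real backend reports through `GGML_LOG_xxx` rather than `fprintf` — probing with the standard OpenCL C API looks roughly like this:

```c
#include <stdio.h>
#include <CL/cl.h>

// Hypothetical sketch, not from this patch: return 0 instead of aborting
// when no OpenCL platform or GPU device is present.
static int opencl_gpu_available(void) {
    cl_platform_id platform;
    cl_uint n_platforms = 0;
    if (clGetPlatformIDs(1, &platform, &n_platforms) != CL_SUCCESS || n_platforms == 0) {
        fprintf(stderr, "opencl: no OpenCL platform available\n");
        return 0;
    }
    cl_device_id device;
    cl_uint n_devices = 0;
    if (clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, 1, &device, &n_devices) != CL_SUCCESS || n_devices == 0) {
        fprintf(stderr, "opencl: no GPU device available\n");
        return 0;
    }
    return 1;
}
```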
parent c27ac678dd
commit a76c56fa1a

17 changed files with 9014 additions and 1 deletion
@@ -0,0 +1,271 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable

// assumed Q4_0 block size and number of subgroups (waves) per workgroup
#define QK4_0 32
#define N_SIMDGROUP 4
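For reference, the bit manipulation in the macros below implements standard Q4_0 dequantization: each 4-bit value q in [0, 15] decodes to (q - 8) * d, where d is the block's scale, so nibble 0xF decodes to 7 * d and nibble 0x0 to -8 * d. A hedged CPU sketch (the function name is invented, and the nibble-to-weight order shown is the plain ggml Q4_0 convention — the repacked "noshuffle" layout this kernel reads may order weights differently):

```c
// Hypothetical reference, not part of this kernel file: dequantize one
// Q4_0 block of QK4_0 weights stored as packed 4-bit pairs plus a scale d.
static void dequant_q4_0_ref(const unsigned char qs[QK4_0 / 2], float d, float out[QK4_0]) {
    for (int i = 0; i < QK4_0 / 2; ++i) {
        out[i]             = (float)((qs[i] & 0x0F) - 8) * d; // low nibble
        out[i + QK4_0 / 2] = (float)((qs[i] >> 4)   - 8) * d; // high nibble
    }
}
```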
// Scalar-broadcast path, high half: dequantize and accumulate two rows of a
// block pair, pulling y out of subgroup lanes 0 and 1 one float at a time.
// Declares shared_y, which the _lo variant below reuses.
#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \
    float shared_y; \
    shared_y = sub_group_broadcast(y.s0, 0); \
    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s1, 0); \
    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s2, 0); \
    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s3, 0); \
    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s4, 0); \
    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s5, 0); \
    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s6, 0); \
    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s7, 0); \
    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s0, 1); \
    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s1, 1); \
    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s2, 1); \
    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s3, 1); \
    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s4, 1); \
    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s5, 1); \
    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s6, 1); \
    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s7, 1); \
    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \

// Scalar-broadcast path, low half: same as above but reads y from subgroup
// lanes 2 and 3; reuses shared_y declared by the _hi variant.
#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \
    shared_y = sub_group_broadcast(y.s0, 2); \
    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s1, 2); \
    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s2, 2); \
    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s3, 2); \
    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s4, 2); \
    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s5, 2); \
    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s6, 2); \
    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s7, 2); \
    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s0, 3); \
    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s1, 3); \
    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s2, 3); \
    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s3, 3); \
    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s4, 3); \
    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s5, 3); \
    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s6, 3); \
    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
    shared_y = sub_group_broadcast(y.s7, 3); \
    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \

// Vector-broadcast path, high half: fetch y eight floats at a time with a
// single sub_group_broadcast of the whole float8 (lanes 0 and 1), then
// dequantize and accumulate both rows. Declares shared_y, reused by _lo.
#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \
    float8 shared_y; \
    shared_y = sub_group_broadcast(y, 0); \
    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
    shared_y = sub_group_broadcast(y, 1); \
    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \

// Vector-broadcast path, low half: same as above but broadcasts the float8
// from subgroup lanes 2 and 3; reuses shared_y declared by the _hi variant.
#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \
    shared_y = sub_group_broadcast(y, 2); \
    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
    shared_y = sub_group_broadcast(y, 3); \
    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \

__attribute__((qcom_reqd_sub_group_size("full")))
__kernel void kernel_gemv_noshuffle(
        __read_only image1d_buffer_t src0_q,  // quantized A
        global half2 * src0_d,                // A scales
        __read_only image1d_buffer_t src1,    // B
        ulong offset1,                        // offset to B (0)
        global float * dst,                   // C
        ulong offsetd,                        // offset to C (0)
        int ne00,                             // K
        int ne01,                             // M
        int ne02,                             // 1
        int ne10,                             // K
        int ne12,                             // 1
        int ne0,                              // M
        int ne1,                              // N
        int r2,                               // 1
        int r3)
{
    uint groupId = get_local_id(1);
    uint gid     = get_global_id(0);
    ushort slid  = get_sub_group_local_id();

    uint K = ne00;
    uint M = ne01;

    uint LINE_STRIDE_A  = M / 2;
    uint BLOCK_STRIDE_A = N_SIMDGROUP * M;

    __private uint4  regA;
    __private half2  regS;
    __private float8 regB;

    __private float2 totalSum = (float2)(0.0f);

    // loop along K in block granularity, skip 4 blocks every iter
    for (uint k = groupId; k < (K / QK4_0); k += N_SIMDGROUP) {
        regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows
        // first 4 fibers in each wave load 8 B values to their private scope
        if (slid < 4) {
            regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));
            regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));
        }

        // load half weights for two blocks in consecutive rows
        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
#ifdef VECTOR_SUB_GROUP_BROADCAT
        dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB);
#else
        dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB);
#endif // VECTOR_SUB_GROUP_BROADCAT

        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
#ifdef VECTOR_SUB_GROUP_BROADCAT
        dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB);
#else
        dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB);
#endif // VECTOR_SUB_GROUP_BROADCAT
    }

    // reduction in local memory, assumes #wave=4
    // (SIMDGROUP_WIDTH is expected to be defined via program build options)
    __local float2 reduceLM[SIMDGROUP_WIDTH * 3];
    if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum;
    if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;
    if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;
    barrier(CLK_LOCAL_MEM_FENCE);
    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];

    // 2 outputs per fiber in wave 0
    if (groupId == 0) {
        dst = (global float*)((global char*)dst + offsetd);
        vstore2(totalSum, 0, &(dst[gid * 2]));
    }
}
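For completeness, a hedged dispatch sketch. None of this is the commit's host code: the function name, queue, shape, and the assumed wave width of 64 are illustrative. The kernel wants a 2D NDRange with N_SIMDGROUP (4) waves per workgroup along dimension 1 and SIMDGROUP_WIDTH fibers along dimension 0, and since each fiber of wave 0 writes two outputs, dimension 0 covers M / 2 (which must divide evenly by the wave width):

```c
#include <CL/cl.h>

// Hypothetical dispatch sketch, not from this commit: launch kernel_gemv_noshuffle
// for an M x K by K x 1 GEMV. Assumes the program was built with
// -DSIMDGROUP_WIDTH=64 and that all 15 kernel arguments were already set.
static cl_int enqueue_gemv_noshuffle(cl_command_queue queue, cl_kernel kernel, size_t M) {
    const size_t simdgroup_width = 64;           // assumed wave width on Adreno
    const size_t n_simdgroup     = 4;            // matches N_SIMDGROUP in the kernel
    size_t global[2] = { M / 2, n_simdgroup };   // each fiber of wave 0 writes 2 outputs
    size_t local[2]  = { simdgroup_width, n_simdgroup }; // 4 waves per workgroup
    return clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global, local, 0, NULL, NULL);
}
```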