vulkan: implement several ops relevant for ggml_opt (#11769)
* vulkan: support memset_tensor * vulkan: support GGML_OP_SUM * vulkan: implement GGML_OP_ARGMAX * vulkan: implement GGML_OP_SUB * vulkan: implement GGML_OP_COUNT_EQUAL * vulkan: implement GGML_OP_OPT_STEP_ADAMW * vulkan: fix check_results RWKV_WKV6 crash and memory leaks * vulkan: implement GGML_OP_REPEAT_BACK * tests: remove invalid test-backend-ops REPEAT_BACK tests * vulkan: fix COUNT_EQUAL memset using a fillBuffer command
This commit is contained in:
parent
0f2bbe6564
commit
2eea03d86a
8 changed files with 569 additions and 223 deletions
51
ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp
Normal file
51
ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp
Normal file
|
@ -0,0 +1,51 @@
|
|||
#version 450
|
||||
|
||||
#include "generic_head.comp"
|
||||
#include "types.comp"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
|
||||
shared FLOAT_TYPE tmpmax[BLOCK_SIZE];
|
||||
shared uint tmp[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||
const uint col = gl_LocalInvocationID.x;
|
||||
|
||||
if (col >= p.KX) {
|
||||
return;
|
||||
}
|
||||
A_TYPE amax = data_a[row*p.KX + col];
|
||||
tmp[col] = col;
|
||||
|
||||
for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) {
|
||||
A_TYPE val = data_a[row*p.KX + i];
|
||||
if (val > amax) {
|
||||
amax = val;
|
||||
tmp[col] = i;
|
||||
}
|
||||
}
|
||||
tmpmax[col] = amax;
|
||||
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) {
|
||||
if (col < s && col + s < p.KX) {
|
||||
if (tmpmax[col] < tmpmax[col + s]) {
|
||||
tmpmax[col] = tmpmax[col + s];
|
||||
tmp[col] = tmp[col + s];
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
if (col == 0) {
|
||||
data_d[row] = D_TYPE(tmp[0]);
|
||||
}
|
||||
}
|
31
ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp
Normal file
31
ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp
Normal file
|
@ -0,0 +1,31 @@
|
|||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
#include "types.comp"
|
||||
#include "generic_head.comp"
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
|
||||
layout (binding = 2) buffer D {D_TYPE data_d[];};
|
||||
|
||||
const uint CHUNK_SIZE = 512;
|
||||
|
||||
void main() {
|
||||
const uint base = gl_WorkGroupID.x * CHUNK_SIZE;
|
||||
const uint col = gl_LocalInvocationID.x;
|
||||
|
||||
uint count = 0;
|
||||
[[unroll]]
|
||||
for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) {
|
||||
const uint idx = base + i + col;
|
||||
if (idx >= p.KX) {
|
||||
break;
|
||||
}
|
||||
count += uint(data_a[idx] == data_b[idx]);
|
||||
}
|
||||
|
||||
atomicAdd(data_d[0], D_TYPE(count));
|
||||
}
|
42
ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp
Normal file
42
ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp
Normal file
|
@ -0,0 +1,42 @@
|
|||
#version 450
|
||||
|
||||
#include "generic_head.comp"
|
||||
#include "types.comp"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) buffer X {A_TYPE x[];};
|
||||
layout (binding = 1) readonly buffer G {A_TYPE grad[];};
|
||||
layout (binding = 2) buffer GM {A_TYPE gradm[];};
|
||||
layout (binding = 3) buffer GV {A_TYPE gradv[];};
|
||||
layout (binding = 4) readonly buffer P {float params[7];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
if (i >= p.KX) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float alpha = params[0];
|
||||
const float beta1 = params[1];
|
||||
const float beta2 = params[2];
|
||||
const float eps = params[3];
|
||||
const float wd = params[4];
|
||||
const float beta1h = params[5];
|
||||
const float beta2h = params[6];
|
||||
|
||||
const float gi = grad[i];
|
||||
const float gmi = gradm[i]*beta1 + gi*(1.0f - beta1);
|
||||
const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2);
|
||||
|
||||
gradm[i] = gmi;
|
||||
gradv[i] = gvi;
|
||||
|
||||
const float mh = gmi*beta1h;
|
||||
const float vh = sqrt(gvi*beta2h) + eps;
|
||||
|
||||
x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
|
||||
}
|
37
ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp
Normal file
37
ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp
Normal file
|
@ -0,0 +1,37 @@
|
|||
#version 450
|
||||
|
||||
#include "types.comp"
|
||||
#include "generic_unary_head.comp"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Destination multi-index (inlined dst_idx)
|
||||
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
|
||||
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
|
||||
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
|
||||
const uint i12_offset = i12*p.ne11*p.ne10;
|
||||
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
|
||||
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
|
||||
const uint d_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
|
||||
|
||||
// Accumulate from sources
|
||||
A_TYPE acc = A_TYPE(0);
|
||||
for (uint i3 = i13; i3 < p.ne03; i3 += p.ne13) {
|
||||
for (uint i2 = i12; i2 < p.ne02; i2 += p.ne12) {
|
||||
for (uint i1 = i11; i1 < p.ne01; i1 += p.ne11) {
|
||||
for (uint i0 = i10; i0 < p.ne00; i0 += p.ne10) {
|
||||
acc += data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
data_d[get_doffset() + d_idx] = D_TYPE(acc);
|
||||
}
|
29
ggml/src/ggml-vulkan/vulkan-shaders/sub.comp
Normal file
29
ggml/src/ggml-vulkan/vulkan-shaders/sub.comp
Normal file
|
@ -0,0 +1,29 @@
|
|||
#version 450
|
||||
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
#include "types.comp"
|
||||
#include "generic_binary_head.comp"
|
||||
|
||||
const uint num_threads = 256;
|
||||
|
||||
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
uint idx = get_idx();
|
||||
|
||||
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
|
||||
const uint num_iter = 2;
|
||||
|
||||
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
|
||||
if (idx >= p.ne) {
|
||||
continue;
|
||||
}
|
||||
uint i00, i01, i02, i03;
|
||||
get_indices(idx, i00, i01, i02, i03);
|
||||
|
||||
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) - FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
|
||||
|
||||
idx += num_threads;
|
||||
}
|
||||
}
|
|
@ -443,6 +443,8 @@ void process_shaders() {
|
|||
string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
|
||||
|
@ -452,6 +454,7 @@ void process_shaders() {
|
|||
string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
|
@ -501,7 +504,9 @@ void process_shaders() {
|
|||
|
||||
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
|
||||
|
||||
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
|
||||
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
|
||||
|
||||
string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
|
||||
|
@ -513,6 +518,8 @@ void process_shaders() {
|
|||
|
||||
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
for (auto &c : compiles) {
|
||||
c.wait();
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue