ggml-vulkan: adds support for op CONV_TRANSPOSE_1D (#13813)
* ggml-vulkan: adds op CONV_TRANSPOSE_1D
* test-backend-ops: adds more sophisticated tests for CONV_TRANSPOSE_1D
* Missing barrier added to shader; number of additional tests reduced to 108.
* Fixes typo in variable name.
* Removes extra whitespace.
* Adds int64->int32 casts to prevent possible warnings.
* Problem size reduced in tests so they pass with llvmpipe.
* supports_op condition moved from unintended position.
This commit is contained in:
parent 3e63a58ef7
commit 0d3984424f
4 changed files with 186 additions and 2 deletions
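For context, the new kernel implements GGML_OP_CONV_TRANSPOSE_1D for F32 tensors in the Vulkan backend. Below is a minimal host-side sketch of how the op is built through the public ggml API, the same path test-backend-ops exercises; shapes follow the shader comments (kernel [K, Cout, Cin], input [L, Cin], result [KL, Cout]). The sizes are illustrative and the graph is only constructed here, not executed.

```cpp
#include "ggml.h"
#include <cstdio>

int main() {
    // Small scratch context; enough for the illustrative tensor sizes below.
    ggml_init_params params = { /*mem_size*/ 16 * 1024 * 1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(params);

    const int K = 3, Cout = 8, Cin = 4, L = 32, s0 = 2;

    ggml_tensor * kernel = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, K, Cout, Cin); // src0
    ggml_tensor * input  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, L, Cin);       // src1

    // p0 = 0 and d0 = 1: padding/dilation other than that are not supported by this op.
    ggml_tensor * out = ggml_conv_transpose_1d(ctx, kernel, input, s0, /*p0*/ 0, /*d0*/ 1);

    // Output length KL = (L - 1) * s0 + K.
    printf("out: [%lld, %lld]\n", (long long) out->ne[0], (long long) out->ne[1]);

    ggml_free(ctx);
    return 0;
}
```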
ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp (new file, 98 lines)
@@ -0,0 +1,98 @@
#version 450

#include "types.comp"

layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin]
layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin]
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // dst - result [KL, Cout]

layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in;

layout (push_constant) uniform parameter {
    uint32_t Cout;
    uint32_t Cin;
    uint32_t K;
    uint32_t L;
    uint32_t KL;

    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb11;
    uint32_t nb1;

    int32_t s0;
} p;

uint32_t Cout_idx = gl_WorkGroupID.x;
const uint32_t bs = gl_WorkGroupSize.x;
uint32_t tid = gl_LocalInvocationID.x;
// Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K.
uint32_t tmp_len = bs*p.s0+p.K;
shared D_TYPE tmp[4096];

uint splitWork(uint workSize){
    return (bs + workSize -1) / bs;
}

void main(){
    for(uint32_t i = 0; i < splitWork(tmp_len); i++){
        uint32_t idx = i*bs+tid;
        if(idx < tmp_len){
            tmp[idx] = 0.0;
        }
    }

    uint32_t L_blocks = splitWork(p.L);
    for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){
        if(L_block_id > 0){
            barrier();
            // Shift values in tmp to the current processing window
            for(int i = 0; i < splitWork(tmp_len); i++){
                uint32_t idx = i*bs+tid;
                if(idx >= bs*p.s0 && idx < tmp_len){
                    tmp[idx-bs*p.s0] = tmp[idx];
                    tmp[idx] = 0.0;
                }else if(idx >= p.K && idx < bs*p.s0){
                    tmp[idx] = 0.0;
                }
            }
        }
        barrier();

        // Save contributions of the block to tmp
        uint32_t L_idx = L_block_id*bs + tid;
        for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){
            D_TYPE dp = 0.0;
            for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){
                A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02];
                if(L_idx < p.L){
                    B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11];
                    dp = fma(elemKrn, elemInp, dp);
                }
            }
            tmp[tid*p.s0 + K_idx] += dp;
            barrier();
        }

        // Save the computed values except the last block that can have different size
        uint32_t KLb_idx = L_block_id*bs*p.s0;
        if(L_block_id < L_blocks-1){
            for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){
                uint32_t sh_idx = p.s0*tid+s0_idx;
                uint32_t KL_idx = KLb_idx+sh_idx;
                if(KL_idx < p.KL){
                    data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx];
                }
            }
        }
    }

    for(uint32_t i = 0; i < splitWork(tmp_len); i++){
        uint32_t idx = i*bs+tid;
        uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx;
        if(KL_idx < p.KL){
            data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx];
        }
    }
}
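For reference, each workgroup handles one output channel (Cout_idx) and scatters the contributions of a block of input positions into the shared tmp window before writing them out. The scalar sketch below mirrors that computation and the shader's flat indexing (the nb* push constants are treated as element strides). It is for illustration only and is not part of the commit.

```cpp
// Hypothetical scalar reference for conv_transpose_1d as the shader computes it.
// kernel: [K, Cout, Cin], element strides nb01 (per Cout) and nb02 (per Cin)
// input:  [L, Cin],       element stride  nb11 (per Cin)
// output: [KL, Cout],     element stride  nb1  (per Cout), KL = (L - 1) * s0 + K
#include <algorithm>
#include <cstdint>
#include <vector>

void conv_transpose_1d_ref(const std::vector<float> & a,   // kernel
                           const std::vector<float> & b,   // input
                           std::vector<float>       & d,   // output, pre-sized to hold KL * Cout
                           uint32_t K, uint32_t Cout, uint32_t Cin,
                           uint32_t L, uint32_t KL, int32_t s0,
                           uint32_t nb01, uint32_t nb02, uint32_t nb11, uint32_t nb1) {
    std::fill(d.begin(), d.end(), 0.0f);
    for (uint32_t cout = 0; cout < Cout; ++cout) {            // one workgroup per cout in the shader
        for (uint32_t l = 0; l < L; ++l) {                    // one invocation per input position
            for (uint32_t k = 0; k < K; ++k) {
                // Dot product over input channels, as in the shader's inner Cin loop.
                float dp = 0.0f;
                for (uint32_t cin = 0; cin < Cin; ++cin) {
                    dp += a[k + cout * nb01 + cin * nb02] * b[l + cin * nb11];
                }
                // Each input position scatters K values into the output, s0 elements apart.
                uint32_t kl = l * (uint32_t) s0 + k;
                if (kl < KL) {
                    d[kl + cout * nb1] += dp;
                }
            }
        }
    }
}
```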
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -622,6 +622,8 @@ void process_shaders() {
    string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

    string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});

    string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

    string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
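The registration above compiles only an F32 variant of the shader (A_TYPE, B_TYPE and D_TYPE all set to float), which is why the commit message mentions the supports_op condition: the backend must advertise the op only for F32 tensors so other types fall back to a different backend. The snippet below is a hypothetical sketch of such a gate, for illustration only; the function name and the surrounding dispatch in ggml-vulkan.cpp are assumptions, not the literal diff.

```cpp
#include "ggml.h"

// Hypothetical supports_op gate for GGML_OP_CONV_TRANSPOSE_1D (assumed shape,
// not the actual ggml-vulkan.cpp change): only the F32 shader exists.
static bool vk_supports_conv_transpose_1d(const ggml_tensor * op) {
    return op->src[0]->type == GGML_TYPE_F32 &&   // kernel
           op->src[1]->type == GGML_TYPE_F32 &&   // input
           op->type         == GGML_TYPE_F32;     // result
}
```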