metal : add GGML_OP_CONV_TRANSPOSE_1D kernels (ggml/1026)
* wip * wip implementation f32 * kernel conv transpose 1d f32 working * initial commit
This commit is contained in:
parent
3b4f2e33e2
commit
667d70d170
2 changed files with 121 additions and 0 deletions
|
|
@ -2671,6 +2671,79 @@ kernel void kernel_im2col_ext(
|
|||
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
|
||||
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
|
||||
|
||||
// Function type describing the argument list of the 1D transposed-convolution kernel
// with f32 kernel weights:
//   src0, src1 - read-only f32 device buffers
//   dst        - output buffer, passed as raw bytes
//   IC, IL, K, s0 - 32-bit integer parameters
//   nb0, nb1   - 64-bit unsigned parameters
//   tgpig/tgpg - threadgroup position in the grid / number of threadgroups in the grid
// NOTE(review): src0 is fixed to float here, so this type cannot describe a
// half-weight variant - confirm whether the typedef is still needed.
typedef void (conv_transpose_1d_t)(
        device const float * src0,
        device const float * src1,
        device        char * dst,
        constant   int32_t & IC,
        constant   int32_t & IL,
        constant   int32_t & K,
        constant   int32_t & s0,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]]);
|
||||
|
||||
// 1D transposed convolution: each threadgroup produces one output element.
//
// Grid mapping (as used by the indexing below):
//   tgpig[0] - position along the output length
//   tgpig[1] - output channel; tgpg[1] is the number of output channels
//
// Data access (as used by the indexing below):
//   src0 - kernel weights of type T, read as a dense [IC][tgpg[1]][K] array
//   src1 - f32 input, read as a dense [IC][IL] array
//   dst  - f32 output, addressed with byte strides nb0 (position) and nb1 (channel)
//
// Parameters:
//   IC - number of input channels    IL - input length
//   K  - kernel width                s0 - stride between consecutive input samples
//
// Input sample i contributes to output positions [i*s0, i*s0 + K), so the value at
// output position x is the sum, over every channel c and every sample i with
// 0 <= x - i*s0 < K, of src0[c][oc][x - i*s0] * src1[c][i].
//
// NOTE(review): no thread-in-threadgroup index is used, so every thread of a
// threadgroup stores the same value; presumably dispatched with a single thread per
// threadgroup - confirm on the host side.
template <typename T>
kernel void kernel_conv_transpose_1d(
        device const     T * src0,
        device const float * src1,
        device        char * dst,
        constant   int32_t & IC,
        constant   int32_t & IL,
        constant   int32_t & K,
        constant   int32_t & s0,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]]) {

    // widen the grid coordinates once so that all index arithmetic below is 64-bit
    const int64_t out_pos  = tgpig[0];
    const int64_t out_ch   = tgpig[1];
    const int64_t n_out_ch = tgpg[1];

    // Only samples with i*s0 <= out_pos < i*s0 + K contribute. For a positive stride
    // these are exactly the i in [first, last) computed below, so there is no need to
    // visit all IL samples for every output element. The visiting order (ascending i)
    // is unchanged, hence the accumulated value is unchanged as well.
    int64_t i_begin = 0;
    int64_t i_end   = IL;

    if (s0 > 0) {
        if (out_pos >= K) {
            // smallest i with i*s0 + K > out_pos
            i_begin = (out_pos - K) / s0 + 1;
        }

        // one past the largest i with i*s0 <= out_pos
        const int64_t i_last = out_pos / s0 + 1;
        if (i_last < i_end) {
            i_end = i_last;
        }
    }

    float v = 0.0f;

    for (int64_t c = 0; c < IC; c++) {
        // 64-bit offsets: narrowing these products to 32 bits would wrap for large tensors
        const int64_t kernel_offset = (c * n_out_ch + out_ch) * K;
        const int64_t input_offset  = c * IL;

        for (int64_t i = i_begin; i < i_end; i++) {
            // kernel tap that maps input sample i onto this output position
            const int64_t k = out_pos - i * s0;

            // always true inside the window when s0 > 0; kept so that a non-positive
            // stride still behaves like a full scan over all samples
            if (k >= 0 && k < K) {
                v += src0[kernel_offset + k] * src1[input_offset + i];
            }
        }
    }

    device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1);

    dst_ptr[0] = v;
}
|
||||
|
||||
// Explicit instantiation with f32 kernel weights (src0) and f32 input (src1);
// visible to the host as "kernel_conv_transpose_1d_f32_f32".
template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
kernel void kernel_conv_transpose_1d<float>(
        device const float * src0,
        device const float * src1,
        device        char * dst,
        constant   int32_t & IC,
        constant   int32_t & IL,
        constant   int32_t & K,
        constant   int32_t & s0,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]]);
|
||||
|
||||
// Explicit instantiation with f16 kernel weights (src0) and f32 input (src1);
// visible to the host as "kernel_conv_transpose_1d_f16_f32".
template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
kernel void kernel_conv_transpose_1d<half>(
        device const  half * src0,
        device const float * src1,
        device        char * dst,
        constant   int32_t & IC,
        constant   int32_t & IL,
        constant   int32_t & K,
        constant   int32_t & s0,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]]);
|
||||
|
||||
kernel void kernel_upscale_f32(
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue