metal : add POOL2D and fix IM2COL (#9943)
* add pool_2d Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com> * fix im2col and add unittest for N>=1024 Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com> * add tests for N % 1024 != 0 Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com> * remove trailing whitespaces Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com> * apply suggestions Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com> * apply more optimization - original IM2COL kernel + _ext with MIN() Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com> * apply review: change kernel name of pool_2d Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com> * apply review Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com> * fix more formatting and enhance readability Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com> --------- Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com>
This commit is contained in:
parent
873279b159
commit
4c9388fb96
3 changed files with 298 additions and 18 deletions
|
|
@ -1933,6 +1933,85 @@ kernel void kernel_im2col(
|
|||
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
|
||||
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
|
||||
|
||||
typedef void (im2col_ext_t)(
|
||||
device const float * x,
|
||||
device char * dst,
|
||||
constant int32_t & ofs0,
|
||||
constant int32_t & ofs1,
|
||||
constant int32_t & IW,
|
||||
constant int32_t & IH,
|
||||
constant int32_t & CHW,
|
||||
constant int32_t & s0,
|
||||
constant int32_t & s1,
|
||||
constant int32_t & p0,
|
||||
constant int32_t & p1,
|
||||
constant int32_t & d0,
|
||||
constant int32_t & d1,
|
||||
constant int32_t & N,
|
||||
constant int32_t & KH,
|
||||
constant int32_t & KW,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template <typename T>
|
||||
kernel void kernel_im2col_ext(
|
||||
device const float * x,
|
||||
device char * dst,
|
||||
constant int32_t & ofs0,
|
||||
constant int32_t & ofs1,
|
||||
constant int32_t & IW,
|
||||
constant int32_t & IH,
|
||||
constant int32_t & CHW,
|
||||
constant int32_t & s0,
|
||||
constant int32_t & s1,
|
||||
constant int32_t & p0,
|
||||
constant int32_t & p1,
|
||||
constant int32_t & d0,
|
||||
constant int32_t & d1,
|
||||
constant int32_t & N,
|
||||
constant int32_t & KH,
|
||||
constant int32_t & KW,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
|
||||
const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]
|
||||
|
||||
const int32_t d = tgpig[0] / CHW;
|
||||
const int32_t chw = tgpig[0] % CHW;
|
||||
const int32_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
|
||||
const int32_t HW = tgpig[0] % KHW;
|
||||
|
||||
const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0];
|
||||
if (tpitg_0 >= N) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t tpitg_1 = HW / KW;
|
||||
const int32_t tpitg_2 = HW % KW;
|
||||
|
||||
const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
|
||||
const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
|
||||
|
||||
const int32_t offset_dst =
|
||||
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
|
||||
(tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
|
||||
|
||||
device T * pdst = (device T *) (dst);
|
||||
|
||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
pdst[offset_dst] = 0.0f;
|
||||
} else {
|
||||
const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
|
||||
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
|
||||
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
|
||||
|
||||
kernel void kernel_upscale_f32(
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
|
|
@ -6372,3 +6451,102 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t
|
|||
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
|
||||
|
||||
kernel void kernel_pool_2d_max_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant int32_t & k0,
|
||||
constant int32_t & k1,
|
||||
constant int32_t & s0,
|
||||
constant int32_t & s1,
|
||||
constant int32_t & p0,
|
||||
constant int32_t & p1,
|
||||
constant int64_t & IH,
|
||||
constant int64_t & IW,
|
||||
constant int64_t & OH,
|
||||
constant int64_t & OW,
|
||||
constant int64_t & parallel_elements,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= parallel_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int idx = gid;
|
||||
const int I_HW = IH * IW;
|
||||
const int O_HW = OH * OW;
|
||||
const int nc = idx / O_HW;
|
||||
const int cur_oh = idx % O_HW / OW;
|
||||
const int cur_ow = idx % O_HW % OW;
|
||||
|
||||
device const float * i_ptr = src0 + nc * I_HW;
|
||||
device float * o_ptr = dst + nc * O_HW;
|
||||
|
||||
const int start_h = cur_oh * s1 - p1;
|
||||
const int bh = MAX(0, start_h);
|
||||
const int eh = MIN(IH, start_h + k1);
|
||||
const int start_w = cur_ow * s0 - p0;
|
||||
const int bw = MAX(0, start_w);
|
||||
const int ew = MIN(IW, start_w + k0);
|
||||
|
||||
float res = -INFINITY;
|
||||
|
||||
for (int i = bh; i < eh; i += 1) {
|
||||
for (int j = bw; j < ew; j += 1) {
|
||||
res = MAX(res, i_ptr[i * IW + j]);
|
||||
}
|
||||
}
|
||||
|
||||
o_ptr[cur_oh * OW + cur_ow] = res;
|
||||
}
|
||||
|
||||
kernel void kernel_pool_2d_avg_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant int32_t & k0,
|
||||
constant int32_t & k1,
|
||||
constant int32_t & s0,
|
||||
constant int32_t & s1,
|
||||
constant int32_t & p0,
|
||||
constant int32_t & p1,
|
||||
constant int64_t & IH,
|
||||
constant int64_t & IW,
|
||||
constant int64_t & OH,
|
||||
constant int64_t & OW,
|
||||
constant int64_t & parallel_elements,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= parallel_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int idx = gid;
|
||||
const int I_HW = IH * IW;
|
||||
const int O_HW = OH * OW;
|
||||
const int nc = idx / O_HW;
|
||||
const int cur_oh = idx % O_HW / OW;
|
||||
const int cur_ow = idx % O_HW % OW;
|
||||
|
||||
device const float * i_ptr = src0 + nc * I_HW;
|
||||
device float * o_ptr = dst + nc * O_HW;
|
||||
|
||||
const int start_h = cur_oh * s1 - p1;
|
||||
const int bh = MAX(0, start_h);
|
||||
const int eh = MIN(IH, start_h + k1);
|
||||
const int start_w = cur_ow * s0 - p0;
|
||||
const int bw = MAX(0, start_w);
|
||||
const int ew = MIN(IW, start_w + k0);
|
||||
// const float scale = 1. / ((eh - bh) * (ew - bw));
|
||||
const float scale = 1. / (k0 * k1);
|
||||
|
||||
float res = 0;
|
||||
|
||||
for (int i = bh; i < eh; i += 1) {
|
||||
for (int j = bw; j < ew; j += 1) {
|
||||
float cur = i_ptr[i * IW + j];
|
||||
res += cur * scale;
|
||||
}
|
||||
}
|
||||
|
||||
o_ptr[cur_oh * OW + cur_ow] = res;
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue