opencl: add multi and vision rope, gelu_quick and im2col (#12600)

* opencl: add `im2col`

* opencl: add `gelu_quick`

* opencl: add mrope

* opencl: add vision rope
This commit is contained in:
lhez 2025-03-27 08:08:08 -07:00 committed by GitHub
parent f125b8dccf
commit 5dec47dcd4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 774 additions and 14 deletions

View file

@ -404,6 +404,7 @@ kernel void kernel_scale(
// gelu
//------------------------------------------------------------------------------
#define GELU_COEF_A 0.044715f
#define GELU_QUICK_COEF -1.702f
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
kernel void kernel_gelu(
@ -434,6 +435,32 @@ kernel void kernel_gelu_4(
dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
kernel void kernel_gelu_quick(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
float x = src0[get_global_id(0)];
dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}
kernel void kernel_gelu_quick_4(
global float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
float4 x = src0[get_global_id(0)];
dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}
//------------------------------------------------------------------------------
// silu
//------------------------------------------------------------------------------
@ -1325,6 +1352,368 @@ kernel void kernel_rope_neox_f16(
}
}
kernel void kernel_rope_multi_f32(
global void * src0,
ulong offset0,
global int * src1,
ulong offset1,
global float * src2,
ulong offset2,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3,
int n_past,
int n_dims,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow,
int4 sections
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
src2 = (global float*)((global char*)src2 + offset2);
dst = (global float*)((global char*)dst + offsetd);
int i3 = get_group_id(2);
int i2 = get_group_id(1);
int i1 = get_group_id(0);
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
global int * pos = src1;
const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3;
const int sec_w = sections.s1 + sections.s0;
float inv_ndims = -1.f/n_dims;
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
if (i0 < n_dims) {
int ic = i0/2;
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0f;
if (sector < sections.s0) {
theta_base = pos[i2];
}
else if (sector >= sections.s0 && sector < sec_w) {
theta_base = pos[i2 + ne2 * 1];
}
else if (sector >= sec_w && sector < sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 2];
}
else if (sector >= sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 3];
}
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims/2];
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
} else {
global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
kernel void kernel_rope_multi_f16(
global void * src0,
ulong offset0,
global int * src1,
ulong offset1,
global float * src2,
ulong offset2,
global half * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3,
int n_past,
int n_dims,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow,
int4 sections
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
src2 = (global float*)((global char*)src2 + offset2);
dst = (global float*)((global char*)dst + offsetd);
int i3 = get_group_id(2);
int i2 = get_group_id(1);
int i1 = get_group_id(0);
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
global int * pos = src1;
const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3;
const int sec_w = sections.s1 + sections.s0;
float inv_ndims = -1.f/n_dims;
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
if (i0 < n_dims) {
int ic = i0/2;
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0f;
if (sector < sections.s0) {
theta_base = pos[i2];
}
else if (sector >= sections.s0 && sector < sec_w) {
theta_base = pos[i2 + ne2 * 1];
}
else if (sector >= sec_w && sector < sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 2];
}
else if (sector >= sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 3];
}
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims/2];
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
} else {
global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
kernel void kernel_rope_vision_f32(
global void * src0,
ulong offset0,
global int * src1,
ulong offset1,
global float * src2,
ulong offset2,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3,
int n_past,
int n_dims,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow,
int4 sections
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
src2 = (global float*)((global char*)src2 + offset2);
dst = (global float*)((global char*)dst + offsetd);
int i3 = get_group_id(2);
int i2 = get_group_id(1);
int i1 = get_group_id(0);
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
global int * pos = src1;
const int sect_dims = sections.s0 + sections.s1;
const int sec_w = sections.s1 + sections.s0;
float inv_ndims = -1.f/n_dims;
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
int ic = i0/2;
const int sector = (i0/2) % sect_dims;
float theta_base = 0.0f;
if (sector < sections.s0) {
const int p = sector;
theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p);
} else if (sector >= sections.s0 && sector < sec_w) {
const int p = sector - sections.s0;
theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p);
}
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims];
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
}
}
kernel void kernel_rope_vision_f16(
global void * src0,
ulong offset0,
global int * src1,
ulong offset1,
global float * src2,
ulong offset2,
global half * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3,
int n_past,
int n_dims,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow,
int4 sections
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
src2 = (global float*)((global char*)src2 + offset2);
dst = (global float*)((global char*)dst + offsetd);
int i3 = get_group_id(2);
int i2 = get_group_id(1);
int i1 = get_group_id(0);
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
global int * pos = src1;
const int sect_dims = sections.s0 + sections.s1;
const int sec_w = sections.s1 + sections.s0;
float inv_ndims = -1.f/n_dims;
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
int ic = i0/2;
const int sector = (i0/2) % sect_dims;
float theta_base = 0.0f;
if (sector < sections.s0) {
const int p = sector;
theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p);
} else if (sector >= sections.s0 && sector < sec_w) {
const int p = sector - sections.s0;
theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p);
}
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims];
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
}
}
//------------------------------------------------------------------------------
// cpy
//------------------------------------------------------------------------------

View file

@ -0,0 +1,146 @@
#ifdef cl_khr_fp16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#elif defined(cl_amd_fp16)
#pragma OPENCL EXTENSION cl_amd_fp16 : enable
#else
#error "Half precision floating point not supportedby OpenCL implementation on your device."
#endif
#ifdef cl_khr_subgroups
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#elif defined(cl_intel_subgroups)
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#error "Subgroup not supported on your device."
#endif
#ifdef cl_intel_required_subgroup_size
// Always use subgroup size of 32 on Intel.
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
// Always use subgroups size of 64 on Adreno.
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#else
// TODO: do not know how to choose subgroup size on other GPUs.
#error "Selecting subgroup size is not supported on your device."
#endif
kernel void kernel_im2col_f32(
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
ulong batch_offset,
ulong delta_offset,
long IW,
long IH,
long IC,
long OW,
long OH,
long KW,
long KH,
long pelements,
long CHW,
int s0,
int s1,
int p0,
int p1,
int d0,
int d1
) {
// threadIdx.x + blockIdx.x * blockDim.x
long i = get_global_id(0);
if (i >= pelements) {
return;
}
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
long ksize = OW * (KH > 1 ? KW : 1);
long kx = i / ksize;
long kd = kx * ksize;
long ky = (i - kd) / OW;
long ix = i % OW;
long oh = get_group_id(1);
long batch = get_group_id(2) / IC;
long ic = get_group_id(2) % IC;
long iiw = ix * s0 + kx * d0 - p0;
long iih = oh * s1 + ky * d1 - p1;
long offset_dst =
((batch * OH + oh) * OW + ix) * CHW +
(ic * (KW * KH) + ky * KW + kx);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
long offset_src = ic * delta_offset + batch * batch_offset;
dst[offset_dst] = src1[offset_src + iih * IW + iiw];
}
}
kernel void kernel_im2col_f16(
global float * src1,
ulong offset1,
global half * dst,
ulong offsetd,
ulong batch_offset,
ulong delta_offset,
long IW,
long IH,
long IC,
long OW,
long OH,
long KW,
long KH,
long pelements,
long CHW,
int s0,
int s1,
int p0,
int p1,
int d0,
int d1
) {
long i = get_global_id(0);
if (i >= pelements) {
return;
}
src1 = (global float*)((global char*)src1 + offset1);
dst = (global half*)((global char*)dst + offsetd);
long ksize = OW * (KH > 1 ? KW : 1);
long kx = i / ksize;
long kd = kx * ksize;
long ky = (i - kd) / OW;
long ix = i % OW;
long oh = get_group_id(1);
long batch = get_group_id(2) / IC;
long ic = get_group_id(2) % IC;
long iiw = ix * s0 + kx * d0 - p0;
long iih = oh * s1 + ky * d1 - p1;
long offset_dst =
((batch * OH + oh) * OW + ix) * CHW +
(ic * (KW * KH) + ky * KW + kx);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
long offset_src = ic * delta_offset + batch * batch_offset;
dst[offset_dst] = src1[offset_src + iih * IW + iiw];
}
}