opencl: Noncontiguous norm
, rms_norm
, disable fp16
for some ops (#12217)
* opencl: support noncontiguous `norm` * opencl: support noncontiguous `rms_norm` * opencl: disable fp16 for `ADD`, `MUL`, `SCALE`, `RELU`, `GELU`, `SILU`, `CLAMP`
This commit is contained in:
parent
776f9e59cc
commit
d76a86d967
2 changed files with 65 additions and 31 deletions
|
@ -506,14 +506,23 @@ kernel void kernel_norm(
|
|||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
float eps,
|
||||
local float * sum
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
dst = (global void*)((global char*)dst + offsetd);
|
||||
|
||||
global float * x = (global float *) ((global char *) src0 + get_group_id(0)*nb01);
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0);
|
||||
|
||||
global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
|
||||
// MEAN
|
||||
// parallel sum
|
||||
|
@ -533,7 +542,7 @@ kernel void kernel_norm(
|
|||
|
||||
// recenter and VARIANCE
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
global float * y = dst + get_group_id(0)*ne00;
|
||||
global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
sum[get_local_id(0)] = 0.0f;
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
y[i00] = x[i00] - mean;
|
||||
|
@ -566,14 +575,23 @@ kernel void kernel_rms_norm(
|
|||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
float eps,
|
||||
local float * sum // Note, the size depends on number of subgroups
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
global float4 * x = (global float4 *) ((global char *) src0 + get_group_id(0)*nb01);
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0);
|
||||
|
||||
global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
global float * x_scalar = (global float *) x;
|
||||
float4 sumf = 0;
|
||||
float all_sum = 0;
|
||||
|
@ -607,7 +625,7 @@ kernel void kernel_rms_norm(
|
|||
const float mean = sum[0];
|
||||
const float scale = 1.0f/sqrt(mean + eps);
|
||||
|
||||
global float4 * y = (global float4 *) (dst + get_group_id(0)*ne00);
|
||||
global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
global float * y_scalar = (global float *) y;
|
||||
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
||||
y[i00] = x[i00] * scale;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue