vulkan: implement more backpropagation operators (#11914)
* vulkan: implement GGML_OP_ROPE_BACK * vulkan: implement GGML_OP_RMS_NORM_BACK * vulkan: implement GGML_OP_SILU_BACK * vulkan: implement GGML_OP_SOFTMAX_BACK
This commit is contained in:
parent
0b52745649
commit
61d4f39dfe
6 changed files with 233 additions and 7 deletions
55
ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp
Normal file
55
ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp
Normal file
|
@ -0,0 +1,55 @@
|
|||
#version 450
|
||||
|
||||
#include "generic_head.comp"
|
||||
#include "types.comp"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#define BLOCK_SIZE 512
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer G {A_TYPE data_a[];};
|
||||
layout (binding = 1) readonly buffer X {B_TYPE data_b[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
shared FLOAT_TYPE sum_xx[BLOCK_SIZE];
|
||||
shared FLOAT_TYPE sum_xg[BLOCK_SIZE];
|
||||
|
||||
// Backward pass of RMS normalization. Each workgroup handles one row of p.KX
// elements: data_a supplies the incoming gradient g, data_b the forward input x,
// and the gradient with respect to x is written to data_d.
//
// With m = sum(x*x) / KX and eps = p.param1, the value written is
//   d[i] = g[i] / sqrt(m + eps) - x[i] * dot(x, g) / (KX * (m + eps)^1.5)
void main() {
    const uint tid = gl_LocalInvocationID.x;
    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
    const uint row_base = row * p.KX;

    // Each invocation accumulates its strided share of sum(x*x) and sum(x*g).
    FLOAT_TYPE acc_xx = FLOAT_TYPE(0.0f);
    FLOAT_TYPE acc_xg = FLOAT_TYPE(0.0f);

    [[unroll]] for (uint i = tid; i < p.KX; i += BLOCK_SIZE) {
        const FLOAT_TYPE g_val = FLOAT_TYPE(data_a[row_base + i]);
        const FLOAT_TYPE x_val = FLOAT_TYPE(data_b[row_base + i]);
        acc_xx += x_val * x_val;
        acc_xg += x_val * g_val;
    }

    sum_xx[tid] = acc_xx;
    sum_xg[tid] = acc_xg;

    // Tree reduction over the workgroup; the totals end up in element 0.
    barrier();
    [[unroll]] for (int stride = BLOCK_SIZE / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            sum_xx[tid] += sum_xx[tid + stride];
            sum_xg[tid] += sum_xg[tid + stride];
        }
        barrier();
    }

    const FLOAT_TYPE total_xx = sum_xx[0];
    const FLOAT_TYPE total_xg = sum_xg[0];

    const FLOAT_TYPE eps = FLOAT_TYPE(p.param1);
    const FLOAT_TYPE mean = total_xx / FLOAT_TYPE(p.KX);

    // scale_g multiplies g[i]; scale_x multiplies x[i] (see formula above).
    const FLOAT_TYPE scale_g = inversesqrt(mean + eps);
    const FLOAT_TYPE scale_x = -scale_g * total_xg / (total_xx + FLOAT_TYPE(p.KX) * eps);

    [[unroll]] for (uint i = tid; i < p.KX; i += BLOCK_SIZE) {
        const FLOAT_TYPE g_val = FLOAT_TYPE(data_a[row_base + i]);
        const FLOAT_TYPE x_val = FLOAT_TYPE(data_b[row_base + i]);
        data_d[row_base + i] = D_TYPE(scale_g * g_val + scale_x * x_val);
    }
}
|
|
@ -29,6 +29,7 @@ layout (push_constant) uniform parameter {
|
|||
uint s1;
|
||||
uint s2;
|
||||
int sections[4];
|
||||
uint is_back;
|
||||
} p;
|
||||
|
||||
float rope_yarn_ramp(const float low, const float high, const uint i0) {
|
||||
|
@ -48,6 +49,10 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out
|
|||
// Get n-d magnitude scaling corrected for interpolation
|
||||
mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
|
||||
}
|
||||
// Backpropagation uses inverted rotation
|
||||
if (p.is_back != 0) {
|
||||
theta = -theta;
|
||||
}
|
||||
cos_theta = cos(theta) * mscale;
|
||||
sin_theta = sin(theta) * mscale;
|
||||
}
|
||||
|
|
26
ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp
Normal file
26
ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp
Normal file
|
@ -0,0 +1,26 @@
|
|||
#version 450
|
||||
|
||||
#include "generic_head.comp"
|
||||
#include "types.comp"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer G {A_TYPE data_g[];};
|
||||
layout (binding = 1) readonly buffer X {B_TYPE data_x[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
// Elementwise backward pass of SiLU: d[i] = g[i] * SiLU'(x[i]) for i < p.KX,
// where data_g holds the incoming gradient and data_x the forward input.
//
// With s = 1 / (1 + exp(-x)), the derivative of SiLU(x) = x * s is
//   SiLU'(x) = s + x * s * (1 - s)
//            = 1/(1+exp(-x)) + x*exp(-x)/(1+exp(-x))^2
void main() {
    const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    // Invocations whose index falls at or beyond p.KX write nothing.
    if (idx < p.KX) {
        const float x_val = float(data_x[idx]);
        const float sig = 1.0f / (1.0f + exp(-x_val));
        const float d_silu = sig + x_val * sig * (1 - sig);

        data_d[idx] = D_TYPE(data_g[idx] * d_silu);
    }
}
|
50
ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp
Normal file
50
ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp
Normal file
|
@ -0,0 +1,50 @@
|
|||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
#include "generic_head.comp"
|
||||
#include "types.comp"
|
||||
|
||||
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
// In this shader Y = softmax(X) and X is not provided as input.
|
||||
|
||||
layout (binding = 0) readonly buffer G {A_TYPE data_g[];};
|
||||
layout (binding = 1) readonly buffer Y {B_TYPE data_y[];};
|
||||
layout (binding = 2) buffer D {D_TYPE data_d[];};
|
||||
|
||||
shared FLOAT_TYPE sum_yg[BLOCK_SIZE];
|
||||
|
||||
// Backward pass of softmax. Each workgroup handles one row of p.KX elements:
// data_g supplies the incoming gradient g, data_y the forward output y, and
// the result is written to data_d as
//   d[i] = scale * (g[i] - dot(y, g)) * y[i]
// with scale taken from p.param1.
void main() {
    const uint tid = gl_LocalInvocationID.x;
    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
    const uint row_base = row * p.KX;

    FLOAT_TYPE scale = p.param1;

    // Each invocation accumulates its strided share of dot(y, g).
    FLOAT_TYPE acc_yg = FLOAT_TYPE(0.0f);

    [[unroll]] for (uint i = tid; i < p.KX; i += BLOCK_SIZE) {
        const FLOAT_TYPE g_val = FLOAT_TYPE(data_g[row_base + i]);
        const FLOAT_TYPE y_val = FLOAT_TYPE(data_y[row_base + i]);
        acc_yg += y_val * g_val;
    }

    sum_yg[tid] = acc_yg;

    // Tree reduction over the workgroup; the total ends up in element 0.
    barrier();
    [[unroll]] for (uint stride = BLOCK_SIZE / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            sum_yg[tid] += sum_yg[tid + stride];
        }
        barrier();
    }

    const FLOAT_TYPE dot_yg = sum_yg[0];

    [[unroll]] for (uint i = tid; i < p.KX; i += BLOCK_SIZE) {
        const FLOAT_TYPE g_val = FLOAT_TYPE(data_g[row_base + i]);
        const FLOAT_TYPE y_val = FLOAT_TYPE(data_y[row_base + i]);
        data_d[row_base + i] = D_TYPE(scale * (g_val - dot_yg) * y_val);
    }
}
|
|
@ -427,6 +427,7 @@ void process_shaders() {
|
|||
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
||||
|
@ -477,6 +478,7 @@ void process_shaders() {
|
|||
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
@ -485,6 +487,7 @@ void process_shaders() {
|
|||
|
||||
string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue