vulkan: implement more backpropagation operators (#11914)

* vulkan: implement GGML_OP_ROPE_BACK

* vulkan: implement GGML_OP_RMS_NORM_BACK

* vulkan: implement GGML_OP_SILU_BACK

* vulkan: implement GGML_OP_SOFTMAX_BACK
This commit is contained in:
Rémy O 2025-02-25 12:04:45 +01:00 committed by GitHub
parent 0b52745649
commit 61d4f39dfe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 233 additions and 7 deletions

View file

@ -0,0 +1,55 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer G {A_TYPE data_a[];};
layout (binding = 1) readonly buffer X {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE sum_xx[BLOCK_SIZE];
shared FLOAT_TYPE sum_xg[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
// Compute derivative of x[i]/norm(x) = g[i]/norm(x) - x[i] dot(x,g)/KX / norm(x)^1.5
// partial sums for thread in warp
sum_xx[tid] = FLOAT_TYPE(0.0f);
sum_xg[tid] = FLOAT_TYPE(0.0f);
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
const FLOAT_TYPE gi = FLOAT_TYPE(data_a[row*p.KX + col]);
const FLOAT_TYPE xi = FLOAT_TYPE(data_b[row*p.KX + col]);
sum_xx[tid] += xi * xi;
sum_xg[tid] += xi * gi;
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
sum_xx[tid] += sum_xx[tid + s];
sum_xg[tid] += sum_xg[tid + s];
}
barrier();
}
const FLOAT_TYPE eps = FLOAT_TYPE(p.param1);
const FLOAT_TYPE mean = sum_xx[0] / FLOAT_TYPE(p.KX);
const FLOAT_TYPE scale_g = inversesqrt(mean + eps);
const FLOAT_TYPE scale_x = -scale_g * sum_xg[0] / (sum_xx[0] + FLOAT_TYPE(p.KX) * eps);
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
data_d[row*p.KX + col] = D_TYPE(
scale_g * FLOAT_TYPE(data_a[row*p.KX + col]) +
scale_x * FLOAT_TYPE(data_b[row*p.KX + col]));
}
}

View file

@ -29,6 +29,7 @@ layout (push_constant) uniform parameter {
uint s1;
uint s2;
int sections[4];
uint is_back;
} p;
float rope_yarn_ramp(const float low, const float high, const uint i0) {
@ -48,6 +49,10 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
}
// Backprogagation uses inverted rotation
if (p.is_back != 0) {
theta = -theta;
}
cos_theta = cos(theta) * mscale;
sin_theta = sin(theta) * mscale;
}

View file

@ -0,0 +1,26 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer G {A_TYPE data_g[];};
layout (binding = 1) readonly buffer X {B_TYPE data_x[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
// Compute derivative of SiLU(x): 1/(1+exp(-x)) - x*exp(-x)/(1+exp(-x))^2
const float xi = float(data_x[i]);
const float s = 1.0f / (1.0f + exp(-xi));
data_d[i] = D_TYPE(data_g[i] * (s + xi * s * (1 - s)));
}

View file

@ -0,0 +1,50 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#include "generic_head.comp"
#include "types.comp"
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// In this shader Y = softmax(X) and X is not provided as input.
layout (binding = 0) readonly buffer G {A_TYPE data_g[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_y[];};
layout (binding = 2) buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE sum_yg[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
FLOAT_TYPE scale = p.param1;
// partial sums for thread in warp
sum_yg[tid] = FLOAT_TYPE(0.0f);
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
const FLOAT_TYPE gi = FLOAT_TYPE(data_g[row*p.KX + col]);
const FLOAT_TYPE yi = FLOAT_TYPE(data_y[row*p.KX + col]);
sum_yg[tid] += yi * gi;
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
sum_yg[tid] += sum_yg[tid + s];
}
barrier();
}
const FLOAT_TYPE dot_yg = sum_yg[0];
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
data_d[row*p.KX + col] = D_TYPE(scale
* (FLOAT_TYPE(data_g[row*p.KX + col]) - dot_yg)
* FLOAT_TYPE(data_y[row*p.KX + col]));
}
}

View file

@ -427,6 +427,7 @@ void process_shaders() {
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
@ -477,6 +478,7 @@ void process_shaders() {
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
@ -485,6 +487,7 @@ void process_shaders() {
string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});