vulkan: support multi/vision rope, and noncontiguous rope (#11902)
This commit is contained in:
parent
c2ea16f260
commit
bf42a23d0a
7 changed files with 204 additions and 41 deletions
|
@ -25,6 +25,10 @@ layout (push_constant) uniform parameter {
|
|||
float corr_dims[2];
|
||||
float theta_scale;
|
||||
uint has_ff;
|
||||
uint ne02;
|
||||
uint s1;
|
||||
uint s2;
|
||||
int sections[4];
|
||||
} p;
|
||||
|
||||
float rope_yarn_ramp(const float low, const float high, const uint i0) {
|
||||
|
|
60
ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
Normal file
60
ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
Normal file
|
@ -0,0 +1,60 @@
|
|||
#version 450
|
||||
|
||||
#include "rope_head.comp"
|
||||
|
||||
void main() {
|
||||
const uint i0 = 2*gl_GlobalInvocationID.y;
|
||||
uint ne0 = p.ncols;
|
||||
uint ne1 = p.p_delta_rows;
|
||||
uint ne2 = p.ne02;
|
||||
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint row_dst = gl_GlobalInvocationID.x;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
const uint i = row_dst*ne0 + i0;
|
||||
|
||||
data_d[i + 0] = data_a[i + 0];
|
||||
data_d[i + 1] = data_a[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const uint row_x = row_dst % ne1;
|
||||
const uint channel_x = row_dst / ne1;
|
||||
|
||||
const uint idst = row_dst*ne0 + i0/2;
|
||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
|
||||
|
||||
const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
|
||||
const int sec_w = p.sections[1] + p.sections[0];
|
||||
const uint sector = (i0 / 2) % sect_dims;
|
||||
|
||||
float theta_base = 0.0;
|
||||
if (sector < p.sections[0]) {
|
||||
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= p.sections[0] && sector < sec_w) {
|
||||
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
|
||||
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w + p.sections[2]) {
|
||||
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
|
||||
|
||||
const float x0 = float(data_a[ix + 0]);
|
||||
const float x1 = float(data_a[ix + p.n_dims/2]);
|
||||
|
||||
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
|
||||
data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||
}
|
|
@ -3,15 +3,18 @@
|
|||
#include "rope_head.comp"
|
||||
|
||||
void main() {
|
||||
const uint col = gl_GlobalInvocationID.y * 2;
|
||||
const uint row = gl_GlobalInvocationID.x;
|
||||
const uint i0 = 2*gl_GlobalInvocationID.y;
|
||||
uint ne0 = p.ncols;
|
||||
uint ne1 = p.p_delta_rows;
|
||||
|
||||
if (col >= p.ncols) {
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (col >= p.n_dims) {
|
||||
const uint i = row*p.ncols + col;
|
||||
const uint row_dst = gl_GlobalInvocationID.x;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
const uint i = row_dst*ne0 + i0;
|
||||
|
||||
data_d[i + 0] = data_a[i + 0];
|
||||
data_d[i + 1] = data_a[i + 1];
|
||||
|
@ -19,19 +22,22 @@ void main() {
|
|||
return;
|
||||
}
|
||||
|
||||
const uint i = row*p.ncols + col/2;
|
||||
const uint i2 = row/p.p_delta_rows;
|
||||
const uint row_x = row_dst % ne1;
|
||||
const uint channel_x = row_dst / ne1;
|
||||
|
||||
const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
|
||||
const uint idst = row_dst*ne0 + i0/2;
|
||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
|
||||
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
|
||||
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
|
||||
|
||||
const float x0 = float(data_a[i + 0]);
|
||||
const float x1 = float(data_a[i + p.n_dims/2]);
|
||||
const float x0 = float(data_a[ix + 0]);
|
||||
const float x1 = float(data_a[ix + p.n_dims/2]);
|
||||
|
||||
data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
|
||||
data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
|
||||
data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
|
|
|
@ -3,15 +3,18 @@
|
|||
#include "rope_head.comp"
|
||||
|
||||
void main() {
|
||||
const uint col = gl_GlobalInvocationID.y * 2;
|
||||
const uint row = gl_GlobalInvocationID.x;
|
||||
const uint i0 = 2*gl_GlobalInvocationID.y;
|
||||
uint ne0 = p.ncols;
|
||||
uint ne1 = p.p_delta_rows;
|
||||
|
||||
if (col >= p.ncols) {
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (col >= p.n_dims) {
|
||||
const uint i = row*p.ncols + col;
|
||||
const uint row_dst = gl_GlobalInvocationID.x;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
const uint i = row_dst*ne0 + i0;
|
||||
|
||||
data_d[i + 0] = data_a[i + 0];
|
||||
data_d[i + 1] = data_a[i + 1];
|
||||
|
@ -19,19 +22,22 @@ void main() {
|
|||
return;
|
||||
}
|
||||
|
||||
const uint i = row*p.ncols + col;
|
||||
const uint i2 = row/p.p_delta_rows;
|
||||
const uint row_x = row_dst % ne1;
|
||||
const uint channel_x = row_dst / ne1;
|
||||
|
||||
const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
|
||||
const uint idst = row_dst*ne0 + i0;
|
||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
|
||||
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
|
||||
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
|
||||
|
||||
const float x0 = float(data_a[i + 0]);
|
||||
const float x1 = float(data_a[i + 1]);
|
||||
const float x0 = float(data_a[ix + 0]);
|
||||
const float x1 = float(data_a[ix + 1]);
|
||||
|
||||
data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
|
||||
data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
|
||||
data_d[idst + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
|
|
47
ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
Normal file
47
ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
Normal file
|
@ -0,0 +1,47 @@
|
|||
#version 450
|
||||
|
||||
#include "rope_head.comp"
|
||||
|
||||
void main() {
|
||||
const uint i0 = 2*gl_GlobalInvocationID.y;
|
||||
uint ne0 = p.ncols;
|
||||
uint ne1 = p.p_delta_rows;
|
||||
uint ne2 = p.ne02;
|
||||
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint row_dst = gl_GlobalInvocationID.x;
|
||||
|
||||
const uint row_x = row_dst % ne1;
|
||||
const uint channel_x = row_dst / ne1;
|
||||
|
||||
const uint idst = row_dst*ne0 + i0/2;
|
||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
|
||||
|
||||
const int sect_dims = p.sections[0] + p.sections[1];
|
||||
const int sec_w = p.sections[1] + p.sections[0];
|
||||
const uint sector = (i0 / 2) % sect_dims;
|
||||
|
||||
float theta_base = 0.0;
|
||||
if (sector < p.sections[0]) {
|
||||
const uint p0 = sector;
|
||||
theta_base = data_pos[channel_x]*pow(p.theta_scale, p0);
|
||||
}
|
||||
else if (sector >= p.sections[0] && sector < sec_w) {
|
||||
const uint p0 = sector - p.sections[0];
|
||||
theta_base = data_pos[channel_x + ne2]*pow(p.theta_scale, p0);
|
||||
}
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
|
||||
|
||||
const float x0 = float(data_a[ix + 0]);
|
||||
const float x1 = float(data_a[ix + p.n_dims]);
|
||||
|
||||
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
|
||||
data_d[idst + p.n_dims] = D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||
}
|
|
@ -491,6 +491,14 @@ void process_shaders() {
|
|||
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
|
||||
string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
|
||||
string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
|
||||
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
|
||||
|
||||
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue