clip : refactor set input for cgraph + fix qwen2.5vl input (#13136)
* clip : refactor set input for cgraph * more strict assert * minicpmv : use clip_n_mmproj_embd instead of copying the same code everywhere * split qwen2 and qwen2.5 code blocks * minor style fix
This commit is contained in:
parent
a4c340f974
commit
5fa9e63be8
1 changed files with 215 additions and 258 deletions
|
@ -170,8 +170,8 @@ struct clip_hparams {
|
||||||
std::vector<int32_t> image_grid_pinpoints;
|
std::vector<int32_t> image_grid_pinpoints;
|
||||||
int32_t image_crop_resolution;
|
int32_t image_crop_resolution;
|
||||||
std::unordered_set<int32_t> vision_feature_layer;
|
std::unordered_set<int32_t> vision_feature_layer;
|
||||||
int32_t attn_window_size;
|
int32_t attn_window_size = 0;
|
||||||
int32_t n_wa_pattern;
|
int32_t n_wa_pattern = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct clip_layer {
|
struct clip_layer {
|
||||||
|
@ -325,7 +325,6 @@ struct clip_ctx {
|
||||||
float image_std[3];
|
float image_std[3];
|
||||||
bool use_gelu = false;
|
bool use_gelu = false;
|
||||||
bool use_silu = false;
|
bool use_silu = false;
|
||||||
int32_t ftype = 1;
|
|
||||||
|
|
||||||
gguf_context_ptr ctx_gguf;
|
gguf_context_ptr ctx_gguf;
|
||||||
ggml_context_ptr ctx_data;
|
ggml_context_ptr ctx_data;
|
||||||
|
@ -776,7 +775,6 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
|
||||||
const int image_size_width = imgs.entries[0]->nx;
|
const int image_size_width = imgs.entries[0]->nx;
|
||||||
const int image_size_height = imgs.entries[0]->ny;
|
const int image_size_height = imgs.entries[0]->ny;
|
||||||
|
|
||||||
const bool use_mrope = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL;
|
|
||||||
const bool use_window_attn = hparams.n_wa_pattern > 0;
|
const bool use_window_attn = hparams.n_wa_pattern > 0;
|
||||||
|
|
||||||
const int n_wa_pattern = hparams.n_wa_pattern;
|
const int n_wa_pattern = hparams.n_wa_pattern;
|
||||||
|
@ -785,10 +783,11 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
|
||||||
const int patches_w = image_size_width / patch_size;
|
const int patches_w = image_size_width / patch_size;
|
||||||
const int patches_h = image_size_height / patch_size;
|
const int patches_h = image_size_height / patch_size;
|
||||||
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
|
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
|
||||||
const int num_position_ids = use_mrope ? num_positions * 4 : num_positions;
|
const int num_position_ids = num_positions * 4; // m-rope requires 4 dim per position
|
||||||
const int hidden_size = hparams.hidden_size;
|
const int hidden_size = hparams.hidden_size;
|
||||||
const int n_head = hparams.n_head;
|
const int n_head = hparams.n_head;
|
||||||
const int d_head = hidden_size / n_head;
|
const int d_head = hidden_size / n_head;
|
||||||
|
const int n_layer = hparams.n_layer;
|
||||||
const float eps = hparams.eps;
|
const float eps = hparams.eps;
|
||||||
|
|
||||||
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
|
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
|
||||||
|
@ -870,7 +869,7 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
|
||||||
}
|
}
|
||||||
|
|
||||||
// loop over layers
|
// loop over layers
|
||||||
for (int il = 0; il < ctx->max_feature_layer; il++) {
|
for (int il = 0; il < n_layer; il++) {
|
||||||
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
|
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
|
||||||
|
|
||||||
// rmsnorm1
|
// rmsnorm1
|
||||||
|
@ -1115,15 +1114,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
||||||
if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
|
if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
|
||||||
int pos_w = image_size_width/patch_size;
|
int pos_w = image_size_width/patch_size;
|
||||||
int pos_h = image_size_height/patch_size;
|
int pos_h = image_size_height/patch_size;
|
||||||
if (ctx->minicpmv_version == 2) {
|
int n_output_dim = clip_n_mmproj_embd(ctx);
|
||||||
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1);
|
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_output_dim, pos_w * pos_h, 1);
|
||||||
}
|
|
||||||
else if (ctx->minicpmv_version == 3) {
|
|
||||||
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
|
|
||||||
}
|
|
||||||
else if (ctx->minicpmv_version == 4) {
|
|
||||||
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
|
|
||||||
}
|
|
||||||
ggml_set_name(pos_embed, "pos_embed");
|
ggml_set_name(pos_embed, "pos_embed");
|
||||||
ggml_set_input(pos_embed);
|
ggml_set_input(pos_embed);
|
||||||
}
|
}
|
||||||
|
@ -1461,23 +1453,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
||||||
}
|
}
|
||||||
|
|
||||||
{ // attention
|
{ // attention
|
||||||
int hidden_size = 4096;
|
int hidden_size = clip_n_mmproj_embd(ctx);
|
||||||
const int d_head = 128;
|
const int d_head = 128;
|
||||||
int n_head = hidden_size/d_head;
|
int n_head = hidden_size/d_head;
|
||||||
int num_query = 96;
|
int num_query = 96;
|
||||||
if (ctx->minicpmv_version == 2) {
|
if (ctx->minicpmv_version == 2) {
|
||||||
hidden_size = 4096;
|
|
||||||
n_head = hidden_size/d_head;
|
|
||||||
num_query = 96;
|
num_query = 96;
|
||||||
}
|
}
|
||||||
else if (ctx->minicpmv_version == 3) {
|
else if (ctx->minicpmv_version == 3) {
|
||||||
hidden_size = 3584;
|
|
||||||
n_head = hidden_size/d_head;
|
|
||||||
num_query = 64;
|
num_query = 64;
|
||||||
}
|
}
|
||||||
else if (ctx->minicpmv_version == 4) {
|
else if (ctx->minicpmv_version == 4) {
|
||||||
hidden_size = 3584;
|
|
||||||
n_head = hidden_size/d_head;
|
|
||||||
num_query = 64;
|
num_query = 64;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1760,6 +1746,8 @@ struct clip_model_loader {
|
||||||
LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str());
|
LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str());
|
||||||
LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector);
|
LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector);
|
||||||
LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version);
|
LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version);
|
||||||
|
LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor);
|
||||||
|
LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
|
||||||
LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
|
LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
|
||||||
LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
|
LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
|
||||||
}
|
}
|
||||||
|
@ -3038,15 +3026,43 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||||
const int patch_size = hparams.patch_size;
|
const int patch_size = hparams.patch_size;
|
||||||
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
|
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
|
||||||
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
|
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
|
||||||
const int pos_w = ctx->load_image_size.width / patch_size;
|
const int pos_w = ctx->load_image_size.width / patch_size;
|
||||||
const int pos_h = ctx->load_image_size.height / patch_size;
|
const int pos_h = ctx->load_image_size.height / patch_size;
|
||||||
|
|
||||||
const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl
|
const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl
|
||||||
|
|
||||||
|
auto get_inp_tensor = [&gf](const char * name) {
|
||||||
|
struct ggml_tensor * inp = ggml_graph_get_tensor(gf, name);
|
||||||
|
if (inp == nullptr) {
|
||||||
|
GGML_ABORT("Failed to get tensor %s", name);
|
||||||
|
}
|
||||||
|
if (!(inp->flags & GGML_TENSOR_FLAG_INPUT)) {
|
||||||
|
GGML_ABORT("Tensor %s is not an input tensor", name);
|
||||||
|
}
|
||||||
|
return inp;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto set_input_f32 = [&get_inp_tensor](const char * name, std::vector<float> & values) {
|
||||||
|
ggml_tensor * cur = get_inp_tensor(name);
|
||||||
|
GGML_ASSERT(cur->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(ggml_nelements(cur) == (int64_t)values.size());
|
||||||
|
ggml_backend_tensor_set(cur, values.data(), 0, ggml_nbytes(cur));
|
||||||
|
};
|
||||||
|
|
||||||
|
auto set_input_i32 = [&get_inp_tensor](const char * name, std::vector<int32_t> & values) {
|
||||||
|
ggml_tensor * cur = get_inp_tensor(name);
|
||||||
|
GGML_ASSERT(cur->type == GGML_TYPE_I32);
|
||||||
|
GGML_ASSERT(ggml_nelements(cur) == (int64_t)values.size());
|
||||||
|
ggml_backend_tensor_set(cur, values.data(), 0, ggml_nbytes(cur));
|
||||||
|
};
|
||||||
|
|
||||||
|
// set input pixel values
|
||||||
{
|
{
|
||||||
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
|
size_t nelem = 0;
|
||||||
std::vector<float> inp_data(ggml_nelements(inp_raw));
|
for (const auto & img : imgs.entries) {
|
||||||
float * data = inp_data.data();
|
nelem += img->nx * img->ny * 3;
|
||||||
|
}
|
||||||
|
std::vector<float> inp_raw(nelem);
|
||||||
|
|
||||||
// layout of data (note: the channel dim is unrolled to better visualize the layout):
|
// layout of data (note: the channel dim is unrolled to better visualize the layout):
|
||||||
//
|
//
|
||||||
|
@ -3065,7 +3081,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||||
const int n = nx * ny;
|
const int n = nx * ny;
|
||||||
|
|
||||||
for (int b = 0; b < batch_size; b++) {
|
for (int b = 0; b < batch_size; b++) {
|
||||||
float * batch_entry = data + b * (3*n);
|
float * batch_entry = inp_raw.data() + b * (3*n);
|
||||||
for (int y = 0; y < ny; y++) {
|
for (int y = 0; y < ny; y++) {
|
||||||
for (int x = 0; x < nx; x++) {
|
for (int x = 0; x < nx; x++) {
|
||||||
size_t base_src = 3*(y * nx + x); // idx of the first channel
|
size_t base_src = 3*(y * nx + x); // idx of the first channel
|
||||||
|
@ -3077,266 +3093,207 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
|
set_input_f32("inp_raw", inp_raw);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
|
// set input per projector
|
||||||
{
|
switch (ctx->proj_type) {
|
||||||
// inspired from siglip:
|
case PROJECTOR_TYPE_MINICPMV:
|
||||||
// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
|
{
|
||||||
// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
|
// inspired from siglip:
|
||||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
|
||||||
std::vector<int> pos_data(ggml_nelements(positions));
|
// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
|
||||||
int * data = pos_data.data();
|
std::vector<int32_t> positions(pos_h * pos_w);
|
||||||
int bucket_coords_h[1024];
|
int bucket_coords_h[1024];
|
||||||
int bucket_coords_w[1024];
|
int bucket_coords_w[1024];
|
||||||
for (int i = 0; i < pos_h; i++){
|
for (int i = 0; i < pos_h; i++){
|
||||||
bucket_coords_h[i] = std::floor(70.0*i/pos_h);
|
bucket_coords_h[i] = std::floor(70.0*i/pos_h);
|
||||||
}
|
|
||||||
for (int i = 0; i < pos_w; i++){
|
|
||||||
bucket_coords_w[i] = std::floor(70.0*i/pos_w);
|
|
||||||
}
|
|
||||||
for (int i = 0, id = 0; i < pos_h; i++){
|
|
||||||
for (int j = 0; j < pos_w; j++){
|
|
||||||
data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
|
|
||||||
}
|
}
|
||||||
}
|
for (int i = 0; i < pos_w; i++){
|
||||||
ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions));
|
bucket_coords_w[i] = std::floor(70.0*i/pos_w);
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
// inspired from resampler of Qwen-VL:
|
|
||||||
// -> https://huggingface.co/Qwen/Qwen-VL/tree/main
|
|
||||||
// -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23
|
|
||||||
struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed");
|
|
||||||
int embed_dim = 4096;
|
|
||||||
if (ctx->minicpmv_version == 2) {
|
|
||||||
embed_dim = 4096;
|
|
||||||
}
|
|
||||||
else if (ctx->minicpmv_version == 3) {
|
|
||||||
embed_dim = 3584;
|
|
||||||
}
|
|
||||||
else if (ctx->minicpmv_version == 4) {
|
|
||||||
embed_dim = 3584;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
GGML_ABORT("Unknown minicpmv version");
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO @ngxson : this is very inefficient, can we do this using ggml_sin and ggml_cos?
|
|
||||||
auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
|
|
||||||
|
|
||||||
std::vector<float> pos_data(ggml_nelements(pos_embed));
|
|
||||||
float * data = pos_data.data();
|
|
||||||
for(int i = 0; i < pos_w * pos_h; ++i){
|
|
||||||
for(int j = 0; j < embed_dim; ++j){
|
|
||||||
data[i * embed_dim + j] = pos_embed_t[i][j];
|
|
||||||
}
|
}
|
||||||
}
|
for (int i = 0, id = 0; i < pos_h; i++){
|
||||||
|
for (int j = 0; j < pos_w; j++){
|
||||||
|
positions[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
set_input_i32("positions", positions);
|
||||||
|
|
||||||
ggml_backend_tensor_set(pos_embed, data, 0, ggml_nbytes(pos_embed));
|
// inspired from resampler of Qwen-VL:
|
||||||
}
|
// -> https://huggingface.co/Qwen/Qwen-VL/tree/main
|
||||||
}
|
// -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23
|
||||||
else {
|
int embed_dim = clip_n_mmproj_embd(ctx);
|
||||||
// non-minicpmv models
|
|
||||||
|
|
||||||
if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
|
// TODO @ngxson : this is very inefficient, can we do this using ggml_sin and ggml_cos?
|
||||||
// pw * ph = number of tokens output by ViT after apply patch merger
|
auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
|
||||||
// ipw * ipw = number of vision token been processed inside ViT
|
|
||||||
const int merge_ratio = 2;
|
|
||||||
const int pw = image_size_width / patch_size / merge_ratio;
|
|
||||||
const int ph = image_size_height / patch_size / merge_ratio;
|
|
||||||
const int ipw = image_size_width / patch_size;
|
|
||||||
const int iph = image_size_height / patch_size;
|
|
||||||
|
|
||||||
std::vector<int> idx (ph * pw);
|
std::vector<float> pos_embed(embed_dim * pos_w * pos_h);
|
||||||
std::vector<int> inv_idx(ph * pw);
|
for(int i = 0; i < pos_w * pos_h; ++i){
|
||||||
|
for(int j = 0; j < embed_dim; ++j){
|
||||||
|
pos_embed[i * embed_dim + j] = pos_embed_t[i][j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (use_window_attn) {
|
set_input_f32("pos_embed", pos_embed);
|
||||||
const int attn_window_size = 112;
|
} break;
|
||||||
struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
|
case PROJECTOR_TYPE_QWEN2VL:
|
||||||
struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
|
{
|
||||||
struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
|
const int merge_ratio = 2;
|
||||||
|
const int pw = image_size_width / patch_size;
|
||||||
const int grid_window = attn_window_size / patch_size / merge_ratio;
|
const int ph = image_size_height / patch_size;
|
||||||
int dst = 0;
|
std::vector<int> positions(num_positions * 4);
|
||||||
// [num_vision_tokens, num_vision_tokens] attention mask tensor
|
int ptr = 0;
|
||||||
std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
|
for (int y = 0; y < ph; y += merge_ratio) {
|
||||||
int mask_row = 0;
|
for (int x = 0; x < pw; x += merge_ratio) {
|
||||||
|
for (int dy = 0; dy < 2; dy++) {
|
||||||
for (int y = 0; y < ph; y += grid_window)
|
for (int dx = 0; dx < 2; dx++) {
|
||||||
{
|
positions[ ptr] = y + dy;
|
||||||
for (int x = 0; x < pw; x += grid_window)
|
positions[ num_patches + ptr] = x + dx;
|
||||||
{
|
positions[2 * num_patches + ptr] = y + dy;
|
||||||
const int win_h = std::min(grid_window, ph - y);
|
positions[3 * num_patches + ptr] = x + dx;
|
||||||
const int win_w = std::min(grid_window, pw - x);
|
ptr++;
|
||||||
const int dst_0 = dst;
|
|
||||||
// group all tokens belong to the same window togather (to a continue range)
|
|
||||||
for (int dy = 0; dy < win_h; dy++) {
|
|
||||||
for (int dx = 0; dx < win_w; dx++) {
|
|
||||||
const int src = (y + dy) * pw + (x + dx);
|
|
||||||
assert(src < (int)idx.size());
|
|
||||||
assert(dst < (int)inv_idx.size());
|
|
||||||
idx [src] = dst;
|
|
||||||
inv_idx[dst] = src;
|
|
||||||
dst++;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
|
|
||||||
int row_offset = mask_row * (ipw * iph);
|
|
||||||
std::fill(
|
|
||||||
mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
|
|
||||||
mask.begin() + row_offset + (dst * merge_ratio * merge_ratio),
|
|
||||||
0.0);
|
|
||||||
mask_row++;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
|
set_input_i32("positions", positions);
|
||||||
ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
|
} break;
|
||||||
ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
|
case PROJECTOR_TYPE_QWEN25VL:
|
||||||
} else {
|
|
||||||
std::iota(idx.begin(), idx.end(), 0);
|
|
||||||
std::iota(inv_idx.begin(), inv_idx.end(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
|
||||||
const int mpow = merge_ratio * merge_ratio;
|
|
||||||
std::vector<int> positions_data(ggml_nelements(positions));
|
|
||||||
int * data = positions_data.data();
|
|
||||||
|
|
||||||
int ptr = 0;
|
|
||||||
for (int y = 0; y < iph; y += merge_ratio)
|
|
||||||
{
|
{
|
||||||
for (int x = 0; x < ipw; x += merge_ratio)
|
// pw * ph = number of tokens output by ViT after apply patch merger
|
||||||
{
|
// ipw * ipw = number of vision token been processed inside ViT
|
||||||
for (int dy = 0; dy < 2; dy++) {
|
const int merge_ratio = 2;
|
||||||
for (int dx = 0; dx < 2; dx++) {
|
const int pw = image_size_width / patch_size / merge_ratio;
|
||||||
auto remap = idx[ptr / mpow];
|
const int ph = image_size_height / patch_size / merge_ratio;
|
||||||
remap = remap * mpow + (ptr % mpow);
|
const int ipw = image_size_width / patch_size;
|
||||||
|
const int iph = image_size_height / patch_size;
|
||||||
|
|
||||||
data[ remap] = y + dy;
|
std::vector<int> idx (ph * pw);
|
||||||
data[ num_patches + remap] = x + dx;
|
std::vector<int> inv_idx(ph * pw);
|
||||||
data[2 * num_patches + remap] = y + dy;
|
|
||||||
data[3 * num_patches + remap] = x + dx;
|
if (use_window_attn) {
|
||||||
ptr++;
|
const int attn_window_size = 112;
|
||||||
|
const int grid_window = attn_window_size / patch_size / merge_ratio;
|
||||||
|
int dst = 0;
|
||||||
|
// [num_vision_tokens, num_vision_tokens] attention mask tensor
|
||||||
|
std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
|
||||||
|
int mask_row = 0;
|
||||||
|
|
||||||
|
for (int y = 0; y < ph; y += grid_window) {
|
||||||
|
for (int x = 0; x < pw; x += grid_window) {
|
||||||
|
const int win_h = std::min(grid_window, ph - y);
|
||||||
|
const int win_w = std::min(grid_window, pw - x);
|
||||||
|
const int dst_0 = dst;
|
||||||
|
// group all tokens belong to the same window togather (to a continue range)
|
||||||
|
for (int dy = 0; dy < win_h; dy++) {
|
||||||
|
for (int dx = 0; dx < win_w; dx++) {
|
||||||
|
const int src = (y + dy) * pw + (x + dx);
|
||||||
|
GGML_ASSERT(src < (int)idx.size());
|
||||||
|
GGML_ASSERT(dst < (int)inv_idx.size());
|
||||||
|
idx [src] = dst;
|
||||||
|
inv_idx[dst] = src;
|
||||||
|
dst++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
|
||||||
|
int row_offset = mask_row * (ipw * iph);
|
||||||
|
std::fill(
|
||||||
|
mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
|
||||||
|
mask.begin() + row_offset + (dst * merge_ratio * merge_ratio),
|
||||||
|
0.0);
|
||||||
|
mask_row++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
set_input_i32("window_idx", idx);
|
||||||
|
set_input_i32("inv_window_idx", inv_idx);
|
||||||
|
set_input_f32("window_mask", mask);
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < ph * pw; i++) {
|
||||||
|
idx[i] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const int mpow = merge_ratio * merge_ratio;
|
||||||
|
std::vector<int> positions(num_positions * 4);
|
||||||
|
|
||||||
|
int ptr = 0;
|
||||||
|
for (int y = 0; y < iph; y += merge_ratio) {
|
||||||
|
for (int x = 0; x < ipw; x += merge_ratio) {
|
||||||
|
for (int dy = 0; dy < 2; dy++) {
|
||||||
|
for (int dx = 0; dx < 2; dx++) {
|
||||||
|
auto remap = idx[ptr / mpow];
|
||||||
|
remap = (remap * mpow) + (ptr % mpow);
|
||||||
|
|
||||||
|
positions[ remap] = y + dy;
|
||||||
|
positions[ num_patches + remap] = x + dx;
|
||||||
|
positions[2 * num_patches + remap] = y + dy;
|
||||||
|
positions[3 * num_patches + remap] = x + dx;
|
||||||
|
ptr++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions));
|
set_input_i32("positions", positions);
|
||||||
}
|
} break;
|
||||||
else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
case PROJECTOR_TYPE_PIXTRAL:
|
||||||
// do nothing
|
{
|
||||||
}
|
// set the 2D positions
|
||||||
else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
|
int n_patches_per_col = image_size_width / patch_size;
|
||||||
// do nothing
|
std::vector<int> pos_data(num_positions);
|
||||||
}
|
// dimension H
|
||||||
else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
|
for (int i = 0; i < num_positions; i++) {
|
||||||
// set the 2D positions
|
pos_data[i] = i / n_patches_per_col;
|
||||||
int n_patches_per_col = image_size_width / patch_size;
|
}
|
||||||
std::vector<int> pos_data(num_positions);
|
set_input_i32("pos_h", pos_data);
|
||||||
struct ggml_tensor * pos;
|
// dimension W
|
||||||
// dimension H
|
for (int i = 0; i < num_positions; i++) {
|
||||||
pos = ggml_graph_get_tensor(gf, "pos_h");
|
pos_data[i] = i % n_patches_per_col;
|
||||||
for (int i = 0; i < num_positions; i++) {
|
}
|
||||||
pos_data[i] = i / n_patches_per_col;
|
set_input_i32("pos_w", pos_data);
|
||||||
}
|
} break;
|
||||||
ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
|
case PROJECTOR_TYPE_GLM_EDGE:
|
||||||
// dimension W
|
{
|
||||||
pos = ggml_graph_get_tensor(gf, "pos_w");
|
|
||||||
for (int i = 0; i < num_positions; i++) {
|
|
||||||
pos_data[i] = i % n_patches_per_col;
|
|
||||||
}
|
|
||||||
ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
// llava and other models
|
// llava and other models
|
||||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
std::vector<int32_t> positions(num_positions);
|
||||||
|
|
||||||
int* positions_data = (int*)malloc(ggml_nbytes(positions));
|
|
||||||
for (int i = 0; i < num_positions; i++) {
|
for (int i = 0; i < num_positions; i++) {
|
||||||
positions_data[i] = i;
|
positions[i] = i;
|
||||||
}
|
}
|
||||||
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
|
set_input_i32("positions", positions);
|
||||||
free(positions_data);
|
} break;
|
||||||
|
case PROJECTOR_TYPE_MLP:
|
||||||
|
case PROJECTOR_TYPE_MLP_NORM:
|
||||||
|
case PROJECTOR_TYPE_LDP:
|
||||||
|
case PROJECTOR_TYPE_LDPV2:
|
||||||
|
{
|
||||||
|
// llava and other models
|
||||||
|
std::vector<int32_t> positions(num_positions);
|
||||||
|
for (int i = 0; i < num_positions; i++) {
|
||||||
|
positions[i] = i;
|
||||||
|
}
|
||||||
|
set_input_i32("positions", positions);
|
||||||
|
|
||||||
if (ctx->proj_type != PROJECTOR_TYPE_GLM_EDGE) {
|
|
||||||
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
|
|
||||||
// The patches vector is used to get rows to index into the embeds with;
|
// The patches vector is used to get rows to index into the embeds with;
|
||||||
// we should skip dim 0 only if we have CLS to avoid going out of bounds
|
// we should skip dim 0 only if we have CLS to avoid going out of bounds
|
||||||
// when retrieving the rows.
|
// when retrieving the rows.
|
||||||
int patch_offset = model.class_embedding ? 1 : 0;
|
int patch_offset = model.class_embedding ? 1 : 0;
|
||||||
int* patches_data = (int*)malloc(ggml_nbytes(patches));
|
std::vector<int32_t> patches(num_patches);
|
||||||
for (int i = 0; i < num_patches; i++) {
|
for (int i = 0; i < num_patches; i++) {
|
||||||
patches_data[i] = i + patch_offset;
|
patches[i] = i + patch_offset;
|
||||||
}
|
}
|
||||||
ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
|
set_input_i32("patches", patches);
|
||||||
free(patches_data);
|
} break;
|
||||||
}
|
case PROJECTOR_TYPE_GEMMA3:
|
||||||
}
|
case PROJECTOR_TYPE_IDEFICS3:
|
||||||
}
|
|
||||||
|
|
||||||
if (use_window_attn && (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL)) {
|
|
||||||
struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
|
|
||||||
struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
|
|
||||||
struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
|
|
||||||
|
|
||||||
const int merge_ratio = 2;
|
|
||||||
const int attn_window_size = 112;
|
|
||||||
const int pw = image_size_width / patch_size / merge_ratio;
|
|
||||||
const int ph = image_size_height / patch_size / merge_ratio;
|
|
||||||
const int grid_window = attn_window_size / patch_size / merge_ratio;
|
|
||||||
const int ipw = image_size_width / patch_size;
|
|
||||||
const int iph = image_size_height / patch_size;
|
|
||||||
/*
|
|
||||||
pw * ph = number of tokens output by ViT after apply patch merger
|
|
||||||
ipw * ipw = number of vision token been processed inside ViT
|
|
||||||
*/
|
|
||||||
|
|
||||||
std::vector<int> idx(ph * pw);
|
|
||||||
std::vector<int> inv_idx(ph * pw);
|
|
||||||
int dst = 0;
|
|
||||||
// [num_vision_tokens, num_vision_tokens] attention mask tensor
|
|
||||||
std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
|
|
||||||
int mask_row = 0;
|
|
||||||
|
|
||||||
for (int y = 0; y < ph; y+=grid_window)
|
|
||||||
{
|
|
||||||
for (int x = 0; x < pw; x+=grid_window)
|
|
||||||
{
|
{
|
||||||
const int win_h = std::min(grid_window, ph - y);
|
// do nothing
|
||||||
const int win_w = std::min(grid_window, pw - x);
|
} break;
|
||||||
const int dst_0 = dst;
|
default:
|
||||||
// group all tokens belong to the same window togather (to a continue range)
|
GGML_ABORT("Unknown projector type");
|
||||||
for (int dy = 0; dy < win_h; dy++) {
|
|
||||||
for (int dx = 0; dx < win_w; dx++) {
|
|
||||||
const int src = (y + dy) * pw + (x + dx);
|
|
||||||
assert(src < (int)idx.size());
|
|
||||||
assert(dst < (int)inv_idx.size());
|
|
||||||
idx[src] = dst;
|
|
||||||
inv_idx[dst] = src;
|
|
||||||
dst++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
|
|
||||||
int row_offset = mask_row * (ipw * iph);
|
|
||||||
std::fill(
|
|
||||||
mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
|
|
||||||
mask.begin() + row_offset + (dst * merge_ratio * merge_ratio),
|
|
||||||
0.0);
|
|
||||||
mask_row++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
|
|
||||||
ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
|
|
||||||
ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);
|
ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);
|
||||||
|
@ -3537,7 +3494,7 @@ bool clip_is_glm(const struct clip_ctx * ctx) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
|
bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
|
||||||
return ctx->proj_type == PROJECTOR_TYPE_QWEN2VL;
|
return ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool clip_is_llava(const struct clip_ctx * ctx) {
|
bool clip_is_llava(const struct clip_ctx * ctx) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue