clip : Add Qwen2.5VL support (#12402)

* implment vision model architecture, gguf convertor

* handle window attention inputs

* add debug utils

* fix few incorrect tensor memory layout

* move position id remap out of ggml to avoid int32 cuda operations

* cleaning up

* ignore transformers Qwen2_5_xxx type check

* remove not so often use `qwen2vl-cli` debug functions

* remove commented-out code blocks

* fix attn weight scaling after rebase

* add `PROJECTOR_TYPE_QWEN2_5_VL`

* remove `KEY_USE_GLU_MLP`, `KEY_USE_RMS_NORM`

* replace `KEY_FULLATTN_BLK_IDX` with `KEY_WIN_ATTN_PATTERN`

* remove `attn_window_size` from gguf

* fix model conversion

* clean up

* fix merging problem

* add test

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
This commit is contained in:
HimariO 2025-04-27 16:10:34 +08:00 committed by GitHub
parent 2d451c8059
commit ca2bb89eac
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 594 additions and 102 deletions

View file

@ -28,6 +28,7 @@
#include <cinttypes>
#include <limits>
#include <array>
#include <numeric>
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
@ -169,6 +170,8 @@ struct clip_hparams {
std::vector<int32_t> image_grid_pinpoints;
int32_t image_crop_resolution;
std::unordered_set<int32_t> vision_feature_layer;
int32_t attn_window_size;
int32_t n_wa_pattern;
};
struct clip_layer {
@ -200,6 +203,9 @@ struct clip_layer {
struct ggml_tensor * ff_down_w = nullptr;
struct ggml_tensor * ff_down_b = nullptr;
struct ggml_tensor * ff_g_w = NULL;
struct ggml_tensor * ff_g_b = NULL;
// layernorm 2
struct ggml_tensor * ln_2_w = nullptr;
struct ggml_tensor * ln_2_b = nullptr;
@ -319,6 +325,7 @@ struct clip_ctx {
float image_std[3];
bool use_gelu = false;
bool use_silu = false;
int32_t ftype = 1;
gguf_context_ptr ctx_gguf;
ggml_context_ptr ctx_data;
@ -762,6 +769,236 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
return gf;
}
static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
const int image_size_width = imgs.entries[0]->nx;
const int image_size_height = imgs.entries[0]->ny;
const bool use_mrope = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL;
const bool use_window_attn = hparams.n_wa_pattern > 0;
const int n_wa_pattern = hparams.n_wa_pattern;
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
const int patches_w = image_size_width / patch_size;
const int patches_h = image_size_height / patch_size;
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
const int num_position_ids = use_mrope ? num_positions * 4 : num_positions;
const int hidden_size = hparams.hidden_size;
const int n_head = hparams.n_head;
const int d_head = hidden_size / n_head;
const float eps = hparams.eps;
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
const int batch_size = imgs.entries.size();
GGML_ASSERT(batch_size == 1);
struct ggml_init_params params = {
/*.mem_size =*/ ctx->buf_compute_meta.size(),
/*.mem_buffer =*/ ctx->buf_compute_meta.data(),
/*.no_alloc =*/ true,
};
ggml_context_ptr ctx0_ptr(ggml_init(params));
auto ctx0 = ctx0_ptr.get();
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size);
ggml_set_name(inp_raw, "inp_raw");
ggml_set_input(inp_raw);
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
GGML_ASSERT(image_size_width % (patch_size * 2) == 0);
GGML_ASSERT(image_size_height % (patch_size * 2) == 0);
auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
inp = ggml_add(ctx0, inp, inp_1);
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
inp = ggml_reshape_4d(
ctx0, inp,
hidden_size * 2, patches_w / 2, patches_h, batch_size);
inp = ggml_reshape_4d(
ctx0, inp,
hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
inp = ggml_reshape_3d(
ctx0, inp,
hidden_size, patches_w * patches_h, batch_size);
if (model.patch_bias) {
// inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
inp = ggml_add(ctx0, inp, model.patch_bias);
}
struct ggml_tensor * embeddings = inp;
struct ggml_tensor * window_mask = nullptr;
struct ggml_tensor * window_idx = nullptr;
struct ggml_tensor * inv_window_idx = nullptr;
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
ggml_set_name(positions, "positions");
ggml_set_input(positions);
// pre-layernorm
if (model.pre_ln_w) {
embeddings = ggml_rms_norm(ctx0, embeddings, eps);
ggml_set_name(embeddings, "pre_ln");
embeddings = ggml_mul(ctx0, embeddings, model.pre_ln_w);
}
if (use_window_attn) {
// handle window attention inputs
inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
ggml_set_name(inv_window_idx, "inv_window_idx");
ggml_set_input(inv_window_idx);
// mask for window attention
window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, num_positions, num_positions);
ggml_set_name(window_mask, "window_mask");
ggml_set_input(window_mask);
// embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
GGML_ASSERT(batch_size == 1);
embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4);
embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx);
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size);
}
// loop over layers
for (int il = 0; il < ctx->max_feature_layer; il++) {
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
// rmsnorm1
cur = ggml_rms_norm(ctx0, cur, eps);
cur = ggml_mul(ctx0, cur, model.layers[il].ln_1_w);
// self-attention
{
struct ggml_tensor * Q =
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
Q = ggml_rope_multi(
ctx0, Q, positions, nullptr,
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
struct ggml_tensor * K =
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
K = ggml_rope_multi(
ctx0, K, positions, nullptr,
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
struct ggml_tensor * V =
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true;
if (full_attn) {
KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
} else {
KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f / sqrtf((float)d_head), 0.0f);
}
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
}
// attention output
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
// re-add the layer input, e.g., residual
cur = ggml_add(ctx0, cur, embeddings);
embeddings = cur; // embeddings = residual, cur = hidden_states
// rms norm2
cur = ggml_rms_norm(ctx0, cur, eps);
cur = ggml_mul(ctx0, cur, model.layers[il].ln_2_w);
// mlp
// ffn_up
auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b);
auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur);
cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b);
// TODO : only 2 of these 3 are actually used, should we remove one of them?
if (ctx->use_gelu) {
cur_gate = ggml_gelu_inplace(ctx0, cur_gate);
} else if (ctx->use_silu) {
cur_gate = ggml_silu_inplace(ctx0, cur_gate);
} else {
cur_gate = ggml_gelu_quick_inplace(ctx0, cur_gate);
}
cur = ggml_mul(ctx0, cur_gate, cur_up);
// ffn_down
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
// residual 2
cur = ggml_add(ctx0, embeddings, cur);
embeddings = cur;
}
// post-layernorm
if (model.post_ln_w) {
embeddings = ggml_rms_norm(ctx0, embeddings, eps);
ggml_set_name(embeddings, "post_ln");
embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w);
}
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
// GELU activation
embeddings = ggml_gelu(ctx0, embeddings);
// Second linear layer
embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
if (use_window_attn) {
window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
ggml_set_name(window_idx, "window_idx");
ggml_set_input(window_idx);
// embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
GGML_ASSERT(batch_size == 1);
embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4);
embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4, batch_size);
}
// build the graph
ggml_build_forward_expand(gf, embeddings);
return gf;
}
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
@ -1331,6 +1568,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
GGML_ASSERT(imgs.entries.size() == 1);
res = clip_image_build_graph_pixtral(ctx, *imgs.entries[0]);
} break;
case PROJECTOR_TYPE_QWEN25VL:
{
res = clip_image_build_graph_qwen25vl(ctx, imgs);
} break;
default:
{
// TODO: we should have one build_* function per model
@ -1507,6 +1748,10 @@ struct clip_model_loader {
{
hparams.rope_theta = 10000.0f;
} break;
case PROJECTOR_TYPE_QWEN25VL:
{
get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
} break;
default:
break;
}
@ -1600,8 +1845,10 @@ struct clip_model_loader {
// legacy naming (the in and out is reversed! don't ask me why)
layer.ff_i_w = layer.ff_down_w;
layer.ff_o_w = layer.ff_up_w;
layer.ff_g_w = layer.ff_gate_w;
layer.ff_i_b = layer.ff_down_b;
layer.ff_o_b = layer.ff_up_b;
layer.ff_g_b = layer.ff_gate_b;
}
switch (ctx_clip.proj_type) {
@ -1700,6 +1947,7 @@ struct clip_model_loader {
vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight"));
} break;
case PROJECTOR_TYPE_QWEN2VL:
case PROJECTOR_TYPE_QWEN25VL:
{
vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
@ -2651,7 +2899,7 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
else {
GGML_ABORT("Unknown minicpmv version");
}
} else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
} else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
int patch_size = params.patch_size * 2;
int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
@ -2792,6 +3040,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
const int pos_w = ctx->load_image_size.width / patch_size;
const int pos_h = ctx->load_image_size.height / patch_size;
const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl
{
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
std::vector<float> inp_data(ggml_nelements(inp_raw));
@ -2890,31 +3140,93 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
// non-minicpmv models
if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
// pw * ph = number of tokens output by ViT after apply patch merger
// ipw * ipw = number of vision token been processed inside ViT
const int merge_ratio = 2;
const int pw = image_size_width / patch_size / merge_ratio;
const int ph = image_size_height / patch_size / merge_ratio;
const int ipw = image_size_width / patch_size;
const int iph = image_size_height / patch_size;
const int pw = image_size_width / patch_size;
const int ph = image_size_height / patch_size;
int* positions_data = (int*)malloc(ggml_nbytes(positions));
std::vector<int> idx (ph * pw);
std::vector<int> inv_idx(ph * pw);
if (use_window_attn) {
const int attn_window_size = 112;
struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
const int grid_window = attn_window_size / patch_size / merge_ratio;
int dst = 0;
// [num_vision_tokens, num_vision_tokens] attention mask tensor
std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
int mask_row = 0;
for (int y = 0; y < ph; y += grid_window)
{
for (int x = 0; x < pw; x += grid_window)
{
const int win_h = std::min(grid_window, ph - y);
const int win_w = std::min(grid_window, pw - x);
const int dst_0 = dst;
// group all tokens belong to the same window togather (to a continue range)
for (int dy = 0; dy < win_h; dy++) {
for (int dx = 0; dx < win_w; dx++) {
const int src = (y + dy) * pw + (x + dx);
assert(src < (int)idx.size());
assert(dst < (int)inv_idx.size());
idx [src] = dst;
inv_idx[dst] = src;
dst++;
}
}
for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
int row_offset = mask_row * (ipw * iph);
std::fill(
mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
mask.begin() + row_offset + (dst * merge_ratio * merge_ratio),
0.0);
mask_row++;
}
}
}
ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
} else {
std::iota(idx.begin(), idx.end(), 0);
std::iota(inv_idx.begin(), inv_idx.end(), 0);
}
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
const int mpow = merge_ratio * merge_ratio;
std::vector<int> positions_data(ggml_nelements(positions));
int * data = positions_data.data();
int ptr = 0;
for (int y = 0; y < ph; y+=2)
for (int y = 0; y < iph; y += merge_ratio)
{
for (int x = 0; x < pw; x+=2)
for (int x = 0; x < ipw; x += merge_ratio)
{
for (int dy = 0; dy < 2; dy++) {
for (int dx = 0; dx < 2; dx++) {
positions_data[ptr] = y + dy;
positions_data[num_patches + ptr] = x + dx;
positions_data[num_patches * 2 + ptr] = y + dy;
positions_data[num_patches * 3 + ptr] = x + dx;
auto remap = idx[ptr / mpow];
remap = remap * mpow + (ptr % mpow);
data[ remap] = y + dy;
data[ num_patches + remap] = x + dx;
data[2 * num_patches + remap] = y + dy;
data[3 * num_patches + remap] = x + dx;
ptr++;
}
}
}
}
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
free(positions_data);
ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions));
}
else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
// do nothing
@ -2967,6 +3279,65 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
}
if (use_window_attn && ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
const int merge_ratio = 2;
const int attn_window_size = 112;
const int pw = image_size_width / patch_size / merge_ratio;
const int ph = image_size_height / patch_size / merge_ratio;
const int grid_window = attn_window_size / patch_size / merge_ratio;
const int ipw = image_size_width / patch_size;
const int iph = image_size_height / patch_size;
/*
pw * ph = number of tokens output by ViT after apply patch merger
ipw * ipw = number of vision token been processed inside ViT
*/
std::vector<int> idx(ph * pw);
std::vector<int> inv_idx(ph * pw);
int dst = 0;
// [num_vision_tokens, num_vision_tokens] attention mask tensor
std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
int mask_row = 0;
for (int y = 0; y < ph; y+=grid_window)
{
for (int x = 0; x < pw; x+=grid_window)
{
const int win_h = std::min(grid_window, ph - y);
const int win_w = std::min(grid_window, pw - x);
const int dst_0 = dst;
// group all tokens belong to the same window togather (to a continue range)
for (int dy = 0; dy < win_h; dy++) {
for (int dx = 0; dx < win_w; dx++) {
const int src = (y + dy) * pw + (x + dx);
assert(src < (int)idx.size());
assert(dst < (int)inv_idx.size());
idx[src] = dst;
inv_idx[dst] = src;
dst++;
}
}
for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
int row_offset = mask_row * (ipw * iph);
std::fill(
mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
mask.begin() + row_offset + (dst * merge_ratio * merge_ratio),
0.0);
mask_row++;
}
}
}
ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
}
ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);
auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf);
@ -3142,6 +3513,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
case PROJECTOR_TYPE_GLM_EDGE:
return ctx->vision_model.mm_model_mlp_3_w->ne[1];
case PROJECTOR_TYPE_QWEN2VL:
case PROJECTOR_TYPE_QWEN25VL:
return ctx->vision_model.mm_1_b->ne[0];
case PROJECTOR_TYPE_GEMMA3:
return ctx->vision_model.mm_input_proj_w->ne[0];