diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index cf35fb86..ea3a951b 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2554,11 +2554,12 @@ class Qwen2VLModel(TextModel): except FileNotFoundError: self._set_vocab_gpt2() - def get_tensors(self) -> Iterator[tuple[str, Tensor]]: - for name, data in super().get_tensors(): - if name.startswith("visual."): - continue - yield name, data + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + if name.startswith("visual."): + # skip visual tensors + return [] + return [(self.map_tensor_name(name), data_torch)] @ModelBase.register("WavTokenizerDec") diff --git a/examples/llava/clip-impl.h b/examples/llava/clip-impl.h index 16d0a8ef..04bfcbb5 100644 --- a/examples/llava/clip-impl.h +++ b/examples/llava/clip-impl.h @@ -34,9 +34,14 @@ #define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor" #define KEY_PROJ_TYPE "clip.projector_type" +#define KEY_USE_GLU_MLP "clip.use_glu_mlp" // for qwen2.5vl +#define KEY_USE_RMS_NORM "clip.use_rms_norm" // for qwen2.5vl + #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" +#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern" +#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size" // @@ -55,6 +60,7 @@ #define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s" #define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" #define TN_FFN_UP "%s.blk.%d.ffn_up.%s" +#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" #define TN_LN_1 "%s.blk.%d.ln1.%s" #define TN_LN_2 "%s.blk.%d.ln2.%s" #define TN_LN_PRE "%s.pre_ln.%s" @@ -95,6 +101,7 @@ enum projector_type { PROJECTOR_TYPE_GEMMA3, PROJECTOR_TYPE_IDEFICS3, PROJECTOR_TYPE_PIXTRAL, + PROJECTOR_TYPE_QWEN25VL, PROJECTOR_TYPE_UNKNOWN, }; @@ -105,6 +112,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_MINICPMV, "resampler"}, { PROJECTOR_TYPE_GLM_EDGE, "adapter"}, { PROJECTOR_TYPE_QWEN2VL, "qwen2vl_merger"}, + { PROJECTOR_TYPE_QWEN25VL, "qwen2.5vl_merger"}, { PROJECTOR_TYPE_GEMMA3, "gemma3"}, { PROJECTOR_TYPE_IDEFICS3, "idefics3"}, { PROJECTOR_TYPE_PIXTRAL, "pixtral"}, diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index e8c01c68..b6a1f40e 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -28,6 +28,7 @@ #include #include #include +#include struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL}; @@ -169,6 +170,8 @@ struct clip_hparams { std::vector image_grid_pinpoints; int32_t image_crop_resolution; std::unordered_set vision_feature_layer; + int32_t attn_window_size; + int32_t n_wa_pattern; }; struct clip_layer { @@ -200,6 +203,9 @@ struct clip_layer { struct ggml_tensor * ff_down_w = nullptr; struct ggml_tensor * ff_down_b = nullptr; + struct ggml_tensor * ff_g_w = NULL; + struct ggml_tensor * ff_g_b = NULL; + // layernorm 2 struct ggml_tensor * ln_2_w = nullptr; struct ggml_tensor * ln_2_b = nullptr; @@ -319,6 +325,7 @@ struct clip_ctx { float image_std[3]; bool use_gelu = false; bool use_silu = false; + int32_t ftype = 1; gguf_context_ptr ctx_gguf; ggml_context_ptr ctx_data; @@ -762,6 +769,236 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i return gf; } +static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_image_f32_batch & imgs) { + const auto & model = ctx->vision_model; + const auto & hparams = model.hparams; + + const int image_size_width = imgs.entries[0]->nx; + const int image_size_height = imgs.entries[0]->ny; + + const bool use_mrope = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL; + const bool use_window_attn = hparams.n_wa_pattern > 0; + + const int n_wa_pattern = hparams.n_wa_pattern; + const int patch_size = hparams.patch_size; + const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); + const int patches_w = image_size_width / patch_size; + const int patches_h = image_size_height / patch_size; + const int num_positions = num_patches + (model.class_embedding ? 1 : 0); + const int num_position_ids = use_mrope ? num_positions * 4 : num_positions; + const int hidden_size = hparams.hidden_size; + const int n_head = hparams.n_head; + const int d_head = hidden_size / n_head; + const float eps = hparams.eps; + + int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; + + const int batch_size = imgs.entries.size(); + GGML_ASSERT(batch_size == 1); + + struct ggml_init_params params = { + /*.mem_size =*/ ctx->buf_compute_meta.size(), + /*.mem_buffer =*/ ctx->buf_compute_meta.data(), + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx0_ptr(ggml_init(params)); + auto ctx0 = ctx0_ptr.get(); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size); + ggml_set_name(inp_raw, "inp_raw"); + ggml_set_input(inp_raw); + + struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + + GGML_ASSERT(image_size_width % (patch_size * 2) == 0); + GGML_ASSERT(image_size_height % (patch_size * 2) == 0); + + auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + inp = ggml_add(ctx0, inp, inp_1); + + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b] + inp = ggml_reshape_4d( + ctx0, inp, + hidden_size * 2, patches_w / 2, patches_h, batch_size); + inp = ggml_reshape_4d( + ctx0, inp, + hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2)); + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3)); + inp = ggml_reshape_3d( + ctx0, inp, + hidden_size, patches_w * patches_h, batch_size); + + if (model.patch_bias) { + // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp)); + inp = ggml_add(ctx0, inp, model.patch_bias); + } + struct ggml_tensor * embeddings = inp; + struct ggml_tensor * window_mask = nullptr; + struct ggml_tensor * window_idx = nullptr; + struct ggml_tensor * inv_window_idx = nullptr; + + struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + // pre-layernorm + if (model.pre_ln_w) { + embeddings = ggml_rms_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "pre_ln"); + + embeddings = ggml_mul(ctx0, embeddings, model.pre_ln_w); + } + + if (use_window_attn) { + // handle window attention inputs + inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4); + ggml_set_name(inv_window_idx, "inv_window_idx"); + ggml_set_input(inv_window_idx); + // mask for window attention + window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, num_positions, num_positions); + ggml_set_name(window_mask, "window_mask"); + ggml_set_input(window_mask); + + // embeddings shape: [hidden_size, patches_w * patches_h, batch_size] + GGML_ASSERT(batch_size == 1); + embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4); + embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx); + embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size); + } + + // loop over layers + for (int il = 0; il < ctx->max_feature_layer; il++) { + struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states + + // rmsnorm1 + cur = ggml_rms_norm(ctx0, cur, eps); + cur = ggml_mul(ctx0, cur, model.layers[il].ln_1_w); + + // self-attention + { + + struct ggml_tensor * Q = + ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b); + + Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size); + Q = ggml_rope_multi( + ctx0, Q, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); + Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size); + + struct ggml_tensor * K = + ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b); + + K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); + K = ggml_rope_multi( + ctx0, K, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); + K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size); + + struct ggml_tensor * V = + ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b); + + V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size); + V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); + V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); + + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true; + if (full_attn) { + KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); + } else { + KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f / sqrtf((float)d_head), 0.0f); + } + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); + KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size); + } + + // attention output + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b); + + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, embeddings); + + embeddings = cur; // embeddings = residual, cur = hidden_states + + // rms norm2 + cur = ggml_rms_norm(ctx0, cur, eps); + cur = ggml_mul(ctx0, cur, model.layers[il].ln_2_w); + + // mlp + // ffn_up + auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); + cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b); + + auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur); + cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b); + // TODO : only 2 of these 3 are actually used, should we remove one of them? + if (ctx->use_gelu) { + cur_gate = ggml_gelu_inplace(ctx0, cur_gate); + } else if (ctx->use_silu) { + cur_gate = ggml_silu_inplace(ctx0, cur_gate); + } else { + cur_gate = ggml_gelu_quick_inplace(ctx0, cur_gate); + } + cur = ggml_mul(ctx0, cur_gate, cur_up); + + // ffn_down + cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); + + // residual 2 + cur = ggml_add(ctx0, embeddings, cur); + + embeddings = cur; + } + + // post-layernorm + if (model.post_ln_w) { + embeddings = ggml_rms_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "post_ln"); + + embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w); + } + + embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size); + + embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); + + // GELU activation + embeddings = ggml_gelu(ctx0, embeddings); + + // Second linear layer + embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); + + if (use_window_attn) { + window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4); + ggml_set_name(window_idx, "window_idx"); + ggml_set_input(window_idx); + + // embeddings shape: [hidden_size, patches_w * patches_h, batch_size] + GGML_ASSERT(batch_size == 1); + embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4); + embeddings = ggml_get_rows(ctx0, embeddings, window_idx); + embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4, batch_size); + } + + // build the graph + ggml_build_forward_expand(gf, embeddings); + + return gf; +} + static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) { const auto & model = ctx->vision_model; const auto & hparams = model.hparams; @@ -1331,6 +1568,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 GGML_ASSERT(imgs.entries.size() == 1); res = clip_image_build_graph_pixtral(ctx, *imgs.entries[0]); } break; + case PROJECTOR_TYPE_QWEN25VL: + { + res = clip_image_build_graph_qwen25vl(ctx, imgs); + } break; default: { // TODO: we should have one build_* function per model @@ -1507,6 +1748,10 @@ struct clip_model_loader { { hparams.rope_theta = 10000.0f; } break; + case PROJECTOR_TYPE_QWEN25VL: + { + get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern); + } break; default: break; } @@ -1600,8 +1845,10 @@ struct clip_model_loader { // legacy naming (the in and out is reversed! don't ask me why) layer.ff_i_w = layer.ff_down_w; layer.ff_o_w = layer.ff_up_w; + layer.ff_g_w = layer.ff_gate_w; layer.ff_i_b = layer.ff_down_b; layer.ff_o_b = layer.ff_up_b; + layer.ff_g_b = layer.ff_gate_b; } switch (ctx_clip.proj_type) { @@ -1700,6 +1947,7 @@ struct clip_model_loader { vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight")); } break; case PROJECTOR_TYPE_QWEN2VL: + case PROJECTOR_TYPE_QWEN25VL: { vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); @@ -2651,7 +2899,7 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i else { GGML_ABORT("Unknown minicpmv version"); } - } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) { + } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) { int patch_size = params.patch_size * 2; int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0); @@ -2792,6 +3040,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima const int pos_w = ctx->load_image_size.width / patch_size; const int pos_h = ctx->load_image_size.height / patch_size; + const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl + { struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw"); std::vector inp_data(ggml_nelements(inp_raw)); @@ -2890,31 +3140,93 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // non-minicpmv models if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) { - struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); + // pw * ph = number of tokens output by ViT after apply patch merger + // ipw * ipw = number of vision token been processed inside ViT + const int merge_ratio = 2; + const int pw = image_size_width / patch_size / merge_ratio; + const int ph = image_size_height / patch_size / merge_ratio; + const int ipw = image_size_width / patch_size; + const int iph = image_size_height / patch_size; - const int pw = image_size_width / patch_size; - const int ph = image_size_height / patch_size; - int* positions_data = (int*)malloc(ggml_nbytes(positions)); + std::vector idx (ph * pw); + std::vector inv_idx(ph * pw); + + if (use_window_attn) { + const int attn_window_size = 112; + struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx"); + struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx"); + struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask"); + + const int grid_window = attn_window_size / patch_size / merge_ratio; + int dst = 0; + // [num_vision_tokens, num_vision_tokens] attention mask tensor + std::vector mask(pow(ipw * iph, 2), std::numeric_limits::lowest()); + int mask_row = 0; + + for (int y = 0; y < ph; y += grid_window) + { + for (int x = 0; x < pw; x += grid_window) + { + const int win_h = std::min(grid_window, ph - y); + const int win_w = std::min(grid_window, pw - x); + const int dst_0 = dst; + // group all tokens belong to the same window togather (to a continue range) + for (int dy = 0; dy < win_h; dy++) { + for (int dx = 0; dx < win_w; dx++) { + const int src = (y + dy) * pw + (x + dx); + assert(src < (int)idx.size()); + assert(dst < (int)inv_idx.size()); + idx [src] = dst; + inv_idx[dst] = src; + dst++; + } + } + + for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) { + int row_offset = mask_row * (ipw * iph); + std::fill( + mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio), + mask.begin() + row_offset + (dst * merge_ratio * merge_ratio), + 0.0); + mask_row++; + } + } + } + + ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx)); + ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx)); + ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask)); + } else { + std::iota(idx.begin(), idx.end(), 0); + std::iota(inv_idx.begin(), inv_idx.end(), 0); + } + + struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); + const int mpow = merge_ratio * merge_ratio; + std::vector positions_data(ggml_nelements(positions)); + int * data = positions_data.data(); int ptr = 0; - for (int y = 0; y < ph; y+=2) + for (int y = 0; y < iph; y += merge_ratio) { - for (int x = 0; x < pw; x+=2) + for (int x = 0; x < ipw; x += merge_ratio) { for (int dy = 0; dy < 2; dy++) { for (int dx = 0; dx < 2; dx++) { - positions_data[ptr] = y + dy; - positions_data[num_patches + ptr] = x + dx; - positions_data[num_patches * 2 + ptr] = y + dy; - positions_data[num_patches * 3 + ptr] = x + dx; + auto remap = idx[ptr / mpow]; + remap = remap * mpow + (ptr % mpow); + + data[ remap] = y + dy; + data[ num_patches + remap] = x + dx; + data[2 * num_patches + remap] = y + dy; + data[3 * num_patches + remap] = x + dx; ptr++; } } } } - ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); - free(positions_data); + ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions)); } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { // do nothing @@ -2967,6 +3279,65 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } } + if (use_window_attn && ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) { + struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx"); + struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx"); + struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask"); + + const int merge_ratio = 2; + const int attn_window_size = 112; + const int pw = image_size_width / patch_size / merge_ratio; + const int ph = image_size_height / patch_size / merge_ratio; + const int grid_window = attn_window_size / patch_size / merge_ratio; + const int ipw = image_size_width / patch_size; + const int iph = image_size_height / patch_size; + /* + pw * ph = number of tokens output by ViT after apply patch merger + ipw * ipw = number of vision token been processed inside ViT + */ + + std::vector idx(ph * pw); + std::vector inv_idx(ph * pw); + int dst = 0; + // [num_vision_tokens, num_vision_tokens] attention mask tensor + std::vector mask(pow(ipw * iph, 2), std::numeric_limits::lowest()); + int mask_row = 0; + + for (int y = 0; y < ph; y+=grid_window) + { + for (int x = 0; x < pw; x+=grid_window) + { + const int win_h = std::min(grid_window, ph - y); + const int win_w = std::min(grid_window, pw - x); + const int dst_0 = dst; + // group all tokens belong to the same window togather (to a continue range) + for (int dy = 0; dy < win_h; dy++) { + for (int dx = 0; dx < win_w; dx++) { + const int src = (y + dy) * pw + (x + dx); + assert(src < (int)idx.size()); + assert(dst < (int)inv_idx.size()); + idx[src] = dst; + inv_idx[dst] = src; + dst++; + } + } + + for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) { + int row_offset = mask_row * (ipw * iph); + std::fill( + mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio), + mask.begin() + row_offset + (dst * merge_ratio * merge_ratio), + 0.0); + mask_row++; + } + } + } + + ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx)); + ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx)); + ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask)); + } + ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads); auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf); @@ -3142,6 +3513,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_GLM_EDGE: return ctx->vision_model.mm_model_mlp_3_w->ne[1]; case PROJECTOR_TYPE_QWEN2VL: + case PROJECTOR_TYPE_QWEN25VL: return ctx->vision_model.mm_1_b->ne[0]; case PROJECTOR_TYPE_GEMMA3: return ctx->vision_model.mm_input_proj_w->ne[0]; diff --git a/examples/llava/qwen2_vl_surgery.py b/examples/llava/qwen2_vl_surgery.py index c87606b4..7951a6fa 100644 --- a/examples/llava/qwen2_vl_surgery.py +++ b/examples/llava/qwen2_vl_surgery.py @@ -1,14 +1,16 @@ import argparse -from typing import Dict +from typing import Dict, List, Optional import torch import numpy as np from gguf import * from transformers import ( - Qwen2VLForConditionalGeneration, - Qwen2VLProcessor, AutoProcessor, - Qwen2VLConfig + Qwen2VLConfig, + Qwen2VLProcessor, + Qwen2VLForConditionalGeneration, + Qwen2_5_VLConfig, # type: ignore[reportAttributeAccessIssue] + Qwen2_5_VLForConditionalGeneration, # type: ignore[reportAttributeAccessIssue] ) @@ -19,61 +21,93 @@ def k(raw_key: str, arch: str) -> str: return raw_key.format(arch=arch) -def to_gguf_name(name: str) -> str: - og = name - name = name.replace("text_model", "t").replace("vision_model", "v") - name = name.replace("blocks", "blk").replace("embeddings.", "") - name = name.replace("attn.", "attn_") - name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.") - # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln") - name = name.replace("norm1", "ln1").replace("norm2", "ln2") - name = name.replace("merger.mlp", 'mm') - print(f"[to_gguf_name] {og} --> {name}") - return name +def get_n_wa_pattern(fullatt_block_indexes: Optional[List[int]]): + if fullatt_block_indexes is None: + return 0 + n_wa = fullatt_block_indexes[0] + for a, b in zip(fullatt_block_indexes, fullatt_block_indexes[1:]): + if b - a - 1 != n_wa: + raise ValueError( + f"window/full attention layer should have fix pattern of " + f"for each full-attention layer followed by {n_wa} window-attention layers" + ) + return n_wa + 1 -def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]: - vision_model = qwen2vl.visual - tensor_map = {} - for name, ten in vision_model.state_dict().items(): - ten = ten.numpy() - if 'qkv' in name: - if ten.ndim == 2: # weight - c3, _ = ten.shape - else: # bias - c3 = ten.shape[0] - assert c3 % 3 == 0 - c = c3 // 3 - wq = ten[:c] - wk = ten[c: c * 2] - wv = ten[c * 2:] - tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq - tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk - tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv - elif 'merger' in name: - if name.endswith("ln_q.weight"): - tensor_map['v.post_ln.weight'] = ten - elif name.endswith("ln_q.bias"): - tensor_map['v.post_ln.bias'] = ten +class VL2: + + @staticmethod + def to_gguf_name(name: str) -> str: + og = name + name = name.replace("text_model", "t").replace("vision_model", "v") + name = name.replace("blocks", "blk").replace("embeddings.", "") + name = name.replace("attn.", "attn_") + name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.") + # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln") + name = name.replace("norm1", "ln1").replace("norm2", "ln2") + name = name.replace("merger.mlp", 'mm') + print(f"[to_gguf_name] {og} --> {name}") + return name + + @classmethod + def find_vision_tensors(cls, qwen2vl, dtype) -> Dict[str, np.ndarray]: + vision_model = qwen2vl.visual + tensor_map = {} + for name, ten in vision_model.state_dict().items(): + ten = ten.numpy() + if 'qkv' in name: + if ten.ndim == 2: # weight + c3, _ = ten.shape + else: # bias + c3 = ten.shape[0] + assert c3 % 3 == 0 + c = c3 // 3 + wq = ten[:c] + wk = ten[c: c * 2] + wv = ten[c * 2:] + tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq + tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk + tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv + elif 'merger' in name: + if name.endswith("ln_q.weight"): + tensor_map['v.post_ln.weight'] = ten + elif name.endswith("ln_q.bias"): + tensor_map['v.post_ln.bias'] = ten + else: + # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias" + tensor_map[cls.to_gguf_name(name)] = ten + elif 'patch_embed.proj.weight' in name: + # NOTE: split Conv3D into Conv2Ds + c1, c2, kt, kh, kw = ten.shape + assert kt == 2, "Current implmentation only support temporal_patch_size of 2" + tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...] + tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...] else: - # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias" - tensor_map[to_gguf_name(name)] = ten - elif 'patch_embed.proj.weight' in name: - # NOTE: split Conv3D into Conv2Ds - c1, c2, kt, kh, kw = ten.shape - assert kt == 2, "Current implmentation only support temporal_patch_size of 2" - tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...] - tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...] - else: - tensor_map[to_gguf_name(f"vision_model.{name}")] = ten + tensor_map[cls.to_gguf_name(f"vision_model.{name}")] = ten - for new_name, ten in tensor_map.items(): - if ten.ndim <= 1 or new_name.endswith("_norm.weight"): - tensor_map[new_name] = ten.astype(np.float32) - else: - tensor_map[new_name] = ten.astype(dtype) - tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder - return tensor_map + for new_name, ten in tensor_map.items(): + if ten.ndim <= 1 or new_name.endswith("_norm.weight"): + tensor_map[new_name] = ten.astype(np.float32) + else: + tensor_map[new_name] = ten.astype(dtype) + tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder + return tensor_map + + +class VL25(VL2): + + @staticmethod + def to_gguf_name(name: str) -> str: + og = name + name = name.replace("text_model", "t").replace("vision_model", "v") + name = name.replace("blocks", "blk").replace("embeddings.", "") + name = name.replace("attn.", "attn_") + name = name.replace("mlp.down_proj", "ffn_down").replace("mlp.up_proj", "ffn_up") + name = name.replace("mlp.gate_proj", "ffn_gate").replace("proj.", "out.") + name = name.replace("norm1", "ln1").replace("norm2", "ln2") + name = name.replace("merger.mlp", 'mm') + print(f"[vl25][to_gguf_name] {og} --> {name}") + return name def main(args): @@ -82,7 +116,7 @@ def main(args): np_dtype = np.float32 ftype = 0 elif args.data_type == 'fp16': - dtype = torch.float32 + dtype = torch.float16 np_dtype = np.float16 ftype = 1 else: @@ -92,11 +126,18 @@ def main(args): model_path = "" model_name = args.model_name print("model_name: ", model_name) - qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained( - model_name, torch_dtype=dtype, device_map="cpu" - ) - cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType] - vcfg = cfg.vision_config + if args.model_type == "qwen2vl": + qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained( + model_name, torch_dtype=dtype, device_map="cpu" + ) + cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType] + vcfg = cfg.vision_config + else: + qwen2vl = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model_name, torch_dtype=dtype, device_map="cpu" + ) + cfg: Qwen2_5_VLConfig = qwen2vl.config # type: ignore[reportAssignmentType] + vcfg = cfg.vision_config if os.path.isdir(model_name): local_model = True @@ -113,7 +154,6 @@ def main(args): fout.add_bool("clip.has_text_encoder", False) fout.add_bool("clip.has_vision_encoder", True) fout.add_bool("clip.has_qwen2vl_merger", True) - fout.add_string("clip.projector_type", "qwen2vl_merger") print(cfg.vision_config) if 'silu' in cfg.vision_config.hidden_act.lower(): @@ -125,14 +165,25 @@ def main(args): else: raise ValueError() - tensor_map = find_vision_tensors(qwen2vl, np_dtype) + if args.model_type == "qwen2.5vl": + fout.add_uint32("clip.vision.n_wa_pattern", get_n_wa_pattern(vcfg.fullatt_block_indexes)) + fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size) + fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size) + fout.add_string("clip.projector_type", "qwen2.5vl_merger") + else: + fout.add_string("clip.projector_type", "qwen2vl_merger") + fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim) + fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size) + + if args.model_type == "qwen2.5vl": + tensor_map = VL25.find_vision_tensors(qwen2vl, np_dtype) + else: + tensor_map = VL2.find_vision_tensors(qwen2vl, np_dtype) for name, data in tensor_map.items(): fout.add_tensor(name, data) fout.add_uint32("clip.vision.patch_size", vcfg.patch_size) fout.add_uint32("clip.vision.image_size", 14 * 40) # some reasonable size that is divable by (14*2) - fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim) - fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size) fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads) fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6) fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth) @@ -160,6 +211,7 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct") + parser.add_argument("--model_type", nargs='?', choices=['qwen2vl', 'qwen2.5vl'], default="qwen2vl") parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32") args = parser.parse_args() main(args) diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index eca7b7f1..cf427108 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -23,6 +23,9 @@ #include #include #include +#include +#include +#include static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, @@ -367,14 +370,14 @@ static void debug_test_mrope_2d() { // 1. Initialize backend ggml_backend_t backend = NULL; std::string backend_name = ""; -#ifdef GGML_USE_CUDA - fprintf(stderr, "%s: using CUDA backend\n", __func__); - backend = ggml_backend_cuda_init(0); // init device 0 - backend_name = "cuda"; - if (!backend) { - fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); - } -#endif +// #ifdef GGML_USE_CUDA +// fprintf(stderr, "%s: using CUDA backend\n", __func__); +// backend = ggml_backend_cuda_init(0); // init device 0 +// backend_name = "cuda"; +// if (!backend) { +// fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); +// } +// #endif // if there aren't GPU Backends fallback to CPU backend if (!backend) { backend = ggml_backend_cpu_init(); @@ -483,28 +486,82 @@ static void debug_test_mrope_2d() { ggml_backend_free(backend); } -static void debug_dump_img_embed(struct llava_context * ctx_llava) { - int n_embd = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama)); - int ne = n_embd * 4; - float vals[56 * 56 * 3]; +enum model_output_type { + conv3d, + patch_embed, + patch_win_attn_scatter, + first_attn_layer, + last_attn_layer, + attn_softmax, + final_layer, +}; + +static void debug_dump_img_embed(struct llava_context * ctx_llava, model_output_type output_type) { + constexpr int ih = 140; + constexpr int iw = 196; + // constexpr int ih = 56; + // constexpr int iw = 56; + // int n_embd = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama)); + int n_embd = 1280; + int merge = 1; + if (output_type == model_output_type::final_layer) { + n_embd = 2048; + merge = 2; + } + else if (output_type == model_output_type::attn_softmax) { + merge = 1; + n_embd = (ih/14/merge) * (iw/14/merge) * 16; + } + + int ne = (ih/14/merge) * (iw/14/merge) * n_embd; + float vals[iw * ih * 3]; // float embd[ne]; std::vector embd; embd.resize(ne); - for (int i = 0; i < 56*56; i++) + for (int i = 0; i < iw*ih; i++) { for (int c = 0; c < 3; c++) - vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56); + vals[i * 3 + c] = (float)i / (iw*ih); } - clip_encode_float_image(ctx_llava->ctx_clip, 16, vals, 56, 56, embd.data()); + clip_encode_float_image(ctx_llava->ctx_clip, 8, vals, ih, iw, embd.data()); - std::ofstream outFile("img_embed.bin", std::ios::binary); + std::string file_postfix = ""; + switch (output_type) + { + case model_output_type::conv3d: + file_postfix = "conv3d"; + break; + case model_output_type::patch_embed: + file_postfix = "patch_embed"; + break; + case model_output_type::patch_win_attn_scatter: + file_postfix = "scatter"; + break; + case model_output_type::first_attn_layer: + file_postfix = "first_attn"; + break; + case model_output_type::last_attn_layer: + file_postfix = "last_attn"; + break; + case model_output_type::attn_softmax: + file_postfix = "attn_softmax"; + break; + case model_output_type::final_layer: + file_postfix = "final"; + break; + default: + break; + } + auto output_path = "img_embed_" + file_postfix + ".bin"; + + std::ofstream outFile(output_path, std::ios::binary); if (outFile.is_open()) { outFile.write(reinterpret_cast(embd.data()), ne * sizeof(float)); outFile.close(); - std::cout << "Data successfully written to mrope.bin" << std::endl; + std::cout << "Data successfully written to ::[ " << output_path << std::endl; } else { std::cerr << "Error opening file!" << std::endl; } @@ -551,8 +608,9 @@ int main(int argc, char ** argv) { } else if (params.image[0].empty()) { auto ctx_llava = llava_init_context(¶ms, model); - debug_test_mrope_2d(); - debug_dump_img_embed(ctx_llava); + // debug_test_mrope_2d(); + debug_dump_img_embed(ctx_llava, model_output_type::final_layer); + // debug_dump_img_embed(ctx_llava, model_output_type::last_attn_layer); llama_perf_context_print(ctx_llava->ctx_llama); ctx_llava->model = NULL; diff --git a/examples/llava/tests.sh b/examples/llava/tests.sh index e612857e..4002f9d5 100755 --- a/examples/llava/tests.sh +++ b/examples/llava/tests.sh @@ -55,6 +55,7 @@ add_test "llama-mtmd-cli" "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K" # mode add_test "llama-mtmd-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K" add_test "llama-mtmd-cli" "openbmb/MiniCPM-o-2_6-gguf:Q4_0" add_test "llama-qwen2vl-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M" +add_test "llama-qwen2vl-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M" # to test the big models, run: ./tests.sh big add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"