clip : Add Qwen2.5VL support (#12402)
* implment vision model architecture, gguf convertor * handle window attention inputs * add debug utils * fix few incorrect tensor memory layout * move position id remap out of ggml to avoid int32 cuda operations * cleaning up * ignore transformers Qwen2_5_xxx type check * remove not so often use `qwen2vl-cli` debug functions * remove commented-out code blocks * fix attn weight scaling after rebase * add `PROJECTOR_TYPE_QWEN2_5_VL` * remove `KEY_USE_GLU_MLP`, `KEY_USE_RMS_NORM` * replace `KEY_FULLATTN_BLK_IDX` with `KEY_WIN_ATTN_PATTERN` * remove `attn_window_size` from gguf * fix model conversion * clean up * fix merging problem * add test --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
This commit is contained in:
parent
2d451c8059
commit
ca2bb89eac
6 changed files with 594 additions and 102 deletions
|
@ -2554,11 +2554,12 @@ class Qwen2VLModel(TextModel):
|
|||
except FileNotFoundError:
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
|
||||
for name, data in super().get_tensors():
|
||||
if name.startswith("visual."):
|
||||
continue
|
||||
yield name, data
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
if name.startswith("visual."):
|
||||
# skip visual tensors
|
||||
return []
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register("WavTokenizerDec")
|
||||
|
|
|
@ -34,9 +34,14 @@
|
|||
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
|
||||
#define KEY_PROJ_TYPE "clip.projector_type"
|
||||
|
||||
#define KEY_USE_GLU_MLP "clip.use_glu_mlp" // for qwen2.5vl
|
||||
#define KEY_USE_RMS_NORM "clip.use_rms_norm" // for qwen2.5vl
|
||||
|
||||
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
|
||||
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
|
||||
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
|
||||
#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern"
|
||||
#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
|
||||
|
||||
|
||||
//
|
||||
|
@ -55,6 +60,7 @@
|
|||
#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s"
|
||||
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
|
||||
#define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
|
||||
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
|
||||
#define TN_LN_1 "%s.blk.%d.ln1.%s"
|
||||
#define TN_LN_2 "%s.blk.%d.ln2.%s"
|
||||
#define TN_LN_PRE "%s.pre_ln.%s"
|
||||
|
@ -95,6 +101,7 @@ enum projector_type {
|
|||
PROJECTOR_TYPE_GEMMA3,
|
||||
PROJECTOR_TYPE_IDEFICS3,
|
||||
PROJECTOR_TYPE_PIXTRAL,
|
||||
PROJECTOR_TYPE_QWEN25VL,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
|
@ -105,6 +112,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
|||
{ PROJECTOR_TYPE_MINICPMV, "resampler"},
|
||||
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
|
||||
{ PROJECTOR_TYPE_QWEN2VL, "qwen2vl_merger"},
|
||||
{ PROJECTOR_TYPE_QWEN25VL, "qwen2.5vl_merger"},
|
||||
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
|
||||
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
|
||||
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include <cinttypes>
|
||||
#include <limits>
|
||||
#include <array>
|
||||
#include <numeric>
|
||||
|
||||
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
|
||||
|
||||
|
@ -169,6 +170,8 @@ struct clip_hparams {
|
|||
std::vector<int32_t> image_grid_pinpoints;
|
||||
int32_t image_crop_resolution;
|
||||
std::unordered_set<int32_t> vision_feature_layer;
|
||||
int32_t attn_window_size;
|
||||
int32_t n_wa_pattern;
|
||||
};
|
||||
|
||||
struct clip_layer {
|
||||
|
@ -200,6 +203,9 @@ struct clip_layer {
|
|||
struct ggml_tensor * ff_down_w = nullptr;
|
||||
struct ggml_tensor * ff_down_b = nullptr;
|
||||
|
||||
struct ggml_tensor * ff_g_w = NULL;
|
||||
struct ggml_tensor * ff_g_b = NULL;
|
||||
|
||||
// layernorm 2
|
||||
struct ggml_tensor * ln_2_w = nullptr;
|
||||
struct ggml_tensor * ln_2_b = nullptr;
|
||||
|
@ -319,6 +325,7 @@ struct clip_ctx {
|
|||
float image_std[3];
|
||||
bool use_gelu = false;
|
||||
bool use_silu = false;
|
||||
int32_t ftype = 1;
|
||||
|
||||
gguf_context_ptr ctx_gguf;
|
||||
ggml_context_ptr ctx_data;
|
||||
|
@ -762,6 +769,236 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
|
|||
return gf;
|
||||
}
|
||||
|
||||
static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
|
||||
const auto & model = ctx->vision_model;
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int image_size_width = imgs.entries[0]->nx;
|
||||
const int image_size_height = imgs.entries[0]->ny;
|
||||
|
||||
const bool use_mrope = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL;
|
||||
const bool use_window_attn = hparams.n_wa_pattern > 0;
|
||||
|
||||
const int n_wa_pattern = hparams.n_wa_pattern;
|
||||
const int patch_size = hparams.patch_size;
|
||||
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
|
||||
const int patches_w = image_size_width / patch_size;
|
||||
const int patches_h = image_size_height / patch_size;
|
||||
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
|
||||
const int num_position_ids = use_mrope ? num_positions * 4 : num_positions;
|
||||
const int hidden_size = hparams.hidden_size;
|
||||
const int n_head = hparams.n_head;
|
||||
const int d_head = hidden_size / n_head;
|
||||
const float eps = hparams.eps;
|
||||
|
||||
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
|
||||
|
||||
const int batch_size = imgs.entries.size();
|
||||
GGML_ASSERT(batch_size == 1);
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ ctx->buf_compute_meta.size(),
|
||||
/*.mem_buffer =*/ ctx->buf_compute_meta.data(),
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
ggml_context_ptr ctx0_ptr(ggml_init(params));
|
||||
auto ctx0 = ctx0_ptr.get();
|
||||
|
||||
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size);
|
||||
ggml_set_name(inp_raw, "inp_raw");
|
||||
ggml_set_input(inp_raw);
|
||||
|
||||
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
|
||||
GGML_ASSERT(image_size_width % (patch_size * 2) == 0);
|
||||
GGML_ASSERT(image_size_height % (patch_size * 2) == 0);
|
||||
|
||||
auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
inp = ggml_add(ctx0, inp, inp_1);
|
||||
|
||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
|
||||
inp = ggml_reshape_4d(
|
||||
ctx0, inp,
|
||||
hidden_size * 2, patches_w / 2, patches_h, batch_size);
|
||||
inp = ggml_reshape_4d(
|
||||
ctx0, inp,
|
||||
hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
|
||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
|
||||
inp = ggml_reshape_3d(
|
||||
ctx0, inp,
|
||||
hidden_size, patches_w * patches_h, batch_size);
|
||||
|
||||
if (model.patch_bias) {
|
||||
// inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
|
||||
inp = ggml_add(ctx0, inp, model.patch_bias);
|
||||
}
|
||||
struct ggml_tensor * embeddings = inp;
|
||||
struct ggml_tensor * window_mask = nullptr;
|
||||
struct ggml_tensor * window_idx = nullptr;
|
||||
struct ggml_tensor * inv_window_idx = nullptr;
|
||||
|
||||
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
|
||||
ggml_set_name(positions, "positions");
|
||||
ggml_set_input(positions);
|
||||
|
||||
// pre-layernorm
|
||||
if (model.pre_ln_w) {
|
||||
embeddings = ggml_rms_norm(ctx0, embeddings, eps);
|
||||
ggml_set_name(embeddings, "pre_ln");
|
||||
|
||||
embeddings = ggml_mul(ctx0, embeddings, model.pre_ln_w);
|
||||
}
|
||||
|
||||
if (use_window_attn) {
|
||||
// handle window attention inputs
|
||||
inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
|
||||
ggml_set_name(inv_window_idx, "inv_window_idx");
|
||||
ggml_set_input(inv_window_idx);
|
||||
// mask for window attention
|
||||
window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, num_positions, num_positions);
|
||||
ggml_set_name(window_mask, "window_mask");
|
||||
ggml_set_input(window_mask);
|
||||
|
||||
// embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
|
||||
GGML_ASSERT(batch_size == 1);
|
||||
embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4);
|
||||
embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx);
|
||||
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size);
|
||||
}
|
||||
|
||||
// loop over layers
|
||||
for (int il = 0; il < ctx->max_feature_layer; il++) {
|
||||
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
|
||||
|
||||
// rmsnorm1
|
||||
cur = ggml_rms_norm(ctx0, cur, eps);
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].ln_1_w);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
|
||||
struct ggml_tensor * Q =
|
||||
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
|
||||
|
||||
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
|
||||
Q = ggml_rope_multi(
|
||||
ctx0, Q, positions, nullptr,
|
||||
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
|
||||
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
|
||||
Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
|
||||
|
||||
struct ggml_tensor * K =
|
||||
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
|
||||
|
||||
K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
|
||||
K = ggml_rope_multi(
|
||||
ctx0, K, positions, nullptr,
|
||||
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
|
||||
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
|
||||
K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
|
||||
|
||||
struct ggml_tensor * V =
|
||||
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
|
||||
|
||||
V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
|
||||
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
|
||||
V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
|
||||
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||
const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true;
|
||||
if (full_attn) {
|
||||
KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
|
||||
} else {
|
||||
KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f / sqrtf((float)d_head), 0.0f);
|
||||
}
|
||||
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
|
||||
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
|
||||
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||
|
||||
cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
|
||||
}
|
||||
|
||||
// attention output
|
||||
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
|
||||
|
||||
// re-add the layer input, e.g., residual
|
||||
cur = ggml_add(ctx0, cur, embeddings);
|
||||
|
||||
embeddings = cur; // embeddings = residual, cur = hidden_states
|
||||
|
||||
// rms norm2
|
||||
cur = ggml_rms_norm(ctx0, cur, eps);
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].ln_2_w);
|
||||
|
||||
// mlp
|
||||
// ffn_up
|
||||
auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
|
||||
cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b);
|
||||
|
||||
auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur);
|
||||
cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b);
|
||||
// TODO : only 2 of these 3 are actually used, should we remove one of them?
|
||||
if (ctx->use_gelu) {
|
||||
cur_gate = ggml_gelu_inplace(ctx0, cur_gate);
|
||||
} else if (ctx->use_silu) {
|
||||
cur_gate = ggml_silu_inplace(ctx0, cur_gate);
|
||||
} else {
|
||||
cur_gate = ggml_gelu_quick_inplace(ctx0, cur_gate);
|
||||
}
|
||||
cur = ggml_mul(ctx0, cur_gate, cur_up);
|
||||
|
||||
// ffn_down
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
|
||||
|
||||
// residual 2
|
||||
cur = ggml_add(ctx0, embeddings, cur);
|
||||
|
||||
embeddings = cur;
|
||||
}
|
||||
|
||||
// post-layernorm
|
||||
if (model.post_ln_w) {
|
||||
embeddings = ggml_rms_norm(ctx0, embeddings, eps);
|
||||
ggml_set_name(embeddings, "post_ln");
|
||||
|
||||
embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w);
|
||||
}
|
||||
|
||||
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
|
||||
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
||||
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
|
||||
|
||||
// GELU activation
|
||||
embeddings = ggml_gelu(ctx0, embeddings);
|
||||
|
||||
// Second linear layer
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
|
||||
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
|
||||
|
||||
if (use_window_attn) {
|
||||
window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
|
||||
ggml_set_name(window_idx, "window_idx");
|
||||
ggml_set_input(window_idx);
|
||||
|
||||
// embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
|
||||
GGML_ASSERT(batch_size == 1);
|
||||
embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4);
|
||||
embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
|
||||
embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4, batch_size);
|
||||
}
|
||||
|
||||
// build the graph
|
||||
ggml_build_forward_expand(gf, embeddings);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
|
||||
const auto & model = ctx->vision_model;
|
||||
const auto & hparams = model.hparams;
|
||||
|
@ -1331,6 +1568,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
GGML_ASSERT(imgs.entries.size() == 1);
|
||||
res = clip_image_build_graph_pixtral(ctx, *imgs.entries[0]);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN25VL:
|
||||
{
|
||||
res = clip_image_build_graph_qwen25vl(ctx, imgs);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
// TODO: we should have one build_* function per model
|
||||
|
@ -1507,6 +1748,10 @@ struct clip_model_loader {
|
|||
{
|
||||
hparams.rope_theta = 10000.0f;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN25VL:
|
||||
{
|
||||
get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -1600,8 +1845,10 @@ struct clip_model_loader {
|
|||
// legacy naming (the in and out is reversed! don't ask me why)
|
||||
layer.ff_i_w = layer.ff_down_w;
|
||||
layer.ff_o_w = layer.ff_up_w;
|
||||
layer.ff_g_w = layer.ff_gate_w;
|
||||
layer.ff_i_b = layer.ff_down_b;
|
||||
layer.ff_o_b = layer.ff_up_b;
|
||||
layer.ff_g_b = layer.ff_gate_b;
|
||||
}
|
||||
|
||||
switch (ctx_clip.proj_type) {
|
||||
|
@ -1700,6 +1947,7 @@ struct clip_model_loader {
|
|||
vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN2VL:
|
||||
case PROJECTOR_TYPE_QWEN25VL:
|
||||
{
|
||||
vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
|
||||
vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
|
||||
|
@ -2651,7 +2899,7 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
|
|||
else {
|
||||
GGML_ABORT("Unknown minicpmv version");
|
||||
}
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
|
||||
int patch_size = params.patch_size * 2;
|
||||
int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
|
||||
int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
|
||||
|
@ -2792,6 +3040,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
const int pos_w = ctx->load_image_size.width / patch_size;
|
||||
const int pos_h = ctx->load_image_size.height / patch_size;
|
||||
|
||||
const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl
|
||||
|
||||
{
|
||||
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
|
||||
std::vector<float> inp_data(ggml_nelements(inp_raw));
|
||||
|
@ -2890,31 +3140,93 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
// non-minicpmv models
|
||||
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
|
||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
||||
// pw * ph = number of tokens output by ViT after apply patch merger
|
||||
// ipw * ipw = number of vision token been processed inside ViT
|
||||
const int merge_ratio = 2;
|
||||
const int pw = image_size_width / patch_size / merge_ratio;
|
||||
const int ph = image_size_height / patch_size / merge_ratio;
|
||||
const int ipw = image_size_width / patch_size;
|
||||
const int iph = image_size_height / patch_size;
|
||||
|
||||
const int pw = image_size_width / patch_size;
|
||||
const int ph = image_size_height / patch_size;
|
||||
int* positions_data = (int*)malloc(ggml_nbytes(positions));
|
||||
std::vector<int> idx (ph * pw);
|
||||
std::vector<int> inv_idx(ph * pw);
|
||||
|
||||
if (use_window_attn) {
|
||||
const int attn_window_size = 112;
|
||||
struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
|
||||
struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
|
||||
struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
|
||||
|
||||
const int grid_window = attn_window_size / patch_size / merge_ratio;
|
||||
int dst = 0;
|
||||
// [num_vision_tokens, num_vision_tokens] attention mask tensor
|
||||
std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
|
||||
int mask_row = 0;
|
||||
|
||||
for (int y = 0; y < ph; y += grid_window)
|
||||
{
|
||||
for (int x = 0; x < pw; x += grid_window)
|
||||
{
|
||||
const int win_h = std::min(grid_window, ph - y);
|
||||
const int win_w = std::min(grid_window, pw - x);
|
||||
const int dst_0 = dst;
|
||||
// group all tokens belong to the same window togather (to a continue range)
|
||||
for (int dy = 0; dy < win_h; dy++) {
|
||||
for (int dx = 0; dx < win_w; dx++) {
|
||||
const int src = (y + dy) * pw + (x + dx);
|
||||
assert(src < (int)idx.size());
|
||||
assert(dst < (int)inv_idx.size());
|
||||
idx [src] = dst;
|
||||
inv_idx[dst] = src;
|
||||
dst++;
|
||||
}
|
||||
}
|
||||
|
||||
for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
|
||||
int row_offset = mask_row * (ipw * iph);
|
||||
std::fill(
|
||||
mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
|
||||
mask.begin() + row_offset + (dst * merge_ratio * merge_ratio),
|
||||
0.0);
|
||||
mask_row++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
|
||||
ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
|
||||
ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
|
||||
} else {
|
||||
std::iota(idx.begin(), idx.end(), 0);
|
||||
std::iota(inv_idx.begin(), inv_idx.end(), 0);
|
||||
}
|
||||
|
||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
||||
const int mpow = merge_ratio * merge_ratio;
|
||||
std::vector<int> positions_data(ggml_nelements(positions));
|
||||
int * data = positions_data.data();
|
||||
|
||||
int ptr = 0;
|
||||
for (int y = 0; y < ph; y+=2)
|
||||
for (int y = 0; y < iph; y += merge_ratio)
|
||||
{
|
||||
for (int x = 0; x < pw; x+=2)
|
||||
for (int x = 0; x < ipw; x += merge_ratio)
|
||||
{
|
||||
for (int dy = 0; dy < 2; dy++) {
|
||||
for (int dx = 0; dx < 2; dx++) {
|
||||
positions_data[ptr] = y + dy;
|
||||
positions_data[num_patches + ptr] = x + dx;
|
||||
positions_data[num_patches * 2 + ptr] = y + dy;
|
||||
positions_data[num_patches * 3 + ptr] = x + dx;
|
||||
auto remap = idx[ptr / mpow];
|
||||
remap = remap * mpow + (ptr % mpow);
|
||||
|
||||
data[ remap] = y + dy;
|
||||
data[ num_patches + remap] = x + dx;
|
||||
data[2 * num_patches + remap] = y + dy;
|
||||
data[3 * num_patches + remap] = x + dx;
|
||||
ptr++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
|
||||
free(positions_data);
|
||||
ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions));
|
||||
}
|
||||
else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
// do nothing
|
||||
|
@ -2967,6 +3279,65 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
}
|
||||
}
|
||||
|
||||
if (use_window_attn && ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
|
||||
struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
|
||||
struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
|
||||
struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
|
||||
|
||||
const int merge_ratio = 2;
|
||||
const int attn_window_size = 112;
|
||||
const int pw = image_size_width / patch_size / merge_ratio;
|
||||
const int ph = image_size_height / patch_size / merge_ratio;
|
||||
const int grid_window = attn_window_size / patch_size / merge_ratio;
|
||||
const int ipw = image_size_width / patch_size;
|
||||
const int iph = image_size_height / patch_size;
|
||||
/*
|
||||
pw * ph = number of tokens output by ViT after apply patch merger
|
||||
ipw * ipw = number of vision token been processed inside ViT
|
||||
*/
|
||||
|
||||
std::vector<int> idx(ph * pw);
|
||||
std::vector<int> inv_idx(ph * pw);
|
||||
int dst = 0;
|
||||
// [num_vision_tokens, num_vision_tokens] attention mask tensor
|
||||
std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
|
||||
int mask_row = 0;
|
||||
|
||||
for (int y = 0; y < ph; y+=grid_window)
|
||||
{
|
||||
for (int x = 0; x < pw; x+=grid_window)
|
||||
{
|
||||
const int win_h = std::min(grid_window, ph - y);
|
||||
const int win_w = std::min(grid_window, pw - x);
|
||||
const int dst_0 = dst;
|
||||
// group all tokens belong to the same window togather (to a continue range)
|
||||
for (int dy = 0; dy < win_h; dy++) {
|
||||
for (int dx = 0; dx < win_w; dx++) {
|
||||
const int src = (y + dy) * pw + (x + dx);
|
||||
assert(src < (int)idx.size());
|
||||
assert(dst < (int)inv_idx.size());
|
||||
idx[src] = dst;
|
||||
inv_idx[dst] = src;
|
||||
dst++;
|
||||
}
|
||||
}
|
||||
|
||||
for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
|
||||
int row_offset = mask_row * (ipw * iph);
|
||||
std::fill(
|
||||
mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
|
||||
mask.begin() + row_offset + (dst * merge_ratio * merge_ratio),
|
||||
0.0);
|
||||
mask_row++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
|
||||
ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
|
||||
ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
|
||||
}
|
||||
|
||||
ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);
|
||||
|
||||
auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf);
|
||||
|
@ -3142,6 +3513,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
|||
case PROJECTOR_TYPE_GLM_EDGE:
|
||||
return ctx->vision_model.mm_model_mlp_3_w->ne[1];
|
||||
case PROJECTOR_TYPE_QWEN2VL:
|
||||
case PROJECTOR_TYPE_QWEN25VL:
|
||||
return ctx->vision_model.mm_1_b->ne[0];
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
return ctx->vision_model.mm_input_proj_w->ne[0];
|
||||
|
|
|
@ -1,14 +1,16 @@
|
|||
import argparse
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from gguf import *
|
||||
from transformers import (
|
||||
Qwen2VLForConditionalGeneration,
|
||||
Qwen2VLProcessor,
|
||||
AutoProcessor,
|
||||
Qwen2VLConfig
|
||||
Qwen2VLConfig,
|
||||
Qwen2VLProcessor,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
Qwen2_5_VLConfig, # type: ignore[reportAttributeAccessIssue]
|
||||
Qwen2_5_VLForConditionalGeneration, # type: ignore[reportAttributeAccessIssue]
|
||||
)
|
||||
|
||||
|
||||
|
@ -19,61 +21,93 @@ def k(raw_key: str, arch: str) -> str:
|
|||
return raw_key.format(arch=arch)
|
||||
|
||||
|
||||
def to_gguf_name(name: str) -> str:
|
||||
og = name
|
||||
name = name.replace("text_model", "t").replace("vision_model", "v")
|
||||
name = name.replace("blocks", "blk").replace("embeddings.", "")
|
||||
name = name.replace("attn.", "attn_")
|
||||
name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.")
|
||||
# name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln")
|
||||
name = name.replace("norm1", "ln1").replace("norm2", "ln2")
|
||||
name = name.replace("merger.mlp", 'mm')
|
||||
print(f"[to_gguf_name] {og} --> {name}")
|
||||
return name
|
||||
def get_n_wa_pattern(fullatt_block_indexes: Optional[List[int]]):
|
||||
if fullatt_block_indexes is None:
|
||||
return 0
|
||||
n_wa = fullatt_block_indexes[0]
|
||||
for a, b in zip(fullatt_block_indexes, fullatt_block_indexes[1:]):
|
||||
if b - a - 1 != n_wa:
|
||||
raise ValueError(
|
||||
f"window/full attention layer should have fix pattern of "
|
||||
f"for each full-attention layer followed by {n_wa} window-attention layers"
|
||||
)
|
||||
return n_wa + 1
|
||||
|
||||
|
||||
def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
|
||||
vision_model = qwen2vl.visual
|
||||
tensor_map = {}
|
||||
for name, ten in vision_model.state_dict().items():
|
||||
ten = ten.numpy()
|
||||
if 'qkv' in name:
|
||||
if ten.ndim == 2: # weight
|
||||
c3, _ = ten.shape
|
||||
else: # bias
|
||||
c3 = ten.shape[0]
|
||||
assert c3 % 3 == 0
|
||||
c = c3 // 3
|
||||
wq = ten[:c]
|
||||
wk = ten[c: c * 2]
|
||||
wv = ten[c * 2:]
|
||||
tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq
|
||||
tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk
|
||||
tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv
|
||||
elif 'merger' in name:
|
||||
if name.endswith("ln_q.weight"):
|
||||
tensor_map['v.post_ln.weight'] = ten
|
||||
elif name.endswith("ln_q.bias"):
|
||||
tensor_map['v.post_ln.bias'] = ten
|
||||
class VL2:
|
||||
|
||||
@staticmethod
|
||||
def to_gguf_name(name: str) -> str:
|
||||
og = name
|
||||
name = name.replace("text_model", "t").replace("vision_model", "v")
|
||||
name = name.replace("blocks", "blk").replace("embeddings.", "")
|
||||
name = name.replace("attn.", "attn_")
|
||||
name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.")
|
||||
# name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln")
|
||||
name = name.replace("norm1", "ln1").replace("norm2", "ln2")
|
||||
name = name.replace("merger.mlp", 'mm')
|
||||
print(f"[to_gguf_name] {og} --> {name}")
|
||||
return name
|
||||
|
||||
@classmethod
|
||||
def find_vision_tensors(cls, qwen2vl, dtype) -> Dict[str, np.ndarray]:
|
||||
vision_model = qwen2vl.visual
|
||||
tensor_map = {}
|
||||
for name, ten in vision_model.state_dict().items():
|
||||
ten = ten.numpy()
|
||||
if 'qkv' in name:
|
||||
if ten.ndim == 2: # weight
|
||||
c3, _ = ten.shape
|
||||
else: # bias
|
||||
c3 = ten.shape[0]
|
||||
assert c3 % 3 == 0
|
||||
c = c3 // 3
|
||||
wq = ten[:c]
|
||||
wk = ten[c: c * 2]
|
||||
wv = ten[c * 2:]
|
||||
tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq
|
||||
tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk
|
||||
tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv
|
||||
elif 'merger' in name:
|
||||
if name.endswith("ln_q.weight"):
|
||||
tensor_map['v.post_ln.weight'] = ten
|
||||
elif name.endswith("ln_q.bias"):
|
||||
tensor_map['v.post_ln.bias'] = ten
|
||||
else:
|
||||
# "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias"
|
||||
tensor_map[cls.to_gguf_name(name)] = ten
|
||||
elif 'patch_embed.proj.weight' in name:
|
||||
# NOTE: split Conv3D into Conv2Ds
|
||||
c1, c2, kt, kh, kw = ten.shape
|
||||
assert kt == 2, "Current implmentation only support temporal_patch_size of 2"
|
||||
tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...]
|
||||
tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
|
||||
else:
|
||||
# "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias"
|
||||
tensor_map[to_gguf_name(name)] = ten
|
||||
elif 'patch_embed.proj.weight' in name:
|
||||
# NOTE: split Conv3D into Conv2Ds
|
||||
c1, c2, kt, kh, kw = ten.shape
|
||||
assert kt == 2, "Current implmentation only support temporal_patch_size of 2"
|
||||
tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...]
|
||||
tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
|
||||
else:
|
||||
tensor_map[to_gguf_name(f"vision_model.{name}")] = ten
|
||||
tensor_map[cls.to_gguf_name(f"vision_model.{name}")] = ten
|
||||
|
||||
for new_name, ten in tensor_map.items():
|
||||
if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
|
||||
tensor_map[new_name] = ten.astype(np.float32)
|
||||
else:
|
||||
tensor_map[new_name] = ten.astype(dtype)
|
||||
tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder
|
||||
return tensor_map
|
||||
for new_name, ten in tensor_map.items():
|
||||
if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
|
||||
tensor_map[new_name] = ten.astype(np.float32)
|
||||
else:
|
||||
tensor_map[new_name] = ten.astype(dtype)
|
||||
tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder
|
||||
return tensor_map
|
||||
|
||||
|
||||
class VL25(VL2):
|
||||
|
||||
@staticmethod
|
||||
def to_gguf_name(name: str) -> str:
|
||||
og = name
|
||||
name = name.replace("text_model", "t").replace("vision_model", "v")
|
||||
name = name.replace("blocks", "blk").replace("embeddings.", "")
|
||||
name = name.replace("attn.", "attn_")
|
||||
name = name.replace("mlp.down_proj", "ffn_down").replace("mlp.up_proj", "ffn_up")
|
||||
name = name.replace("mlp.gate_proj", "ffn_gate").replace("proj.", "out.")
|
||||
name = name.replace("norm1", "ln1").replace("norm2", "ln2")
|
||||
name = name.replace("merger.mlp", 'mm')
|
||||
print(f"[vl25][to_gguf_name] {og} --> {name}")
|
||||
return name
|
||||
|
||||
|
||||
def main(args):
|
||||
|
@ -82,7 +116,7 @@ def main(args):
|
|||
np_dtype = np.float32
|
||||
ftype = 0
|
||||
elif args.data_type == 'fp16':
|
||||
dtype = torch.float32
|
||||
dtype = torch.float16
|
||||
np_dtype = np.float16
|
||||
ftype = 1
|
||||
else:
|
||||
|
@ -92,11 +126,18 @@ def main(args):
|
|||
model_path = ""
|
||||
model_name = args.model_name
|
||||
print("model_name: ", model_name)
|
||||
qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=dtype, device_map="cpu"
|
||||
)
|
||||
cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType]
|
||||
vcfg = cfg.vision_config
|
||||
if args.model_type == "qwen2vl":
|
||||
qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=dtype, device_map="cpu"
|
||||
)
|
||||
cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType]
|
||||
vcfg = cfg.vision_config
|
||||
else:
|
||||
qwen2vl = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=dtype, device_map="cpu"
|
||||
)
|
||||
cfg: Qwen2_5_VLConfig = qwen2vl.config # type: ignore[reportAssignmentType]
|
||||
vcfg = cfg.vision_config
|
||||
|
||||
if os.path.isdir(model_name):
|
||||
local_model = True
|
||||
|
@ -113,7 +154,6 @@ def main(args):
|
|||
fout.add_bool("clip.has_text_encoder", False)
|
||||
fout.add_bool("clip.has_vision_encoder", True)
|
||||
fout.add_bool("clip.has_qwen2vl_merger", True)
|
||||
fout.add_string("clip.projector_type", "qwen2vl_merger")
|
||||
|
||||
print(cfg.vision_config)
|
||||
if 'silu' in cfg.vision_config.hidden_act.lower():
|
||||
|
@ -125,14 +165,25 @@ def main(args):
|
|||
else:
|
||||
raise ValueError()
|
||||
|
||||
tensor_map = find_vision_tensors(qwen2vl, np_dtype)
|
||||
if args.model_type == "qwen2.5vl":
|
||||
fout.add_uint32("clip.vision.n_wa_pattern", get_n_wa_pattern(vcfg.fullatt_block_indexes))
|
||||
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size)
|
||||
fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size)
|
||||
fout.add_string("clip.projector_type", "qwen2.5vl_merger")
|
||||
else:
|
||||
fout.add_string("clip.projector_type", "qwen2vl_merger")
|
||||
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
|
||||
fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
|
||||
|
||||
if args.model_type == "qwen2.5vl":
|
||||
tensor_map = VL25.find_vision_tensors(qwen2vl, np_dtype)
|
||||
else:
|
||||
tensor_map = VL2.find_vision_tensors(qwen2vl, np_dtype)
|
||||
for name, data in tensor_map.items():
|
||||
fout.add_tensor(name, data)
|
||||
|
||||
fout.add_uint32("clip.vision.patch_size", vcfg.patch_size)
|
||||
fout.add_uint32("clip.vision.image_size", 14 * 40) # some reasonable size that is divable by (14*2)
|
||||
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
|
||||
fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
|
||||
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads)
|
||||
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
|
||||
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth)
|
||||
|
@ -160,6 +211,7 @@ def main(args):
|
|||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
|
||||
parser.add_argument("--model_type", nargs='?', choices=['qwen2vl', 'qwen2.5vl'], default="qwen2vl")
|
||||
parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
|
|
@ -23,6 +23,9 @@
|
|||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
|
||||
static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed,
|
||||
|
@ -367,14 +370,14 @@ static void debug_test_mrope_2d() {
|
|||
// 1. Initialize backend
|
||||
ggml_backend_t backend = NULL;
|
||||
std::string backend_name = "";
|
||||
#ifdef GGML_USE_CUDA
|
||||
fprintf(stderr, "%s: using CUDA backend\n", __func__);
|
||||
backend = ggml_backend_cuda_init(0); // init device 0
|
||||
backend_name = "cuda";
|
||||
if (!backend) {
|
||||
fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
|
||||
}
|
||||
#endif
|
||||
// #ifdef GGML_USE_CUDA
|
||||
// fprintf(stderr, "%s: using CUDA backend\n", __func__);
|
||||
// backend = ggml_backend_cuda_init(0); // init device 0
|
||||
// backend_name = "cuda";
|
||||
// if (!backend) {
|
||||
// fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
|
||||
// }
|
||||
// #endif
|
||||
// if there aren't GPU Backends fallback to CPU backend
|
||||
if (!backend) {
|
||||
backend = ggml_backend_cpu_init();
|
||||
|
@ -483,28 +486,82 @@ static void debug_test_mrope_2d() {
|
|||
ggml_backend_free(backend);
|
||||
}
|
||||
|
||||
static void debug_dump_img_embed(struct llava_context * ctx_llava) {
|
||||
int n_embd = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama));
|
||||
int ne = n_embd * 4;
|
||||
float vals[56 * 56 * 3];
|
||||
enum model_output_type {
|
||||
conv3d,
|
||||
patch_embed,
|
||||
patch_win_attn_scatter,
|
||||
first_attn_layer,
|
||||
last_attn_layer,
|
||||
attn_softmax,
|
||||
final_layer,
|
||||
};
|
||||
|
||||
static void debug_dump_img_embed(struct llava_context * ctx_llava, model_output_type output_type) {
|
||||
constexpr int ih = 140;
|
||||
constexpr int iw = 196;
|
||||
// constexpr int ih = 56;
|
||||
// constexpr int iw = 56;
|
||||
// int n_embd = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama));
|
||||
int n_embd = 1280;
|
||||
int merge = 1;
|
||||
if (output_type == model_output_type::final_layer) {
|
||||
n_embd = 2048;
|
||||
merge = 2;
|
||||
}
|
||||
else if (output_type == model_output_type::attn_softmax) {
|
||||
merge = 1;
|
||||
n_embd = (ih/14/merge) * (iw/14/merge) * 16;
|
||||
}
|
||||
|
||||
int ne = (ih/14/merge) * (iw/14/merge) * n_embd;
|
||||
float vals[iw * ih * 3];
|
||||
// float embd[ne];
|
||||
std::vector<float> embd;
|
||||
embd.resize(ne);
|
||||
|
||||
for (int i = 0; i < 56*56; i++)
|
||||
for (int i = 0; i < iw*ih; i++)
|
||||
{
|
||||
for (int c = 0; c < 3; c++)
|
||||
vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56);
|
||||
vals[i * 3 + c] = (float)i / (iw*ih);
|
||||
}
|
||||
|
||||
clip_encode_float_image(ctx_llava->ctx_clip, 16, vals, 56, 56, embd.data());
|
||||
clip_encode_float_image(ctx_llava->ctx_clip, 8, vals, ih, iw, embd.data());
|
||||
|
||||
std::ofstream outFile("img_embed.bin", std::ios::binary);
|
||||
std::string file_postfix = "";
|
||||
switch (output_type)
|
||||
{
|
||||
case model_output_type::conv3d:
|
||||
file_postfix = "conv3d";
|
||||
break;
|
||||
case model_output_type::patch_embed:
|
||||
file_postfix = "patch_embed";
|
||||
break;
|
||||
case model_output_type::patch_win_attn_scatter:
|
||||
file_postfix = "scatter";
|
||||
break;
|
||||
case model_output_type::first_attn_layer:
|
||||
file_postfix = "first_attn";
|
||||
break;
|
||||
case model_output_type::last_attn_layer:
|
||||
file_postfix = "last_attn";
|
||||
break;
|
||||
case model_output_type::attn_softmax:
|
||||
file_postfix = "attn_softmax";
|
||||
break;
|
||||
case model_output_type::final_layer:
|
||||
file_postfix = "final";
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
auto output_path = "img_embed_" + file_postfix + ".bin";
|
||||
|
||||
std::ofstream outFile(output_path, std::ios::binary);
|
||||
if (outFile.is_open()) {
|
||||
outFile.write(reinterpret_cast<const char*>(embd.data()), ne * sizeof(float));
|
||||
|
||||
outFile.close();
|
||||
std::cout << "Data successfully written to mrope.bin" << std::endl;
|
||||
std::cout << "Data successfully written to ::[ " << output_path << std::endl;
|
||||
} else {
|
||||
std::cerr << "Error opening file!" << std::endl;
|
||||
}
|
||||
|
@ -551,8 +608,9 @@ int main(int argc, char ** argv) {
|
|||
} else if (params.image[0].empty()) {
|
||||
auto ctx_llava = llava_init_context(¶ms, model);
|
||||
|
||||
debug_test_mrope_2d();
|
||||
debug_dump_img_embed(ctx_llava);
|
||||
// debug_test_mrope_2d();
|
||||
debug_dump_img_embed(ctx_llava, model_output_type::final_layer);
|
||||
// debug_dump_img_embed(ctx_llava, model_output_type::last_attn_layer);
|
||||
|
||||
llama_perf_context_print(ctx_llava->ctx_llama);
|
||||
ctx_llava->model = NULL;
|
||||
|
|
|
@ -55,6 +55,7 @@ add_test "llama-mtmd-cli" "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K" # mode
|
|||
add_test "llama-mtmd-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
|
||||
add_test "llama-mtmd-cli" "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
|
||||
add_test "llama-qwen2vl-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
|
||||
add_test "llama-qwen2vl-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
|
||||
|
||||
# to test the big models, run: ./tests.sh big
|
||||
add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue