mtmd : support SmolVLM (version 1 and 2) (#13050)
* mtmd : support SmolVLM (version 1 and 2)
* correct chat template
* fix n_patches
* scale_factor is an int
* add more models to test
This commit is contained in:
parent
ab47dec3d3
commit
dc39a5e7a8
10 changed files with 279 additions and 65 deletions
|
@ -33,13 +33,13 @@
|
|||
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
|
||||
#define KEY_PROJ_DIM "clip.%s.projection_dim"
|
||||
#define KEY_TOKENS "tokenizer.ggml.tokens"
|
||||
#define KEY_N_POSITIONS "clip.text.context_length"
|
||||
#define KEY_IMAGE_SIZE "clip.vision.image_size"
|
||||
#define KEY_PATCH_SIZE "clip.vision.patch_size"
|
||||
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
|
||||
#define KEY_IMAGE_STD "clip.vision.image_std"
|
||||
#define KEY_PROJ_TYPE "clip.projector_type"
|
||||
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
|
||||
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
|
||||
#define KEY_PROJ_TYPE "clip.projector_type"
|
||||
|
||||
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
|
||||
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
|
||||
|
@ -72,6 +72,7 @@
|
|||
#define TN_IMAGE_NEWLINE "model.image_newline"
|
||||
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
|
||||
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
|
||||
#define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3
|
||||
|
||||
// minicpmv
|
||||
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
|
||||
|
@ -99,6 +100,7 @@ enum projector_type {
|
|||
PROJECTOR_TYPE_GLM_EDGE,
|
||||
PROJECTOR_TYPE_MERGER,
|
||||
PROJECTOR_TYPE_GEMMA3,
|
||||
PROJECTOR_TYPE_IDEFICS3,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
|
@ -110,6 +112,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
|||
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
|
||||
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
|
||||
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
|
||||
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
|
||||
};
|
||||
|
||||
static projector_type clip_projector_type_from_string(const std::string & str) {
|
||||
|
|
|
@ -159,6 +159,7 @@ struct clip_hparams {
|
|||
int32_t projection_dim;
|
||||
int32_t n_head;
|
||||
int32_t n_layer;
|
||||
int32_t proj_scale_factor = 0; // idefics3
|
||||
|
||||
patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT;
|
||||
|
||||
|
@ -506,6 +507,35 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
|
|||
embeddings = ggml_mul_mat(ctx0,
|
||||
ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
|
||||
embeddings);
|
||||
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
|
||||
// https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
|
||||
|
||||
ggml_tensor * cur = embeddings;
|
||||
const int scale_factor = model.hparams.proj_scale_factor;
|
||||
const int n_embd = cur->ne[0];
|
||||
const int seq = cur->ne[1];
|
||||
const int bsz = 1; // batch size, always 1 for now since we don't support batching
|
||||
const int height = std::sqrt(seq);
|
||||
const int width = std::sqrt(seq);
|
||||
GGML_ASSERT(scale_factor != 0);
|
||||
cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
|
||||
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
||||
cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
|
||||
n_embd * scale_factor * scale_factor,
|
||||
height / scale_factor,
|
||||
width / scale_factor,
|
||||
bsz);
|
||||
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
||||
cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur),
|
||||
n_embd * scale_factor * scale_factor,
|
||||
seq / (scale_factor * scale_factor),
|
||||
bsz);
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.projection, cur);
|
||||
embeddings = cur;
|
||||
} else {
|
||||
GGML_ABORT("SigLIP: Unsupported projector type");
|
||||
}
|
||||
|
||||
// build the graph
|
||||
|
@ -1081,12 +1111,20 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||
}
|
||||
|
||||
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
return clip_image_build_graph_siglip(ctx, imgs);
|
||||
} else {
|
||||
// TODO: we should have one build_* function per model
|
||||
return clip_image_build_graph_legacy(ctx, imgs, load_image_size, is_inf);
|
||||
ggml_cgraph * res;
|
||||
switch (ctx->proj_type) {
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
case PROJECTOR_TYPE_IDEFICS3:
|
||||
{
|
||||
res = clip_image_build_graph_siglip(ctx, imgs);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
// TODO: we should have one build_* function per model
|
||||
res = clip_image_build_graph_legacy(ctx, imgs, load_image_size, is_inf);
|
||||
} break;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
struct clip_model_loader {
|
||||
|
@ -1147,6 +1185,8 @@ struct clip_model_loader {
|
|||
}
|
||||
|
||||
void load_hparams() {
|
||||
auto & hparams = ctx_clip.vision_model.hparams;
|
||||
|
||||
// projector type
|
||||
{
|
||||
std::string proj_type;
|
||||
|
@ -1177,7 +1217,6 @@ struct clip_model_loader {
|
|||
get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false);
|
||||
get_bool(KEY_USE_SILU, ctx_clip.use_silu, false);
|
||||
|
||||
auto & hparams = ctx_clip.vision_model.hparams;
|
||||
get_u32(string_format(KEY_N_EMBD, "vision"), hparams.hidden_size);
|
||||
get_u32(string_format(KEY_N_HEAD, "vision"), hparams.n_head);
|
||||
get_u32(string_format(KEY_N_FF, "vision"), hparams.n_intermediate);
|
||||
|
@ -1233,6 +1272,16 @@ struct clip_model_loader {
|
|||
LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
|
||||
LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
|
||||
}
|
||||
|
||||
// model-specific params
|
||||
switch (ctx_clip.proj_type) {
|
||||
case PROJECTOR_TYPE_IDEFICS3:
|
||||
{
|
||||
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void load_tensors() {
|
||||
|
@ -1422,6 +1471,10 @@ struct clip_model_loader {
|
|||
vision_model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);
|
||||
vision_model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_IDEFICS3:
|
||||
{
|
||||
vision_model.projection = get_tensor(TN_MM_PROJECTOR);
|
||||
} break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown projector type");
|
||||
}
|
||||
|
@ -2195,10 +2248,12 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
|||
return true;
|
||||
}
|
||||
|
||||
if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
if (ctx->has_glm_projector
|
||||
|| ctx->proj_type == PROJECTOR_TYPE_GEMMA3
|
||||
|| ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
|
||||
clip_image_u8 resized_image;
|
||||
int sz = params.image_size;
|
||||
image_manipulation::bicubic_resize(*img, resized_image, sz, sz);
|
||||
image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz});
|
||||
clip_image_f32_ptr img_f32(clip_image_f32_init());
|
||||
//clip_image_save_to_bmp(resized_image, "resized.bmp");
|
||||
normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
|
||||
|
@ -2330,6 +2385,8 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
|
|||
n_patches = x_patch * y_patch;
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
n_patches = 256;
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
|
||||
n_patches /= ctx->vision_model.hparams.proj_scale_factor;
|
||||
}
|
||||
|
||||
return n_patches;
|
||||
|
@ -2597,6 +2654,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
// do nothing
|
||||
}
|
||||
else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
|
||||
// do nothing
|
||||
}
|
||||
else {
|
||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
||||
|
||||
|
@ -2783,37 +2843,34 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
|
|||
}
|
||||
|
||||
int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_LDP) {
|
||||
return ctx->vision_model.mm_model_block_1_block_2_1_b->ne[0];
|
||||
}
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_LDPV2) {
|
||||
return ctx->vision_model.mm_model_peg_0_b->ne[0];
|
||||
}
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
|
||||
return ctx->vision_model.mm_2_b->ne[0];
|
||||
}
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
|
||||
return ctx->vision_model.mm_3_b->ne[0];
|
||||
}
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
|
||||
if (ctx->minicpmv_version == 2) {
|
||||
return 4096;
|
||||
}
|
||||
else if (ctx->minicpmv_version == 3) {
|
||||
return 3584;
|
||||
}
|
||||
else if (ctx->minicpmv_version == 4) {
|
||||
return 3584;
|
||||
}
|
||||
}
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE){
|
||||
return ctx->vision_model.mm_model_mlp_3_w->ne[1];
|
||||
}
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
return ctx->vision_model.mm_1_b->ne[0];
|
||||
}
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
return ctx->vision_model.mm_input_proj_w->ne[0];
|
||||
switch (ctx->proj_type) {
|
||||
case PROJECTOR_TYPE_LDP:
|
||||
return ctx->vision_model.mm_model_block_1_block_2_1_b->ne[0];
|
||||
case PROJECTOR_TYPE_LDPV2:
|
||||
return ctx->vision_model.mm_model_peg_0_b->ne[0];
|
||||
case PROJECTOR_TYPE_MLP:
|
||||
return ctx->vision_model.mm_2_b->ne[0];
|
||||
case PROJECTOR_TYPE_MLP_NORM:
|
||||
return ctx->vision_model.mm_3_b->ne[0];
|
||||
case PROJECTOR_TYPE_RESAMPLER:
|
||||
if (ctx->minicpmv_version == 2) {
|
||||
return 4096;
|
||||
} else if (ctx->minicpmv_version == 3) {
|
||||
return 3584;
|
||||
} else if (ctx->minicpmv_version == 4) {
|
||||
return 3584;
|
||||
}
|
||||
break; // Should not happen if version is valid
|
||||
case PROJECTOR_TYPE_GLM_EDGE:
|
||||
return ctx->vision_model.mm_model_mlp_3_w->ne[1];
|
||||
case PROJECTOR_TYPE_MERGER:
|
||||
return ctx->vision_model.mm_1_b->ne[0];
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
return ctx->vision_model.mm_input_proj_w->ne[0];
|
||||
case PROJECTOR_TYPE_IDEFICS3:
|
||||
return ctx->vision_model.projection->ne[1];
|
||||
default:
|
||||
break; // Fall through to throw
|
||||
}
|
||||
|
||||
std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
|
||||
|
|
|
@ -176,6 +176,8 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
|
|||
|
||||
std::string prompt_modified(text.text);
|
||||
std::string marker_modified(ctx->image_marker);
|
||||
projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
|
||||
|
||||
// a bit hacky here, but works for now
|
||||
// for some models, we need to add prefix and suffix to the image embeddings
|
||||
if (clip_is_gemma3(ctx->ctx_clip)) {
|
||||
|
@ -183,6 +185,11 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
|
|||
// <start_of_image> ... (image embeddings) ... <end_of_image>
|
||||
marker_modified = "<start_of_image>" + ctx->image_marker + "<end_of_image>";
|
||||
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
|
||||
|
||||
} else if (proj_type == PROJECTOR_TYPE_IDEFICS3) {
|
||||
// https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
|
||||
marker_modified = "<fake_token_around_image><global-img>" + ctx->image_marker + "<fake_token_around_image>";
|
||||
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
|
||||
}
|
||||
|
||||
// llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
|
||||
|
|
|
@ -28,6 +28,9 @@ add_test() {
|
|||
arr_tmpl+=("$tmpl")
|
||||
}
|
||||
|
||||
add_test "llama-mtmd-cli" "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0"
|
||||
add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M"
|
||||
add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0"
|
||||
add_test "llama-mtmd-cli" "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"
|
||||
add_test "llama-mtmd-cli" "guinmoon/MobileVLM-3B-GGUF:Q4_K_M" "deepseek"
|
||||
add_test "llama-mtmd-cli" "THUDM/glm-edge-v-5b-gguf:Q4_K_M"
|
||||
|
@ -39,7 +42,13 @@ add_test "llama-mtmd-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
|
|||
add_test "llama-mtmd-cli" "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
|
||||
add_test "llama-qwen2vl-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
|
||||
|
||||
# add_test "llama-mtmd-cli" "cmp-nct/Yi-VL-6B-GGUF:Q5_K" # this model has broken chat template, not usable
|
||||
# these models always give the wrong answer, not sure why
|
||||
# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
|
||||
# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-256M-Instruct-GGUF:Q8_0"
|
||||
# add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-256M-Video-Instruct-GGUF:Q8_0"
|
||||
|
||||
# this model has broken chat template, not usable
|
||||
# add_test "llama-mtmd-cli" "cmp-nct/Yi-VL-6B-GGUF:Q5_K"
|
||||
|
||||
###############
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue