mtmd : add vision support for llama 4 (#13282)

* wip llama 4 conversion

* rm redundant __init__

* fix conversion

* fix conversion

* test impl

* try this

* reshape patch_embeddings_0

* fix view

* rm ffn_post_norm

* cgraph ok

* f32 for pos embd

* add image marker tokens

* Llama4UnfoldConvolution

* correct pixel shuffle

* fix merge conflicts

* correct

* add debug_graph

* logits matched, but it still preceives the image incorrectly

* fix style

* add image_grid_pinpoints

* handle llama 4 preprocessing

* rm load_image_size

* rm unused line

* fix

* small fix 2

* add test & docs

* fix llava-1.6 test

* test: add notion of huge models

* add comment

* add warn about degraded quality
This commit is contained in:
Xuan-Son Nguyen 2025-05-19 13:04:14 +02:00 committed by GitHub
parent f71f40a284
commit 92ecdcc06a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 424 additions and 82 deletions

View file

@ -359,9 +359,12 @@ struct clip_ctx {
int max_nodes = 8192;
ggml_backend_sched_ptr sched;
clip_image_size load_image_size;
// for debugging
bool debug_graph = false;
std::vector<ggml_tensor *> debug_print_tensors;
clip_ctx(clip_context_params & ctx_params) {
debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
if (!backend_cpu) {
throw std::runtime_error("failed to initialize CPU backend");
@ -440,7 +443,7 @@ struct clip_graph {
};
ctx0_ptr.reset(ggml_init(params));
ctx0 = ctx0_ptr.get();
gf = ggml_new_graph(ctx0);
gf = ggml_new_graph_custom(ctx0, ctx->max_nodes, false);
}
ggml_cgraph * build_siglip() {
@ -522,7 +525,7 @@ struct clip_graph {
ggml_set_input(pos_w);
auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta);
return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta, true);
};
ggml_tensor * inp = build_inp();
@ -936,6 +939,101 @@ struct clip_graph {
return gf;
}
ggml_cgraph * build_llama4() {
GGML_ASSERT(model.class_embedding != nullptr);
GGML_ASSERT(model.position_embeddings != nullptr);
const int n_pos = n_patches + 1; // +1 for [CLS]
// 2D input positions
ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
ggml_set_name(pos_h, "pos_h");
ggml_set_input(pos_h);
ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
ggml_set_name(pos_w, "pos_w");
ggml_set_input(pos_w);
ggml_tensor * inp = build_inp_raw();
// Llama4UnfoldConvolution
{
ggml_tensor * kernel = ggml_reshape_4d(ctx0, model.patch_embeddings_0,
patch_size, patch_size, 3, n_embd);
inp = ggml_im2col(ctx0, kernel, inp, patch_size, patch_size, 0, 0, 1, 1, true, inp->type);
inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);
cb(inp, "patch_conv", -1);
}
// add CLS token
inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
// build ViT with 2D position embeddings
auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
// first half is X axis and second half is Y axis
// ref: https://github.com/huggingface/transformers/blob/40a493c7ed4f19f08eadb0639cf26d49bfa5e180/src/transformers/models/llama4/modeling_llama4.py#L1312
// ref: https://github.com/Blaizzy/mlx-vlm/blob/a57156aa87b33cca6e5ee6cfc14dd4ef8f611be6/mlx_vlm/models/llama4/vision.py#L441
return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
};
ggml_tensor * cur = build_vit(
inp, n_pos,
NORM_TYPE_NORMAL,
hparams.ffn_op,
model.position_embeddings,
add_pos);
// remove CLS token
cur = ggml_view_2d(ctx0, cur,
n_embd, n_patches,
ggml_row_size(cur->type, n_embd), 0);
// pixel shuffle
// based on Llama4VisionPixelShuffleMLP
// https://github.com/huggingface/transformers/blob/2932f318a20d9e54cc7aea052e040164d85de7d6/src/transformers/models/llama4/modeling_llama4.py#L1151
{
const int scale_factor = model.hparams.proj_scale_factor;
const int bsz = 1; // batch size, always 1 for now since we don't support batching
GGML_ASSERT(scale_factor > 0);
GGML_ASSERT(n_patches_x == n_patches_y); // llama4 only supports square images
cur = ggml_reshape_4d(ctx0, cur,
n_embd * scale_factor,
n_patches_x / scale_factor,
n_patches_y,
bsz);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
n_embd * scale_factor * scale_factor,
n_patches_x / scale_factor,
n_patches_y / scale_factor,
bsz);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
// flatten to 2D
cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur),
n_embd * scale_factor * scale_factor,
n_patches / scale_factor / scale_factor);
cb(cur, "pixel_shuffle", -1);
}
// based on Llama4VisionMLP2 (always uses GELU activation, no bias)
{
cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur);
cur = ggml_gelu(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur);
cur = ggml_gelu(ctx0, cur);
cb(cur, "adapter_mlp", -1);
}
// Llama4MultiModalProjector
cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur);
cb(cur, "projected", -1);
// build the graph
ggml_build_forward_expand(gf, cur);
return gf;
}
// this graph is used by llava, granite and glm
// due to having embedding_stack (used by granite), we cannot reuse build_vit
ggml_cgraph * build_llava() {
@ -1315,11 +1413,15 @@ private:
// utility functions
//
void cb(ggml_tensor * cur, const char * name, int il) const {
// TODO: implement this
GGML_UNUSED(cur);
GGML_UNUSED(name);
GGML_UNUSED(il);
void cb(ggml_tensor * cur0, const char * name, int il) const {
if (ctx->debug_graph) {
ggml_tensor * cur = ggml_cpy(ctx0, cur0, ggml_dup_tensor(ctx0, cur0));
std::string cur_name = il >= 0 ? std::string(name) + "_" + std::to_string(il) : name;
ggml_set_name(cur, cur_name.c_str());
ggml_set_output(cur);
ggml_build_forward_expand(gf, cur);
ctx->debug_print_tensors.push_back(cur);
}
}
// build vision transformer (ViT) cgraph
@ -1630,9 +1732,10 @@ private:
static ggml_tensor * build_rope_2d(
ggml_context * ctx0,
ggml_tensor * cur,
ggml_tensor * pos_h,
ggml_tensor * pos_w,
const float freq_base
ggml_tensor * pos_a, // first half
ggml_tensor * pos_b, // second half
const float freq_base,
const bool interleave_freq
) {
const int64_t n_dim = cur->ne[0];
const int64_t n_head = cur->ne[1];
@ -1646,7 +1749,9 @@ private:
// ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
// then for the second half, we use freq_scale to shift the inv_freq
// ^ why? replace (2i) with (2i+1) in the above equation
const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
const float freq_scale_odd = interleave_freq
? std::pow(freq_base, (float)-2/n_dim)
: 1.0;
// first half
ggml_tensor * first;
@ -1659,7 +1764,7 @@ private:
first = ggml_rope_ext(
ctx0,
first,
pos_h, // positions
pos_a, // positions
nullptr, // freq factors
n_dim/2, // n_dims
0, 0, freq_base,
@ -1679,7 +1784,7 @@ private:
second = ggml_rope_ext(
ctx0,
second,
pos_w, // positions
pos_b, // positions
nullptr, // freq factors
n_dim/2, // n_dims
0, 0, freq_base,
@ -1723,6 +1828,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
res = graph.build_internvl();
} break;
case PROJECTOR_TYPE_LLAMA4:
{
res = graph.build_llama4();
} break;
default:
{
res = graph.build_llava();
@ -1926,6 +2035,21 @@ struct clip_model_loader {
hparams.warmup_image_size = hparams.patch_size * 8;
get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
} break;
case PROJECTOR_TYPE_LLAMA4:
{
hparams.rope_theta = 10000.0f;
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor);
// borrowed from llava-1.6
const int isize = hparams.image_size;
hparams.image_grid_pinpoints = {
isize, isize*2, // 336, 672
isize*2, isize, // 672, 336
isize*2, isize*2, // 672, 672
isize*3, isize, // 1008, 336
isize, isize*3, // 336, 1008
};
} break;
default:
break;
}
@ -1946,6 +2070,10 @@ struct clip_model_loader {
LOG_INF("%s: ffn_op: %s\n", __func__, log_ffn_op.c_str());
LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
if (ctx_clip.proj_type == PROJECTOR_TYPE_LLAMA4) {
LOG_WRN("%s: llama 4 vision is known to have degraded quality: https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
}
}
}
@ -2001,7 +2129,7 @@ struct clip_model_loader {
vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false);
vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false);
vision_model.position_embeddings = get_tensor(TN_POS_EMBD, false);
// layers
vision_model.layers.resize(hparams.n_layer);
@ -2182,6 +2310,12 @@ struct clip_model_loader {
vision_model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
vision_model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
} break;
case PROJECTOR_TYPE_LLAMA4:
{
vision_model.mm_model_proj = get_tensor(TN_MM_PROJECTOR);
vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
} break;
default:
GGML_ASSERT(false && "unknown projector type");
}
@ -2328,14 +2462,6 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
return ctx_clip;
}
void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size) {
ctx_clip->load_image_size = *load_image_size; // copy
}
struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip) {
return &ctx_clip->load_image_size;
}
struct clip_image_size * clip_image_size_init() {
struct clip_image_size * load_image_size = new struct clip_image_size();
load_image_size->width = 448;
@ -2849,7 +2975,7 @@ private:
// used by llava 1.6 with custom list of pinpoints
static clip_image_size select_best_resolution(const std::vector<int32_t> & pinpoints, const clip_image_size & original_size) {
std::vector<clip_image_size> possible_resolutions;
std::vector<clip_image_size> possible_resolutions; // TODO @ngxson : construct this inside hparams, not here
for (size_t i = 0; i < pinpoints.size(); i += 2) {
possible_resolutions.push_back(clip_image_size{pinpoints[i], pinpoints[i+1]});
}
@ -2916,12 +3042,6 @@ private:
}
};
// TODO @ngxson : decprecate the load_image_size singleton pattern
int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
const auto inst = llava_uhd::get_slice_instructions(ctx_clip, ctx_clip->load_image_size);
return inst.grid_size.width;
}
// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
// res_imgs memory is being allocated here, previous allocations will be freed if found
bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
@ -2943,9 +3063,12 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
res_imgs->entries.push_back(std::move(res));
}
res_imgs->grid_x = inst.grid_size.width;
res_imgs->grid_y = inst.grid_size.height;
return true;
}
else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
} else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
clip_image_u8 resized;
auto patch_size = params.patch_size * 2;
auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, patch_size, params.image_size);
@ -2971,8 +3094,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
res_imgs->entries.push_back(std::move(img_f32));
return true;
}
else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
} else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
clip_image_u8 resized_image;
auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
@ -2980,6 +3103,22 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
res_imgs->entries.push_back(std::move(img_f32));
return true;
} else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
GGML_ASSERT(!params.image_grid_pinpoints.empty());
auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
for (size_t i = 0; i < imgs.size(); ++i) {
clip_image_f32_ptr res(clip_image_f32_init());
normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
res_imgs->entries.push_back(std::move(res));
}
res_imgs->grid_x = inst.grid_size.width;
res_imgs->grid_y = inst.grid_size.height;
return true;
}
// the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
@ -3098,6 +3237,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
const auto & params = ctx->vision_model.hparams;
int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
int scale_factor = ctx->vision_model.hparams.proj_scale_factor;
if (ctx->proj_type == PROJECTOR_TYPE_LDP
|| ctx->proj_type == PROJECTOR_TYPE_LDPV2
@ -3136,6 +3276,8 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
} else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
n_patches /= (scale_factor * scale_factor);
}
return n_patches;
@ -3247,6 +3389,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
// build the inference graph
ctx->debug_print_tensors.clear();
ggml_backend_sched_reset(ctx->sched.get());
ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);
@ -3261,8 +3404,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
const int n_pos = num_patches + (model.class_embedding ? 1 : 0);
const int pos_w = ctx->load_image_size.width / patch_size;
const int pos_h = ctx->load_image_size.height / patch_size;
const int pos_w = image_size_width / patch_size;
const int pos_h = image_size_height / patch_size;
const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl
@ -3528,6 +3671,23 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
{
// do nothing
} break;
case PROJECTOR_TYPE_LLAMA4:
{
// set the 2D positions
int n_patches_per_col = image_size_width / patch_size;
std::vector<int> pos_data(num_patches + 1, 0); // +1 for the [CLS] token
// last pos is always kept 0, it's for CLS
// dimension H
for (int i = 0; i < num_patches; i++) {
pos_data[i] = (i / n_patches_per_col) + 1;
}
set_input_i32("pos_h", pos_data);
// dimension W
for (int i = 0; i < num_patches; i++) {
pos_data[i] = (i % n_patches_per_col) + 1;
}
set_input_i32("pos_w", pos_data);
} break;
default:
GGML_ABORT("Unknown projector type");
}
@ -3548,6 +3708,18 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
return false;
}
// print debug nodes
if (ctx->debug_graph) {
LOG_INF("\n\n---\n\n");
LOG_INF("\n\nDebug graph:\n\n");
for (ggml_tensor * t : ctx->debug_print_tensors) {
std::vector<uint8_t> data(ggml_nbytes(t));
ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t));
print_tensor_shape(t);
print_tensor_data(t, data.data(), 3);
}
}
// the last node is the embedding tensor
ggml_tensor * embeddings = ggml_graph_node(gf, -1);
@ -3596,6 +3768,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->vision_model.projection->ne[1];
case PROJECTOR_TYPE_INTERNVL:
return ctx->vision_model.mm_3_w->ne[1];
case PROJECTOR_TYPE_LLAMA4:
return ctx->vision_model.mm_model_proj->ne[1];
default:
GGML_ABORT("Unknown projector type");
}