mtmd : add ultravox audio input (#13623)

* convert ok, load ok

* warmup ok

* test

* still does not work?

* fix padding

* temporary give up

* fix merge conflict

* build_ultravox()

* rm test

* fix merge conflict

* add necessary mtmd APIs

* first working version (only 4s of audio)

* will this monster compile?

* fix compile

* please compile

* fPIC

* fix windows

* various fixes

* clean up audio_helpers

* fix conversion

* add some debug stuff

* long audio input ok

* adapt the api

* add --audio arg

* final touch UX

* add miniaudio to readme

* fix typo

* refactor kv metadata

* mtmd_default_marker()
This commit is contained in:
Xuan-Son Nguyen 2025-05-22 20:42:48 +02:00 committed by GitHub
parent ab86335760
commit 797990c4bc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 95401 additions and 259 deletions

View file

@ -35,6 +35,7 @@ struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callbac
enum ffn_op_type {
FFN_GELU,
FFN_GELU_ERF,
FFN_SILU,
FFN_GELU_QUICK,
};
@ -165,6 +166,9 @@ enum patch_merge_type {
};
struct clip_hparams {
bool has_vision = false;
bool has_audio = false;
int32_t image_size;
int32_t patch_size;
int32_t n_embd;
@ -191,6 +195,10 @@ struct clip_hparams {
int32_t attn_window_size = 0;
int32_t n_wa_pattern = 0;
int32_t spatial_merge_size = 0;
// audio
int32_t n_mel_bins = 0; // whisper preprocessor
int32_t proj_stack_factor = 0; // ultravox
};
struct clip_layer {
@ -332,6 +340,14 @@ struct clip_vision_model {
// pixtral
ggml_tensor * token_embd_img_break = nullptr;
ggml_tensor * mm_patch_merger_w = nullptr;
// ultravox / whisper encoder
ggml_tensor * conv1d_1_w = nullptr;
ggml_tensor * conv1d_1_b = nullptr;
ggml_tensor * conv1d_2_w = nullptr;
ggml_tensor * conv1d_2_b = nullptr;
ggml_tensor * mm_norm_pre_w = nullptr;
ggml_tensor * mm_norm_mid_w = nullptr;
};
struct clip_ctx {
@ -1408,6 +1424,104 @@ struct clip_graph {
return gf;
}
// whisper encoder with custom projector
ggml_cgraph * build_whisper_enc() {
const int n_frames = img.nx;
const int n_pos = n_frames / 2;
GGML_ASSERT(model.position_embeddings->ne[1] >= n_pos);
ggml_tensor * inp = build_inp_raw(1);
// conv1d block
{
// convolution + gelu
ggml_tensor * cur = ggml_conv_1d_ph(ctx0, model.conv1d_1_w, inp, 1, 1);
cur = ggml_add(ctx0, cur, model.conv1d_1_b);
cur = ggml_gelu_erf(ctx0, cur);
cur = ggml_conv_1d_ph(ctx0, model.conv1d_2_w, cur, 2, 1);
cur = ggml_add(ctx0, cur, model.conv1d_2_b);
cur = ggml_gelu_erf(ctx0, cur);
// transpose
inp = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
cb(inp, "after_conv1d", -1);
}
// sanity check (only check one layer, but it should be the same for all)
GGML_ASSERT(model.layers[0].ln_1_w && model.layers[0].ln_1_b);
GGML_ASSERT(model.layers[0].ln_2_w && model.layers[0].ln_2_b);
GGML_ASSERT(model.layers[0].q_b);
GGML_ASSERT(model.layers[0].v_b);
GGML_ASSERT(!model.layers[0].k_b); // no bias for k
GGML_ASSERT(model.post_ln_w && model.post_ln_b);
ggml_tensor * pos_embd_selected = ggml_view_2d(
ctx0, model.position_embeddings,
model.position_embeddings->ne[0], n_pos,
model.position_embeddings->nb[1], 0
);
ggml_tensor * cur = build_vit(
inp, n_pos,
NORM_TYPE_NORMAL,
hparams.ffn_op,
pos_embd_selected,
nullptr);
cb(cur, "after_transformer", -1);
// StackAudioFrames
// https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
{
int64_t stride = n_embd * hparams.proj_stack_factor;
int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
int64_t pad = padded_len - ggml_nelements(cur);
if (pad > 0) {
cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
}
cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
ggml_row_size(cur->type, stride), 0);
}
cb(cur, "after_stacked", -1);
// UltravoxProjector
{
// pre-norm
cur = ggml_rms_norm(ctx0, cur, 1e-6);
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
// ffn in
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
// swiglu
{
int64_t split_point = cur->ne[0] / 2;
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
// see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
x1 = ggml_silu(ctx0, x1);
cur = ggml_mul(ctx0, x0, x1);
}
// mid-norm
cur = ggml_rms_norm(ctx0, cur, 1e-6);
cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
// ffn out
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
}
cb(cur, "projected", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
private:
//
// utility functions
@ -1562,8 +1676,8 @@ private:
return inp;
}
ggml_tensor * build_inp_raw() {
ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, 3);
ggml_tensor * build_inp_raw(int channels = 3) {
ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, channels);
ggml_set_name(inp_raw, "inp_raw");
ggml_set_input(inp_raw);
return inp_raw;
@ -1641,6 +1755,11 @@ private:
cur = ggml_gelu(ctx0, cur);
cb(cur, "ffn_gelu", il);
} break;
case FFN_GELU_ERF:
{
cur = ggml_gelu_erf(ctx0, cur);
cb(cur, "ggml_gelu_erf", il);
} break;
case FFN_GELU_QUICK:
{
cur = ggml_gelu_quick(ctx0, cur);
@ -1832,6 +1951,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
res = graph.build_llama4();
} break;
case PROJECTOR_TYPE_ULTRAVOX:
{
res = graph.build_whisper_enc();
} break;
default:
{
res = graph.build_llava();
@ -1915,18 +2038,30 @@ struct clip_model_loader {
// other hparams
{
get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); // legacy
get_bool(KEY_HAS_AUDIO_ENC, hparams.has_audio, false);
get_bool(KEY_HAS_VISION_ENC, hparams.has_vision, false);
get_u32(KEY_N_EMBD, hparams.n_embd);
get_u32(KEY_N_HEAD, hparams.n_head);
get_u32(KEY_N_FF, hparams.n_ff);
get_u32(KEY_N_BLOCK, hparams.n_layer);
get_u32(KEY_PROJ_DIM, hparams.projection_dim);
get_f32(KEY_LAYER_NORM_EPS, hparams.eps);
get_u32(KEY_IMAGE_SIZE, hparams.image_size);
get_u32(KEY_PATCH_SIZE, hparams.patch_size);
get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
const char * prefix = hparams.has_vision ? "vision" : "audio";
get_u32(string_format(KEY_N_EMBD, prefix), hparams.n_embd);
get_u32(string_format(KEY_N_HEAD, prefix), hparams.n_head);
get_u32(string_format(KEY_N_FF, prefix), hparams.n_ff);
get_u32(string_format(KEY_N_BLOCK, prefix), hparams.n_layer);
get_u32(string_format(KEY_PROJ_DIM, prefix), hparams.projection_dim);
get_f32(string_format(KEY_LAYER_NORM_EPS, prefix), hparams.eps);
if (hparams.has_vision) {
get_u32(KEY_IMAGE_SIZE, hparams.image_size);
get_u32(KEY_PATCH_SIZE, hparams.patch_size);
get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); // legacy
} else if (hparams.has_audio) {
get_u32(KEY_A_NUM_MEL_BINS, hparams.n_mel_bins);
} else {
throw std::runtime_error(string_format("%s: neither vision nor audio encoder is present\n", __func__));
}
// default warmup value
hparams.warmup_image_size = hparams.image_size;
@ -1964,7 +2099,7 @@ struct clip_model_loader {
}
}
{
if (hparams.has_vision) {
int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN);
int idx_std = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD);
GGML_ASSERT(idx_mean >= 0 && "image_mean not found");
@ -2050,30 +2185,43 @@ struct clip_model_loader {
isize, isize*3, // 336, 1008
};
} break;
case PROJECTOR_TYPE_ULTRAVOX:
{
get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor);
if (hparams.n_mel_bins != 128) {
throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__));
}
hparams.ffn_op = FFN_GELU_ERF;
log_ffn_op = "gelu_erf"; // temporary solution for logging
} break;
default:
break;
}
LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str());
LOG_INF("%s: has_vision_encoder: %d\n", __func__, hparams.has_vision);
LOG_INF("%s: has_audio_encoder: %d\n", __func__, hparams.has_audio);
LOG_INF("%s: n_embd: %d\n", __func__, hparams.n_embd);
LOG_INF("%s: n_head: %d\n", __func__, hparams.n_head);
LOG_INF("%s: n_ff: %d\n", __func__, hparams.n_ff);
LOG_INF("%s: n_layer: %d\n", __func__, hparams.n_layer);
LOG_INF("%s: projection_dim: %d\n", __func__, hparams.projection_dim);
LOG_INF("%s: image_size: %d\n", __func__, hparams.image_size);
LOG_INF("%s: patch_size: %d\n", __func__, hparams.patch_size);
LOG_INF("\n");
LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector);
LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version);
LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor);
LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
LOG_INF("%s: ffn_op: %s\n", __func__, log_ffn_op.c_str());
LOG_INF("%s: projection_dim: %d\n", __func__, hparams.projection_dim);
LOG_INF("\n");
if (hparams.has_vision) {
LOG_INF("%s: image_size: %d\n", __func__, hparams.image_size);
LOG_INF("%s: patch_size: %d\n", __func__, hparams.patch_size);
LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector);
LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version);
LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor);
LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
} else if (hparams.has_audio) {
LOG_INF("%s: n_mel_bins: %d\n", __func__, hparams.n_mel_bins);
LOG_INF("%s: proj_stack_factor: %d\n", __func__, hparams.proj_stack_factor);
}
LOG_INF("\n");
LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
if (ctx_clip.proj_type == PROJECTOR_TYPE_LLAMA4) {
LOG_WRN("%s: llama 4 vision is known to have degraded quality: https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
}
}
}
@ -2082,6 +2230,9 @@ struct clip_model_loader {
std::map<std::string, size_t> tensor_offset;
std::vector<ggml_tensor *> tensors_to_load;
// TODO @ngxson : support both audio and video in the future
const char * prefix = hparams.has_audio ? "a" : "v";
// get offsets
for (int64_t i = 0; i < gguf_get_n_tensors(ctx_gguf.get()); ++i) {
const char * name = gguf_get_tensor_name(ctx_gguf.get(), i);
@ -2119,47 +2270,47 @@ struct clip_model_loader {
vision_model.class_embedding = get_tensor(TN_CLASS_EMBD, false);
vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, "v", "weight"), false);
vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, "v", "bias"), false);
vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, prefix, "weight"), false);
vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, prefix, "bias"), false);
vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, "v", "weight"), false);
vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, "v", "bias"), false);
vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, prefix, "weight"), false);
vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, prefix, "bias"), false);
vision_model.patch_bias = get_tensor(TN_PATCH_BIAS, false);
vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false);
vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
vision_model.position_embeddings = get_tensor(TN_POS_EMBD, false);
vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false);
// layers
vision_model.layers.resize(hparams.n_layer);
for (int il = 0; il < hparams.n_layer; ++il) {
auto & layer = vision_model.layers[il];
layer.k_w = get_tensor(string_format(TN_ATTN_K, "v", il, "weight"));
layer.q_w = get_tensor(string_format(TN_ATTN_Q, "v", il, "weight"));
layer.v_w = get_tensor(string_format(TN_ATTN_V, "v", il, "weight"));
layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight"));
layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, "v", il, "weight"), false);
layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, "v", il, "weight"), false);
layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false);
layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false);
layer.ls_1_w = get_tensor(string_format(TN_LS_1, "v", il, "weight"), false); // no bias
layer.ls_2_w = get_tensor(string_format(TN_LS_2, "v", il, "weight"), false); // no bias
layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight"));
layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight"));
layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight"));
layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "weight"));
layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, prefix, il, "weight"), false);
layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, prefix, il, "weight"), false);
layer.ln_1_w = get_tensor(string_format(TN_LN_1, prefix, il, "weight"), false);
layer.ln_2_w = get_tensor(string_format(TN_LN_2, prefix, il, "weight"), false);
layer.ls_1_w = get_tensor(string_format(TN_LS_1, prefix, il, "weight"), false); // no bias
layer.ls_2_w = get_tensor(string_format(TN_LS_2, prefix, il, "weight"), false); // no bias
layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false);
layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false);
layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false);
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false);
layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);
layer.k_b = get_tensor(string_format(TN_ATTN_K, prefix, il, "bias"), false);
layer.q_b = get_tensor(string_format(TN_ATTN_Q, prefix, il, "bias"), false);
layer.v_b = get_tensor(string_format(TN_ATTN_V, prefix, il, "bias"), false);
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "bias"), false);
layer.ln_1_b = get_tensor(string_format(TN_LN_1, prefix, il, "bias"), false);
layer.ln_2_b = get_tensor(string_format(TN_LN_2, prefix, il, "bias"), false);
// ffn
layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false);
layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, "v", il, "bias"), false);
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, prefix, il, "weight"));
layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, prefix, il, "bias"), false);
layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, prefix, il, "weight"), false);
layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, prefix, il, "bias"), false);
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight"));
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false);
// some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
// note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
@ -2301,6 +2452,17 @@ struct clip_model_loader {
vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
} break;
case PROJECTOR_TYPE_ULTRAVOX:
{
vision_model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
vision_model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
vision_model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
vision_model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
vision_model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
vision_model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
vision_model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
vision_model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight"));
} break;
case PROJECTOR_TYPE_INTERNVL:
{
vision_model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
@ -2358,13 +2520,19 @@ struct clip_model_loader {
}
void alloc_compute_meta() {
const auto & hparams = ctx_clip.vision_model.hparams;
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
// create a fake batch
clip_image_f32_batch batch;
clip_image_f32_ptr img(clip_image_f32_init());
img->nx = ctx_clip.vision_model.hparams.warmup_image_size;
img->ny = ctx_clip.vision_model.hparams.warmup_image_size;
if (hparams.has_vision) {
img->nx = hparams.warmup_image_size;
img->ny = hparams.warmup_image_size;
} else {
img->nx = 1024; // TODO @ngxson : use a better default
img->ny = hparams.n_mel_bins;
}
img->buf.resize(img->nx * img->ny * 3);
batch.entries.push_back(std::move(img));
@ -3278,6 +3446,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
} else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
n_patches /= (scale_factor * scale_factor);
} else if (ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX) {
const int proj_stack_factor = ctx->vision_model.hparams.proj_stack_factor;
const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
n_patches = n_len / proj_stack_factor / 2;
}
return n_patches;
@ -3435,7 +3607,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
};
// set input pixel values
{
if (!imgs.is_audio) {
size_t nelem = 0;
for (const auto & img : imgs.entries) {
nelem += img->nx * img->ny * 3;
@ -3472,6 +3644,16 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
}
set_input_f32("inp_raw", inp_raw);
} else {
// audio input
GGML_ASSERT(imgs.entries.size() == 1);
const auto & mel_inp = imgs.entries[0];
const int n_step = mel_inp->nx;
const int n_mel = mel_inp->ny;
std::vector<float> inp_raw(n_step * n_mel);
std::memcpy(inp_raw.data(), mel_inp->buf.data(), n_step * n_mel * sizeof(float));
set_input_f32("inp_raw", inp_raw);
}
// set input per projector
@ -3668,6 +3850,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_IDEFICS3:
case PROJECTOR_TYPE_INTERNVL:
case PROJECTOR_TYPE_ULTRAVOX:
{
// do nothing
} break;
@ -3766,6 +3949,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->vision_model.mm_input_proj_w->ne[0];
case PROJECTOR_TYPE_IDEFICS3:
return ctx->vision_model.projection->ne[1];
case PROJECTOR_TYPE_ULTRAVOX:
return ctx->vision_model.mm_2_w->ne[1];
case PROJECTOR_TYPE_INTERNVL:
return ctx->vision_model.mm_3_w->ne[1];
case PROJECTOR_TYPE_LLAMA4:
@ -3798,6 +3983,14 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) {
return ctx->proj_type == PROJECTOR_TYPE_GEMMA3;
}
bool clip_has_vision_encoder(const struct clip_ctx * ctx) {
return ctx->vision_model.hparams.has_vision;
}
bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
return ctx->vision_model.hparams.has_audio;
}
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
clip_image_f32 clip_img;
clip_img.buf.resize(h * w * 3);
@ -3818,3 +4011,14 @@ bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img,
projector_type clip_get_projector_type(const struct clip_ctx * ctx) {
return ctx->proj_type;
}
void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel) {
clip_image_f32 * audio = new clip_image_f32;
audio->nx = n_frames;
audio->ny = n_mel;
audio->buf.resize(n_frames * n_mel);
std::memcpy(audio->buf.data(), mel, n_frames * n_mel * sizeof(float));
batch->entries.push_back(clip_image_f32_ptr(audio));
batch->is_audio = true;
}