mtmd : add C public API (#13184)
* init * wip * working version * add mtmd::bitmaps * add test target * rm redundant define * test: mtmd_input_chunks_free * rm outdated comment * fix merging issue * explicitly create mtmd::input_chunks * mtmd_input_chunk_copy * add clone() * add const to various places * add warning about breaking changes * helper: use mtmd_image_tokens_get_n_pos
This commit is contained in:
parent
9fdfcdaedd
commit
27aa259532
7 changed files with 730 additions and 258 deletions
|
@ -63,7 +63,7 @@ static void sigint_handler(int signo) {
|
|||
#endif
|
||||
|
||||
struct mtmd_cli_context {
|
||||
mtmd_context_ptr ctx_vision;
|
||||
mtmd::context_ptr ctx_vision;
|
||||
common_init_result llama_init;
|
||||
|
||||
llama_model * model;
|
||||
|
@ -72,7 +72,7 @@ struct mtmd_cli_context {
|
|||
llama_batch batch;
|
||||
int n_batch;
|
||||
|
||||
std::vector<mtmd_bitmap> bitmaps;
|
||||
mtmd::bitmaps bitmaps;
|
||||
|
||||
// note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
|
||||
// so here we don't need to keep track of chat history
|
||||
|
@ -115,12 +115,12 @@ struct mtmd_cli_context {
|
|||
|
||||
void init_vision_context(common_params & params) {
|
||||
const char * clip_path = params.mmproj.path.c_str();
|
||||
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mtmd_context_params{
|
||||
/* use_gpu */ params.mmproj_use_gpu,
|
||||
/* timings */ true,
|
||||
/* n_threads */ params.cpuparams.n_threads,
|
||||
/* verbosity */ params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO,
|
||||
}));
|
||||
mtmd_context_params mparams = mtmd_context_params_default();
|
||||
mparams.use_gpu = params.mmproj_use_gpu;
|
||||
mparams.print_timings = true;
|
||||
mparams.n_threads = params.cpuparams.n_threads;
|
||||
mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
|
||||
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
|
||||
if (!ctx_vision.get()) {
|
||||
LOG_ERR("Failed to load vision model from %s\n", clip_path);
|
||||
exit(1);
|
||||
|
@ -139,11 +139,11 @@ struct mtmd_cli_context {
|
|||
}
|
||||
|
||||
bool load_image(const std::string & fname) {
|
||||
mtmd_bitmap bitmap;
|
||||
if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
|
||||
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_file(fname.c_str()));
|
||||
if (!bmp.ptr) {
|
||||
return false;
|
||||
}
|
||||
bitmaps.push_back(std::move(bitmap));
|
||||
bitmaps.entries.push_back(std::move(bmp));
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
@ -193,27 +193,40 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_
|
|||
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
|
||||
|
||||
mtmd_input_text text;
|
||||
text.text = formatted_chat.prompt;
|
||||
text.text = formatted_chat.prompt.c_str();
|
||||
text.add_special = add_bos;
|
||||
text.parse_special = true;
|
||||
mtmd_input_chunks chunks;
|
||||
|
||||
if (g_is_interrupted) return 0;
|
||||
|
||||
int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, ctx.bitmaps);
|
||||
mtmd::input_chunks chunks(mtmd_input_chunks_init());
|
||||
auto bitmaps_c_ptr = ctx.bitmaps.c_ptr();
|
||||
int32_t res = mtmd_tokenize(ctx.ctx_vision.get(),
|
||||
chunks.ptr.get(), // output
|
||||
&text, // text
|
||||
bitmaps_c_ptr.data(),
|
||||
bitmaps_c_ptr.size());
|
||||
if (res != 0) {
|
||||
LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
|
||||
return 1;
|
||||
}
|
||||
|
||||
ctx.bitmaps.clear();
|
||||
ctx.bitmaps.entries.clear();
|
||||
|
||||
if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
|
||||
llama_pos new_n_past;
|
||||
if (mtmd_helper_eval_chunks(ctx.ctx_vision.get(),
|
||||
ctx.lctx, // lctx
|
||||
chunks.ptr.get(), // chunks
|
||||
ctx.n_past, // n_past
|
||||
0, // seq_id
|
||||
ctx.n_batch, // n_batch
|
||||
true, // logits_last
|
||||
&new_n_past)) {
|
||||
LOG_ERR("Unable to eval prompt\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
ctx.n_past += mtmd_helper_get_n_pos(chunks);
|
||||
ctx.n_past = new_n_past;
|
||||
|
||||
LOG("\n");
|
||||
|
||||
|
@ -246,7 +259,7 @@ int main(int argc, char ** argv) {
|
|||
struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling);
|
||||
int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict;
|
||||
|
||||
// ctrl+C handling
|
||||
// Ctrl+C handling
|
||||
{
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue