mtmd : Expose helper_decode_image_chunk (#13366)
* mtmd: Expose helper_decode_image, output_embd_copy, image_tokens_copy/free * Slim down * Cleanups
This commit is contained in:
parent
ee01d71e58
commit
f05a6d71a0
2 changed files with 90 additions and 51 deletions
|
@ -580,6 +580,79 @@ struct decode_embd_batch {
|
|||
}
|
||||
};
|
||||
|
||||
// Helper function for decoding an image whose embeddings have already been calculated
|
||||
int32_t mtmd_helper_decode_image_chunk(
|
||||
mtmd_context * ctx,
|
||||
struct llama_context * lctx,
|
||||
const mtmd_input_chunk * chunk,
|
||||
float * encoded_embd,
|
||||
llama_pos n_past,
|
||||
llama_seq_id seq_id,
|
||||
int32_t n_batch,
|
||||
llama_pos * new_n_past) {
|
||||
if (mtmd_input_chunk_get_type(chunk) != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
||||
LOG_ERR("failed to decode image chunk: input chunk not of image type\n");
|
||||
return -1;
|
||||
}
|
||||
const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
|
||||
if (!image_tokens) {
|
||||
LOG_ERR("failed to decode image chunk: image tokens are null\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
|
||||
int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
|
||||
|
||||
int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
|
||||
int32_t i_batch = 0;
|
||||
int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
|
||||
decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
|
||||
|
||||
const int nx = mtmd_image_tokens_get_nx(image_tokens);
|
||||
const int ny = mtmd_image_tokens_get_ny(image_tokens);
|
||||
|
||||
if (mtmd_decode_use_mrope(ctx)) {
|
||||
batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
|
||||
} else {
|
||||
batch_embd.set_position_normal(n_past, seq_id);
|
||||
}
|
||||
|
||||
if (mtmd_decode_use_non_causal(ctx)) {
|
||||
llama_set_causal_attn(lctx, false);
|
||||
// TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
|
||||
}
|
||||
|
||||
while (i_batch < n_img_batches) { // split into batches
|
||||
int pos_offset = i_batch*n_batch;
|
||||
int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
|
||||
llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
|
||||
|
||||
LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
|
||||
|
||||
int64_t t1 = ggml_time_ms();
|
||||
int32_t ret = llama_decode(lctx, batch_embd_view);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("failed to decode image\n");
|
||||
llama_set_causal_attn(lctx, true); // restore causal attn
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (ctx->print_timings) {
|
||||
LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
|
||||
}
|
||||
|
||||
i_batch++;
|
||||
}
|
||||
|
||||
n_past += mtmd_image_tokens_get_n_pos(image_tokens);
|
||||
*new_n_past = n_past;
|
||||
|
||||
if (mtmd_decode_use_non_causal(ctx)) {
|
||||
llama_set_causal_attn(lctx, true);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
|
||||
struct llama_context * lctx,
|
||||
const mtmd_input_chunk * chunk,
|
||||
|
@ -591,8 +664,6 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
|
|||
int32_t ret;
|
||||
llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
|
||||
auto chunk_type = mtmd_input_chunk_get_type(chunk);
|
||||
int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
|
||||
int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
|
||||
|
||||
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
|
||||
size_t n_tokens;
|
||||
|
@ -637,57 +708,13 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
|
|||
if (ctx->print_timings) {
|
||||
LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
|
||||
}
|
||||
|
||||
int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
|
||||
int32_t i_batch = 0;
|
||||
int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
|
||||
float * embd = mtmd_get_output_embd(ctx);
|
||||
decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
|
||||
|
||||
const int nx = mtmd_image_tokens_get_nx(image_tokens);
|
||||
const int ny = mtmd_image_tokens_get_ny(image_tokens);
|
||||
|
||||
if (mtmd_decode_use_mrope(ctx)) {
|
||||
batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
|
||||
} else {
|
||||
batch_embd.set_position_normal(n_past, seq_id);
|
||||
ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("failed to decode image\n");
|
||||
llama_batch_free(text_batch);
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (mtmd_decode_use_non_causal(ctx)) {
|
||||
llama_set_causal_attn(lctx, false);
|
||||
// TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
|
||||
}
|
||||
|
||||
while (i_batch < n_img_batches) { // split into batches
|
||||
int pos_offset = i_batch*n_batch;
|
||||
int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
|
||||
llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
|
||||
|
||||
LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
|
||||
|
||||
int64_t t1 = ggml_time_ms();
|
||||
ret = llama_decode(lctx, batch_embd_view);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("failed to decode image\n");
|
||||
llama_set_causal_attn(lctx, true); // restore causal attn
|
||||
llama_batch_free(text_batch);
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (ctx->print_timings) {
|
||||
LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
|
||||
}
|
||||
|
||||
i_batch++;
|
||||
}
|
||||
|
||||
n_past += mtmd_image_tokens_get_n_pos(image_tokens);
|
||||
*new_n_past = n_past;
|
||||
|
||||
if (mtmd_decode_use_non_causal(ctx)) {
|
||||
llama_set_causal_attn(lctx, true);
|
||||
}
|
||||
|
||||
} else {
|
||||
GGML_ABORT("chunk type not supported");
|
||||
}
|
||||
|
|
|
@ -231,6 +231,18 @@ MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
|
|||
bool logits_last,
|
||||
llama_pos * new_n_past);
|
||||
|
||||
// helper function to decode an image whose embeddings have already been calculated
|
||||
// this helper will handle batching and pre/post decoding setup (for ex. gemma 3 requires non-causal attention)
|
||||
// ret 0 on success, -1 on chunk not being a valid image chunk, 1 on decode failure
|
||||
MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx,
|
||||
struct llama_context * lctx,
|
||||
const mtmd_input_chunk * chunk,
|
||||
float * encoded_embd,
|
||||
llama_pos n_past,
|
||||
llama_seq_id seq_id,
|
||||
int32_t n_batch,
|
||||
llama_pos * new_n_past);
|
||||
|
||||
/////////////////////////////////////////
|
||||
|
||||
// test function, to be used in test-mtmd-c-api.c
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue