llama-mtmd-cli: Sigint rework in mtmd vision example (#13080)
* Sigint rework in mtmd vision example
* Applied suggestions on mtmd-cli PR
* Forgot to invert one of the conditions
* Update examples/llava/mtmd-cli.cpp
* Removed redundant exit check

Co-authored-by: pl752 <maximpl752@gmail.com>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
Parent: ecda2ec4b3
Commit: 5630406959
1 changed file with 24 additions and 7 deletions
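Before the diff itself, here is the overall shape of the rework: instead of logging and calling _exit(130) from inside the signal handler, the handler now only flips volatile flags; the generation and chat loops poll those flags and unwind normally, and main() reports the conventional exit status 130 (128 + SIGINT). The following is a minimal, self-contained sketch of that pattern, assuming a POSIX environment; the plain signal() call, the dot-printing loop, and the constants are illustrative only and are not the actual mtmd-cli code.

#include <signal.h>
#include <stdio.h>
#include <unistd.h>   // _exit(), usleep() -- POSIX; the real code also handles Windows

// volatile, because the handler runs asynchronously to the main loop
static volatile bool g_is_generating  = false;
static volatile bool g_is_interrupted = false;

static void sigint_handler(int signo) {
    if (signo != SIGINT) return;
    if (g_is_generating) {
        g_is_generating = false;       // Ctrl+C during generation: stop this response
    } else if (g_is_interrupted) {
        _exit(1);                      // Ctrl+C while already shutting down: force quit
    } else {
        g_is_interrupted = true;       // Ctrl+C at the prompt: request a clean exit
    }
}

int main() {
    signal(SIGINT, sigint_handler);    // the real code uses sigaction()/SetConsoleCtrlHandler

    g_is_generating = true;
    for (int i = 0; i < 50 && g_is_generating && !g_is_interrupted; i++) {
        printf(".");                   // stand-in for sampling and printing one token
        fflush(stdout);
        usleep(100 * 1000);            // slowed down so Ctrl+C can be tried interactively
    }
    g_is_generating = false;

    if (g_is_interrupted) printf("\nInterrupted by user\n");
    return g_is_interrupted ? 130 : 0; // 130 = 128 + SIGINT(2), the conventional status
}

One visible consequence in the diff below: the "Interrupted by user" message is no longer printed from inside the handler; it is printed from main() once the loops have unwound.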
examples/llava/mtmd-cli.cpp
@@ -24,7 +24,9 @@
 #include <signal.h>
 #endif
 
-static bool g_is_generating = false;
+// volatile, because of signal being an interrupt
+static volatile bool g_is_generating = false;
+static volatile bool g_is_interrupted = false;
 
 /**
  * Please note that this is NOT a production-ready stuff.
@@ -50,8 +52,10 @@ static void sigint_handler(int signo) {
             g_is_generating = false;
         } else {
             console::cleanup();
-            LOG("\nInterrupted by user\n");
-            _exit(130);
+            if (g_is_interrupted) {
+                _exit(1);
+            }
+            g_is_interrupted = true;
         }
     }
 }
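The hunk above only touches the else branch of the handler. Folding the added lines into the surrounding context, the reworked sigint_handler should read roughly as below; the two enclosing if statements are inferred from the context lines rather than shown in the hunk, so treat this as a reconstruction, not a verbatim excerpt.

static void sigint_handler(int signo) {
    if (signo == SIGINT) {
        if (g_is_generating) {
            // Ctrl+C while a response is streaming: just stop generating
            g_is_generating = false;
        } else {
            console::cleanup();
            if (g_is_interrupted) {
                // Ctrl+C while shutdown is already pending: hard exit
                _exit(1);
            }
            // otherwise remember the interrupt; main() reports it and exits with 130
            g_is_interrupted = true;
        }
    }
}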
@@ -167,7 +171,7 @@ struct decode_embd_batch {
 static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
     llama_tokens generated_tokens;
     for (int i = 0; i < n_predict; i++) {
-        if (i > n_predict || !g_is_generating) {
+        if (i > n_predict || !g_is_generating || g_is_interrupted) {
             printf("\n");
             break;
         }
@@ -184,6 +188,11 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
         printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
         fflush(stdout);
 
+        if (g_is_interrupted) {
+            printf("\n");
+            break;
+        }
+
         // eval the token
         common_batch_clear(ctx.batch);
         common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
@@ -219,6 +228,9 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
     text.add_special = add_bos;
     text.parse_special = true;
     mtmd_input_chunks chunks;
+
+    if (g_is_interrupted) return 0;
+
     int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
     if (res != 0) {
         LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
@@ -276,6 +288,8 @@ int main(int argc, char ** argv) {
 #endif
     }
 
+    if (g_is_interrupted) return 130;
+
     if (is_single_turn) {
         g_is_generating = true;
         if (params.prompt.find("<__image__>") == std::string::npos) {
@@ -287,7 +301,7 @@ int main(int argc, char ** argv) {
         if (eval_message(ctx, msg, params.image, true)) {
             return 1;
         }
-        if (generate_response(ctx, smpl, n_predict)) {
+        if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
             return 1;
         }
 
@@ -302,12 +316,13 @@ int main(int argc, char ** argv) {
         std::vector<std::string> images_fname;
         std::string content;
 
-        while (true) {
+        while (!g_is_interrupted) {
             g_is_generating = false;
             LOG("\n> ");
             console::set_display(console::user_input);
             std::string line;
             console::readline(line, false);
+            if (g_is_interrupted) break;
             console::set_display(console::reset);
             line = string_strip(line);
             if (line.empty()) {
@@ -335,6 +350,7 @@ int main(int argc, char ** argv) {
             msg.role = "user";
             msg.content = content;
             int ret = eval_message(ctx, msg, images_fname, is_first_msg);
+            if (g_is_interrupted) break;
             if (ret == 2) {
                 // non-fatal error
                 images_fname.clear();
@@ -352,6 +368,7 @@ int main(int argc, char ** argv) {
             is_first_msg = false;
         }
     }
+    if (g_is_interrupted) LOG("\nInterrupted by user\n");
     llama_perf_context_print(ctx.lctx);
-    return 0;
+    return g_is_interrupted ? 130 : 0;
 }