diff --git a/examples/run/run.cpp b/examples/run/run.cpp
index 4da1e502..de736c7d 100644
--- a/examples/run/run.cpp
+++ b/examples/run/run.cpp
@@ -323,25 +323,17 @@ class File {
         return 0;
     }
 
-    std::string read_all(const std::string & filename){
-        open(filename, "r");
-        lock();
-        if (!file) {
-            printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
-            return "";
-        }
-
+    std::string to_string() {
         fseek(file, 0, SEEK_END);
-        size_t size = ftell(file);
+        const size_t size = ftell(file);
         fseek(file, 0, SEEK_SET);
-
         std::string out;
         out.resize(size);
-        size_t read_size = fread(&out[0], 1, size, file);
+        const size_t read_size = fread(&out[0], 1, size, file);
         if (read_size != size) {
-            printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
-            return "";
+            printe("Error reading file: %s", strerror(errno));
         }
+
         return out;
     }
 
@@ -1098,59 +1090,66 @@ static int get_user_input(std::string & user_input, const std::string & user) {
 
 // Reads a chat template file to be used
 static std::string read_chat_template_file(const std::string & chat_template_file) {
-    if(chat_template_file.empty()){
-        return "";
-    }
-
     File file;
-    std::string chat_template = "";
-    chat_template = file.read_all(chat_template_file);
-    if(chat_template.empty()){
+    if (!file.open(chat_template_file, "r")) {
         printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
         return "";
     }
-    return chat_template;
+
+    return file.to_string();
+}
+
+static int process_user_message(const Opt & opt, const std::string & user_input, LlamaData & llama_data,
+                                const common_chat_templates_ptr & chat_templates, int & prev_len,
+                                const bool stdout_a_terminal) {
+    add_message("user", opt.user.empty() ? user_input : opt.user, llama_data);
+    int new_len;
+    if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) < 0) {
+        return 1;
+    }
+
+    std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
+    std::string response;
+    if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
+        return 1;
+    }
+
+    if (!opt.user.empty()) {
+        return 2;
+    }
+
+    add_message("assistant", response, llama_data);
+    if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) < 0) {
+        return 1;
+    }
+
+    return 0;
 }
 
 // Main chat loop function
-static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
+static int chat_loop(LlamaData & llama_data, const Opt & opt) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
-
-    std::string chat_template = "";
-    if(!chat_template_file.empty()){
-        chat_template = read_chat_template_file(chat_template_file);
+    std::string chat_template;
+    if (!opt.chat_template_file.empty()) {
+        chat_template = read_chat_template_file(opt.chat_template_file);
     }
-    auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template.empty() ? nullptr : chat_template);
+    common_chat_templates_ptr chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template);
 
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
         std::string user_input;
-        if (get_user_input(user_input, user) == 1) {
+        if (get_user_input(user_input, opt.user) == 1) {
             return 0;
         }
 
-        add_message("user", user.empty() ? user_input : user, llama_data);
-        int new_len;
-        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
+        const int ret = process_user_message(opt, user_input, llama_data, chat_templates, prev_len, stdout_a_terminal);
+        if (ret == 1) {
             return 1;
-        }
-
-        std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
-        std::string response;
-        if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
-            return 1;
-        }
-
-        if (!user.empty()) {
+        } else if (ret == 2) {
             break;
         }
-
-        add_message("assistant", response, llama_data);
-        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
-            return 1;
-        }
     }
 
     return 0;
@@ -1208,7 +1207,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
-    if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
+    if (chat_loop(llama_data, opt)) {
         return 1;
     }
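
For anyone who wants to exercise the new `File` API in isolation, below is a minimal standalone sketch of the `open()`/`to_string()` split this patch introduces. It is an approximation, not a drop-in copy of the class in run.cpp: `printe` is stubbed with `fprintf(stderr, ...)` and the `lock()` handling is omitted.

```cpp
// Standalone sketch of the File::open()/to_string() pattern from this patch.
// Assumptions: printe is replaced with fprintf(stderr, ...); run.cpp's lock()
// logic is omitted for brevity.
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <string>

class File {
  public:
    FILE * file = nullptr;

    ~File() {
        if (file) {
            fclose(file);
        }
    }

    // Returns the handle (null on failure) so callers can test the result,
    // mirroring the `if (!file.open(...))` check in read_chat_template_file().
    FILE * open(const std::string & filename, const char * mode) {
        file = fopen(filename.c_str(), mode);
        return file;
    }

    // Reads the whole file into a string. As in the patch, a read error is
    // reported but the (possibly truncated) buffer is still returned.
    std::string to_string() {
        fseek(file, 0, SEEK_END);
        const size_t size = ftell(file);
        fseek(file, 0, SEEK_SET);
        std::string out;
        out.resize(size);
        const size_t read_size = fread(&out[0], 1, size, file);
        if (read_size != size) {
            fprintf(stderr, "Error reading file: %s\n", strerror(errno));
        }
        return out;
    }
};

int main() {
    File file;
    if (!file.open("template.jinja", "r")) {  // hypothetical input file
        fprintf(stderr, "Error opening file: %s\n", strerror(errno));
        return 1;
    }
    printf("%s", file.to_string().c_str());
    return 0;
}
```

One caller-side contract worth noting from the diff itself: `process_user_message` returns 0 to continue the chat loop, 1 on error, and 2 when a preset `opt.user` prompt means the loop should stop after a single response, which is why `chat_loop` breaks on `ret == 2`.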