Some llama-run cleanups (#11973)

Use the consolidated open function call from the File class. Rename
read_all() to to_string(). Remove exclusive locking: the intent of
that lock is to prevent multiple processes from writing to the same
file, which is not an issue for readers, although we may want to
consider adding a shared lock (see the sketch below). Stop passing
nullptr as a reference, since references are never supposed to be
null. clang-format the code for consistent styling.

Signed-off-by: Eric Curtin <ecurtin@redhat.com>
Author: Eric Curtin <ecurtin@redhat.com>, 2025-02-23 13:14:32 +00:00 (committed via GitHub)
commit f777a73e18, parent af7747c95a
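
On the shared-lock idea mentioned in the message: it is not part of this
commit, but on POSIX it could look roughly like the sketch below. This is
illustrative only; it assumes flock(2) semantics and an already-open FILE *,
and the helper name is hypothetical.

#include <cstdio>       // FILE; fileno() is a POSIX extension
#include <sys/file.h>   // flock(), LOCK_SH

// Hypothetical helper, not in this commit: take a shared (reader) lock on an
// already-open FILE *. Any number of readers can hold LOCK_SH concurrently,
// while a writer holding LOCK_EX excludes them all.
static bool lock_shared(FILE * file) {
    const int fd = fileno(file);
    if (fd < 0) {
        return false;
    }

    return flock(fd, LOCK_SH) == 0;  // blocks until the shared lock is granted
}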

@@ -323,25 +323,17 @@ class File {
         return 0;
     }
 
-    std::string read_all(const std::string & filename){
-        open(filename, "r");
-        lock();
-        if (!file) {
-            printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
-            return "";
-        }
-
+    std::string to_string() {
         fseek(file, 0, SEEK_END);
-        size_t size = ftell(file);
+        const size_t size = ftell(file);
         fseek(file, 0, SEEK_SET);
 
         std::string out;
         out.resize(size);
-        size_t read_size = fread(&out[0], 1, size, file);
+        const size_t read_size = fread(&out[0], 1, size, file);
         if (read_size != size) {
-            printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
-            return "";
+            printe("Error reading file: %s", strerror(errno));
         }
 
         return out;
     }
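
With read_all() gone, opening the file and reporting open errors become the
caller's job; to_string() assumes the File was already opened successfully. A
minimal caller sketch of the new pattern, mirroring the read_chat_template_file()
change in the next hunk (it assumes File::open() yields a falsy value on
failure, as that hunk uses it; path is a std::string placeholder):

    File file;
    if (!file.open(path, "r")) {
        printe("Error opening file '%s': %s", path.c_str(), strerror(errno));
        return "";
    }

    return file.to_string();  // slurps the whole file into a std::string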
@@ -1098,42 +1090,21 @@ static int get_user_input(std::string & user_input, const std::string & user) {
 
 // Reads a chat template file to be used
 static std::string read_chat_template_file(const std::string & chat_template_file) {
-    if(chat_template_file.empty()){
-        return "";
-    }
-
     File file;
-    std::string chat_template = "";
-    chat_template = file.read_all(chat_template_file);
-    if(chat_template.empty()){
+    if (!file.open(chat_template_file, "r")) {
         printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
         return "";
     }
-    return chat_template;
+
+    return file.to_string();
 }
 
-// Main chat loop function
-static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
-    int prev_len = 0;
-    llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
-    std::string chat_template = "";
-    if(!chat_template_file.empty()){
-        chat_template = read_chat_template_file(chat_template_file);
-    }
-    auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template.empty() ? nullptr : chat_template);
-
-    static const bool stdout_a_terminal = is_stdout_a_terminal();
-    while (true) {
-        // Get user input
-        std::string user_input;
-        if (get_user_input(user_input, user) == 1) {
-            return 0;
-        }
-
-        add_message("user", user.empty() ? user_input : user, llama_data);
-
+static int process_user_message(const Opt & opt, const std::string & user_input, LlamaData & llama_data,
+                                const common_chat_templates_ptr & chat_templates, int & prev_len,
+                                const bool stdout_a_terminal) {
+    add_message("user", opt.user.empty() ? user_input : opt.user, llama_data);
+
     int new_len;
-    if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
+    if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) < 0) {
         return 1;
     }
@@ -1143,14 +1114,42 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, const std
-            return 1;
-        }
-
-        if (!user.empty()) {
-            break;
-        }
-
-        add_message("assistant", response, llama_data);
-        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
-            return 1;
-        }
+        return 1;
+    }
+
+    if (!opt.user.empty()) {
+        return 2;
+    }
+
+    add_message("assistant", response, llama_data);
+    if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) < 0) {
+        return 1;
+    }
+
+    return 0;
+}
+
+// Main chat loop function
+static int chat_loop(LlamaData & llama_data, const Opt & opt) {
+    int prev_len = 0;
+    llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
+    std::string chat_template;
+    if (!opt.chat_template_file.empty()) {
+        chat_template = read_chat_template_file(opt.chat_template_file);
+    }
+
+    common_chat_templates_ptr chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template);
+
+    static const bool stdout_a_terminal = is_stdout_a_terminal();
+    while (true) {
+        // Get user input
+        std::string user_input;
+        if (get_user_input(user_input, opt.user) == 1) {
+            return 0;
+        }
+
+        const int ret = process_user_message(opt, user_input, llama_data, chat_templates, prev_len, stdout_a_terminal);
+        if (ret == 1) {
+            return 1;
+        } else if (ret == 2) {
+            break;
+        }
     }
 
     return 0;
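
Note the small return-code contract this refactor introduces for
process_user_message(): 0 means keep looping, 1 propagates an error out of
chat_loop(), and 2 signals that a fixed opt.user prompt was answered once and
the loop should end.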
@@ -1208,7 +1207,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
-    if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
+    if (chat_loop(llama_data, opt)) {
         return 1;
     }