Some llama-run cleanups (#11973)
Use the consolidated open function call from the File class. Rename read_all to to_string(). Remove exclusive locking: the intent of that lock is to prevent multiple processes from writing to the same file, which is not an issue for readers, although we may want to consider adding a shared lock. Stop passing nullptr as a reference, since references are never supposed to be null. clang-format the code for consistent styling.

Signed-off-by: Eric Curtin <ecurtin@redhat.com>
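For context on the locking note: on POSIX systems the exclusive/shared distinction maps to flock(2)'s LOCK_EX and LOCK_SH. A shared reader lock, if we add one later, might look like the sketch below. This is hypothetical and not part of this commit; the name lock_shared and the free-function shape are made up for illustration, and it assumes a POSIX platform and the FILE * member the File class already holds.

#include <cerrno>
#include <cstdio>
#include <cstring>
#include <sys/file.h>  // flock, LOCK_SH, LOCK_UN

// Hypothetical reader-side lock: many processes may hold LOCK_SH at once,
// while a writer's LOCK_EX blocks until every shared lock is released.
static bool lock_shared(FILE * file) {
    if (!file) {
        return false;
    }

    if (flock(fileno(file), LOCK_SH) != 0) {
        fprintf(stderr, "flock(LOCK_SH) failed: %s\n", strerror(errno));
        return false;
    }

    return true;  // release later with flock(fileno(file), LOCK_UN)
}

Because flock locks are advisory, dropping the exclusive lock on the read path only affects coordination with writers that also call flock; writers still serialize among themselves.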
parent af7747c95a
commit f777a73e18
1 changed file with 45 additions and 46 deletions
@@ -323,25 +323,17 @@ class File {
         return 0;
     }
 
-    std::string read_all(const std::string & filename){
-        open(filename, "r");
-        lock();
-        if (!file) {
-            printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
-            return "";
-        }
-
+    std::string to_string() {
         fseek(file, 0, SEEK_END);
-        size_t size = ftell(file);
+        const size_t size = ftell(file);
         fseek(file, 0, SEEK_SET);
 
         std::string out;
         out.resize(size);
-        size_t read_size = fread(&out[0], 1, size, file);
+        const size_t read_size = fread(&out[0], 1, size, file);
         if (read_size != size) {
-            printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
-            return "";
+            printe("Error reading file: %s", strerror(errno));
         }
 
         return out;
     }
@@ -1098,59 +1090,66 @@ static int get_user_input(std::string & user_input, const std::string & user) {
 
 // Reads a chat template file to be used
 static std::string read_chat_template_file(const std::string & chat_template_file) {
-    if(chat_template_file.empty()){
-        return "";
-    }
-
     File file;
-    std::string chat_template = "";
-    chat_template = file.read_all(chat_template_file);
-    if(chat_template.empty()){
+    if (!file.open(chat_template_file, "r")) {
         printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
         return "";
     }
-    return chat_template;
+
+    return file.to_string();
 }
 
+static int process_user_message(const Opt & opt, const std::string & user_input, LlamaData & llama_data,
+                                const common_chat_templates_ptr & chat_templates, int & prev_len,
+                                const bool stdout_a_terminal) {
+    add_message("user", opt.user.empty() ? user_input : opt.user, llama_data);
+    int new_len;
+    if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) < 0) {
+        return 1;
+    }
+
+    std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
+    std::string response;
+    if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
+        return 1;
+    }
+
+    if (!opt.user.empty()) {
+        return 2;
+    }
+
+    add_message("assistant", response, llama_data);
+    if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) < 0) {
+        return 1;
+    }
+
+    return 0;
+}
+
 // Main chat loop function
-static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
+static int chat_loop(LlamaData & llama_data, const Opt & opt) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
-    std::string chat_template = "";
-    if(!chat_template_file.empty()){
-        chat_template = read_chat_template_file(chat_template_file);
+    std::string chat_template;
+    if (!opt.chat_template_file.empty()) {
+        chat_template = read_chat_template_file(opt.chat_template_file);
     }
-    auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template.empty() ? nullptr : chat_template);
 
+    common_chat_templates_ptr chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template);
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
         std::string user_input;
-        if (get_user_input(user_input, user) == 1) {
+        if (get_user_input(user_input, opt.user) == 1) {
             return 0;
         }
 
-        add_message("user", user.empty() ? user_input : user, llama_data);
-        int new_len;
-        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
+        const int ret = process_user_message(opt, user_input, llama_data, chat_templates, prev_len, stdout_a_terminal);
+        if (ret == 1) {
             return 1;
-        }
-
-        std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
-        std::string response;
-        if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
-            return 1;
-        }
-
-        if (!user.empty()) {
+        } else if (ret == 2) {
             break;
         }
-
-        add_message("assistant", response, llama_data);
-        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
-            return 1;
-        }
     }
 
     return 0;
@@ -1208,7 +1207,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
-    if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
+    if (chat_loop(llama_data, opt)) {
        return 1;
    }
 
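As background for the "passing nullptr as reference" cleanup visible in the old common_chat_templates_init call: a C++ reference must bind to a live object, so a parameter declared const std::string & can never legitimately receive nullptr — the conversion would go through std::string's const char * constructor, which is undefined behavior for a null pointer (and deleted outright since C++23). A minimal illustration, separate from the patch; has_template is a made-up name for the example:

#include <cstdio>
#include <string>

// A parameter of reference type may assume a live object; "absent" is
// better expressed with an empty string, as the patch does.
static bool has_template(const std::string & tmpl) {
    return !tmpl.empty();
}

int main() {
    std::string chat_template;  // empty means "no template configured"
    // has_template(nullptr);   // would invoke std::string(nullptr): undefined
    //                          // behavior before C++23, ill-formed since
    printf("has template: %s\n", has_template(chat_template) ? "yes" : "no");
    return 0;
}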