run : add --chat-template-file (#11961)
Relates to: https://github.com/ggml-org/llama.cpp/issues/11178

Added the --chat-template-file CLI option to llama-run. If specified, the file is read and its content is passed to common_chat_templates_from_model to override the model's chat template.

Signed-off-by: Michael Engel <mengel@redhat.com>
parent d04e7163c8
commit 0d559580a0
1 changed file with 70 additions and 5 deletions
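Example usage (illustrative only; the model path and template file name below are placeholders, not part of this change):

    llama-run --chat-template-file ./chat-template.jinja some-model.gguf

Since the option implicitly sets --jinja, the referenced file should contain a Jinja chat template.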
@@ -113,6 +113,7 @@ class Opt {
     llama_context_params ctx_params;
     llama_model_params model_params;
     std::string model_;
+    std::string chat_template_file;
     std::string user;
     bool use_jinja = false;
     int context_size = -1, ngl = -1;
@@ -148,6 +149,16 @@ class Opt {
         return 0;
     }
 
+    int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) {
+        if (i + 1 >= argc) {
+            return 1;
+        }
+
+        option_value = argv[++i];
+
+        return 0;
+    }
+
     int parse(int argc, const char ** argv) {
         bool options_parsing = true;
         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
@@ -169,6 +180,11 @@ class Opt {
                 verbose = true;
             } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
                 use_jinja = true;
+            } else if (options_parsing && strcmp(argv[i], "--chat-template-file") == 0){
+                if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) {
+                    return 1;
+                }
+                use_jinja = true;
             } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
                 help = true;
                 return 0;
@@ -207,6 +223,11 @@ class Opt {
             "Options:\n"
             "  -c, --context-size <value>\n"
             "      Context size (default: %d)\n"
+            "  --chat-template-file <path>\n"
+            "      Path to the file containing the chat template to use with the model.\n"
+            "      Only supports jinja templates and implicitly sets the --jinja flag.\n"
+            "  --jinja\n"
+            "      Use jinja templating for the chat template of the model\n"
             "  -n, -ngl, --ngl <value>\n"
             "      Number of GPU layers (default: %d)\n"
             "  --temp <value>\n"
@@ -261,13 +282,12 @@ static int get_terminal_width() {
 #endif
 }
 
-#ifdef LLAMA_USE_CURL
 class File {
   public:
     FILE * file = nullptr;
 
     FILE * open(const std::string & filename, const char * mode) {
-        file = fopen(filename.c_str(), mode);
+        file = ggml_fopen(filename.c_str(), mode);
 
         return file;
     }
@@ -303,6 +323,28 @@ class File {
         return 0;
     }
 
+    std::string read_all(const std::string & filename){
+        open(filename, "r");
+        lock();
+        if (!file) {
+            printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
+            return "";
+        }
+
+        fseek(file, 0, SEEK_END);
+        size_t size = ftell(file);
+        fseek(file, 0, SEEK_SET);
+
+        std::string out;
+        out.resize(size);
+        size_t read_size = fread(&out[0], 1, size, file);
+        if (read_size != size) {
+            printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
+            return "";
+        }
+        return out;
+    }
+
     ~File() {
         if (fd >= 0) {
 # ifdef _WIN32
@@ -327,6 +369,7 @@ class File {
 # endif
 };
 
+#ifdef LLAMA_USE_CURL
 class HttpClient {
   public:
     int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
@@ -1053,11 +1096,33 @@ static int get_user_input(std::string & user_input, const std::string & user) {
     return 0;
 }
 
+// Reads a chat template file to be used
+static std::string read_chat_template_file(const std::string & chat_template_file) {
+    if(chat_template_file.empty()){
+        return "";
+    }
+
+    File file;
+    std::string chat_template = "";
+    chat_template = file.read_all(chat_template_file);
+    if(chat_template.empty()){
+        printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
+        return "";
+    }
+    return chat_template;
+}
+
 // Main chat loop function
-static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
+static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
-    auto chat_templates = common_chat_templates_init(llama_data.model.get(), "");
+
+    std::string chat_template = "";
+    if(!chat_template_file.empty()){
+        chat_template = read_chat_template_file(chat_template_file);
+    }
+    auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template.empty() ? nullptr : chat_template);
+
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
@@ -1143,7 +1208,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
-    if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
+    if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
         return 1;
     }
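Note: the file passed via --chat-template-file is expected to contain a Jinja chat template. A minimal sketch of such a file (generic tags, not tied to any particular model) might look like:

    {%- for message in messages -%}
    <|{{ message.role }}|>{{ message.content }}
    {%- endfor -%}
    {%- if add_generation_prompt -%}<|assistant|>{%- endif -%}

The variables messages, message.role, message.content and add_generation_prompt are the ones Jinja chat templates conventionally receive.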