tool-call
: fix Qwen 2.5 Coder support, add micro benchmarks, support trigger patterns for lazy grammars (#12034)
* sampler: turn lazy grammar trigger words to regexes * add scripts/tool_bench.sh & .py * constrain llama json output regardless of function name if matches at beginning * update relaxed newline space rule in grammar tests * support add_generation_prompt query parameter (useful for /apply_template) * Update src/llama-grammar.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
fa31c438e0
commit
669912d9a5
26 changed files with 1314 additions and 408 deletions
437
common/chat.cpp
437
common/chat.cpp
|
@ -449,12 +449,6 @@ std::string common_chat_format_name(common_chat_format format) {
|
|||
}
|
||||
}
|
||||
|
||||
const common_grammar_options grammar_options {
|
||||
/* .dotall = */ false,
|
||||
/* .compact_spaces = */ false,
|
||||
// /* .compact_spaces = */ true,
|
||||
};
|
||||
|
||||
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
|
||||
// // https://json.nlohmann.me/features/parsing/sax_interface/
|
||||
struct json_error_locator : public nlohmann::json_sax<json> {
|
||||
|
@ -500,6 +494,34 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
|
|||
}
|
||||
}
|
||||
|
||||
static bool parse_literal(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
|
||||
auto expected_it = expected.begin();
|
||||
auto tmp_it = it;
|
||||
while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
|
||||
++tmp_it;
|
||||
++expected_it;
|
||||
}
|
||||
if (expected_it == expected.end()) {
|
||||
it = tmp_it;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static std::optional<std::smatch> parse_pattern(std::string::const_iterator & it, const std::string::const_iterator & end, const std::regex & expected) {
|
||||
std::smatch match;
|
||||
if (std::regex_match(it, end, match, expected)) {
|
||||
it = match.suffix().first;
|
||||
return match;
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static void consume_spaces(std::string::const_iterator & it, const std::string::const_iterator & end) {
|
||||
while (it != end && std::isspace(*it)) {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
|
||||
|
@ -509,7 +531,8 @@ static common_chat_msg parse_json_tool_calls(
|
|||
const std::string& input,
|
||||
const std::optional<std::regex> & trigger_opt,
|
||||
const std::regex & function_regex,
|
||||
const std::regex & close_regex) {
|
||||
const std::regex & close_regex,
|
||||
bool allow_raw_python = false) {
|
||||
std::smatch match;
|
||||
|
||||
common_chat_msg result;
|
||||
|
@ -540,14 +563,19 @@ static common_chat_msg parse_json_tool_calls(
|
|||
it = rit->suffix().first;
|
||||
|
||||
json arguments;
|
||||
if (!parse_json(it, end, arguments)) {
|
||||
if (parse_json(it, end, arguments)) {
|
||||
if (!std::regex_search(it, end, match, close_regex)) {
|
||||
throw std::runtime_error("Malformed input, missing closing pattern: " + input);
|
||||
}
|
||||
it = match.suffix().first;
|
||||
result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
|
||||
} else {
|
||||
if (allow_raw_python && name == "python") {
|
||||
result.tool_calls.push_back({name, json({{"code", std::string(it, end)}}).dump(), /* id= */ ""});
|
||||
break;
|
||||
}
|
||||
throw std::runtime_error("Failed to parse json tool call arguments: " + input);
|
||||
}
|
||||
if (!std::regex_search(it, end, match, close_regex)) {
|
||||
throw std::runtime_error("Malformed input, missing closing pattern: " + input);
|
||||
}
|
||||
it = match.suffix().first;
|
||||
result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
|
||||
}
|
||||
|
||||
if (!result.tool_calls.empty()) {
|
||||
|
@ -559,29 +587,29 @@ static common_chat_msg parse_json_tool_calls(
|
|||
return result;
|
||||
}
|
||||
|
||||
static common_chat_tool_call process_tool_call(const json & tool_call) {
|
||||
const auto & arguments = tool_call.at("arguments");
|
||||
return {
|
||||
/* .name = */ tool_call.at("name"),
|
||||
/* .arguments = */ arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||
/* .id = */ tool_call.contains("id") ? tool_call.at("id") : "",
|
||||
};
|
||||
}
|
||||
static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
|
||||
auto content_end = input.find(prefix);
|
||||
size_t tc_start = std::string::npos;
|
||||
|
||||
common_chat_msg result;
|
||||
result.role = "assistant";
|
||||
const auto process_tool_calls = [&](const json & tool_calls) {
|
||||
for (const auto & tool_call : tool_calls) {
|
||||
const auto & arguments = tool_call.at("arguments");
|
||||
result.tool_calls.push_back({
|
||||
tool_call.at("name"),
|
||||
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||
tool_call.contains("id") ? tool_call.at("id") : "",
|
||||
});
|
||||
}
|
||||
};
|
||||
if (content_end == std::string::npos) {
|
||||
result.content = input;
|
||||
} else {
|
||||
tc_start = content_end + prefix.size() - rstrip_prefix;
|
||||
result.content = input.substr(0, content_end);
|
||||
auto tool_calls = json::parse(input.substr(tc_start));
|
||||
process_tool_calls(tool_calls);
|
||||
for (const auto & tool_call : tool_calls) {
|
||||
result.tool_calls.emplace_back(process_tool_call(tool_call));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
@ -700,7 +728,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
|
|||
data.grammar_lazy = false;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
builder.add_schema("root", schema);
|
||||
}, grammar_options);
|
||||
});
|
||||
|
||||
auto tweaked_messages = common_chat_template::add_system(
|
||||
inputs.messages,
|
||||
|
@ -770,8 +798,11 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
|
|||
schema["maxItems"] = 1;
|
||||
}
|
||||
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
|
||||
}, grammar_options);
|
||||
data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
|
||||
});
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"});
|
||||
data.preserved_tokens = {
|
||||
"[TOOL_CALLS]",
|
||||
};
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
|
||||
return data;
|
||||
|
@ -813,14 +844,18 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
|
|||
schema["maxItems"] = 1;
|
||||
}
|
||||
builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
|
||||
}, grammar_options);
|
||||
data.grammar_triggers.push_back({"<|START_ACTION|>", /* .at_start = */ false});
|
||||
});
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
||||
"<|START_ACTION|>",
|
||||
});
|
||||
data.preserved_tokens = {
|
||||
"<|START_ACTION|>",
|
||||
"<|END_ACTION|>",
|
||||
"<|START_RESPONSE|>",
|
||||
"<|END_RESPONSE|>",
|
||||
"<|START_THINKING|>",
|
||||
"<|END_THINKING|>",
|
||||
"<|END_ACTION|>",
|
||||
};
|
||||
auto adjusted_messages = json::array();
|
||||
for (const auto & msg : inputs.messages) {
|
||||
|
@ -840,9 +875,9 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
|
|||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) {
|
||||
static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|>)([\\s\\S\\n\\r]*)");
|
||||
static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
|
||||
static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
|
||||
static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)");
|
||||
static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>");
|
||||
static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>");
|
||||
|
||||
std::smatch match;
|
||||
|
||||
|
@ -945,23 +980,23 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
|
|||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"{\" space "
|
||||
"( \"\\\"type\\\":\" space \"\\\"function\\\",\" space )? "
|
||||
"\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
|
||||
builder.add_schema(name + "-args", parameters) +
|
||||
" \"}\""));
|
||||
data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
|
||||
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
|
||||
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
|
||||
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
|
||||
"\"}\" space"));
|
||||
});
|
||||
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
|
||||
});
|
||||
data.grammar_triggers.push_back({"{\"name\":", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"{\n \"name\":", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"{\n \"name\":", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"{\"type\": \"function\"", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"{\n \"type\": \"function\"", /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({"{\n \"type\": \"function\"", /* .at_start = */ true});
|
||||
if (!builtin_tools.empty()) {
|
||||
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
||||
data.preserved_tokens.push_back("<|python_tag|>");
|
||||
}
|
||||
// Allow a few empty lines on top of the usual constrained json schema space rule.
|
||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||
}, grammar_options);
|
||||
});
|
||||
data.additional_stops.push_back("<|eom_id|>");
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
|
||||
{"tools_in_user_message", false},
|
||||
|
@ -974,33 +1009,33 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
|
|||
}
|
||||
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
|
||||
// TODO: tighten & simplify the parser, don't accept leading text context.
|
||||
static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
|
||||
static std::regex close_regex("\\}");
|
||||
static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
|
||||
static std::regex function_regex(
|
||||
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
|
||||
static std::regex close_regex("\\}\\s*");
|
||||
static std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)");
|
||||
|
||||
if (with_builtin_tools) {
|
||||
std::smatch match;
|
||||
if (std::regex_match(input, match, builtin_call_regex)) {
|
||||
auto name = match[1].str();
|
||||
auto raw_args = match[2].str();
|
||||
try {
|
||||
auto name = match[1].str();
|
||||
auto arg_name = match[2].str();
|
||||
auto arg_value_str = match[3].str();
|
||||
auto arg_value = json::parse(arg_value_str);
|
||||
|
||||
// TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
|
||||
auto it_eq = raw_args.find('=');
|
||||
auto arg_name = raw_args.substr(0, it_eq);
|
||||
auto arg_value_str = raw_args.substr(it_eq + 1);
|
||||
auto arg_value = json::parse(arg_value_str);
|
||||
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
msg.content = match.prefix().str();
|
||||
msg.tool_calls.push_back({
|
||||
/* .name = */ name,
|
||||
/* .arguments = */ (json {
|
||||
{arg_name, arg_value},
|
||||
}).dump(),
|
||||
/* .id = */ "",
|
||||
});
|
||||
return msg;
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
msg.tool_calls.push_back({
|
||||
/* .name = */ name,
|
||||
/* .arguments = */ (json {
|
||||
{arg_name, arg_value},
|
||||
}).dump(),
|
||||
/* .id = */ "",
|
||||
});
|
||||
return msg;
|
||||
} catch (const std::exception & e) {
|
||||
LOG_WRN("Failed to parse builtin tool call arguments (%s): %s", e.what(), input.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
|
||||
|
@ -1017,10 +1052,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
|
|||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||
tool_rules.push_back(builder.add_rule(name + "-call",
|
||||
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
|
||||
"```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
|
||||
"```json\\n\" " + builder.add_schema(name + "-args", parameters) + " "
|
||||
"\"```<|tool▁call▁end|>\""));
|
||||
});
|
||||
// Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
|
||||
// so we accept common variants (then it's all constrained)
|
||||
|
@ -1029,18 +1064,20 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
|
|||
"(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
|
||||
"\"<|tool▁calls▁end|>\""
|
||||
" space");
|
||||
data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({"<|tool_calls_begin|>", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({"<|tool calls begin|>", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({"<|tool\\_calls\\_begin|>", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool▁calls▁begin|>"});
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls_begin|>"});
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool calls begin|>"});
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool\\_calls\\_begin|>"});
|
||||
data.preserved_tokens = {
|
||||
"<think>",
|
||||
"</think>",
|
||||
"<|tool▁calls▁begin|>",
|
||||
"<|tool▁call▁begin|>",
|
||||
"<|tool▁sep|>",
|
||||
"<|tool▁calls▁end|",
|
||||
"<|tool▁call▁end|>",
|
||||
"<|tool▁calls▁end|",
|
||||
};
|
||||
}, grammar_options);
|
||||
});
|
||||
}
|
||||
auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
|
||||
|
@ -1129,8 +1166,11 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
|
|||
schema["maxItems"] = 1;
|
||||
}
|
||||
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
|
||||
}, grammar_options);
|
||||
data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
|
||||
});
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, " functools["});
|
||||
data.preserved_tokens = {
|
||||
" functools[",
|
||||
};
|
||||
data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2;
|
||||
} else {
|
||||
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
|
@ -1158,11 +1198,28 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
|||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
auto args_rule = builder.add_schema(name + "-args", parameters);
|
||||
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
|
||||
first_tool_rules.push_back(builder.add_rule(name + "-call", "( \"assistant<|end_header_id|>\\n\" )? \"" + name + "\\n\" " + args_rule));
|
||||
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
|
||||
data.grammar_triggers.push_back({name, /* .at_start = */ true});
|
||||
data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
regex_escape(name + "\n"),
|
||||
});
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
regex_escape("assistant<|end_header_id|>\n" + name + "\n"),
|
||||
});
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
||||
regex_escape(">>>" + name + "\n"),
|
||||
});
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
||||
">>>assistant<|end_header_id|>\n" + name,
|
||||
});
|
||||
});
|
||||
data.preserved_tokens = {
|
||||
"<|end_header_id|>",
|
||||
};
|
||||
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
|
||||
if (inputs.parallel_tool_calls) {
|
||||
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
|
||||
|
@ -1171,34 +1228,20 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
|||
builder.add_rule("root", first_rule);
|
||||
}
|
||||
|
||||
}, grammar_options);
|
||||
});
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
static bool consume(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
|
||||
auto expected_it = expected.begin();
|
||||
auto tmp_it = it;
|
||||
while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
|
||||
++tmp_it;
|
||||
++expected_it;
|
||||
}
|
||||
if (expected_it == expected.end()) {
|
||||
it = tmp_it;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) {
|
||||
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
||||
static std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)");
|
||||
static std::regex close_regex(R"($|(?=>>>))");
|
||||
|
||||
std::string content;
|
||||
auto it = input.begin();
|
||||
const auto end = input.end();
|
||||
|
||||
if (consume(it, end, "all\n")) {
|
||||
if (parse_literal(it, end, "all\n")) {
|
||||
std::smatch match;
|
||||
if (std::regex_search(it, end, match, function_regex)) {
|
||||
auto fun_it = match.prefix().second;
|
||||
|
@ -1213,7 +1256,7 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
|
|||
}
|
||||
// TODO: tighten & simplify.
|
||||
try {
|
||||
auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex);
|
||||
auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex, /* allow_raw_python= */ true);
|
||||
res.content = content + res.content;
|
||||
return res;
|
||||
} catch (const std::exception & e) {
|
||||
|
@ -1266,12 +1309,13 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
|
|||
});
|
||||
if (has_raw_python) {
|
||||
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
||||
data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
||||
data.preserved_tokens.push_back("<|python_tag|>");
|
||||
}
|
||||
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
|
||||
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
|
||||
}, grammar_options);
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
|
||||
});
|
||||
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
// TODO: if (has_raw_python)
|
||||
|
@ -1306,6 +1350,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
|
|||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
std::vector<std::string> tool_call_alts;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
|
@ -1319,57 +1364,173 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
|
|||
}},
|
||||
{"required", json::array({"name", "arguments"})},
|
||||
}));
|
||||
tool_call_alts.push_back(builder.add_rule(
|
||||
name + "-function-tag",
|
||||
"\"<function\" ( \"=" + name + "\" | \" name=\\\"" + name + "\\\"\" ) \">\" space " +
|
||||
builder.add_schema(name + "-args", parameters) + " "
|
||||
"\"</function>\" space"));
|
||||
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
||||
"<function=" + name + ">",
|
||||
});
|
||||
auto escaped_name = regex_escape(name);
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
||||
"<function\\s+name\\s*=\\s*\"" + escaped_name + "\"",
|
||||
});
|
||||
});
|
||||
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
|
||||
auto any_tool_call = builder.add_rule("any_tool_call", "( " + string_join(tool_rules, " | ") + " ) space");
|
||||
std::vector<std::string> alt_tags {
|
||||
any_tool_call,
|
||||
"\"<tool_call>\" space " + any_tool_call + " \"</tool_call>\"",
|
||||
// The rest is just to accommodate common "good bad" outputs.
|
||||
"\"<function_call>\" space " + any_tool_call + " \"</function_call>\"",
|
||||
"\"<response>\" space " + any_tool_call + " \"</response>\"",
|
||||
"\"<tools>\" space " + any_tool_call + " \"</tools>\"",
|
||||
"\"<json>\" space " + any_tool_call + " \"</json>\"",
|
||||
"\"<xml>\" space " + any_tool_call + " \"</xml>\"",
|
||||
"\"<JSON>\" space " + any_tool_call + " \"</JSON>\"",
|
||||
};
|
||||
auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space");
|
||||
tool_call_alts.push_back(wrappable_tool_call);
|
||||
tool_call_alts.push_back(
|
||||
"( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space ");
|
||||
auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | "));
|
||||
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
|
||||
data.preserved_tokens = { "</tool_call>" };
|
||||
}, grammar_options);
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<tool_call>"});
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function"});
|
||||
// Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
"(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?\\s*\\{\\s*\"", //name\"\\s*:\\s*\"" + escaped_name + "\"",
|
||||
});
|
||||
data.preserved_tokens = {
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
"<function",
|
||||
"<tools>",
|
||||
"</tools>",
|
||||
"<response>",
|
||||
"</response>",
|
||||
"<function_call>",
|
||||
"</function_call>",
|
||||
"<json>",
|
||||
"</json>",
|
||||
"<JSON>",
|
||||
"</JSON>",
|
||||
"```",
|
||||
"```json",
|
||||
"```xml",
|
||||
};
|
||||
});
|
||||
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) {
|
||||
static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input) {
|
||||
const static std::regex open_regex(
|
||||
"(?:"
|
||||
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
|
||||
"(<tool_call>" // match 2 (open_tag)
|
||||
"|<function_call>"
|
||||
"|<tool>"
|
||||
"|<tools>"
|
||||
"|<response>"
|
||||
"|<json>"
|
||||
"|<xml>"
|
||||
"|<JSON>"
|
||||
")?"
|
||||
"(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*)" // match 3 (named tool call + rest)
|
||||
")"
|
||||
"|"
|
||||
"(?:<function=([^>]+)>" // match 4 (function name)
|
||||
"|<function name=\"([^\"]+)\">)" // match 5 (function name again)
|
||||
"([\\s\\S]*)" // match 6 (function arguments + rest)})"
|
||||
);
|
||||
|
||||
try {
|
||||
std::regex start_pattern(R"([\n\s]*<tool_call>)");
|
||||
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
|
||||
std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
|
||||
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
|
||||
auto end = input.end();
|
||||
std::sregex_iterator rend;
|
||||
std::sregex_iterator rit(input.begin(), end, start_pattern);
|
||||
if (rit == rend) {
|
||||
msg.content = input;
|
||||
return msg;
|
||||
}
|
||||
std::string::const_iterator it = input.begin();
|
||||
const std::string::const_iterator end = input.end();
|
||||
std::smatch match;
|
||||
|
||||
msg.content = rit->prefix();
|
||||
|
||||
auto it = rit->suffix().first;
|
||||
while (it != end) {
|
||||
json call;
|
||||
if (!parse_json(it, end, call)) {
|
||||
throw std::runtime_error("Failed to parse json tool call");
|
||||
}
|
||||
const auto & arguments = call.at("arguments");
|
||||
msg.tool_calls.push_back({
|
||||
call.at("name"),
|
||||
arguments.dump(),
|
||||
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
||||
/* id= */ "",
|
||||
});
|
||||
rit = {it, end, middle_pattern};
|
||||
if (rit != rend) {
|
||||
it = rit->suffix().first;
|
||||
} else {
|
||||
rit = {it, end, end_pattern};
|
||||
if (rit == rend) {
|
||||
throw std::runtime_error("Malformed input, missing </tool_call>");
|
||||
if (std::regex_search(it, end, match, open_regex)) {
|
||||
// Add content before the match
|
||||
msg.content += std::string(it, match[0].first);
|
||||
|
||||
auto block_start = match[1].str();
|
||||
std::string block_end = block_start.empty() ? "" : "```";
|
||||
|
||||
auto open_tag = match[2].str();
|
||||
std::string close_tag;
|
||||
|
||||
if (match[3].matched) {
|
||||
close_tag = open_tag.empty() ? "" : "</" + open_tag.substr(1);
|
||||
auto json_it = match[3].first;
|
||||
json tool_call;
|
||||
if (parse_json(json_it, end, tool_call) && tool_call.contains("name") && tool_call.contains("arguments")) {
|
||||
|
||||
msg.tool_calls.emplace_back(process_tool_call(tool_call));
|
||||
it = json_it; // Move iterator past parsed JSON
|
||||
|
||||
// Handle close tags
|
||||
consume_spaces(it, end);
|
||||
if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
|
||||
throw std::runtime_error("Failed to parse closing tag");
|
||||
}
|
||||
consume_spaces(it, end);
|
||||
if (!block_end.empty() && !parse_literal(it, end, block_end)) {
|
||||
throw std::runtime_error("Failed to parse block end");
|
||||
}
|
||||
consume_spaces(it, end);
|
||||
} else {
|
||||
// Not a valid tool call, treat as content
|
||||
msg.content += std::string(match[0].first, match[0].second);
|
||||
it = match[0].second;
|
||||
}
|
||||
} else {
|
||||
auto function_name = match[4].str();
|
||||
if (function_name.empty()) {
|
||||
function_name = match[5].str();
|
||||
}
|
||||
GGML_ASSERT(!function_name.empty());
|
||||
|
||||
close_tag = "</function>";
|
||||
// Start parsing from after the opening tags
|
||||
auto json_it = match[6].first;
|
||||
json arguments;
|
||||
if (parse_json(json_it, end, arguments)) {
|
||||
msg.tool_calls.emplace_back(process_tool_call({
|
||||
{"name", function_name},
|
||||
{"arguments", arguments},
|
||||
}));
|
||||
it = json_it; // Move iterator past parsed JSON
|
||||
|
||||
// Handle close tags
|
||||
consume_spaces(it, end);
|
||||
if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
|
||||
throw std::runtime_error("Failed to parse closing tag");
|
||||
}
|
||||
consume_spaces(it, end);
|
||||
if (!block_end.empty() && !parse_literal(it, end, block_end)) {
|
||||
throw std::runtime_error("Failed to parse block end");
|
||||
}
|
||||
consume_spaces(it, end);
|
||||
} else {
|
||||
// Not a valid tool call, treat as content
|
||||
msg.content += std::string(match[0].first, match[0].second);
|
||||
it = match[0].second;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Add remaining content
|
||||
msg.content += std::string(it, end);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,7 +10,6 @@
|
|||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include "json.hpp"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
@ -483,6 +482,11 @@ void string_replace_all(std::string & s, const std::string & search, const std::
|
|||
s = std::move(builder);
|
||||
}
|
||||
|
||||
std::string regex_escape(const std::string & s) {
|
||||
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
||||
return std::regex_replace(s, special_chars, "\\$0");
|
||||
}
|
||||
|
||||
std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
|
||||
std::ostringstream result;
|
||||
for (size_t i = 0; i < values.size(); ++i) {
|
||||
|
@ -2026,3 +2030,25 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
|
|||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
json common_grammar_trigger::to_json() const {
|
||||
json out {
|
||||
{"type", (int) type},
|
||||
{"value", value},
|
||||
};
|
||||
if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
|
||||
out["token"] = (int) token;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
template <>
|
||||
common_grammar_trigger common_grammar_trigger::from_json(const json & in) {
|
||||
common_grammar_trigger out;
|
||||
out.type = (common_grammar_trigger_type) in.at("type").get<int>();
|
||||
out.value = in.at("value").get<std::string>();
|
||||
if (out.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
|
||||
out.token = (llama_token) in.at("token").get<int>();
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
|
|
@ -110,9 +110,21 @@ enum common_conversation_mode {
|
|||
COMMON_CONVERSATION_MODE_AUTO = 2,
|
||||
};
|
||||
|
||||
enum common_grammar_trigger_type {
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
};
|
||||
|
||||
struct common_grammar_trigger {
|
||||
std::string word;
|
||||
bool at_start;
|
||||
common_grammar_trigger_type type;
|
||||
std::string value;
|
||||
llama_token token = LLAMA_TOKEN_NULL;
|
||||
|
||||
// T can only be nlohmann::ordered_json
|
||||
template <class T> T to_json() const;
|
||||
template <class T> static common_grammar_trigger from_json(const T & in);
|
||||
};
|
||||
|
||||
// sampling parameters
|
||||
|
@ -163,8 +175,7 @@ struct common_params_sampling {
|
|||
|
||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||
bool grammar_lazy = false;
|
||||
std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to trigger lazy grammar
|
||||
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
|
||||
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
|
||||
std::set<llama_token> preserved_tokens;
|
||||
|
||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
|
@ -458,6 +469,8 @@ std::string string_repeat(const std::string & str, size_t n);
|
|||
|
||||
void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
|
||||
|
||||
std::string regex_escape(const std::string & s);
|
||||
|
||||
template<class T>
|
||||
static std::vector<T> string_split(const std::string & str, char delim) {
|
||||
static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
|
||||
|
|
|
@ -264,7 +264,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
|
|||
throw std::runtime_error("At least one of min_value or max_value must be set");
|
||||
}
|
||||
|
||||
const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}";
|
||||
const std::string SPACE_RULE = "| \" \" | \"\\n\"{1,2} [ \\t]{0,20}";
|
||||
|
||||
struct BuiltinRule {
|
||||
std::string content;
|
||||
|
@ -764,11 +764,10 @@ private:
|
|||
public:
|
||||
SchemaConverter(
|
||||
const std::function<json(const std::string &)> & fetch_json,
|
||||
bool dotall,
|
||||
bool compact_spaces)
|
||||
bool dotall)
|
||||
: _fetch_json(fetch_json), _dotall(dotall)
|
||||
{
|
||||
_rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE;
|
||||
_rules["space"] = SPACE_RULE;
|
||||
}
|
||||
|
||||
void resolve_refs(json & schema, const std::string & url) {
|
||||
|
@ -1007,7 +1006,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
|
|||
}
|
||||
|
||||
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
|
||||
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces);
|
||||
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
|
||||
common_grammar_builder builder {
|
||||
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
|
||||
return converter._add_rule(name, rule);
|
||||
|
|
|
@ -16,7 +16,6 @@ struct common_grammar_builder {
|
|||
|
||||
struct common_grammar_options {
|
||||
bool dotall = false;
|
||||
bool compact_spaces = false;
|
||||
};
|
||||
|
||||
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
|
||||
|
|
|
@ -160,16 +160,53 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
} else {
|
||||
std::vector<const char *> trigger_words;
|
||||
trigger_words.reserve(params.grammar_trigger_words.size());
|
||||
for (const auto & str : params.grammar_trigger_words) {
|
||||
trigger_words.push_back(str.word.c_str());
|
||||
std::vector<std::string> patterns_at_start;
|
||||
std::vector<std::string> patterns_anywhere;
|
||||
std::vector<llama_token> trigger_tokens;
|
||||
for (const auto & trigger : params.grammar_triggers) {
|
||||
switch (trigger.type) {
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
||||
{
|
||||
const auto & word = trigger.value;
|
||||
patterns_anywhere.push_back(regex_escape(word));
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
|
||||
{
|
||||
const auto & pattern = trigger.value;
|
||||
(trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
|
||||
{
|
||||
const auto token = trigger.token;
|
||||
trigger_tokens.push_back(token);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown trigger type");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> trigger_patterns;
|
||||
if (!patterns_at_start.empty()) {
|
||||
trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
|
||||
}
|
||||
if (!patterns_anywhere.empty()) {
|
||||
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
|
||||
}
|
||||
|
||||
std::vector<const char *> trigger_patterns_c;
|
||||
trigger_patterns_c.reserve(trigger_patterns.size());
|
||||
for (const auto & regex : trigger_patterns) {
|
||||
trigger_patterns_c.push_back(regex.c_str());
|
||||
}
|
||||
|
||||
grmr = params.grammar_lazy
|
||||
? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
|
||||
trigger_words.data(), trigger_words.size(),
|
||||
params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size())
|
||||
? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
|
||||
trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||
trigger_tokens.data(), trigger_tokens.size())
|
||||
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue