tool-call: fix Qwen 2.5 Coder support, add micro benchmarks, support trigger patterns for lazy grammars (#12034)
* sampler: turn lazy grammar trigger words to regexes
* add scripts/tool_bench.sh & .py
* constrain llama json output regardless of function name if matches at beginning
* update relaxed newline space rule in grammar tests
* support add_generation_prompt query parameter (useful for /apply_template)
* Update src/llama-grammar.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
fa31c438e0
commit
669912d9a5
26 changed files with 1314 additions and 408 deletions
|
@ -237,12 +237,35 @@ static void test_templates(const struct common_chat_templates * tmpls, const std
|
|||
auto earliest_trigger_pos = std::string::npos;
|
||||
auto constrained = data.delta;
|
||||
for (const auto & trigger : data.params.grammar_triggers) {
|
||||
auto pos = constrained.find(trigger.word);
|
||||
if (pos == std::string::npos) {
|
||||
continue;
|
||||
size_t pos = std::string::npos;
|
||||
std::smatch match;
|
||||
switch (trigger.type) {
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
||||
{
|
||||
const auto & word = trigger.value;
|
||||
pos = constrained.find(word);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||
{
|
||||
const auto & pattern = trigger.value;
|
||||
if (std::regex_search(constrained, match, std::regex(pattern))) {
|
||||
pos = match.position();
|
||||
}
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
|
||||
{
|
||||
const auto & pattern = trigger.value;
|
||||
if (std::regex_search(constrained, match, std::regex(pattern)) && match.position() == 0) {
|
||||
pos = 0;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("Unknown trigger type");
|
||||
}
|
||||
if (pos > 0 && trigger.at_start) {
|
||||
fprintf(stderr, "Trigger %s not at start of message, skipping:\n\n%s\n\n", trigger.word.c_str(), constrained.c_str());
|
||||
if (pos == std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
|
||||
|
@ -260,7 +283,8 @@ static void test_templates(const struct common_chat_templates * tmpls, const std
|
|||
|
||||
if (grammar_triggered && test_grammar_if_triggered && !match_string(constrained, grammar.get())) {
|
||||
throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
|
||||
"\n\nGrammar: " + data.params.grammar);
|
||||
"\n\nConstrained: " + constrained +
|
||||
"\n\nGrammar: " + data.params.grammar);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -640,6 +664,93 @@ static void test_template_output_parsers() {
|
|||
inputs_tools)
|
||||
.format);
|
||||
|
||||
// Test parsing
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"<tool_call>\n"
|
||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</tool_call>",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"<function=special_function>{\"arg1\": 1}</function>",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"<function name=\"special_function\">\n"
|
||||
"{\"arg1\": 1}\n"
|
||||
"</function>",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"<tool>\n"
|
||||
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</tool>",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"<tools>\n"
|
||||
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</tools>",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"<response>\n"
|
||||
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</response>",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"```xml\n"
|
||||
"<response>\n"
|
||||
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</response>\n"
|
||||
"```",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"```xml\n"
|
||||
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"```",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"```\n"
|
||||
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"```",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"```\n"
|
||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"```",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"```json\n"
|
||||
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"```",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"```json\n"
|
||||
"\n"
|
||||
" <function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}} \n"
|
||||
" </function_call> \n"
|
||||
"``` ",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"<json>\n"
|
||||
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</json>",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"<xml>\n"
|
||||
" {\n"
|
||||
" \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}\n"
|
||||
" }\n"
|
||||
"</xml>",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"<JSON>\n"
|
||||
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
||||
"</JSON>",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
assert_msg_equals(message_assist_call, common_chat_parse(
|
||||
"{\n \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
||||
|
||||
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
|
||||
"<tool_call>\n"
|
||||
|
@ -789,7 +900,7 @@ static void test_template_output_parsers() {
|
|||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
try {
|
||||
// try {
|
||||
#ifndef _WIN32
|
||||
if (argc > 1) {
|
||||
common_chat_templates_inputs inputs;
|
||||
|
@ -827,8 +938,8 @@ int main(int argc, char ** argv) {
|
|||
std::cout << "\n[chat] All tests passed!" << '\n';
|
||||
}
|
||||
return 0;
|
||||
} catch (const std::exception & e) {
|
||||
std::cerr << "Error: " << e.what() << '\n';
|
||||
return 1;
|
||||
}
|
||||
// } catch (const std::exception & e) {
|
||||
// std::cerr << "Error: " << e.what() << '\n';
|
||||
// return 1;
|
||||
// }
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue