server: extract <think> tags from qwq outputs (#12297)

* extract <think> tags from qwq outputs

* const for all static regexes in chat.cpp
This commit is contained in:
Olivier Chafik 2025-03-10 10:59:03 +00:00 committed by GitHub
parent be421fc429
commit 4e39a3c332
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 162 additions and 134 deletions

View file

@ -445,6 +445,7 @@ std::string common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2"; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1"; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro"; case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING: return "Hermes 2 Pro (extract reasoning)";
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B"; case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)"; case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)";
default: default:
@ -878,9 +879,9 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
return data; return data;
} }
static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) { static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) {
static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)"); static const std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)");
static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>"); static const std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>");
static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>"); static const std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>");
std::smatch match; std::smatch match;
@ -1012,10 +1013,10 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
} }
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) { static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
// TODO: tighten & simplify the parser, don't accept leading text context. // TODO: tighten & simplify the parser, don't accept leading text context.
static std::regex function_regex( static const std::regex function_regex(
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
static std::regex close_regex("\\}\\s*"); static const std::regex close_regex("\\}\\s*");
static std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)"); static const std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)");
if (with_builtin_tools) { if (with_builtin_tools) {
std::smatch match; std::smatch match;
@ -1105,34 +1106,42 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1; data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1;
return data; return data;
} }
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) { static common_chat_msg handle_think_tag_prelude(const std::string & input, bool extract_reasoning, const std::function<common_chat_msg(const std::string &)> & rest_parser) {
static std::regex function_regex("<tool▁call▁begin>function<tool▁sep>([^\n]+)\n```json\n");
static std::regex close_regex("```[\\s\\r\\n]*<tool▁call▁end>");
static std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
static std::regex tool_calls_regex("[\\s\\r\\n]*(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>)([\\s\\S\\r\\n]*?)<tool▁calls▁end>");
common_chat_msg msg;
msg.role = "assistant";
std::smatch match; std::smatch match;
static const std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
if (std::regex_match(input, match, reasoning_content_regex)) { if (std::regex_match(input, match, reasoning_content_regex)) {
std::string rest; auto rest = match[3].str();
auto msg = rest_parser(rest);
auto reasoning_content = string_strip(match[2].str());
if (extract_reasoning) { if (extract_reasoning) {
msg.reasoning_content = string_strip(match[2].str()); msg.reasoning_content = reasoning_content;
} else { } else if (!reasoning_content.empty()) {
msg.content = match[1].str(); std::ostringstream content;
content << "<think>" << reasoning_content << "</think>" << msg.content;
msg.content = content.str();
} }
rest = match[3].str(); return msg;
}
return rest_parser(input);
}
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) {
return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) {
static const std::regex function_regex("<tool▁call▁begin>function<tool▁sep>([^\n]+)\n```json\n");
static const std::regex close_regex("```[\\s\\r\\n]*<tool▁call▁end>");
static const std::regex tool_calls_regex("[\\s\\r\\n]*(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>)([\\s\\S\\r\\n]*?)<tool▁calls▁end>");
if (std::regex_search(rest, match, tool_calls_regex)) { common_chat_msg msg;
msg.role = "assistant";
std::smatch match;
if (std::regex_search(input, match, tool_calls_regex)) {
auto tool_calls = match[1].str(); auto tool_calls = match[1].str();
auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex); auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex);
msg.tool_calls = std::move(msg2.tool_calls); msg.tool_calls = std::move(msg2.tool_calls);
} else { } else {
msg.content += std::string(rest.begin() + rest.find_first_not_of(" \r\n"), rest.end()); msg.content = input;
} }
} else { return msg;
msg.content = input; });
}
return msg;
} }
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
@ -1237,8 +1246,8 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
} }
static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) { static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) {
static std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)"); static const std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)");
static std::regex close_regex(R"($|(?=>>>))"); static const std::regex close_regex(R"($|(?=>>>))");
std::string content; std::string content;
auto it = input.begin(); auto it = input.begin();
@ -1327,7 +1336,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
} }
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) { static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
// This version of Functionary still supports the llama 3.1 tool call format for the python tool. // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); static const std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
std::smatch match; std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) { if (std::regex_search(input, match, python_tag_regex)) {
auto code = match[1].str(); auto code = match[1].str();
@ -1341,8 +1350,8 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
}); });
return msg; return msg;
} }
static std::regex function_regex(R"(<function=(\w+)>)"); static const std::regex function_regex(R"(<function=(\w+)>)");
static std::regex close_regex(R"(</function>)"); static const std::regex close_regex(R"(</function>)");
// TODO: tighten & simplify. // TODO: tighten & simplify.
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
} }
@ -1409,6 +1418,8 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
"(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?\\s*\\{\\s*\"", //name\"\\s*:\\s*\"" + escaped_name + "\"", "(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?\\s*\\{\\s*\"", //name\"\\s*:\\s*\"" + escaped_name + "\"",
}); });
data.preserved_tokens = { data.preserved_tokens = {
"<think>",
"</think>",
"<tool_call>", "<tool_call>",
"</tool_call>", "</tool_call>",
"<function", "<function",
@ -1429,122 +1440,123 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
}); });
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO; data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING : COMMON_CHAT_FORMAT_HERMES_2_PRO;
return data; return data;
} }
static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input) { static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input, bool extract_reasoning) {
const static std::regex open_regex( return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) {
"(?:" static const std::regex open_regex(
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) "(?:"
"(<tool_call>" // match 2 (open_tag) "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
"|<function_call>" "(<tool_call>" // match 2 (open_tag)
"|<tool>" "|<function_call>"
"|<tools>" "|<tool>"
"|<response>" "|<tools>"
"|<json>" "|<response>"
"|<xml>" "|<json>"
"|<JSON>" "|<xml>"
")?" "|<JSON>"
"(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*)" // match 3 (named tool call + rest) ")?"
")" "(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*)" // match 3 (named tool call + rest)
"|" ")"
"(?:<function=([^>]+)>" // match 4 (function name) "|"
"|<function name=\"([^\"]+)\">)" // match 5 (function name again) "(?:<function=([^>]+)>" // match 4 (function name)
"([\\s\\S]*)" // match 6 (function arguments + rest)})" "|<function name=\"([^\"]+)\">)" // match 5 (function name again)
); "([\\s\\S]*)" // match 6 (function arguments + rest)})"
);
try { try {
common_chat_msg msg;
msg.role = "assistant";
common_chat_msg msg; std::string::const_iterator it = input.begin();
msg.role = "assistant"; const std::string::const_iterator end = input.end();
std::smatch match;
std::string::const_iterator it = input.begin(); while (it != end) {
const std::string::const_iterator end = input.end(); if (std::regex_search(it, end, match, open_regex)) {
std::smatch match; // Add content before the match
msg.content += std::string(it, match[0].first);
while (it != end) { auto block_start = match[1].str();
if (std::regex_search(it, end, match, open_regex)) { std::string block_end = block_start.empty() ? "" : "```";
// Add content before the match
msg.content += std::string(it, match[0].first);
auto block_start = match[1].str(); auto open_tag = match[2].str();
std::string block_end = block_start.empty() ? "" : "```"; std::string close_tag;
auto open_tag = match[2].str(); if (match[3].matched) {
std::string close_tag; close_tag = open_tag.empty() ? "" : "</" + open_tag.substr(1);
auto json_it = match[3].first;
json tool_call;
if (parse_json(json_it, end, tool_call) && tool_call.contains("name") && tool_call.contains("arguments")) {
if (match[3].matched) { msg.tool_calls.emplace_back(process_tool_call(tool_call));
close_tag = open_tag.empty() ? "" : "</" + open_tag.substr(1); it = json_it; // Move iterator past parsed JSON
auto json_it = match[3].first;
json tool_call;
if (parse_json(json_it, end, tool_call) && tool_call.contains("name") && tool_call.contains("arguments")) {
msg.tool_calls.emplace_back(process_tool_call(tool_call)); // Handle close tags
it = json_it; // Move iterator past parsed JSON consume_spaces(it, end);
if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
// Handle close tags throw std::runtime_error("Failed to parse closing tag");
consume_spaces(it, end); }
if (!close_tag.empty() && !parse_literal(it, end, close_tag)) { consume_spaces(it, end);
throw std::runtime_error("Failed to parse closing tag"); if (!block_end.empty() && !parse_literal(it, end, block_end)) {
throw std::runtime_error("Failed to parse block end");
}
consume_spaces(it, end);
} else {
// Not a valid tool call, treat as content
msg.content += std::string(match[0].first, match[0].second);
it = match[0].second;
} }
consume_spaces(it, end);
if (!block_end.empty() && !parse_literal(it, end, block_end)) {
throw std::runtime_error("Failed to parse block end");
}
consume_spaces(it, end);
} else { } else {
// Not a valid tool call, treat as content auto function_name = match[4].str();
msg.content += std::string(match[0].first, match[0].second); if (function_name.empty()) {
it = match[0].second; function_name = match[5].str();
}
GGML_ASSERT(!function_name.empty());
close_tag = "</function>";
// Start parsing from after the opening tags
auto json_it = match[6].first;
json arguments;
if (parse_json(json_it, end, arguments)) {
msg.tool_calls.emplace_back(process_tool_call({
{"name", function_name},
{"arguments", arguments},
}));
it = json_it; // Move iterator past parsed JSON
// Handle close tags
consume_spaces(it, end);
if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
throw std::runtime_error("Failed to parse closing tag");
}
consume_spaces(it, end);
if (!block_end.empty() && !parse_literal(it, end, block_end)) {
throw std::runtime_error("Failed to parse block end");
}
consume_spaces(it, end);
} else {
// Not a valid tool call, treat as content
msg.content += std::string(match[0].first, match[0].second);
it = match[0].second;
}
} }
} else { } else {
auto function_name = match[4].str(); // Add remaining content
if (function_name.empty()) { msg.content += std::string(it, end);
function_name = match[5].str(); break;
}
GGML_ASSERT(!function_name.empty());
close_tag = "</function>";
// Start parsing from after the opening tags
auto json_it = match[6].first;
json arguments;
if (parse_json(json_it, end, arguments)) {
msg.tool_calls.emplace_back(process_tool_call({
{"name", function_name},
{"arguments", arguments},
}));
it = json_it; // Move iterator past parsed JSON
// Handle close tags
consume_spaces(it, end);
if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
throw std::runtime_error("Failed to parse closing tag");
}
consume_spaces(it, end);
if (!block_end.empty() && !parse_literal(it, end, block_end)) {
throw std::runtime_error("Failed to parse block end");
}
consume_spaces(it, end);
} else {
// Not a valid tool call, treat as content
msg.content += std::string(match[0].first, match[0].second);
it = match[0].second;
}
} }
} else {
// Add remaining content
msg.content += std::string(it, end);
break;
} }
return msg;
} catch (const std::exception & e) {
LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
common_chat_msg msg;
msg.role = "assistant";
msg.content = input;
return msg;
} }
return msg; });
} catch (const std::exception & e) {
LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
common_chat_msg msg;
msg.role = "assistant";
msg.content = input;
return msg;
}
} }
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
@ -1609,6 +1621,11 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_command_r7b(tmpl, params); return common_chat_params_init_command_r7b(tmpl, params);
} }
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_hermes_2_pro(tmpl, params);
}
// Use generic handler when mixing tools + JSON schema. // Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below. // TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) { if ((params.tools.is_array() && params.json_schema.is_object())) {
@ -1630,11 +1647,6 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_without_tools(tmpl, params); return common_chat_params_init_without_tools(tmpl, params);
} }
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
if (src.find("<tool_call>") != std::string::npos) {
return common_chat_params_init_hermes_2_pro(tmpl, params);
}
// Functionary v3.1 (w/ tools) // Functionary v3.1 (w/ tools)
if (src.find("<|start_header_id|>") != std::string::npos if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) { && src.find("<function=") != std::string::npos) {
@ -1752,7 +1764,9 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
return common_chat_parse_functionary_v3_1_llama_3_1(input); return common_chat_parse_functionary_v3_1_llama_3_1(input);
case COMMON_CHAT_FORMAT_HERMES_2_PRO: case COMMON_CHAT_FORMAT_HERMES_2_PRO:
return common_chat_parse_hermes_2_pro(input); return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ false);
case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING:
return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ true);
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
return common_chat_parse_firefunction_v2(input); return common_chat_parse_firefunction_v2(input);
case COMMON_CHAT_FORMAT_COMMAND_R7B: case COMMON_CHAT_FORMAT_COMMAND_R7B:

View file

@ -53,6 +53,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO, COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_COMMAND_R7B, COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,

View file

@ -766,6 +766,19 @@ static void test_template_output_parsers() {
"{\n \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", "{\n \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
COMMON_CHAT_FORMAT_HERMES_2_PRO)); COMMON_CHAT_FORMAT_HERMES_2_PRO));
assert_msg_equals(message_assist_thoughts_unparsed_think,
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
COMMON_CHAT_FORMAT_HERMES_2_PRO));
assert_msg_equals(message_assist_thoughts_unparsed_think,
common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
COMMON_CHAT_FORMAT_HERMES_2_PRO));
assert_msg_equals(message_assist_thoughts,
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING));
assert_msg_equals(message_assist_thoughts,
common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING));
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
test_templates(tmpls.get(), end_tokens, message_assist_call, tools, test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
"<tool_call>\n" "<tool_call>\n"