From 87c2630546cd8ccc836ae4c7d87d03fa597d2267 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 10 Mar 2025 09:45:07 +0000 Subject: [PATCH] allow missing content in message if tool_calls provided (#12293) --- common/chat.cpp | 29 ++++++++++++++++------------- tests/test-chat.cpp | 15 +++++++++++++++ 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index 1b10219c..1b3f286a 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -60,7 +60,9 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa } msg.role = message.at("role"); - if (message.contains("content")) { + auto has_content = message.contains("content"); + auto has_tool_calls = message.contains("tool_calls"); + if (has_content) { const auto & content = message.at("content"); if (content.is_string()) { msg.content = content; @@ -81,19 +83,8 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa } else if (!content.is_null()) { throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); } - } else { - throw std::runtime_error("Expected 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); } - if (message.contains("reasoning_content")) { - msg.reasoning_content = message.at("reasoning_content"); - } - if (message.contains("name")) { - msg.tool_name = message.at("name"); - } - if (message.contains("tool_call_id")) { - msg.tool_call_id = message.at("tool_call_id"); - } - if (message.contains("tool_calls")) { + if (has_tool_calls) { for (const auto & tool_call : message.at("tool_calls")) { common_chat_tool_call tc; if (!tool_call.contains("type")) { @@ -118,6 +109,18 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa msg.tool_calls.push_back(tc); } } + if (!has_content && !has_tool_calls) { + throw std::runtime_error("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)"); + } + if (message.contains("reasoning_content")) { + msg.reasoning_content = message.at("reasoning_content"); + } + if (message.contains("name")) { + msg.tool_name = message.at("name"); + } + if (message.contains("tool_call_id")) { + msg.tool_call_id = message.at("tool_call_id"); + } msgs.push_back(msg); } diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 35a307c6..35c7ee34 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -480,6 +480,21 @@ static void test_msgs_oaicompat_json_conversion() { "]" ), common_chat_msgs_to_json_oaicompat({message_assist_call_python}).dump(2)); + + auto res = common_chat_msgs_parse_oaicompat(json::parse("[{\"role\": \"assistant\", \"tool_calls\": []}]")); + assert_equals(1, res.size()); + assert_equals(res[0].role, "assistant"); + assert_equals(true, res[0].content.empty()); + assert_equals(true, res[0].tool_calls.empty()); + + try { + common_chat_msgs_parse_oaicompat(json::parse("[{\"role\": \"assistant\"}]")); + throw std::runtime_error("Expected exception"); + } catch (const std::exception & e) { + if (std::string(e.what()).find("'content'") == std::string::npos) { + throw std::runtime_error("Expected exception about missing 'content'"); + } + } } static void test_tools_oaicompat_json_conversion() {