From c9bbc77931d223ed7e7cbcf1cb057bc02fd0db19 Mon Sep 17 00:00:00 2001
From: Olivier Chafik
Date: Mon, 2 Jun 2025 10:15:44 -0700
Subject: [PATCH] `server`: update deepseek reasoning format (pass
 reasoning_content as diffs) (#13933)

* server: update deepseek reasoning format (now in reasoning_content diffs), add legacy option for compat

* update unit/test_tool_call.py::test_thoughts
---
 common/arg.cpp                            |  1 +
 common/chat.cpp                           | 15 ++++++++-------
 common/chat.h                             |  2 +-
 common/common.h                           |  3 ++-
 tests/test-chat.cpp                       |  2 +-
 tools/server/server.cpp                   |  2 +-
 tools/server/tests/unit/test_tool_call.py | 13 ++++++-------
 tools/server/tests/utils.py               | 11 ++++++++++-
 8 files changed, 30 insertions(+), 19 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index cfa9878f..0d0daa36 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2869,6 +2869,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         "(default: deepseek)",
         [](common_params & params, const std::string & value) {
             /**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
+            else if (value == "deepseek-legacy") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; }
             else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
             else { throw std::invalid_argument("invalid value"); }
         }
diff --git a/common/chat.cpp b/common/chat.cpp
index f1ab4c85..1d6974a8 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -82,10 +82,10 @@ json common_chat_msg::to_json_oaicompat() const
 
 std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) {
     std::vector<common_chat_msg_diff> diffs;
-    // if (previous_msg.reasoning_content != current.reasoning_content) {
-    //     auto & diff = diffs.emplace_back();
-    //     diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, current.reasoning_content);
-    // }
+    if (previous_msg.reasoning_content != new_msg.reasoning_content) {
+        auto & diff = diffs.emplace_back();
+        diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, new_msg.reasoning_content);
+    }
     if (previous_msg.content != new_msg.content) {
         auto & diff = diffs.emplace_back();
         diff.content_delta = string_diff(previous_msg.content, new_msg.content);
@@ -385,9 +385,9 @@ json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & t
 
 template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
     json delta = json::object();
-    // if (!diff.reasoning_content_delta.empty()) {
-    //     delta["reasoning_content"] = msg.reasoning_content;
-    // }
+    if (!diff.reasoning_content_delta.empty()) {
+        delta["reasoning_content"] = diff.reasoning_content_delta;
+    }
     if (!diff.content_delta.empty()) {
         delta["content"] = diff.content_delta;
     }
@@ -598,6 +598,7 @@ const char * common_reasoning_format_name(common_reasoning_format format) {
     switch (format) {
         case COMMON_REASONING_FORMAT_NONE:     return "none";
         case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
+        case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
         default:
             throw std::runtime_error("Unknown reasoning format");
     }
diff --git a/common/chat.h b/common/chat.h
index f6b1d0ff..9f59e6b0 100644
--- a/common/chat.h
+++ b/common/chat.h
@@ -70,7 +70,7 @@ struct common_chat_msg {
 };
 
 struct common_chat_msg_diff {
-    // std::string reasoning_content_delta;
+    std::string reasoning_content_delta;
     std::string content_delta;
     size_t tool_call_index = std::string::npos;
     common_chat_tool_call tool_call_delta;
diff --git a/common/common.h b/common/common.h
index cee1e303..f26724b6 100644
--- a/common/common.h
+++ b/common/common.h
@@ -215,7 +215,8 @@ struct common_params_vocoder {
 
 enum common_reasoning_format {
     COMMON_REASONING_FORMAT_NONE,
-    COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`
+    COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
+    COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
 };
 
 struct common_params {
diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp
index 1c980792..c6d998f1 100644
--- a/tests/test-chat.cpp
+++ b/tests/test-chat.cpp
@@ -19,8 +19,8 @@ using json = nlohmann::ordered_json;
 
 static std::ostream & operator<<(std::ostream & os, const common_chat_msg_diff & diff) {
-    // os << "reasoning_content_delta: " << diff.reasoning_content_delta << '\n';
     os << "{ content_delta: " << diff.content_delta << "; ";
+    os << "reasoning_content_delta: " << diff.reasoning_content_delta << "; ";
     if (diff.tool_call_index != std::string::npos) {
         os << "tool_call_index: " << diff.tool_call_index << "; ";
         os << "tool_call_delta.name: " << diff.tool_call_delta.name << "; ";
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 4b92eeac..dad686ea 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -360,7 +360,7 @@ struct server_task {
             params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
         }
         params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format;
-        params.oaicompat_chat_syntax.reasoning_in_content = params.stream;
+        params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (params_base.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
         params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
         params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
     }
diff --git a/tools/server/tests/unit/test_tool_call.py b/tools/server/tests/unit/test_tool_call.py
index 61061074..20f048c6 100755
--- a/tools/server/tests/unit/test_tool_call.py
+++ b/tools/server/tests/unit/test_tool_call.py
@@ -499,13 +499,12 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
 
 
 @pytest.mark.slow
-@pytest.mark.parametrize("n_predict,reasoning_format,stream,expect_reasoning_content,expect_content,hf_repo,template_override", [
-    (128, 'deepseek', CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
-    (128, None, CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
-    (1024, 'deepseek', CompletionMode.NORMAL, "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-    (1024, 'deepseek', CompletionMode.STREAMED, None, "^I need to calculate [\\s\\S]*?To find the sum of [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-    (1024, 'deepseek', CompletionMode.NORMAL, "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
-    (1024, 'deepseek', CompletionMode.STREAMED, None, "^First, I [\\s\\S]*?To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
+@pytest.mark.parametrize("n_predict,reasoning_format,expect_reasoning_content,expect_content,hf_repo,template_override", [
+    (128, 'deepseek', None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
+    (128, None, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
+    (1024, 'deepseek', "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    (1024, 'deepseek', "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
     # (1024, 'none', CompletionMode.NORMAL, None, "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
     # (128, 'deepseek', None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*", "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M", None),
 ])
diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py
index f7e1b3b3..bc547ca0 100644
--- a/tools/server/tests/utils.py
+++ b/tools/server/tests/utils.py
@@ -308,10 +308,12 @@ class ServerProcess:
         stream = data.get('stream', False)
         if stream:
             content: list[str] = []
+            reasoning_content: list[str] = []
             tool_calls: list[dict] = []
             finish_reason: Optional[str] = None
 
             content_parts = 0
+            reasoning_content_parts = 0
             tool_call_parts = 0
             arguments_parts = 0
 
@@ -322,6 +324,10 @@ class ServerProcess:
                     assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
                     content.append(choice['delta']['content'])
                     content_parts += 1
+                if choice['delta'].get('reasoning_content') is not None:
+                    assert len(choice['delta']['reasoning_content']) > 0, f'Expected non empty reasoning_content delta!'
+                    reasoning_content.append(choice['delta']['reasoning_content'])
+                    reasoning_content_parts += 1
                 if choice['delta'].get('finish_reason') is not None:
                     finish_reason = choice['delta']['finish_reason']
                 for tc in choice['delta'].get('tool_calls', []):
@@ -349,8 +355,10 @@ class ServerProcess:
                         tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name']
                     if fct.get('arguments') is not None:
                         tool_call['function']['arguments'] += fct['arguments']
+                        arguments_parts += 1
+                    tool_call_parts += 1
 
-            print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
+            print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
             result = dict(
                 choices=[
                     dict(
                         index=0,
                         message=dict(
                             role='assistant',
                             content=''.join(content) if content else None,
+                            reasoning_content=''.join(reasoning_content) if reasoning_content else None,
                             tool_calls=tool_calls if tool_calls else None,
                         ),
                     )