From 42158ae2e8ead667a83f07247321ce85f32ace66 Mon Sep 17 00:00:00 2001 From: Dorin-Andrei Geman Date: Wed, 21 May 2025 16:07:57 +0300 Subject: [PATCH] server : fix first message identification (#13634) * server : fix first message identification When using the OpenAI SDK (https://github.com/openai/openai-node/blob/master/src/lib/ChatCompletionStream.ts#L623-L626) we noticed that the expected assistant role is missing in the first streaming message. Fix this by correctly checking for the first message. Co-authored-by: Piotr Stankiewicz Signed-off-by: Dorin Geman * server : Fix checks for first role message for stream=True Co-authored-by: Piotr Stankiewicz Signed-off-by: Dorin Geman --------- Signed-off-by: Dorin Geman Co-authored-by: Piotr Stankiewicz --- tools/server/server.cpp | 18 +++++- .../server/tests/unit/test_chat_completion.py | 56 ++++++++++++------- 2 files changed, 53 insertions(+), 21 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 3b1305e1..d48cf46e 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -951,7 +951,7 @@ struct server_task_result_cmpl_partial : server_task_result { } json to_json_oaicompat_chat() { - bool first = n_decoded == 0; + bool first = n_decoded == 1; std::time_t t = std::time(0); json choices; @@ -962,15 +962,18 @@ struct server_task_result_cmpl_partial : server_task_result { {"delta", json{{"role", "assistant"}}}}}); } else { // We have to send this as two updates to conform to openai behavior + // initial_ret is the role message for stream=True json initial_ret = json{{"choices", json::array({json{ {"finish_reason", nullptr}, {"index", 0}, {"delta", json{ - {"role", "assistant"} + {"role", "assistant"}, + {"content", ""} }}}})}, {"created", t}, {"id", oaicompat_cmpl_id}, {"model", oaicompat_model}, + {"system_fingerprint", build_info}, {"object", "chat.completion.chunk"}}; json second_ret = json{ @@ -982,8 +985,19 @@ struct server_task_result_cmpl_partial : server_task_result { {"created", t}, {"id", oaicompat_cmpl_id}, {"model", oaicompat_model}, + {"system_fingerprint", build_info}, {"object", "chat.completion.chunk"}}; + if (prob_output.probs.size() > 0) { + second_ret["choices"][0]["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + + if (timings.prompt_n >= 0) { + second_ret.push_back({"timings", timings.to_json()}); + } + return std::vector({initial_ret, second_ret}); } } else { diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 491cb3a5..bab5d005 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -71,8 +71,14 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte }) content = "" last_cmpl_id = None - for data in res: + for i, data in enumerate(res): choice = data["choices"][0] + if i == 0: + # Check first role message for stream=True + assert choice["delta"]["content"] == "" + assert choice["delta"]["role"] == "assistant" + else: + assert "role" not in choice["delta"] assert data["system_fingerprint"].startswith("b") assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future if last_cmpl_id is None: @@ -242,12 +248,18 @@ def test_chat_completion_with_timings_per_token(): "stream": True, "timings_per_token": True, }) - for data in res: - assert "timings" in data - assert "prompt_per_second" in data["timings"] - assert "predicted_per_second" in data["timings"] - assert "predicted_n" in data["timings"] - assert data["timings"]["predicted_n"] <= 10 + for i, data in enumerate(res): + if i == 0: + # Check first role message for stream=True + assert data["choices"][0]["delta"]["content"] == "" + assert data["choices"][0]["delta"]["role"] == "assistant" + else: + assert "role" not in data["choices"][0]["delta"] + assert "timings" in data + assert "prompt_per_second" in data["timings"] + assert "predicted_per_second" in data["timings"] + assert "predicted_n" in data["timings"] + assert data["timings"]["predicted_n"] <= 10 def test_logprobs(): @@ -295,17 +307,23 @@ def test_logprobs_stream(): ) output_text = '' aggregated_text = '' - for data in res: + for i, data in enumerate(res): choice = data.choices[0] - if choice.finish_reason is None: - if choice.delta.content: - output_text += choice.delta.content - assert choice.logprobs is not None - assert choice.logprobs.content is not None - for token in choice.logprobs.content: - aggregated_text += token.token - assert token.logprob <= 0.0 - assert token.bytes is not None - assert token.top_logprobs is not None - assert len(token.top_logprobs) > 0 + if i == 0: + # Check first role message for stream=True + assert choice.delta.content == "" + assert choice.delta.role == "assistant" + else: + assert choice.delta.role is None + if choice.finish_reason is None: + if choice.delta.content: + output_text += choice.delta.content + assert choice.logprobs is not None + assert choice.logprobs.content is not None + for token in choice.logprobs.content: + aggregated_text += token.token + assert token.logprob <= 0.0 + assert token.bytes is not None + assert token.top_logprobs is not None + assert len(token.top_logprobs) > 0 assert aggregated_text == output_text