server : fix first message identification (#13634)
* server : fix first message identification

  When using the OpenAI SDK
  (https://github.com/openai/openai-node/blob/master/src/lib/ChatCompletionStream.ts#L623-L626)
  we noticed that the expected assistant role was missing from the first
  streaming message. Fix this by correctly checking for the first message.

  Co-authored-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
  Signed-off-by: Dorin Geman <dorin.geman@docker.com>

* server : fix checks for the first role message for stream=True

  Co-authored-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
  Signed-off-by: Dorin Geman <dorin.geman@docker.com>

---------

Signed-off-by: Dorin Geman <dorin.geman@docker.com>
Co-authored-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
This commit is contained in:
parent 797f2ac062
commit 42158ae2e8
2 changed files with 53 additions and 21 deletions
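
Context for the fix: in the OpenAI streaming protocol, the assistant role arrives in the delta of the first chunk, and SDK stream helpers key off it. Below is a minimal sketch of the client-side expectation, using the Python openai client against a local llama.cpp server; the URL, API key, and model name are placeholders, not part of this commit:

# Sketch of the client-side expectation; URL/key/model are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="sk-no-key-required")

stream = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
)

first = next(iter(stream))
delta = first.choices[0].delta
# Before this fix, the server never emitted the role chunk, so helpers
# such as openai-node's ChatCompletionStream saw no assistant role.
assert delta.role == "assistant"
assert delta.content == ""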
@@ -951,7 +951,7 @@ struct server_task_result_cmpl_partial : server_task_result {
     }

     json to_json_oaicompat_chat() {
-        bool first = n_decoded == 0;
+        bool first = n_decoded == 1;
         std::time_t t = std::time(0);
         json choices;

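Why the check flips from n_decoded == 0 to n_decoded == 1: a partial result is presumably assembled only after a token has been decoded, so the counter is already 1 when the first chunk is built; the old == 0 branch never fired, which is why the role message was silently skipped. A toy model of that ordering (the loop and variable names are illustrative, not the server's actual code):

# Toy model of the streaming-loop ordering; illustrative only.
n_decoded = 0
for token in ["Hel", "lo"]:
    n_decoded += 1                    # the token is decoded first ...
    first = (n_decoded == 1)          # ... then the chunk is assembled,
    if first:                         # so the first chunk sees 1, never 0
        print({"delta": {"role": "assistant", "content": ""}})
    print({"delta": {"content": token}})
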
@@ -962,15 +962,18 @@ struct server_task_result_cmpl_partial : server_task_result {
                                             {"delta", json{{"role", "assistant"}}}}});
             } else {
                 // We have to send this as two updates to conform to openai behavior
+                // initial_ret is the role message for stream=True
                 json initial_ret = json{{"choices", json::array({json{
                             {"finish_reason", nullptr},
                             {"index", 0},
                             {"delta", json{
-                                {"role", "assistant"}
+                                {"role", "assistant"},
+                                {"content", ""}
                             }}}})},
                             {"created", t},
                             {"id", oaicompat_cmpl_id},
                             {"model", oaicompat_model},
+                            {"system_fingerprint", build_info},
                             {"object", "chat.completion.chunk"}};

                 json second_ret = json{
@@ -982,8 +985,19 @@ struct server_task_result_cmpl_partial : server_task_result {
                             {"created", t},
                             {"id", oaicompat_cmpl_id},
                             {"model", oaicompat_model},
+                            {"system_fingerprint", build_info},
                             {"object", "chat.completion.chunk"}};

+                if (prob_output.probs.size() > 0) {
+                    second_ret["choices"][0]["logprobs"] = json{
+                        {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
+                    };
+                }
+
+                if (timings.prompt_n >= 0) {
+                    second_ret.push_back({"timings", timings.to_json()});
+                }
+
                 return std::vector<json>({initial_ret, second_ret});
             }
         } else {
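With the two-update behavior, a streaming client now receives a role chunk followed by a content chunk. A sketch of what the first two payloads look like after this change; the id, timestamp, fingerprint, and content values below are made up:

# Illustrative first two stream payloads; field values are made up.
first_chunk = {
    "choices": [{"finish_reason": None, "index": 0,
                 "delta": {"role": "assistant", "content": ""}}],
    "created": 1747400000,
    "id": "chatcmpl-abc123",
    "model": "gpt-3.5-turbo",
    "system_fingerprint": "b0000",
    "object": "chat.completion.chunk",
}
second_chunk = {
    "choices": [{"finish_reason": None, "index": 0,
                 "delta": {"content": "Hel"}}],
    "created": 1747400000,
    "id": "chatcmpl-abc123",
    "model": "gpt-3.5-turbo",
    "system_fingerprint": "b0000",
    "object": "chat.completion.chunk",
}
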
@@ -71,8 +71,14 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
     })
     content = ""
     last_cmpl_id = None
-    for data in res:
+    for i, data in enumerate(res):
         choice = data["choices"][0]
+        if i == 0:
+            # Check first role message for stream=True
+            assert choice["delta"]["content"] == ""
+            assert choice["delta"]["role"] == "assistant"
+        else:
+            assert "role" not in choice["delta"]
         assert data["system_fingerprint"].startswith("b")
         assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
         if last_cmpl_id is None:
@@ -242,12 +248,18 @@ def test_chat_completion_with_timings_per_token():
         "stream": True,
         "timings_per_token": True,
     })
-    for data in res:
-        assert "timings" in data
-        assert "prompt_per_second" in data["timings"]
-        assert "predicted_per_second" in data["timings"]
-        assert "predicted_n" in data["timings"]
-        assert data["timings"]["predicted_n"] <= 10
+    for i, data in enumerate(res):
+        if i == 0:
+            # Check first role message for stream=True
+            assert data["choices"][0]["delta"]["content"] == ""
+            assert data["choices"][0]["delta"]["role"] == "assistant"
+        else:
+            assert "role" not in data["choices"][0]["delta"]
+            assert "timings" in data
+            assert "prompt_per_second" in data["timings"]
+            assert "predicted_per_second" in data["timings"]
+            assert "predicted_n" in data["timings"]
+            assert data["timings"]["predicted_n"] <= 10


 def test_logprobs():
@@ -295,17 +307,23 @@ def test_logprobs_stream():
     )
     output_text = ''
     aggregated_text = ''
-    for data in res:
+    for i, data in enumerate(res):
         choice = data.choices[0]
-        if choice.finish_reason is None:
-            if choice.delta.content:
-                output_text += choice.delta.content
-                assert choice.logprobs is not None
-                assert choice.logprobs.content is not None
-                for token in choice.logprobs.content:
-                    aggregated_text += token.token
-                    assert token.logprob <= 0.0
-                    assert token.bytes is not None
-                    assert token.top_logprobs is not None
-                    assert len(token.top_logprobs) > 0
+        if i == 0:
+            # Check first role message for stream=True
+            assert choice.delta.content == ""
+            assert choice.delta.role == "assistant"
+        else:
+            assert choice.delta.role is None
+            if choice.finish_reason is None:
+                if choice.delta.content:
+                    output_text += choice.delta.content
+                    assert choice.logprobs is not None
+                    assert choice.logprobs.content is not None
+                    for token in choice.logprobs.content:
+                        aggregated_text += token.token
+                        assert token.logprob <= 0.0
+                        assert token.bytes is not None
+                        assert token.top_logprobs is not None
+                        assert len(token.top_logprobs) > 0
     assert aggregated_text == output_text
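For reference, the openai-node helper linked in the commit message rebuilds the final message from the accumulated deltas and fails when no role ever arrives, which is the symptom this commit fixes. A rough Python paraphrase of that accumulation pattern (not the SDK's actual code):

# Rough paraphrase of an SDK-style delta accumulator; not the SDK's code.
def accumulate(chunks):
    message = {"role": None, "content": ""}
    for chunk in chunks:
        delta = chunk["choices"][0]["delta"]
        if "role" in delta:
            message["role"] = delta["role"]
        if delta.get("content"):
            message["content"] += delta["content"]
    if message["role"] is None:
        # Without this fix the server never sent a role, so an
        # accumulator like this could not build a valid message.
        raise ValueError("missing role in streamed deltas")
    return message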