server : streaming of tool calls and thoughts when --jinja is on (#12379)

* add common_json w/ support for truncated json healing
* add common_chat_msg_diff
* partial common_chat_parse
* refactor parser w/ optionals
* server: wire chat diffs in stream mode
* fix trigger of thinking models (must happen after thoughts are closed)
* fix functionary v3.2 raw python!
* rename: common_chat_syntax (now contains format)
* rm common_regex.at_start
* don't return empty <think></think>
* accommodate yet another deepseek r1 distill fantasy syntax (`<|tool▁calls|>`)
* fix QwQ 32B tool call parsing after thoughts (hermes2)
* better logs for grammar triggers
* consume spaces after parse_json_tool_calls
* fix required tool calls w/ thinking models that have pre-opened thinking tags
* fix thinking model's initial trigger + test qwq's template
* run most test_tool_call tests in stream + non-stream modes
* make functionary v3.2 parsing more strict (differentiate first match from others)
* send final diff from server, to close off raw python arguments
* support partial content streaming in Generic mode
* tool-call: allow content prelude before hermes2 tool calls (for Qwen2.5)
* Update function-calling.md
* Update tool_bench.py
* chat-parser: remove input from exception (llm output may contain PII)

---------

Co-authored-by: ochafik <ochafik@google.com>
Co-authored-by: Olivier Chafik <ochafik@users.noreply.github.com>
This commit is contained in:
parent a2d02d5793
commit f5cd27b71d
23 changed files with 3245 additions and 1091 deletions
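What the feature looks like from a client's perspective: with llama-server started with `--jinja`, a `stream: true` chat completion now carries tool-call and reasoning deltas instead of requiring a non-streamed request. Below is a minimal consumption sketch (illustrative only: the model name, URL and tool schema are placeholders; it assumes the standard `openai` Python package and accumulates tool-call argument fragments per index, following the OpenAI streaming contract):

```python
# Sketch: consume streamed tool-call deltas from llama-server (started with --jinja).
# Assumes the `openai` Python package; model name and base_url are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="sk-no-key")

stream = client.chat.completions.create(
    model="placeholder",
    messages=[{"role": "user", "content": "What is the weather in Istanbul?"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }],
    stream=True,
)

tool_calls = {}  # index -> accumulated {"name": ..., "arguments": ...}
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="", flush=True)
    for tc in delta.tool_calls or []:
        acc = tool_calls.setdefault(tc.index, {"name": "", "arguments": ""})
        if tc.function.name:
            acc["name"] = tc.function.name
        if tc.function.arguments:
            # argument fragments only form valid JSON once the call is complete
            acc["arguments"] += tc.function.arguments

print(tool_calls)
```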
server.cpp

@@ -1,3 +1,4 @@
#include "chat.h"
#include "utils.hpp"
#include "arg.h"

@@ -114,11 +115,11 @@ struct slot_params {
struct common_params_speculative speculative;
// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_syntax oaicompat_chat_syntax;
json to_json() const {
std::vector<std::string> samplers;

@@ -176,7 +177,10 @@ struct slot_params {
{"grammar_lazy", sampling.grammar_lazy},
{"grammar_triggers", grammar_triggers},
{"preserved_tokens", sampling.preserved_tokens},
{"chat_format", common_chat_format_name(oaicompat_chat_format)},
{"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)},
{"reasoning_format", (oaicompat_chat_syntax.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "deepseek" : "none")},
{"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content},
{"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open},
{"samplers", samplers},
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},

@@ -352,11 +356,14 @@ struct server_task {
{
auto it = data.find("chat_format");
if (it != data.end()) {
params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<int>());
SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>());
SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format).c_str());
} else {
params.oaicompat_chat_format = defaults.oaicompat_chat_format;
params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
}
params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format;
params.oaicompat_chat_syntax.reasoning_in_content = params.stream;
params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
}
{

@@ -396,7 +403,14 @@ struct server_task {
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
}
} else {
params.sampling.grammar_triggers.push_back(std::move(ct.value));
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
} else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
} else {
throw std::runtime_error("Unknown grammar trigger type");
}
params.sampling.grammar_triggers.emplace_back(std::move(ct.value));
}
}
}

@@ -639,11 +653,12 @@ struct server_task_result_cmpl_final : server_task_result {
slot_params generation_params;
// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_msg oaicompat_msg;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
virtual int get_index() override {
return index;

@@ -738,47 +753,20 @@ struct server_task_result_cmpl_final : server_task_result {
json to_json_oaicompat_chat() {
std::string finish_reason = "length";
common_chat_msg msg;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
SRV_DBG("Parsing chat message: %s\n", content.c_str());
msg = common_chat_parse(content, oaicompat_chat_format);
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
if (!oaicompat_msg.empty()) {
msg = oaicompat_msg;
} else {
msg.role = "assistant";
msg.content = content;
}
json message {
{"role", "assistant"},
};
if (!msg.reasoning_content.empty()) {
message["reasoning_content"] = msg.reasoning_content;
}
if (msg.content.empty() && !msg.tool_calls.empty()) {
message["content"] = json();
} else {
message["content"] = msg.content;
}
if (!msg.tool_calls.empty()) {
auto tool_calls = json::array();
for (const auto & tc : msg.tool_calls) {
tool_calls.push_back({
{"type", "function"},
{"function", {
{"name", tc.name},
{"arguments", tc.arguments},
}},
// Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
// We only generate a random id for the ones that don't generate one by themselves
// (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
{"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
});
}
message["tool_calls"] = tool_calls;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
}
json choice {
{"finish_reason", finish_reason},
{"index", 0},
{"message", message},
{"message", msg.to_json_oaicompat<json>()},
};
if (!stream && probs_output.size() > 0) {

@@ -818,17 +806,35 @@ struct server_task_result_cmpl_final : server_task_result {
std::time_t t = std::time(0);
std::string finish_reason = "length";
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = "stop";
finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls";
}
json choice = json {
{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}
};
json deltas = json::array();
for (const auto & diff : oaicompat_msg_diffs) {
deltas.push_back({
{"choices", json::array({
json {
{"finish_reason", nullptr},
{"index", 0},
{"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
},
})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"},
});
}
json ret = json {
{"choices", json::array({choice})},
deltas.push_back({
{"choices", json::array({
json {
{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()},
},
})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},

@@ -839,18 +845,18 @@ struct server_task_result_cmpl_final : server_task_result {
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens},
}},
};
});
if (timings.prompt_n >= 0) {
ret.push_back({"timings", timings.to_json()});
deltas.back().push_back({"timings", timings.to_json()});
}
// extra fields for debugging purposes
if (verbose) {
ret["__verbose"] = to_json_non_oaicompat();
if (verbose && !deltas.empty()) {
deltas.front()["__verbose"] = to_json_non_oaicompat();
}
return ret;
return deltas;
}
};

@@ -868,10 +874,11 @@ struct server_task_result_cmpl_partial : server_task_result {
result_timings timings;
// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
virtual int get_index() override {
return index;

@@ -955,84 +962,50 @@ struct server_task_result_cmpl_partial : server_task_result {
std::time_t t = std::time(0);
json choices;
if (first) {
if (content.empty()) {
choices = json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"role", "assistant"}}}}});
} else {
// We have to send this as two updates to conform to openai behavior
// initial_ret is the role message for stream=True
json initial_ret = json{{"choices", json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"role", "assistant"},
{"content", ""}
}}}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"}};
json second_ret = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json {
{"content", content}}}
}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"}};
if (prob_output.probs.size() > 0) {
second_ret["choices"][0]["logprobs"] = json{
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
}
if (timings.prompt_n >= 0) {
second_ret.push_back({"timings", timings.to_json()});
}
return std::vector<json>({initial_ret, second_ret});
}
} else {
choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta",
json {
{"content", content},
}},
}});
}
GGML_ASSERT(choices.size() >= 1);
if (prob_output.probs.size() > 0) {
choices[0]["logprobs"] = json{
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
}
json ret = json {
{"choices", choices},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"}
std::vector<json> deltas;
auto add_delta = [&](const json & delta) {
deltas.push_back({
{"choices", json::array({
json {
{"finish_reason", nullptr},
{"index", 0},
{"delta", delta},
},
})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"},
});
};
if (timings.prompt_n >= 0) {
ret.push_back({"timings", timings.to_json()});
// We have to send an initial update to conform to openai behavior
if (first) {
add_delta({
{"role", "assistant"},
{"content", nullptr},
});
}
return std::vector<json>({ret});
for (const auto & diff : oaicompat_msg_diffs) {
add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
}
if (!deltas.empty()) {
GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1);
if (prob_output.probs.size() > 0) {
deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json {
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
}
if (timings.prompt_n >= 0) {
deltas[deltas.size() - 1].push_back({"timings", timings.to_json()});
}
}
return deltas;
}
};

@@ -1293,6 +1266,7 @@ struct server_slot {
std::string generated_text;
llama_tokens generated_tokens;
common_chat_msg chat_msg;
server_tokens cache_tokens;

@@ -1313,6 +1287,7 @@ struct server_slot {
llama_token sampled;
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
std::vector<std::string> generated_tool_call_ids;
// stats
size_t n_sent_text = 0; // number of sent text character

@@ -1342,9 +1317,13 @@ struct server_slot {
n_past = 0;
n_sent_text = 0;
task_type = SERVER_TASK_TYPE_COMPLETION;
chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
generated_tokens.clear();
generated_token_probs.clear();
chat_msg = {};
json_schema = json();
generated_tool_call_ids.clear();
// clear speculative decoding stats
n_draft_total = 0;

@@ -1424,6 +1403,21 @@ struct server_slot {
return timings;
}
const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
auto previous_msg = chat_msg;
SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
auto new_msg = common_chat_parse(
generated_text,
/* is_partial= */ stop != STOP_TYPE_EOS,
params.oaicompat_chat_syntax);
if (!new_msg.empty()) {
new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id);
chat_msg = new_msg;
diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
}
return chat_msg;
}
size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
size_t stop_pos = std::string::npos;
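The `update_chat_msg` helper above is the heart of the streaming path: on every new token it re-parses the whole partial output (`is_partial = true`, which is where the truncated-JSON healing from the commit message comes in), then diffs the result against the previous parse so only newly appended text is sent to the client. A conceptual Python sketch of that diffing idea (illustrative only, not the actual `common_chat_msg_diff` C++ implementation):

```python
# Conceptual sketch of "message diffs": re-parse the full partial output,
# then stream only the appended suffixes (content, reasoning, argument fragments).
def compute_diffs(prev: dict, curr: dict) -> list[dict]:
    diffs = []
    if curr.get("reasoning_content", "") != prev.get("reasoning_content", ""):
        old = prev.get("reasoning_content", "")
        diffs.append({"reasoning_content": curr["reasoning_content"][len(old):]})
    if curr.get("content", "") != prev.get("content", ""):
        diffs.append({"content": curr["content"][len(prev.get("content", "")):]})
    prev_calls = prev.get("tool_calls", [])
    for i, tc in enumerate(curr.get("tool_calls", [])):
        prev_args = prev_calls[i]["arguments"] if i < len(prev_calls) else ""
        entry = {"index": i, "function": {}}
        if i >= len(prev_calls):
            entry["function"]["name"] = tc["name"]   # a new tool call opened
        if tc["arguments"] != prev_args:
            entry["function"]["arguments"] = tc["arguments"][len(prev_args):]
        if entry["function"]:
            diffs.append({"tool_calls": [entry]})
    return diffs

prev = {"reasoning_content": "I need to call the weather tool.", "tool_calls": []}
curr = {"reasoning_content": "I need to call the weather tool.",
        "tool_calls": [{"name": "get_weather", "arguments": '{"location": "Ista'}]}
print(compute_diffs(prev, curr))
# -> [{'tool_calls': [{'index': 0, 'function': {'name': 'get_weather',
#                                               'arguments': '{"location": "Ista'}}]}]
```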
@@ -2475,10 +2469,12 @@ struct server_context {
res->n_prompt_tokens = slot.n_prompt_tokens;
res->post_sampling_probs = slot.params.post_sampling_probs;
res->verbose = slot.params.verbose;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->verbose = slot.params.verbose;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
slot.update_chat_msg(res->oaicompat_msg_diffs);
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {

@@ -2499,7 +2495,7 @@ struct server_context {
res->id_slot = slot.id;
res->index = slot.index;
res->content = std::move(slot.generated_text);
res->content = slot.generated_text;
res->tokens = std::move(slot.generated_tokens);
res->timings = slot.get_timings();
res->prompt = slot.prompt_tokens.detokenize(ctx, true);

@@ -2519,7 +2515,8 @@ struct server_context {
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->oaicompat_chat_format = slot.params.oaicompat_chat_format;
res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
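For reference, each `common_chat_msg_diff` wired through `server_context` above is serialized into an OpenAI-style `delta` object. A hypothetical chunk sequence for a single streamed tool call could look like the following (field values are invented for illustration, not captured from a real run):

```python
# Illustrative only: the "delta" payloads a client might receive, in order.
example_deltas = [
    {"role": "assistant", "content": None},                                    # initial role chunk
    {"reasoning_content": "I should call the weather tool."},                  # thoughts (deepseek reasoning format)
    {"tool_calls": [{"index": 0, "id": "call_abc123", "type": "function",
                     "function": {"name": "get_weather", "arguments": ""}}]},  # tool call opened
    {"tool_calls": [{"index": 0, "function": {"arguments": '{"location": "Ista'}}]},  # argument fragment
    {"tool_calls": [{"index": 0, "function": {"arguments": 'nbul"}'}}]},               # final fragment
    {},  # the closing chunk carries finish_reason "tool_calls" and an empty delta
]
```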
test_chat_completion.py

@@ -75,7 +75,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
choice = data["choices"][0]
if i == 0:
# Check first role message for stream=True
assert choice["delta"]["content"] == ""
assert choice["delta"]["content"] is None
assert choice["delta"]["role"] == "assistant"
else:
assert "role" not in choice["delta"]

@@ -92,7 +92,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
assert choice["finish_reason"] == finish_reason
else:
assert choice["finish_reason"] is None
content += choice["delta"]["content"]
content += choice["delta"]["content"] or ''
def test_chat_completion_with_openai_library():

@@ -251,8 +251,9 @@ def test_chat_completion_with_timings_per_token():
for i, data in enumerate(res):
if i == 0:
# Check first role message for stream=True
assert data["choices"][0]["delta"]["content"] == ""
assert data["choices"][0]["delta"]["content"] is None
assert data["choices"][0]["delta"]["role"] == "assistant"
assert "timings" not in data, f'First event should not have timings: {data}'
else:
assert "role" not in data["choices"][0]["delta"]
assert "timings" in data

@@ -311,7 +312,7 @@ def test_logprobs_stream():
choice = data.choices[0]
if i == 0:
# Check first role message for stream=True
assert choice.delta.content == ""
assert choice.delta.content is None
assert choice.delta.role == "assistant"
else:
assert choice.delta.role is None
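The test changes above encode a small client-visible change: the initial role chunk now carries `"content": null` rather than an empty string, so clients that concatenate deltas should guard against `None`, e.g.:

```python
# Minimal accumulation pattern (stream_of_chunks is a placeholder for parsed SSE events).
content = ""
for chunk in stream_of_chunks:
    delta = chunk["choices"][0]["delta"]
    content += delta.get("content") or ""   # the first chunk now has content = None
```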
test_tool_call.py

@@ -8,6 +8,7 @@ path = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(path))
from utils import *
from enum import Enum
server: ServerProcess

@@ -20,7 +21,11 @@ def create_server():
server = ServerPreset.tinyllama2()
server.model_alias = "tinyllama-2-tool-call"
server.server_port = 8081
server.n_slots = 1
class CompletionMode(Enum):
NORMAL = "normal"
STREAMED = "streamed"
TEST_TOOL = {
"type":"function",

@@ -73,9 +78,8 @@ WEATHER_TOOL = {
}
}
def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a coding assistant."},

@@ -86,13 +90,13 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
"parallel_tool_calls": False,
**kwargs,
})
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
# assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
assert expected_function_name == tool_call["function"]["name"]
actual_arguments = tool_call["function"]["arguments"]

@@ -102,12 +106,16 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,tool,argument_key", [
("google-gemma-2-2b-it", TEST_TOOL, "success"),
("google-gemma-2-2b-it", TEST_TOOL, "success"),
("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
])
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
global server
n_predict = 1024
# server = ServerPreset.stories15m_moe()

@@ -115,31 +123,43 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict,
server.n_predict = n_predict
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, temperature=0.0, top_k=1, top_p=1.0)
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0)
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,tool,argument_key", [
("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
# Functionary v3.2 format supports raw python content, which w/ a dummy stories model will never end on its own.
# ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"),
# ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "codeFalse), True),
# ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
])
def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None):
def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
global server
n_predict = 512
# server = ServerPreset.stories15m_moe()

@@ -147,10 +167,11 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
server.n_predict = n_predict
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict)
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED)
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [
(TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),

@@ -184,9 +205,9 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
# (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
# (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
# (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),

@@ -203,10 +224,9 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
(TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
n_predict = 512
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192
server.n_predict = n_predict

@@ -219,7 +239,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a coding assistant."},

@@ -228,12 +248,12 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
"tool_choice": "required",
"tools": [tool],
"parallel_tool_calls": False,
"stream": stream == CompletionMode.STREAMED,
"temperature": 0.0,
"top_k": 1,
"top_p": 1.0,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]

@@ -248,7 +268,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a coding assistant."},

@@ -258,26 +278,27 @@ def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int,
"tool_choice": tool_choice,
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
choice = body["choices"][0]
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
("meta-llama-Llama-3.3-70B-Instruct", 128, [], None),
("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None),
("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'),
])
def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
global server
server.jinja = True
server.n_predict = n_predict
server.jinja = True
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
("meetkai-functionary-medium-v3.2", 256, [], None),
("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None),

@@ -289,16 +310,17 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None),
("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'),
])
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
global server
server.jinja = True
server.n_predict = n_predict
server.jinja = True
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("hf_repo,template_override", [
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),

@@ -321,11 +343,11 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
# ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
# ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
# ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
# ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),

@@ -339,10 +361,9 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
# ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
])
def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
n_predict = 512
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192
server.n_predict = n_predict

@@ -355,11 +376,11 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_weather(server, max_tokens=n_predict)
do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
def do_test_weather(server: ServerProcess, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"messages": [
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
{"role": "user", "content": "What is the weather in Istanbul?"},

@@ -367,14 +388,13 @@ def do_test_weather(server: ServerProcess, **kwargs):
"tools": [WEATHER_TOOL],
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"], f'Expected weather tool call, got {tool_call["function"]["name"]}'
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
actual_arguments = json.loads(tool_call["function"]["arguments"])
assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
location = actual_arguments["location"]

@@ -383,6 +403,7 @@ def do_test_weather(server: ServerProcess, **kwargs):
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),

@@ -400,9 +421,8 @@ def do_test_weather(server: ServerProcess, **kwargs):
# (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
# ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192 * 2
server.n_predict = n_predict

@@ -415,11 +435,11 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_calc_result(server, result_override, n_predict)
do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED)
def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."},

@@ -466,8 +486,7 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
],
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls is None, f'Expected no tool call in {choice["message"]}'
content = choice["message"].get("content")

@@ -480,18 +499,18 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
@pytest.mark.slow
@pytest.mark.parametrize("n_predict,reasoning_format,expect_content,expect_reasoning_content,hf_repo,template_override", [
(128, 'deepseek', "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(128, None, "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of[\\s\\S]*", "I need to calculate the sum of 102 and 7[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'none', "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of[\\s\\S]*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
@pytest.mark.parametrize("n_predict,reasoning_format,stream,expect_reasoning_content,expect_content,hf_repo,template_override", [
(128, 'deepseek', CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(128, None, CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(1024, 'deepseek', CompletionMode.NORMAL, "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'deepseek', CompletionMode.STREAMED, None, "^<think>I need to calculate [\\s\\S]*?</think>To find the sum of [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'deepseek', CompletionMode.NORMAL, "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
(1024, 'deepseek', CompletionMode.STREAMED, None, "^<think>First, I [\\s\\S]*?</think>To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
# (1024, 'none', CompletionMode.NORMAL, None, "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
# (128, 'deepseek', None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*", "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M", None),
])
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
server.n_slots = 1
server.reasoning_format = reasoning_format
server.jinja = True
server.n_ctx = 8192 * 2

@@ -505,14 +524,14 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "user", "content": "What's the sum of 102 and 7?"},
]
],
"stream": stream == CompletionMode.STREAMED,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
choice = body["choices"][0]
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
content = choice["message"].get("content")

@@ -529,6 +548,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("hf_repo,template_override", [
("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),

@@ -562,10 +582,9 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"),
])
def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
n_predict = 512 # High because of DeepSeek R1
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192
server.n_predict = n_predict

@@ -579,11 +598,11 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_hello_world(server, max_tokens=n_predict)
do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
def do_test_hello_world(server: ServerProcess, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"messages": [
{"role": "system", "content": "You are a tool-calling agent."},
{"role": "user", "content": "say hello world with python"},

@@ -591,16 +610,15 @@ def do_test_hello_world(server: ServerProcess, **kwargs):
"tools": [PYTHON_TOOL],
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
actual_arguments = json.loads(tool_call["function"]["arguments"])
assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
code = actual_arguments["code"]
assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', re.sub(r'#.*\n?', '', code)), f'Expected hello world, got {code}'
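Note that the streamed `test_thoughts` cases above expect the reasoning to arrive inline in `content`, wrapped in `<think>...</think>` (the server sets `reasoning_in_content` when streaming). A client that wants the two separated can split them itself; a rough sketch assuming a single leading think block:

```python
import re

def split_reasoning(text: str) -> tuple[str | None, str]:
    # Assumes at most one leading <think>...</think> block, as in the streamed cases above.
    m = re.match(r"<think>(.*?)</think>(.*)", text, re.DOTALL)
    if m:
        return m.group(1).strip(), m.group(2).lstrip()
    return None, text

print(split_reasoning("<think>First, I add 102 and 7.</think>To find the sum, ..."))
# -> ('First, I add 102 and 7.', 'To find the sum, ...')
```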
utils.py

@@ -294,6 +294,77 @@ class ServerProcess:
print("Partial response from server", json.dumps(data, indent=2))
yield data
def make_any_request(
self,
method: str,
path: str,
data: dict | None = None,
headers: dict | None = None,
timeout: float | None = None,
) -> dict:
stream = data.get('stream', False)
if stream:
content: list[str] = []
tool_calls: list[dict] = []
finish_reason: Optional[str] = None
content_parts = 0
tool_call_parts = 0
arguments_parts = 0
for chunk in self.make_stream_request(method, path, data, headers):
assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
choice = chunk['choices'][0]
if choice['delta'].get('content') is not None:
assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
content.append(choice['delta']['content'])
content_parts += 1
if choice['delta'].get('finish_reason') is not None:
finish_reason = choice['delta']['finish_reason']
for tc in choice['delta'].get('tool_calls', []):
if 'function' not in tc:
raise ValueError(f"Expected function type, got {tc['type']}")
if tc['index'] >= len(tool_calls):
tool_calls.append(dict(
id="",
type="function",
function=dict(
name="",
arguments="",
)
))
tool_call = tool_calls[tc['index']]
if tc.get('id') is not None:
tool_call['id'] = tc['id']
fct = tc['function']
if fct.get('name') is not None:
tool_call['function']['name'] = fct['name']
if fct.get('arguments') is not None:
assert len(fct['arguments']) > 0, f'Expected non empty arguments delta!'
tool_call['function']['arguments'] += fct['arguments']
print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
result = dict(
choices=[
dict(
index=0,
finish_reason=finish_reason,
message=dict(
role='assistant',
content=''.join(content) if content else None,
tool_calls=tool_calls if tool_calls else None,
),
)
],
)
print("Final response from server", json.dumps(result, indent=2))
return result
else:
response = self.make_request(method, path, data, headers, timeout=timeout)
assert response.status_code == 200, f"Server returned error: {response.status_code}"
return response.body
server_instances: Set[ServerProcess] = set()
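A typical call site for the new `make_any_request` helper, as used throughout the tool-call tests above (request fields shown are illustrative): it transparently aggregates SSE chunks when `stream` is true and returns a body shaped like a non-streamed response either way.

```python
# Illustrative usage inside a test; `server` is a started ServerProcess and
# WEATHER_TOOL is the tool schema defined in test_tool_call.py.
body = server.make_any_request("POST", "/v1/chat/completions", data={
    "max_tokens": 256,
    "messages": [{"role": "user", "content": "What is the weather in Istanbul?"}],
    "tools": [WEATHER_TOOL],
    "stream": True,
})
tool_calls = body["choices"][0]["message"].get("tool_calls")
```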
utils.hpp

@@ -474,26 +474,6 @@ static std::string gen_tool_call_id() {
// other common utils
//
static bool ends_with(const std::string & str, const std::string & suffix) {
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}
static size_t find_partial_stop_string(const std::string &stop, const std::string &text) {
if (!text.empty() && !stop.empty()) {
const char text_last_char = text.back();
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
if (stop[char_index] == text_last_char) {
const std::string current_partial = stop.substr(0, char_index + 1);
if (ends_with(text, current_partial)) {
return text.size() - char_index - 1;
}
}
}
}
return std::string::npos;
}
// TODO: reuse llama_detokenize
template <class Iter>
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {

@@ -599,19 +579,16 @@ static json oaicompat_chat_params_parse(
json llama_params;
auto tools = json_value(body, "tools", json());
auto has_tools = tools.is_array() && !tools.empty();
auto stream = json_value(body, "stream", false);
auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
if (tools.is_array() && !tools.empty()) {
if (stream) {
throw std::runtime_error("Cannot use tools with stream");
}
if (!opt.use_jinja) {
if (!opt.use_jinja) {
if (has_tools) {
throw std::runtime_error("tools param requires --jinja flag");
}
}
if (!opt.use_jinja) {
if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) {
throw std::runtime_error("Unsupported param: tool_choice");
if (tool_choice != "auto") {
throw std::runtime_error("tool_choice param requires --jinja flag");
}
}
@@ -749,14 +726,12 @@ static json oaicompat_chat_params_parse(
common_chat_templates_inputs inputs;
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
inputs.tools = common_chat_tools_parse_oaicompat(tools);
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice);
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
inputs.grammar = grammar;
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
inputs.use_jinja = opt.use_jinja;
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
inputs.extract_reasoning = opt.reasoning_format != COMMON_REASONING_FORMAT_NONE;
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
inputs.reasoning_format = opt.reasoning_format;
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
}

@@ -774,7 +749,8 @@ static json oaicompat_chat_params_parse(
throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list.");
}
inputs.extract_reasoning = false;
/* TODO: test this properly */
inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE;
inputs.add_generation_prompt = true;
}

@@ -799,6 +775,7 @@ static json oaicompat_chat_params_parse(
}
llama_params["grammar_triggers"] = grammar_triggers;
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
llama_params["thinking_forced_open"] = chat_params.thinking_forced_open;
for (const auto & stop : chat_params.additional_stops) {
llama_params["stop"].push_back(stop);
}

@@ -812,6 +789,9 @@ static json oaicompat_chat_params_parse(
// Handle "logprobs" field
// TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
if (json_value(body, "logprobs", false)) {
if (has_tools && stream) {
throw std::runtime_error("logprobs is not supported with tools + stream");
}
llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
} else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
throw std::runtime_error("top_logprobs requires logprobs to be set to true");
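The `oaicompat_chat_params_parse` changes above drop the old "Cannot use tools with stream" guard, so tools plus `stream: true` is now accepted when the server runs with `--jinja`; logprobs combined with tools + stream is still rejected, per the last hunk. A raw-HTTP sketch of such a request (hypothetical host/port; assumes the `requests` package):

```python
# Minimal sketch: POST a streaming tool-call request directly and read the SSE lines.
import json
import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "What is the weather in Istanbul?"}],
        "tools": [{"type": "function", "function": {
            "name": "get_weather",
            "parameters": {"type": "object",
                           "properties": {"location": {"type": "string"}},
                           "required": ["location"]},
        }}],
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines():
    # each event is a "data: {...}" line; the stream ends with "data: [DONE]"
    if line.startswith(b"data: ") and line != b"data: [DONE]":
        chunk = json.loads(line[len(b"data: "):])
        print(chunk["choices"][0]["delta"])
```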