minja: sync (qwen3) (#13573)
* minja: sync f06140fa52
- https://github.com/google/minja/pull/67 (@grf53)
- https://github.com/google/minja/pull/66 (@taha-yassine)
- https://github.com/google/minja/pull/63 (@grf53)
- https://github.com/google/minja/pull/58
---------
Co-authored-by: ochafik <ochafik@google.com>
This commit is contained in:
parent
c6a2c9e741
commit
bc098c3cf0
2 changed files with 78 additions and 41 deletions
|
@ -13,10 +13,12 @@
|
|||
#include <chrono>
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
#include <ctime>
|
||||
#include <exception>
|
||||
#include <iomanip>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
|
@ -393,8 +395,8 @@ class chat_template {
|
|||
|
||||
for (const auto & message_ : adjusted_messages) {
|
||||
auto message = message_;
|
||||
if (!message.contains("role") || !message.contains("content")) {
|
||||
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
|
||||
if (!message.contains("role") || (!message.contains("content") && !message.contains("tool_calls"))) {
|
||||
throw std::runtime_error("message must have 'role' and one of 'content' or 'tool_calls' fields: " + message.dump());
|
||||
}
|
||||
std::string role = message.at("role");
|
||||
|
||||
|
@ -415,7 +417,6 @@ class chat_template {
|
|||
}
|
||||
}
|
||||
if (polyfill_tool_calls) {
|
||||
auto content = message.at("content");
|
||||
auto tool_calls = json::array();
|
||||
for (const auto & tool_call : message.at("tool_calls")) {
|
||||
if (tool_call.at("type") != "function") {
|
||||
|
@ -434,8 +435,11 @@ class chat_template {
|
|||
auto obj = json {
|
||||
{"tool_calls", tool_calls},
|
||||
};
|
||||
if (!content.is_null() && !content.empty()) {
|
||||
obj["content"] = content;
|
||||
if (message.contains("content")) {
|
||||
auto content = message.at("content");
|
||||
if (!content.is_null() && !content.empty()) {
|
||||
obj["content"] = content;
|
||||
}
|
||||
}
|
||||
message["content"] = obj.dump(2);
|
||||
message.erase("tool_calls");
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue