run : avoid double tokenization (#14327)

* run : avoid double tokenization by adopting common_tokenize heuristic

* build : fix windows gcc and clang warnings

* lint : fixed trailing whitepace

* run : fix is_first flag
This commit is contained in:
Ruikai Peng 2025-06-23 01:28:06 +08:00 committed by GitHub
parent f1f5e82df6
commit 66aba7aca9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -9,6 +9,9 @@
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#if defined(_WIN32) #if defined(_WIN32)
# ifndef NOMINMAX
# define NOMINMAX
# endif
# include <windows.h> # include <windows.h>
# include <io.h> # include <io.h>
#else #else
@ -940,16 +943,29 @@ static int apply_chat_template(const struct common_chat_templates * tmpls, Llama
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt, static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) { std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == -1; const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == -1;
int n_tokens = prompt.size() + 2 * is_first;
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); prompt_tokens.resize(n_tokens);
prompt_tokens.resize(n_prompt_tokens); n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, prompt_tokens.data(), prompt_tokens.size(),
true) < 0) { is_first, /*parse_special =*/true);
printe("failed to tokenize the prompt\n"); if (n_tokens == std::numeric_limits<int32_t>::min()) {
printe("tokenization failed: input too large\n");
return -1; return -1;
} }
if (n_tokens < 0) {
return n_prompt_tokens; prompt_tokens.resize(-n_tokens);
int check = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
prompt_tokens.data(), prompt_tokens.size(),
is_first, /*parse_special =*/true);
if (check != -n_tokens) {
printe("failed to tokenize the prompt (size mismatch)\n");
return -1;
}
n_tokens = check;
} else {
prompt_tokens.resize(n_tokens);
}
return n_tokens;
} }
// Check if we have enough space in the context to evaluate this batch // Check if we have enough space in the context to evaluate this batch