upgrade to llguidance 0.7.10 (#12576)

This commit is contained in:
Michał Moskal 2025-03-26 11:06:09 -07:00 committed by GitHub
parent 02082f1519
commit 2447ad8a98
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 94 additions and 49 deletions

View file

@ -1086,6 +1086,65 @@ static void test_json_schema() {
});
}
static void one_hot(llama_token_data_array & tok_arr, llama_token selected) {
auto n_vocab = tok_arr.size;
tok_arr.selected = -1;
tok_arr.sorted = false;
for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
tok_arr.data[token_id].id = token_id;
tok_arr.data[token_id].logit = 0.0f;
}
tok_arr.data[selected].logit = 100.0f;
}
static void test_sampler_chain(void) {
auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = false;
llama_sampler * sampler = llama_sampler_chain_init(sparams);
const auto grammar_data = R"(%llguidance {}
start: /[A-Z ]*/)";
llama_sampler_chain_add(sampler, llama_sampler_init_llg(vocab, "lark", grammar_data));
llama_sampler_chain_add(sampler, llama_sampler_init_dist(42));
auto input = "ALL YOUR BASE ARE BELONG TO US";
auto tokens = common_tokenize(vocab, input, false, false);
auto n_vocab = llama_vocab_n_tokens(vocab);
std::vector<llama_token_data> cur;
cur.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
}
auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };
for (const auto token : tokens) {
one_hot(tok_arr, token);
fprintf(stderr, "applying token: %d\n", token);
llama_sampler_apply(sampler, &tok_arr);
auto idx = tok_arr.selected;
fprintf(stderr, " -> %d %f\n", cur[idx].id, cur[idx].logit);
assert(cur[tok_arr.selected].id == token);
llama_sampler_accept(sampler, token);
}
auto tok_eos = llama_vocab_eot(vocab);
if (tok_eos == LLAMA_TOKEN_NULL) {
tok_eos = llama_vocab_eos(vocab);
}
one_hot(tok_arr, tok_eos);
llama_sampler_apply(sampler, &tok_arr);
assert(cur[tok_arr.selected].id == tok_eos);
}
int main(int argc, const char ** argv) {
fprintf(stdout, "Running llguidance integration tests...\n");
@ -1135,6 +1194,9 @@ int main(int argc, const char ** argv) {
test_special_chars();
test_quantifiers();
test_json_schema();
test_sampler_chain();
fprintf(stdout, "All tests passed.\n");
return 0;
}