upgrade to llguidance 0.7.10 (#12576)
This commit is contained in:
parent
02082f1519
commit
2447ad8a98
3 changed files with 94 additions and 49 deletions
|
@ -1086,6 +1086,65 @@ static void test_json_schema() {
|
|||
});
|
||||
}
|
||||
|
||||
static void one_hot(llama_token_data_array & tok_arr, llama_token selected) {
|
||||
auto n_vocab = tok_arr.size;
|
||||
|
||||
tok_arr.selected = -1;
|
||||
tok_arr.sorted = false;
|
||||
for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
|
||||
tok_arr.data[token_id].id = token_id;
|
||||
tok_arr.data[token_id].logit = 0.0f;
|
||||
}
|
||||
|
||||
tok_arr.data[selected].logit = 100.0f;
|
||||
}
|
||||
|
||||
static void test_sampler_chain(void) {
|
||||
auto sparams = llama_sampler_chain_default_params();
|
||||
sparams.no_perf = false;
|
||||
llama_sampler * sampler = llama_sampler_chain_init(sparams);
|
||||
|
||||
const auto grammar_data = R"(%llguidance {}
|
||||
start: /[A-Z ]*/)";
|
||||
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_llg(vocab, "lark", grammar_data));
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_dist(42));
|
||||
|
||||
auto input = "ALL YOUR BASE ARE BELONG TO US";
|
||||
auto tokens = common_tokenize(vocab, input, false, false);
|
||||
|
||||
auto n_vocab = llama_vocab_n_tokens(vocab);
|
||||
|
||||
std::vector<llama_token_data> cur;
|
||||
cur.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
|
||||
cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
|
||||
}
|
||||
auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };
|
||||
|
||||
for (const auto token : tokens) {
|
||||
one_hot(tok_arr, token);
|
||||
|
||||
fprintf(stderr, "applying token: %d\n", token);
|
||||
llama_sampler_apply(sampler, &tok_arr);
|
||||
|
||||
auto idx = tok_arr.selected;
|
||||
fprintf(stderr, " -> %d %f\n", cur[idx].id, cur[idx].logit);
|
||||
assert(cur[tok_arr.selected].id == token);
|
||||
llama_sampler_accept(sampler, token);
|
||||
}
|
||||
|
||||
auto tok_eos = llama_vocab_eot(vocab);
|
||||
if (tok_eos == LLAMA_TOKEN_NULL) {
|
||||
tok_eos = llama_vocab_eos(vocab);
|
||||
}
|
||||
|
||||
one_hot(tok_arr, tok_eos);
|
||||
|
||||
llama_sampler_apply(sampler, &tok_arr);
|
||||
assert(cur[tok_arr.selected].id == tok_eos);
|
||||
}
|
||||
|
||||
int main(int argc, const char ** argv) {
|
||||
fprintf(stdout, "Running llguidance integration tests...\n");
|
||||
|
||||
|
@ -1135,6 +1194,9 @@ int main(int argc, const char ** argv) {
|
|||
test_special_chars();
|
||||
test_quantifiers();
|
||||
test_json_schema();
|
||||
|
||||
test_sampler_chain();
|
||||
|
||||
fprintf(stdout, "All tests passed.\n");
|
||||
return 0;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue