llama : default sampling changes + greedy update (#9897)
* llama : deprecate softmax sampler + fix dist sampler ggml-ci * tests : replace macros with functions ggml-ci * sampling : change temperature sampler logic For t <= 0.0f, keep the max logit intact and set the rest to -inf * cont : no need for special "greedy" logic top-k == 1 is the same * tests : init prob correctly * llama : handle temp <= 0.0 in the temp_ext sampler too ggml-ci * cont : avoid extra loop in temperature sampler for sub-zero temp ggml-ci
This commit is contained in:
parent
bc21975084
commit
55e47786e3
7 changed files with 202 additions and 218 deletions
|
@ -46,7 +46,6 @@ actor LlamaContext {
|
|||
let sparams = llama_sampler_chain_default_params()
|
||||
self.sampling = llama_sampler_chain_init(sparams)
|
||||
llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
|
||||
llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax())
|
||||
llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
|
||||
}
|
||||
|
||||
|
|
|
@ -42,7 +42,6 @@ int main(int argc, char ** argv) {
|
|||
|
||||
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_softmax());
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed));
|
||||
|
||||
// tokenize prompt
|
||||
|
@ -107,7 +106,6 @@ int main(int argc, char ** argv) {
|
|||
|
||||
llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl2, llama_sampler_init_softmax());
|
||||
llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed));
|
||||
|
||||
printf("\nsecond run: %s", params.prompt.c_str());
|
||||
|
@ -171,7 +169,6 @@ int main(int argc, char ** argv) {
|
|||
|
||||
llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl3, llama_sampler_init_softmax());
|
||||
llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed));
|
||||
|
||||
printf("\nsingle seq run: %s", params.prompt.c_str());
|
||||
|
|
|
@ -185,8 +185,6 @@ int main(int argc, char ** argv) {
|
|||
// target model sampling context (reuse the llama_context's sampling instance)
|
||||
struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);
|
||||
|
||||
struct llama_sampler * softmax = llama_sampler_init_softmax();
|
||||
|
||||
// draft sequence data
|
||||
std::vector<seq_draft> drafts(n_seq_dft);
|
||||
|
||||
|
@ -629,7 +627,6 @@ int main(int argc, char ** argv) {
|
|||
common_sampler_free(drafts[s].smpl);
|
||||
}
|
||||
|
||||
llama_sampler_free(softmax);
|
||||
llama_batch_free(batch_dft);
|
||||
|
||||
llama_free(ctx_tgt);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue