From 91a86a6f354aa73a7aab7bc3d283be410fdc93a5 Mon Sep 17 00:00:00 2001 From: oobabooga Date: Tue, 6 May 2025 15:24:15 -0300 Subject: [PATCH] sampling : don't consider -infinity values in top_n_sigma (#13344) --- src/llama-sampling.cpp | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 0c9c6a31..2869f60d 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1757,20 +1757,28 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t // find max logit and calculate mean float max = cur_p->data[0].logit; float logits_sum = 0; + size_t valid_count = 0; for (size_t i = 0; i < cur_p->size; ++i) { - if (cur_p->data[i].logit > max) { - max = cur_p->data[i].logit; + // Only count non-negative infinity values + if (cur_p->data[i].logit != -INFINITY) { + if (cur_p->data[i].logit > max) { + max = cur_p->data[i].logit; + } + logits_sum += cur_p->data[i].logit; + valid_count++; } - logits_sum += cur_p->data[i].logit; } - float mean = logits_sum/cur_p->size; + float mean = valid_count > 0 ? logits_sum/valid_count : 0; // calculate standard deviation float acc = 0; for (size_t i = 0; i < cur_p->size; ++i) { - acc += pow(cur_p->data[i].logit - mean, 2); + // Skip -infinity in std calculation + if (cur_p->data[i].logit != -INFINITY) { + acc += pow(cur_p->data[i].logit - mean, 2); + } } - float std = sqrt(acc/cur_p->size); + float std = valid_count > 0 ? sqrt(acc/valid_count) : 0; //apply mask for (size_t i = 0; i < cur_p->size; ++i) {