sampling : don't consider -infinity values in top_n_sigma (#13344)

This commit is contained in:
oobabooga 2025-05-06 15:24:15 -03:00 committed by GitHub
parent f4ed10b69c
commit 91a86a6f35
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1757,20 +1757,28 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
// find max logit and calculate mean // find max logit and calculate mean
float max = cur_p->data[0].logit; float max = cur_p->data[0].logit;
float logits_sum = 0; float logits_sum = 0;
size_t valid_count = 0;
for (size_t i = 0; i < cur_p->size; ++i) { for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].logit > max) { // Only count non-negative infinity values
max = cur_p->data[i].logit; if (cur_p->data[i].logit != -INFINITY) {
if (cur_p->data[i].logit > max) {
max = cur_p->data[i].logit;
}
logits_sum += cur_p->data[i].logit;
valid_count++;
} }
logits_sum += cur_p->data[i].logit;
} }
float mean = logits_sum/cur_p->size; float mean = valid_count > 0 ? logits_sum/valid_count : 0;
// calculate standard deviation // calculate standard deviation
float acc = 0; float acc = 0;
for (size_t i = 0; i < cur_p->size; ++i) { for (size_t i = 0; i < cur_p->size; ++i) {
acc += pow(cur_p->data[i].logit - mean, 2); // Skip -infinity in std calculation
if (cur_p->data[i].logit != -INFINITY) {
acc += pow(cur_p->data[i].logit - mean, 2);
}
} }
float std = sqrt(acc/cur_p->size); float std = valid_count > 0 ? sqrt(acc/valid_count) : 0;
//apply mask //apply mask
for (size_t i = 0; i < cur_p->size; ++i) { for (size_t i = 0; i < cur_p->size; ++i) {