CUDA: fix FTZ in FA for Gemma 3 (#13991)

This commit is contained in:
Johannes Gäßler 2025-06-04 08:57:05 +02:00 committed by GitHub
parent e0e806f52e
commit 0b4be4c435
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -652,9 +652,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
float KQ_max_scale[cols_per_thread];
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]);
const float KQ_max_diff = KQ_max[col] - KQ_max_new[col];
KQ_max_scale[col] = expf(KQ_max_diff);
KQ_max[col] = KQ_max_new[col];
*((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
}