ggml/ex: calculate accuracy in graph, adapt MNIST (ggml/980)

This commit is contained in:
Johannes Gäßler 2024-10-03 17:29:59 +02:00 committed by Georgi Gerganov
parent eee39bdc96
commit fabdc3bda3
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
11 changed files with 389 additions and 8 deletions

View file

@@ -259,7 +259,7 @@ static __global__ void flash_attn_tile_ext_f16(
}
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
-kqsum_j = warp_reduce_sum(kqsum_j);
+kqsum_j = warp_reduce_sum((float)kqsum_j);
#pragma unroll
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {