ggml/ex: calculate accuracy in graph, adapt MNIST (ggml/980)
This commit is contained in:
parent
eee39bdc96
commit
fabdc3bda3
11 changed files with 389 additions and 8 deletions
|
|
@ -259,7 +259,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
}
|
||||
|
||||
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
|
||||
kqsum_j = warp_reduce_sum(kqsum_j);
|
||||
kqsum_j = warp_reduce_sum((float)kqsum_j);
|
||||
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue