CUDA: fix race condition in FA vector kernels (#13742)
This commit is contained in:
parent
b775345d78
commit
ffd0eae60b
2 changed files with 2 additions and 0 deletions
|
@ -212,6 +212,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (__all_sync(0xFFFFFFFF, skip)) {
|
if (__all_sync(0xFFFFFFFF, skip)) {
|
||||||
|
__syncthreads();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
#endif // GGML_USE_HIP
|
#endif // GGML_USE_HIP
|
||||||
|
|
|
@ -217,6 +217,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (__all_sync(0xFFFFFFFF, skip)) {
|
if (__all_sync(0xFFFFFFFF, skip)) {
|
||||||
|
__syncthreads();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
#endif // GGML_USE_HIP
|
#endif // GGML_USE_HIP
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue