CUDA: fix race condition in MMQ stream-k fixup (#13299)
This commit is contained in:
parent
8afbd96818
commit
93c4e23905
1 changed files with 1 additions and 0 deletions
|
@ -2958,6 +2958,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
||||||
for (int j = threadIdx.y*WARP_SIZE + threadIdx.x; j < mmq_x; j += nwarps*WARP_SIZE) {
|
for (int j = threadIdx.y*WARP_SIZE + threadIdx.x; j < mmq_x; j += nwarps*WARP_SIZE) {
|
||||||
ids_dst_shared[j] = ids_dst[col_low + j];
|
ids_dst_shared[j] = ids_dst[col_low + j];
|
||||||
}
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
const int offset_dst = it*mmq_y;
|
const int offset_dst = it*mmq_y;
|
||||||
dst += offset_dst;
|
dst += offset_dst;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue