CUDA: add mean operation (#14313)
* CUDA: add mean operation * add back sum_rows_f32_cuda * Review: early exit if col!=0
This commit is contained in:
parent
aa0ef5c578
commit
aa064b2eb7
7 changed files with 54 additions and 19 deletions
|
@ -362,6 +362,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
|||
#endif // FP16_AVAILABLE
|
||||
}
|
||||
|
||||
// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
|
||||
template<bool norm>
|
||||
static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
|
||||
const int row = blockIdx.x;
|
||||
const int col = threadIdx.x;
|
||||
|
||||
float sum = 0.0f;
|
||||
for (int i = col; i < ncols; i += blockDim.x) {
|
||||
sum += x[row * ncols + i];
|
||||
}
|
||||
|
||||
sum = warp_reduce_sum(sum);
|
||||
|
||||
if (col != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst[row] = norm ? sum / ncols : sum;
|
||||
}
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ float warp_reduce_max(float x) {
|
||||
#pragma unroll
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue