CUDA: add mean operation (#14313)

* CUDA: add mean operation

* add back sum_rows_f32_cuda

* Review: early exit if col!=0
This commit is contained in:
Aman Gupta 2025-06-22 12:39:54 +08:00 committed by GitHub
parent aa0ef5c578
commit aa064b2eb7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 54 additions and 19 deletions

View file

@ -362,6 +362,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#endif // FP16_AVAILABLE
}
// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
template<bool norm>
static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
sum += x[row * ncols + i];
}
sum = warp_reduce_sum(sum);
if (col != 0) {
return;
}
dst[row] = norm ? sum / ncols : sum;
}
template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll