vulkan: Pad N dimension of B matrix for coopmat2 perf, to avoid bounds checking (#12273)

* vulkan: Pad N dimension of B matrix for coopmat2 perf, to avoid bounds checking
This commit is contained in:
Jeff Bolz 2025-03-17 04:41:59 -05:00 committed by GitHub
parent 2f21123c1d
commit 891c63956d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 34 additions and 22 deletions

View file

@@ -48,6 +48,8 @@ layout (push_constant) uniform parameter
uint broadcast2;
uint broadcast3;
#endif
// N dimension for the B matrix can be >= p.N
uint padded_N;
} p;
@@ -202,18 +204,19 @@ void main() {
#endif
// Use end_k rather than p.K as the dimension because that's what
// we need to bound check against when using split_k
// we need to bound check against when using split_k.
// Bounds check B against padded_N, but bounds check D against N.
tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k);
tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k);
tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k);
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M);
tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k);
tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
#if !defined(MUL_MAT_ID)
// Detect a fast path where all loads are entirely in bounds and no clamping is required
if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
#if QUANT_K == 1
(stride_a % 8) == 0 &&
#endif
@@ -263,7 +266,7 @@ void main() {
#ifdef MUL_MAT_ID
bool unclampedB = true;
#else
bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0;
bool unclampedB = (ic + 1) * BN <= p.padded_N && block_k + BK <= end_k && (block_k % 8) == 0;
#endif
if (unclampedA && unclampedB) {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);