vulkan: Pad N dimension of B matrix for coopmat2 perf, to avoid bounds checking (#12273)
* vulkan: Pad N dimension of B matrix for coopmat2 perf, to avoid bounds checking
This commit is contained in:
parent
2f21123c1d
commit
891c63956d
2 changed files with 34 additions and 22 deletions
|
@ -48,6 +48,8 @@ layout (push_constant) uniform parameter
|
|||
uint broadcast2;
|
||||
uint broadcast3;
|
||||
#endif
|
||||
// N dimension for the B matrix can be >= p.N
|
||||
uint padded_N;
|
||||
} p;
|
||||
|
||||
|
||||
|
@ -202,18 +204,19 @@ void main() {
|
|||
#endif
|
||||
|
||||
// Use end_k rather than p.K as the dimension because that's what
|
||||
// we need to bound check against when using split_k
|
||||
// we need to bound check against when using split_k.
|
||||
// Bounds check B against padded_N, but bounds check D against N.
|
||||
tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k);
|
||||
tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k);
|
||||
tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k);
|
||||
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M);
|
||||
tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
|
||||
tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k);
|
||||
tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);
|
||||
|
||||
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
|
||||
|
||||
#if !defined(MUL_MAT_ID)
|
||||
// Detect a fast path where all loads are entirely in bounds and no clamping is required
|
||||
if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
|
||||
if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
|
||||
#if QUANT_K == 1
|
||||
(stride_a % 8) == 0 &&
|
||||
#endif
|
||||
|
@ -263,7 +266,7 @@ void main() {
|
|||
#ifdef MUL_MAT_ID
|
||||
bool unclampedB = true;
|
||||
#else
|
||||
bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0;
|
||||
bool unclampedB = (ic + 1) * BN <= p.padded_N && block_k + BK <= end_k && (block_k % 8) == 0;
|
||||
#endif
|
||||
if (unclampedA && unclampedB) {
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue