#version 450

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require

#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_buffer_reference : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_vote : enable

#include "types.comp"
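
// Cooperative-matrix matmul: each workgroup computes one BM x BN tile of the
// output, marching over the shared K dimension in steps of BK. A may be
// quantized (dequantized on load when QUANT_K > 1), B is read through a
// transposed tensor view, and with MUL_MAT_ID defined the rows of B are
// gathered indirectly via per-token expert ids (mul_mat_id, e.g. MoE routing).
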
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant

layout (constant_id = 4) const bool enable_smaller_matrices = false;
const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;

layout (push_constant) uniform parameter
{
    uint M;
    uint N;
    uint K;
    uint stride_a;
    uint stride_b;
    uint stride_d;

    uint batch_stride_a;
    uint batch_stride_b;
    uint batch_stride_d;

#ifdef MUL_MAT_ID
    uint nei0;
    uint nei1;
    uint nbi1;
    uint ne11;
#else
    uint k_split;
    uint ne02;
    uint ne12;
    uint broadcast2;
    uint broadcast3;
#endif
    // N dimension for the B matrix can be >= p.N
    uint padded_N;
} p;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

#if QUANT_K > 1
#define DECODEFUNCA , dequantFuncA

#include "dequant_funcs_cm2.comp"

#else
#define DECODEFUNCA
#endif

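// For mul_mat_id, binding 3 supplies an expert id per (nei0, nei1) entry.
// row_ids collects in shared memory the B rows whose id matches this
// workgroup's expert; _ne1 counts the matches (mirrored in _ne1_sh so every
// invocation can read the total after the gather).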
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};

shared u16vec4 row_ids[3072];

layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
    B_TYPE b[];
};

uint _ne1;
shared uint _ne1_sh;

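// Load callback for B: remaps the row coordinate through row_ids so the
// cooperative-matrix load gathers only the rows routed to this expert, and
// returns zero for rows beyond the number of matches.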
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint row_i = blockCoords[0];

    if (row_i >= _ne1) {
        return B_TYPE(0.0);
    }

    const u16vec4 row_idx = row_ids[row_i];
    B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];

    return ret;
}

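// Store callback for D: scatters one element of the result tile back through
// row_ids (row_idx.y addresses batch_stride_d, row_idx.z addresses stride_d),
// bounds-checked against p.M and the number of gathered rows.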
D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic)
{
    uint dr = ir * BM + r;
    uint dc = ic * BN + c;

    if (dr < p.M && dc < _ne1) {
        uint row_i = dc;
        const u16vec4 row_idx = row_ids[row_i];
        data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
    }
    return elem;
}

#endif

void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

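    // The z workgroup dimension selects the expert (mul_mat_id) or the batch;
    // for batched multiplies, the A batch index is derived from the B batch
    // index via the broadcast factors over dims 2 and 3.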
#ifdef MUL_MAT_ID
    const uint expert_idx = gl_GlobalInvocationID.z;
#else
    const uint batch_idx = gl_GlobalInvocationID.z;

    const uint i13 = batch_idx / p.ne12;
    const uint i12 = batch_idx % p.ne12;

    const uint i03 = i13 / p.broadcast3;
    const uint i02 = i12 / p.broadcast2;

    const uint batch_idx_a = i03 * p.ne02 + i02;
#endif

    const uint blocks_m = (p.M + BM - 1) / BM;
    const uint ir = gl_WorkGroupID.x % blocks_m;
    const uint ik = gl_WorkGroupID.x / blocks_m;
    const uint ic = gl_WorkGroupID.y;

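    // MUL_MAT_ID: gather the rows routed to this expert. The first subgroup
    // scans all nei0*nei1 id entries; a ballot over matches gives each hit its
    // compacted slot in row_ids (exclusive bit count) and advances _ne1 by the
    // ballot's bit count. subgroupAny keeps the loop trip count uniform so the
    // ballots stay valid on the final, partially-filled iteration.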
#ifdef MUL_MAT_ID
    // Spread the search across all elements in the first subgroup
    if (gl_SubgroupID == 0) {
        _ne1 = 0;
        uint num_elements = p.nei1 * p.nei0;

        for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) {
            bool in_range = i < num_elements;
            uint ii0 = i % p.nei0;
            uint ii1 = i / p.nei0;
            uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
            uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
            uint idx = subgroupBallotExclusiveBitCount(ballot);
            if (in_range && id == expert_idx) {
                row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
            }
            _ne1 += subgroupBallotBitCount(ballot);
        }
        _ne1_sh = _ne1;
    }

    barrier();

    _ne1 = _ne1_sh;

    // Workgroup has no work
    if (ic * BN >= _ne1) return;
#endif

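    // Each workgroup covers either the full K range (mul_mat_id) or one
    // k_split-sized slice of it; with split-k, every slice writes a separate
    // partial result (see the ik term in pos_d below).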
#ifdef MUL_MAT_ID
    uint start_k = 0;
    const uint end_k = p.K;
#else
    uint start_k = ik * p.k_split;
    const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif

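    // Base offsets into the buffers. Offsets into A are in units of blocks
    // when quantized (hence the division by QUANT_K); pos_d gives each
    // k-slice its own batch-sized region of D.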
#ifdef MUL_MAT_ID
    uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
    uint pos_b = 0;
#else
    uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
    uint pos_b = batch_idx * p.batch_stride_b;
    uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif

    uint stride_a = p.stride_a / QUANT_K;
    uint stride_b = p.stride_b;

    // Hint to the compiler that values are aligned (want 16B alignment).
    // Quants are always block-aligned, no alignment needed.
#if ALIGNED
#if QUANT_K == 1
    stride_a &= ~7;
#endif
    stride_b &= ~7;
#endif

    // Create layouts for both clamped and unclamped accesses
    tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);
    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);

#if QUANT_K > 1
    tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
    tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K);
#endif

    // Use end_k rather than p.K as the dimension because that's what
    // we need to bound check against when using split_k.
    // Bounds check B against padded_N, but bounds check D against N.
    tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k);
    tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k);
    tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M);
    tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
    tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);

    tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);

#if !defined(MUL_MAT_ID)
    // Detect a fast path where all loads are entirely in bounds and no clamping is required
    if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
#if QUANT_K == 1
        (stride_a % 8) == 0 &&
#endif
        (stride_b % 8) == 0 && (start_k % 8) == 0) {
        // Hint to the compiler that values are aligned (want 16B alignment)
        start_k &= ~7;
        stride_b &= ~7;
#if QUANT_K == 1
        stride_a &= ~7;
#endif

        tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
        tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);

        uint k_iters = (end_k - start_k + BK - 1) / BK;
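        // When the remaining output columns fit in a quarter- or half-width
        // tile, use a narrower accumulator (BNover4 / BNover2), which should
        // need fewer registers; the three branches below are otherwise
        // identical.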
        if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {

                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;

                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);

                sum = coopMatMulAdd(mat_a, mat_b, sum);
            }
            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);

            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
            return;
        } else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {

                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;

                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);

                sum = coopMatMulAdd(mat_a, mat_b, sum);
            }
            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);

            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
            return;
        } else {
            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
            for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {

                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);

                sum = coopMatMulAdd(mat_a, mat_b, sum);
            }
            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);

            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
            return;
        }
    } else
#endif // !defined(MUL_MAT_ID)
    {
        tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);

        tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1);

        tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);

        tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
        sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);

        [[dont_unroll]]
        for (uint block_k = start_k; block_k < end_k; block_k += BK) {

            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

            // Clamping is expensive, so detect different code paths for each combination
            // of A and B needing clamping.
            bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0;
#ifdef MUL_MAT_ID
            bool unclampedB = true;
#else
            bool unclampedB = (ic + 1) * BN <= p.padded_N && block_k + BK <= end_k && (block_k % 8) == 0;
#endif
            if (unclampedA && unclampedB) {
                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
#endif
                sum = coopMatMulAdd(mat_a, mat_b, sum);
            } else if (unclampedA && !unclampedB) {
                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);

                sum = coopMatMulAdd(mat_a, mat_b, sum);
            } else if (!unclampedA && unclampedB) {
                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
#endif
                sum = coopMatMulAdd(mat_a, mat_b, sum);
            } else if (!unclampedA && !unclampedB) {
                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);

                sum = coopMatMulAdd(mat_a, mat_b, sum);
            }
        }

        // Convert from ACC_TYPE to D_TYPE
        coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
        mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);

#ifdef MUL_MAT_ID
        // Call callback to store each element, remapping row through shared memory
        coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
#else
        coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
#endif
    }
}