llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
Jeff Bolz a91a41364b
vulkan: optimize coopmat2 dequant functions (#10855)
Change the code to do 16b loads when possible and extract the appropriate
component late, so the code is effectively decoding a pair of elements and
then selecting one. This can allow more commoning to happen in the compiler
when neighboring elements are loaded.
2024-12-21 08:04:45 +01:00

325 lines
11 KiB
Text

#include "types.comp"
// Buffer reference through which a Q4_0 block is read in its packed16 form.
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
    block_q4_0_packed16 block;
};

// Decodes a single element of a Q4_0 block.
//   bl           - reference to the block being decoded
//   blockCoords  - not read by this decoder
//   coordInBlock - coordInBlock[1] is the element index inside the block
// Returns (q - 8) * d, where q is the 4-bit quant selected for the element.
float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint elem = coordInBlock[1];
    const float16_t scale = bl.block.d;

    // Bit 4 of the element index selects the nibble: shift of 0 or 4.
    const uint nibbleShift = (elem & 0x10) >> 2;

    // Fetch the qs entry shared by this element and its neighbour, then keep
    // one nibble in each of its two bytes.
    uint32_t packedPair = uint32_t(bl.block.qs[(elem & 0xE) >> 1]);
    packedPair = (packedPair >> nibbleShift) & 0x0F0F;

    // Pick the byte late: byte 0 for even element indices, byte 1 for odd ones.
    const uint32_t quant = unpack8(packedPair)[elem & 1];

    return (float16_t(quant) - float16_t(8)) * scale;
}
// Buffer reference through which a Q4_1 block is read.
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 {
    block_q4_1 block;
};

// Decodes a single element of a Q4_1 block.
//   bl           - reference to the block being decoded
//   blockCoords  - not read by this decoder
//   coordInBlock - coordInBlock[1] is the element index inside the block
// Returns q * d + m, where q is the 4-bit quant selected for the element.
float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint elem = coordInBlock[1];
    const float16_t scale = bl.block.d;
    const float16_t offset = bl.block.m;

    // Bit 4 of the element index selects the nibble: shift of 0 or 4.
    const uint nibbleShift = (elem & 0x10) >> 2;

    // The low four index bits select the qs byte.
    uint32_t quant = bl.block.qs[elem & 0xF];
    quant = (quant >> nibbleShift) & 0xF;

    return float16_t(quant) * scale + offset;
}
// Buffer reference through which a Q5_0 block is read.
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 {
    block_q5_0 block;
};

// Decodes a single element of a Q5_0 block.
//   bl           - reference to the block being decoded
//   blockCoords  - not read by this decoder
//   coordInBlock - coordInBlock[1] is the element index inside the block
// Returns (q - 16) * d, where q combines a 4-bit nibble from qs with a fifth
// bit taken from qh.
float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint elem = coordInBlock[1];
    const float16_t scale = bl.block.d;

    // Combine the two qh entries into one word: qh[0] is the low half, qh[1] the high half.
    const uint highBits = (uint(bl.block.qh[1]) << 16) | bl.block.qh[0];
    // Move this element's bit of that word into bit position 4.
    const uint fifthBit = ((highBits >> elem) << 4) & 0x10;

    // Bit 4 of the element index selects the nibble: shift of 0 or 4.
    const uint nibbleShift = (elem & 0x10) >> 2;

    // The low four index bits select the qs byte.
    uint32_t quant = bl.block.qs[elem & 0xF];
    quant = (quant >> nibbleShift) & 0xF;

    return (float16_t(quant | fifthBit) - float16_t(16)) * scale;
}
// Buffer reference through which a Q5_1 block is read.
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 {
    block_q5_1 block;
};

// Decodes a single element of a Q5_1 block.
//   bl           - reference to the block being decoded
//   blockCoords  - not read by this decoder
//   coordInBlock - coordInBlock[1] is the element index inside the block
// Returns q * d + m, where q combines a 4-bit nibble from qs with a fifth bit
// taken from qh.
float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint elem = coordInBlock[1];
    const float16_t scale = bl.block.d;
    const float16_t offset = bl.block.m;

    // Move this element's bit of qh into bit position 4.
    const uint highBits = bl.block.qh;
    const uint fifthBit = ((highBits >> elem) << 4) & 0x10;

    // Bit 4 of the element index selects the nibble: shift of 0 or 4.
    const uint nibbleShift = (elem & 0x10) >> 2;

    // The low four index bits select the qs byte.
    uint32_t quant = bl.block.qs[elem & 0xF];
    quant = (quant >> nibbleShift) & 0xF;

    return float16_t(quant | fifthBit) * scale + offset;
}
// Buffer reference through which a Q8_0 block is read in its packed16 form.
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 {
    block_q8_0_packed16 block;
};

// Decodes a single element of a Q8_0 block.
//   bl           - reference to the block being decoded
//   blockCoords  - not read by this decoder
//   coordInBlock - coordInBlock[1] is the element index inside the block
// Returns q * d, where q is the signed byte selected for the element.
float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint elem = coordInBlock[1];
    const float16_t scale = bl.block.d;

    // Load the 16-bit entry holding this element and its neighbour, then pick
    // byte 0 (even index) or byte 1 (odd index).
    const int32_t quant = unpack8(int32_t(bl.block.qs[(elem & 0x1E) >> 1]))[elem & 1];

    return float16_t(quant) * scale;
}
// Buffer reference through which a Q2_K block is read.
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K {
    block_q2_K block;
};

// Decodes a single element of a Q2_K block.
//   bl           - reference to the block being decoded
//   blockCoords  - not read by this decoder
//   coordInBlock - coordInBlock[1] is the element index inside the block
// Returns d.x * (scale nibble) * (2-bit quant) - d.y * (min nibble).
// The ranges noted below assume element indices 0..255.
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint elem = coordInBlock[1];
    const f16vec2 d = bl.block.d;

    const uint half128 = elem / 128;                       // 0,1
    const uint quantByteIdx = half128 * 32 + (elem % 32);  // 0..63
    const uint scaleIdx = elem / 16;                       // 0..15
    const uint quantShift = ((elem % 128) / 32) * 2;       // 0,2,4,6

    const uint32_t packedQuants = bl.block.qs[quantByteIdx];
    // Low nibble multiplies the quant, high nibble multiplies d.y.
    const uint scaleByte = bl.block.scales[scaleIdx];

    return d.x * float16_t(scaleByte & 0xF) * float16_t((packedQuants >> quantShift) & 3) - d.y * float16_t(scaleByte >> 4);
}
// Buffer reference through which a Q3_K block is read.
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K {
    block_q3_K block;
};

// Decodes a single element of a Q3_K block.
//   bl           - reference to the block being decoded
//   blockCoords  - not read by this decoder
//   coordInBlock - coordInBlock[1] is the element index inside the block
// Returns d * (scale - 32) * q, where q is the 2-bit quant from qs, reduced by
// 4 when the element's hmask bit is clear.
// The ranges noted below assume element indices 0..255.
float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint idx = coordInBlock[1];

    const uint n = idx / 128;                    // 0,1
    const uint qsi = n * 32 + (idx % 32);        // 0..63
    const uint hmi = (idx % 32);                 // 0..31
    const uint is = idx / 16;                    // 0..15
    const uint halfsplit = ((idx % 128) / 32);   // 0,1,2,3
    const uint qsshift = halfsplit * 2;          // 0,2,4,6
    const uint m = 1 << (4 * n + halfsplit);     // 1,2,4,8,16,32,64,128

    // The 6-bit scale is split across two scales entries:
    //   low 4 bits:   a nibble of scales[0..7] (low nibble for is < 8, high otherwise)
    //   upper 2 bits: a 2-bit field of scales[8..11], at bit offset 0,2,4 or 6
    uint32_t scaleidx0 = (is < 8) ? is : (is-8);
    uint32_t scaleidx0shift = (is < 8) ? 0 : 4;
    uint32_t scaleidx1 = is + 8 - (is/4)*4;
    uint32_t scaleidx1shift = (is/4)*2;

    const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4));

    // The stored scale is biased by 32.
    const float16_t dl = bl.block.d * float16_t(us - 32);

    // Subtract 4 from the 2-bit quant when the hmask bit for this element is clear.
    float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi] >> qsshift) & 3) - (((bl.block.hmask[hmi] & m) != 0) ? 0 : 4));

    return ret;
}
// Buffer reference through which a Q4_K block is read (used for d and scales).
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K {
    block_q4_K block;
};

// Packed16 view of the same block (used for the qs loads).
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 {
    block_q4_K_packed16 block;
};

// Decodes a single element of a Q4_K block.
//   bl           - reference to the block being decoded
//   blockCoords  - not read by this decoder
//   coordInBlock - coordInBlock[1] is the element index inside the block
// Returns d.x * sc * q - d.y * mbyte, where sc and mbyte are assembled from
// the scales array and q is the 4-bit quant selected for the element.
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);

    const uint elem = coordInBlock[1];
    const uint nibble = (elem & 0x20) >> 5;    // 0,1
    const uint group = (elem & 0xE0) >> 5;     // 0..7
    const bool lowGroup = group < 4;

    // Where the pieces of the scale and the min live inside scales[].
    const uint32_t scaleLoIdx = lowGroup ? group : (group + 4);
    const uint32_t scaleHiIdx = lowGroup ? group : (group - 4);
    const uint32_t minLoIdx = group + 4;
    const uint32_t minHiIdx = lowGroup ? group + 4 : group;

    // The upper two bits of both the scale and the min use the same mask and shift.
    const uint32_t hiMask = lowGroup ? 0x30 : 0xC0;
    const uint32_t hiShift = lowGroup ? 0 : 2;
    // The lower four bits of the min sit in the low or the high nibble.
    const uint32_t minLoMask = lowGroup ? 0xF : 0xF0;
    const uint32_t minLoShift = lowGroup ? 0 : 4;

    const uint32_t sc = uint8_t((bl.block.scales[scaleLoIdx] & 0xF) | ((bl.block.scales[scaleHiIdx] & hiMask) >> hiShift));
    const uint32_t mbyte = uint8_t(((bl.block.scales[minLoIdx] & minLoMask) >> minLoShift) | ((bl.block.scales[minHiIdx] & hiMask) >> hiShift));

    const f16vec2 loadd = bl.block.d;
    const float16_t d = loadd.x * float16_t(sc);
    const float16_t m = loadd.y * float16_t(mbyte);

    // Load the 16-bit entry shared with the neighbouring element, keep one
    // nibble per byte, then select the byte for this element.
    uint qs = uint32_t(bl16.block.qs[((elem & 0xC0) >> 2) + ((elem & 0x1E) >> 1)]);
    qs = (qs >> (nibble * 4)) & 0x0F0F;
    qs = unpack8(qs)[elem & 1];

    return d * float16_t(qs) - m;
}
// Buffer reference through which a Q5_K block is read (used for d and scales).
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {
    block_q5_K block;
};

// Packed16 view of the same block (used for the qh and qs loads).
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 {
    block_q5_K_packed16 block;
};

// Decodes a single element of a Q5_K block.
//   bl           - reference to the block being decoded
//   blockCoords  - not read by this decoder
//   coordInBlock - coordInBlock[1] is the element index inside the block
// Returns d.x * sc * (q + 16 if the element's qh bit is set) - d.y * mbyte,
// where sc and mbyte are assembled from the scales array and q is the 4-bit
// quant selected for the element.
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);

    const uint elem = coordInBlock[1];
    const uint nibble = (elem & 0x20) >> 5;    // 0,1
    const uint group = (elem & 0xE0) >> 5;     // 0..7
    const bool lowGroup = group < 4;

    // Bit 'group' replicated into both bytes of a 16-bit qh entry.
    const uint32_t highBitMask = 0x0101 << group;

    // Where the pieces of the scale and the min live inside scales[].
    const uint32_t scaleLoIdx = lowGroup ? group : (group + 4);
    const uint32_t scaleHiIdx = lowGroup ? group : (group - 4);
    const uint32_t minLoIdx = group + 4;
    const uint32_t minHiIdx = lowGroup ? group + 4 : group;

    // The upper two bits of both the scale and the min use the same mask and shift.
    const uint32_t hiMask = lowGroup ? 0x30 : 0xC0;
    const uint32_t hiShift = lowGroup ? 0 : 2;
    // The lower four bits of the min sit in the low or the high nibble.
    const uint32_t minLoMask = lowGroup ? 0xF : 0xF0;
    const uint32_t minLoShift = lowGroup ? 0 : 4;

    const uint32_t sc = uint8_t((bl.block.scales[scaleLoIdx] & 0xF) | ((bl.block.scales[scaleHiIdx] & hiMask) >> hiShift));
    const uint32_t mbyte = uint8_t(((bl.block.scales[minLoIdx] & minLoMask) >> minLoShift) | ((bl.block.scales[minHiIdx] & hiMask) >> hiShift));

    const f16vec2 loadd = bl.block.d;
    const float16_t d = loadd.x * float16_t(sc);
    const float16_t m = loadd.y * float16_t(mbyte);

    // Fifth bit: load the 16-bit qh entry shared with the neighbouring element,
    // isolate this group's bit in each byte, then select the byte for this element.
    uint qh = uint32_t(bl16.block.qh[(elem & 0x1E) >> 1]);
    qh = qh & highBitMask;
    qh = unpack8(qh)[elem & 1];

    // Low four bits: same load-a-pair-then-select scheme on qs.
    uint qs = uint32_t(bl16.block.qs[((elem & 0xC0) >> 2) + ((elem & 0x1E) >> 1)]);
    qs = (qs >> (nibble * 4)) & 0x0F0F;
    qs = unpack8(qs)[elem & 1];

    return d * (float16_t(qs) + (qh != 0 ? float16_t(16) : float16_t(0))) - m;
}
// Buffer reference through which a Q6_K block is read (used for d and scales).
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
    block_q6_K block;
};

// Packed16 view of the same block (used for the ql and qh loads).
// dequantFuncQ6_K obtains this view by casting a decodeBufQ6_K reference, which
// only guarantees 2-byte alignment, so this view declares the same alignment
// instead of promising a stricter one to the compiler.
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K_packed16 {
    block_q6_K_packed16 block;
};

// Decodes a single element of a Q6_K block.
//   bl           - reference to the block being decoded
//   blockCoords  - not read by this decoder
//   coordInBlock - coordInBlock[1] is the element index inside the block
// Returns d * scales[idx / 16] * (q - 32), where q is a 6-bit value built from
// a 4-bit nibble of ql and a 2-bit field of qh.
float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl);
    const uint idx = coordInBlock[1];

    const uint b = (idx & 0x40) >> 6;          // 0,1: low or high nibble of ql
    const uint qhshift = (idx & 0x60) >> 4;    // 0,2,4,6: bit offset of the 2-bit field in qh
    const uint is = (idx & 0xF0) >> 4;         // 0..15: scale index

    const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);

    // Load the 16-bit ql entry shared with the neighbouring element and keep
    // one nibble in each byte.
    uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]);
    ql = (ql >> (b * 4)) & 0x0F0F;

    // Load the matching 16-bit qh entry and move the 2-bit field of each byte
    // into bits 4..5.
    uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);
    qh = ((qh >> qhshift) & 0x0303) << 4;

    // Combine both parts, then select the byte for this element.
    int q = unpack8(ql | qh)[idx & 1];

    // The stored quant is biased by 32.
    float16_t ret = dscale * float16_t(q - 32);
    return ret;
}
#if defined(DATA_A_IQ4_NL)
// Buffer reference through which an IQ4_NL block is read.
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
    block_iq4_nl block;
};

// Decodes a single element of an IQ4_NL block.
//   bl           - reference to the block being decoded
//   blockCoords  - not read by this decoder
//   coordInBlock - coordInBlock[1] is the element index inside the block
// Returns kvalues_iq4nl[q] * d, where q is the 4-bit index selected for the element.
float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint elem = coordInBlock[1];
    const float16_t scale = bl.block.d;

    // Bit 4 of the element index selects the nibble: shift of 0 or 4.
    const uint nibbleShift = (elem & 0x10) >> 2;

    // The low four index bits select the qs byte.
    uint32_t tableIdx = bl.block.qs[elem & 0xF];
    tableIdx = (tableIdx >> nibbleShift) & 0xF;

    return float16_t(kvalues_iq4nl[tableIdx]) * scale;
}
#endif
// Alias dequantFuncA to the decoder matching whichever DATA_A_* macro is
// defined. The chain is first-match: if several macros are defined, the
// earliest one listed here wins. If none of them is defined, dequantFuncA is
// left undefined.
#if defined(DATA_A_Q4_0)
#define dequantFuncA dequantFuncQ4_0
#elif defined(DATA_A_Q4_1)
#define dequantFuncA dequantFuncQ4_1
#elif defined(DATA_A_Q5_0)
#define dequantFuncA dequantFuncQ5_0
#elif defined(DATA_A_Q5_1)
#define dequantFuncA dequantFuncQ5_1
#elif defined(DATA_A_Q8_0)
#define dequantFuncA dequantFuncQ8_0
#elif defined(DATA_A_Q2_K)
#define dequantFuncA dequantFuncQ2_K
#elif defined(DATA_A_Q3_K)
#define dequantFuncA dequantFuncQ3_K
#elif defined(DATA_A_Q4_K)
#define dequantFuncA dequantFuncQ4_K
#elif defined(DATA_A_Q5_K)
#define dequantFuncA dequantFuncQ5_K
#elif defined(DATA_A_Q6_K)
#define dequantFuncA dequantFuncQ6_K
#elif defined(DATA_A_IQ4_NL)
// dequantFuncIQ4_NL itself is only compiled when DATA_A_IQ4_NL is defined.
#define dequantFuncA dequantFuncIQ4_NL
#endif