CUDA: MMQ support for iq4_nl, iq4_xs (#8278)

This commit is contained in:
Johannes Gäßler 2024-07-05 09:06:31 +02:00 committed by GitHub
parent 0a423800ff
commit 8e558309dc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 226 additions and 80 deletions

View file

@ -68,7 +68,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
const int iqs4 = k_KQ % QI4_0;
const int shift = k_KQ & (QI8_1/2);
const int v = (get_int_from_uint8(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
const int u = Q_q8[k_KQ_0/WARP_SIZE];
const int sumi = ggml_cuda_dp4a(v, u, 0);
@ -108,7 +108,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
const int iqs4 = k_KQ % QI4_1;
const int shift = k_KQ & (QI8_1/2);
const int v = (get_int_from_uint8_aligned(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
const int u = Q_q8[k_KQ_0/WARP_SIZE];
const int sumi = ggml_cuda_dp4a(v, u, 0);
@ -153,8 +153,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
const int iqs8 = k_KQ % QI8_1;
const int shift = k_KQ & (QI8_1/2);
int v = (get_int_from_uint8(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
const int vh = get_int_from_uint8(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
v |= (vh << 4) & 0x00000010; // 0 -> 4
v |= (vh << 11) & 0x00001000; // 1 -> 12
v |= (vh << 18) & 0x00100000; // 2 -> 20
@ -200,8 +200,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
const int iqs8 = k_KQ % QI8_1;
const int shift = k_KQ & (QI8_1/2);
int v = (get_int_from_uint8(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
const int vh = get_int_from_uint8(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
v |= (vh << 4) & 0x00000010; // 0 -> 4
v |= (vh << 11) & 0x00001000; // 1 -> 12
v |= (vh << 18) & 0x00100000; // 2 -> 20
@ -249,7 +249,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
const int ib = k_KQ / QI8_0;
const int iqs = k_KQ % QI8_0;
const int v = get_int_from_int8(K_q8_0[ib].qs, iqs);
const int v = get_int_b2(K_q8_0[ib].qs, iqs);
T Q_d;
if (std::is_same<T, half>::value) {
@ -408,7 +408,7 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__
const T d = x[ib].d;
const int ql0 = x[ib].qs[iqs];
const int qh0 = get_int_from_uint8(x[ib].qh, 0);
const int qh0 = get_int_b2(x[ib].qh, 0);
const int ql = ((ql0 >> (4*shift)) & 0x0F);
const int qh = ((qh0 >> idq) << 4) & 0x10;
const int q = (ql | qh) - 16;
@ -433,7 +433,7 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
const half2 dm = x[ib].dm;
const int ql0 = x[ib].qs[iqs];
const int qh0 = get_int_from_uint8_aligned(x[ib].qh, 0);
const int qh0 = get_int_b4(x[ib].qh, 0);
const int ql = ((ql0 >> (4*shift)) & 0x0F);
const int qh = ((qh0 >> idq) << 4) & 0x10;
const int q = (ql | qh);