sycl: reordered Q4_K MMVQ (#13109)
This commit is contained in:
parent
9c404ed54c
commit
64bb51cf90
7 changed files with 280 additions and 84 deletions
|
@ -357,6 +357,28 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8
|
|||
}
|
||||
#endif
|
||||
|
||||
template <typename dst_t>
|
||||
inline void dequantize_q4_K_common(dst_t * __restrict__ y, const uint8_t * __restrict__ qs_ptr, const float dall,
|
||||
const float dmin, uint8_t * __restrict__ scales_local, int il, int ir) {
|
||||
const int is = 2 * il;
|
||||
constexpr int n = 4;
|
||||
|
||||
uint8_t sc, m;
|
||||
get_scale_min_k4(is + 0, scales_local, sc, m);
|
||||
const float d1 = dall * sc;
|
||||
const float m1 = dmin * m;
|
||||
|
||||
get_scale_min_k4(is + 1, scales_local, sc, m);
|
||||
const float d2 = dall * sc;
|
||||
const float m2 = dmin * m;
|
||||
|
||||
sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(qs_ptr + 32 * il + n * ir);
|
||||
for (int l = 0; l < n; ++l) {
|
||||
y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
|
||||
y[l + 32] = d2 * (q_vec[l] >> 4) - m2;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||
uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
|
||||
|
@ -365,36 +387,22 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
|
|||
const int64_t i = item_ct1.get_group(2);
|
||||
|
||||
#if QK_K == 256
|
||||
// assume 32 threads
|
||||
const int64_t tid = item_ct1.get_local_id(2);
|
||||
const int64_t il = tid/8;
|
||||
const int64_t ir = tid%8;
|
||||
const int64_t is = 2*il;
|
||||
const int64_t n = 4;
|
||||
const int64_t il = tid / 8;
|
||||
const int64_t ir = tid % 8;
|
||||
|
||||
dst_t * y = yy + i*QK_K + 64*il + n*ir;
|
||||
dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;
|
||||
|
||||
const sycl::half2 dm = x[i].dm;
|
||||
const float dall = dm[0];
|
||||
const float dmin = dm[1];
|
||||
|
||||
if (tid < 12)
|
||||
if (tid < 12) {
|
||||
scales_local[tid] = x[i].scales[tid];
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
uint8_t sc, m;
|
||||
get_scale_min_k4(is + 0, scales_local, sc, m);
|
||||
const float d1 = dall * sc;
|
||||
const float m1 = dmin * m;
|
||||
get_scale_min_k4(is + 1, scales_local, sc, m);
|
||||
const float d2 = dall * sc;
|
||||
const float m2 = dmin * m;
|
||||
|
||||
sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(x[i].qs + 32*il + n*ir);
|
||||
for (int l = 0; l < n; ++l) {
|
||||
y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
|
||||
y[l +32] = d2 * (q_vec[l] >> 4) - m2;
|
||||
}
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
dequantize_q4_K_common(y, x[i].qs, dall, dmin, scales_local, il, ir);
|
||||
#else
|
||||
const int64_t tid = item_ct1.get_local_id(2);
|
||||
const uint8_t * q = x[i].qs;
|
||||
|
@ -406,6 +414,36 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
|
|||
#endif
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void dequantize_block_q4_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, uint8_t * scales_local,
|
||||
const sycl::nd_item<1> & item_ct1, int64_t nb) {
|
||||
const int64_t i = item_ct1.get_group(0); // block index
|
||||
const int64_t tid = item_ct1.get_local_id(0); // thread index within block
|
||||
const int64_t il = tid / 8;
|
||||
const int64_t ir = tid % 8;
|
||||
|
||||
dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;
|
||||
|
||||
const uint8_t * base = static_cast<const uint8_t *>(vx);
|
||||
const size_t qs_offset = i * (QK_K / 2);
|
||||
const size_t scales_offset = nb * (QK_K / 2) + i * K_SCALE_SIZE;
|
||||
const size_t dm_offset = nb * (QK_K / 2) + nb * K_SCALE_SIZE + i * sizeof(ggml_half2);
|
||||
|
||||
const uint8_t * qs_ptr = base + qs_offset;
|
||||
const uint8_t * scales_ptr = base + scales_offset;
|
||||
ggml_half2 dm_values = *reinterpret_cast<const ggml_half2 *>(base + dm_offset);
|
||||
|
||||
const float dall = dm_values.x();
|
||||
const float dmin = dm_values.y();
|
||||
|
||||
if (tid < 12) {
|
||||
scales_local[tid] = scales_ptr[tid];
|
||||
}
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
dequantize_q4_K_common(y, qs_ptr, dall, dmin, scales_local, il, ir);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue