sycl: Add reorder to Q6_K mmvq implementation (#13885)
* Add Reorder to Q6_K mmvq implementation
* Address PR comments: clean up comments
* Remove unused parameter after refactoring q4_k
* Adding inline to function and removing unnecessary reference to int

---------

Signed-off-by: nscipione <nicolo.scipione@codeplay.com>
This commit is contained in:
parent
91a8ee6a6f
commit
b460d16ae8
6 changed files with 244 additions and 30 deletions
|
@ -538,6 +538,38 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
|
|||
#endif
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void dequantize_block_q6_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||
const sycl::nd_item<3> & item_ct1, int64_t n_blocks) {
|
||||
const int64_t ib = item_ct1.get_group(2);
|
||||
|
||||
const int64_t tid = item_ct1.get_local_id(2);
|
||||
const int64_t ip = tid / 32; // ip is 0 or 1
|
||||
const int64_t il = tid - 32 * ip; // 0...32
|
||||
const int64_t is = 8 * ip + il / 16;
|
||||
|
||||
const uint8_t * base_ptr = static_cast<const uint8_t *>(vx);
|
||||
const auto ql_offset = ib * (QK_K / 2);
|
||||
const auto qh_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * ib;
|
||||
const auto base_scales_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * n_blocks + (QK_K / 16) * ib;
|
||||
const auto base_d_offset = ((QK_K / 2) + (QK_K / 4) + (QK_K / 16)) * n_blocks;
|
||||
const uint8_t * ql_ptr = base_ptr + ql_offset;
|
||||
const uint8_t * qh_ptr = base_ptr + qh_offset;
|
||||
const uint8_t * scales_ptr = base_ptr + base_scales_offset;
|
||||
const ggml_half * d = (const ggml_half *) (base_ptr + base_d_offset) + ib;
|
||||
|
||||
dst_t * y = yy + ib * QK_K + 128 * ip + il;
|
||||
|
||||
const uint8_t * ql = ql_ptr + 64 * ip + il;
|
||||
const uint8_t qh = *(qh_ptr + 32 * ip + il);
|
||||
const int8_t * sc = reinterpret_cast<const int8_t *>(scales_ptr + is);
|
||||
|
||||
y[0] = *d * sc[0] * ((int8_t) ((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
|
||||
y[32] = *d * sc[2] * ((int8_t) ((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
|
||||
y[64] = *d * sc[4] * ((int8_t) ((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
|
||||
y[96] = *d * sc[6] * ((int8_t) ((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||
const sycl::nd_item<3> &item_ct1,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue