CUDA: refactor and optimize IQ MMVQ (#8215)
* CUDA: refactor and optimize IQ MMVQ * uint -> uint32_t * __dp4a -> ggml_cuda_dp4a * remove MIN_CC_DP4A checks * change default * try CI fix
This commit is contained in:
parent
dae57a1ebc
commit
cb5fad4c6c
8 changed files with 406 additions and 487 deletions
|
@ -3,6 +3,7 @@
|
|||
#include "ggml.h"
|
||||
#include "ggml-cuda.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
|
||||
#if defined(GGML_USE_HIPBLAS)
|
||||
|
@ -268,30 +269,15 @@ static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigne
|
|||
return c;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
|
||||
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
|
||||
c = __builtin_amdgcn_sdot4(a, b, c, false);
|
||||
#elif defined(RDNA3)
|
||||
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
|
||||
#elif defined(__gfx1010__) || defined(__gfx900__)
|
||||
int tmp1;
|
||||
int tmp2;
|
||||
asm("\n \
|
||||
v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
|
||||
v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
|
||||
v_add3_u32 %0, %1, %2, %0 \n \
|
||||
v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
|
||||
v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
|
||||
v_add3_u32 %0, %1, %2, %0 \n \
|
||||
"
|
||||
: "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
|
||||
: "v"(a), "v"(b)
|
||||
);
|
||||
#else
|
||||
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
|
||||
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
|
||||
c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
|
||||
#endif
|
||||
static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) {
|
||||
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
||||
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
||||
unsigned int c;
|
||||
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
|
@ -467,8 +453,48 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
|
|||
}
|
||||
#endif // CUDART_VERSION < 12000
|
||||
|
||||
static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
|
||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
|
||||
c = __builtin_amdgcn_sdot4(a, b, c, false);
|
||||
#elif defined(RDNA3)
|
||||
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
|
||||
#elif defined(__gfx1010__) || defined(__gfx900__)
|
||||
int tmp1;
|
||||
int tmp2;
|
||||
asm("\n \
|
||||
v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
|
||||
v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
|
||||
v_add3_u32 %0, %1, %2, %0 \n \
|
||||
v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
|
||||
v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
|
||||
v_add3_u32 %0, %1, %2, %0 \n \
|
||||
"
|
||||
: "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
|
||||
: "v"(a), "v"(b)
|
||||
);
|
||||
#else
|
||||
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
|
||||
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
|
||||
c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
|
||||
#endif
|
||||
return c;
|
||||
|
||||
#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
|
||||
#if __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
return __dp4a(a, b, c);
|
||||
#else // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
const int8_t * a8 = (const int8_t *) &a;
|
||||
const int8_t * b8 = (const int8_t *) &b;
|
||||
return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
|
||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
}
|
||||
|
||||
// TODO: move to ggml-common.h
|
||||
static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
||||
static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
||||
|
||||
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue