ggml : automatic selection of best CPU backend (#10606)

* ggml : automatic selection of best CPU backend

* amx : minor opt

* add GGML_AVX_VNNI to enable avx-vnni, fix checks
Diego Devesa 2024-12-01 16:12:41 +01:00 committed by GitHub
parent 86dc11c5bc
commit 3420909dff
12 changed files with 599 additions and 156 deletions
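For context, a minimal sketch of what "automatic selection of best CPU backend" can mean in practice: each CPU backend build variant reports how well it matches the host (0 if it cannot run there), and the loader picks the highest-scoring usable variant. The types and function below are illustrative assumptions, not the actual ggml API; the commit also adds a GGML_AVX_VNNI build option to gate the AVX-VNNI paths.

```cpp
// Illustrative sketch only (hypothetical names, not the ggml API): each CPU
// backend variant advertises a score for the current host, e.g. based on which
// instruction sets it was built for and which ones the CPU supports; the
// best-scoring usable variant is the one that gets loaded.
#include <string>
#include <vector>

struct cpu_backend_variant {
    std::string name;   // e.g. "cpu-avx512", "cpu-avx2", "cpu-generic" (illustrative)
    int         score;  // 0 = unusable on this host, higher = better match
};

static const cpu_backend_variant * pick_best_variant(const std::vector<cpu_backend_variant> & variants) {
    const cpu_backend_variant * best = nullptr;
    for (const auto & v : variants) {
        if (v.score > 0 && (best == nullptr || v.score > best->score)) {
            best = &v;
        }
    }
    return best; // nullptr if no variant can run on this CPU
}
```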


@@ -78,7 +78,6 @@ inline void parallel_for_ggml(const ggml_compute_params * params, int n, const f
int tbegin, tend;
balance211(n, params->nth, params->ith, tbegin, tend);
f(tbegin, tend);
-ggml_barrier(params->threadpool); // TODO: might not always be needed
}
// quantized types that have AMX support
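The hunk above drops the ggml_barrier call from parallel_for_ggml, leaving only the range split and the call to f; the removed TODO already noted the barrier might not always be needed, so synchronization is presumably left to the callers that require it. As a rough illustration of what a balance211-style split does (a sketch under assumed semantics, not the actual implementation), each of nth threads gets a contiguous slice of the n items, with the remainder spread over the first threads:

```cpp
// Sketch (assumed semantics): split n items over nth threads; thread ith works
// on the half-open range [tbegin, tend). The first (n % nth) threads each take
// one extra item so the slices differ in size by at most one.
static void balance211_sketch(int n, int nth, int ith, int & tbegin, int & tend) {
    const int base = n / nth;
    const int rem  = n % nth;
    tbegin = ith * base + (ith < rem ? ith : rem);
    tend   = tbegin + base + (ith < rem ? 1 : 0);
}
```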


@@ -1340,21 +1340,19 @@ struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K>
__m512 vb[COLS];
__m512 vc[ROWS * COLS];
-auto loadc = [&](int idx) {
+auto loadc = [&](auto idx) {
vc[idx] = _mm512_setzero_ps();
};
Unroll<ROWS * COLS>{}(loadc);
-auto compute = [&](int idx, int k) {
-// TODO: use `constexpr` here to get rid of interger div
-// when upgraded to C++17
-const int row = idx / COLS;
-const int col = idx % COLS;
+auto compute = [&](auto idx, auto k) {
+constexpr int row = idx / COLS;
+constexpr int col = idx % COLS;
-if (col == 0) {
+if constexpr (col == 0) {
va = _mm512_loadu_ps(A + row * K + k);
}
-if (row == 0) {
+if constexpr (row == 0) {
vb[col] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(B + col * K + k)));
}
vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
@@ -1364,9 +1362,9 @@ struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K>
Unroll<ROWS * COLS>{}(compute, k);
}
-auto storec = [&](int idx) {
-const int row = idx / COLS;
-const int col = idx % COLS;
+auto storec = [&](auto idx) {
+constexpr int row = idx / COLS;
+constexpr int col = idx % COLS;
C[row * ldc + col] = _mm512_reduce_add_ps(vc[idx]);
};
Unroll<ROWS * COLS>{}(storec);
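The recurring change in these kernels is switching the lambda parameters from int to auto and the index math from const int to constexpr int, which resolves the removed TODO about needing C++17: when the Unroll helper passes the index as a std::integral_constant rather than a runtime int, the row/column derived from it are constant expressions and if constexpr can prune the per-row and per-column branches at compile time. A simplified, self-contained sketch of the pattern (assumed shape, not the actual AMX kernel code):

```cpp
// Simplified sketch (assumed; not the real kernels). Unroll<N> invokes f with
// std::integral_constant indices 0..N-1, so inside a generic lambda the index
// is usable in constant expressions such as constexpr divisions.
#include <utility>

template <typename Func, int... Is, typename... Args>
void unroll_impl(const Func & f, std::integer_sequence<int, Is...>, Args... args) {
    (f(std::integral_constant<int, Is>{}, args...), ...); // f(0), f(1), ..., f(N-1)
}

template <int N>
struct Unroll {
    template <typename Func, typename... Args>
    void operator()(const Func & f, Args... args) const {
        unroll_impl(f, std::make_integer_sequence<int, N>{}, args...);
    }
};

constexpr int ROWS = 2;
constexpr int COLS = 4;

// acc: ROWS * COLS accumulators; a: [ROWS][k] row-major; b: [COLS][k] row-major
void demo(float (&acc)[ROWS * COLS], const float * a, const float * b, int k) {
    auto compute = [&](auto idx, int kk) {
        // idx is a std::integral_constant, so these divisions are constexpr
        constexpr int row = idx / COLS;
        constexpr int col = idx % COLS;
        if constexpr (col == 0) {
            // per-row work done only by the first column; this branch is
            // compiled away entirely for every other column
        }
        acc[idx] += a[row * k + kk] * b[col * k + kk];
    };
    for (int kk = 0; kk < k; ++kk) {
        Unroll<ROWS * COLS>{}(compute, kk);
    }
}
```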
@@ -1429,14 +1427,14 @@ struct tinygemm_kernel_vnni<block_q8_0, block_q4_0, float, BLOCK_M, BLOCK_N, BLO
const __m512i off = _mm512_set1_epi8(8);
const __m512i lowMask = _mm512_set1_epi8(0xF);
-auto loadc = [&](int col) {
+auto loadc = [&](auto col) {
vc[col] = _mm512_setzero_ps();
};
Unroll<COLS>{}(loadc);
-auto compute = [&](int col, int i) {
+auto compute = [&](auto col, auto i) {
// load a and compute compensation
-if (col == 0) {
+if constexpr (col == 0) {
const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
vcomp = _mm512_setzero_si512();
for (int k = 0; k < 8; ++k) {
@@ -1468,7 +1466,7 @@ struct tinygemm_kernel_vnni<block_q8_0, block_q4_0, float, BLOCK_M, BLOCK_N, BLO
}
//store to C
-auto storec = [&](int col) {
+auto storec = [&](auto col) {
_mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
};
Unroll<COLS>{}(storec);
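In the q8_0 × q4_0 kernel above, the "compute compensation" step exists because Q4_0 stores each weight as an unsigned 4-bit value u with an implicit offset of 8, i.e. the dequantized weight is d * (u - 8). The VNNI multiply-accumulate works on the raw u values, so the kernel subtracts 8 * Σ a_i once per block (the vcomp term). A scalar sketch of the same arithmetic, assuming already-unpacked nibbles rather than the packed vector layout:

```cpp
#include <cstdint>

// Scalar equivalent (sketch) of the q4_0 x q8_0 dot product with compensation:
//   Σ a_i * (u_i - 8) = Σ a_i * u_i - 8 * Σ a_i
// u: unpacked Q4_0 nibbles in [0, 15]; a: Q8_0 int8 activations;
// d_w, d_a: per-block scales of the weight and activation blocks.
static float dot_q4_0_q8_0_sketch(const uint8_t * u, const int8_t * a, int n, float d_w, float d_a) {
    int32_t acc  = 0; // Σ a_i * u_i (what the VNNI accumulate produces)
    int32_t comp = 0; // Σ a_i (compensation, scaled by 8 below)
    for (int i = 0; i < n; ++i) {
        acc  += (int32_t) a[i] * (int32_t) u[i];
        comp += a[i];
    }
    return d_w * d_a * (float) (acc - 8 * comp);
}
```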
@@ -1492,14 +1490,14 @@ struct tinygemm_kernel_vnni<block_q8_1, block_q4_1, float, 1, BLOCK_N, BLOCK_K>
const __m512i lowMask = _mm512_set1_epi8(0xF);
-auto loadc = [&](int col) {
+auto loadc = [&](auto col) {
vc[col] = _mm512_setzero_ps();
};
Unroll<COLS>{}(loadc);
-auto compute = [&](int col, int i) {
+auto compute = [&](auto col, auto i) {
// load a
-if (col == 0) {
+if constexpr (col == 0) {
const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
for (int k = 0; k < 8; ++k) {
va[k] = _mm512_set1_epi32(a_ptr[k]);
@@ -1533,7 +1531,7 @@ struct tinygemm_kernel_vnni<block_q8_1, block_q4_1, float, 1, BLOCK_N, BLOCK_K>
}
//store to C
-auto storec = [&](int col) {
+auto storec = [&](auto col) {
_mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
};
Unroll<COLS>{}(storec);
@@ -1564,14 +1562,14 @@ struct tinygemm_kernel_vnni<block_q8_0, block_q8_0, float, BLOCK_M, BLOCK_N, BLO
//
const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
-auto loadc = [&](int col) {
+auto loadc = [&](auto col) {
vc[col] = _mm512_setzero_ps();
};
Unroll<COLS>{}(loadc);
-auto compute = [&](int col, int i) {
+auto compute = [&](auto col, auto i) {
// load a and add offset 128
-if (col == 0) {
+if constexpr (col == 0) {
const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
for (int k = 0; k < 8; ++k) {
va[k] = _mm512_set1_epi32(a_ptr[k]);
@@ -1604,7 +1602,7 @@ struct tinygemm_kernel_vnni<block_q8_0, block_q8_0, float, BLOCK_M, BLOCK_N, BLO
}
//store to C
-auto storec = [&](int col) {
+auto storec = [&](auto col) {
_mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
};
Unroll<COLS>{}(storec);
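The q8_0 × q8_0 kernel uses off = 0x80 for a related reason: the AVX512-VNNI byte dot-product instructions multiply an unsigned u8 operand by a signed i8 operand, so one of the two signed Q8_0 inputs is shifted by +128 (the "add offset 128" comment above) and the resulting bias of 128 * Σ of the other operand is subtracted again. A scalar sketch of that identity (illustrative only, not the vector code; where the compensation sum is computed is an implementation detail):

```cpp
#include <cstdint>

// Scalar sketch of the signed-to-unsigned trick used with vpdpbusd-style
// instructions (unsigned u8 operand times signed i8 operand):
//   Σ (a_i + 128) * b_i  -  128 * Σ b_i  =  Σ a_i * b_i
// a, b: Q8_0 int8 values of one block; d_a, d_b: per-block scales.
static float dot_q8_0_q8_0_sketch(const int8_t * a, const int8_t * b, int n, float d_a, float d_b) {
    int32_t acc    = 0; // Σ (a_i + 128) * b_i, with (a_i + 128) in [0, 255]
    int32_t comp_b = 0; // Σ b_i, used to cancel the +128 bias
    for (int i = 0; i < n; ++i) {
        acc    += ((int32_t) a[i] + 128) * (int32_t) b[i];
        comp_b += b[i];
    }
    return d_a * d_b * (float) (acc - 128 * comp_b);
}
```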
@@ -1636,7 +1634,7 @@ struct tinygemm_kernel_vnni<block_q8_K, block_q4_K, float, BLOCK_M, BLOCK_N, BLO
const __m512i lowMask = _mm512_set1_epi8(0xF);
-auto loadc = [&](int col) {
+auto loadc = [&](auto col) {
vc[col] = _mm512_setzero_ps();
};
Unroll<COLS>{}(loadc);
@@ -1650,9 +1648,9 @@ struct tinygemm_kernel_vnni<block_q8_K, block_q4_K, float, BLOCK_M, BLOCK_N, BLO
// int16 {k/2, n, 2}, viewed as 2d {k/2, 2n}, k = 8
// from {16, 8} to {4, 32}
//
-auto compute = [&](int col, int i) {
+auto compute = [&](auto col, auto i) {
// load a
-if (col == 0) {
+if constexpr (col == 0) {
for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32)));
}
@@ -1704,7 +1702,7 @@ struct tinygemm_kernel_vnni<block_q8_K, block_q4_K, float, BLOCK_M, BLOCK_N, BLO
}
//store to C
-auto storec = [&](int col) {
+auto storec = [&](auto col) {
_mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
};
Unroll<COLS>{}(storec);
@@ -1737,15 +1735,15 @@ struct tinygemm_kernel_vnni<block_q8_K, block_q5_K, float, BLOCK_M, BLOCK_N, BLO
const __m512i lowMask = _mm512_set1_epi8(0xF);
-auto loadc = [&](int col) {
+auto loadc = [&](auto col) {
vc[col] = _mm512_setzero_ps();
};
Unroll<COLS>{}(loadc);
// Q5_K and Q4_K shares the same vnni formats, refer to notes above.
-auto compute = [&](int col, int i) {
+auto compute = [&](auto col, auto i) {
// load a
-if (col == 0) {
+if constexpr (col == 0) {
for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32)));
}
@@ -1810,7 +1808,7 @@ struct tinygemm_kernel_vnni<block_q8_K, block_q5_K, float, BLOCK_M, BLOCK_N, BLO
}
//store to C
-auto storec = [&](int col) {
+auto storec = [&](auto col) {
_mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
};
Unroll<COLS>{}(storec);
@@ -1843,13 +1841,13 @@ struct tinygemm_kernel_vnni<block_q8_K, block_q6_K, float, BLOCK_M, BLOCK_N, BLO
const __m512i m32s = _mm512_set1_epi32(32);
const __m512i lowMask = _mm512_set1_epi8(0xF);
-auto loadc = [&](int col) {
+auto loadc = [&](auto col) {
vc[col] = _mm512_setzero_ps();
};
Unroll<COLS>{}(loadc);
-auto compute = [&](int col, int i) {
-if (col == 0) {
+auto compute = [&](auto col, auto i) {
+if constexpr (col == 0) {
// load a
va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0));
va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64));
@@ -1961,13 +1959,13 @@ struct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, B
const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
const __m512i values256 = _mm512_add_epi8(values128, off);
-auto loadc = [&](int col) {
+auto loadc = [&](auto col) {
vc[col] = _mm512_setzero_ps();
};
Unroll<COLS>{}(loadc);
-auto compute = [&](int col, int i) {
-if (col == 0) {
+auto compute = [&](auto col, auto i) {
+if constexpr (col == 0) {
// load a
va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0));
va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64));
@@ -2017,7 +2015,7 @@ struct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, B
}
//store to C
-auto storec = [&](int col) {
+auto storec = [&](auto col) {
_mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
};
Unroll<COLS>{}(storec);