metal : add quantized FA support (#10149)
* metal : add quantized FA (vec) support ggml-ci * metal : add quantized FA (non-vec) support * metal : fix support check ggml-ci * metal : clean-up * metal : clean-up (cont) * metal : fix shared memory calc + reduce smem + comments * metal : float-correctness * metal : minor [no ci]
This commit is contained in:
parent
b8deef0ec0
commit
a1eaf6a960
2 changed files with 567 additions and 191 deletions
|
|
@ -255,9 +255,49 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
||||
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
||||
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
||||
|
|
@ -710,9 +750,49 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
|
||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
|
||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
||||
|
|
@ -869,13 +949,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|||
case GGML_OP_LEAKY_RELU:
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
if (op->src[1]->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[2]->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 256) {
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
return false;
|
||||
}
|
||||
return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
||||
|
|
@ -2822,6 +2896,7 @@ static void ggml_metal_encode_node(
|
|||
GGML_ASSERT(ne11 % 32 == 0);
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == src2->type);
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape (src1, src2));
|
||||
|
||||
|
|
@ -2869,26 +2944,154 @@ static void ggml_metal_encode_node(
|
|||
bool use_vec_kernel = false;
|
||||
|
||||
if (ne01 >= 4 || (ne00%128 != 0)) {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
||||
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
||||
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
||||
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
||||
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||
GGML_LOG_ERROR("add template specialization for this size\n");
|
||||
GGML_ABORT("add template specialization for this size");
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
|
||||
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
|
||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
|
||||
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||
GGML_LOG_ERROR("add template specialization for this size\n");
|
||||
GGML_ABORT("add template specialization for this size");
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
{
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
|
||||
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
|
||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
|
||||
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||
GGML_LOG_ERROR("add template specialization for this size\n");
|
||||
GGML_ABORT("add template specialization for this size");
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
{
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
|
||||
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
|
||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
|
||||
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||
GGML_LOG_ERROR("add template specialization for this size\n");
|
||||
GGML_ABORT("add template specialization for this size");
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
{
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
|
||||
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
|
||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
|
||||
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||
GGML_LOG_ERROR("add template specialization for this size\n");
|
||||
GGML_ABORT("add template specialization for this size");
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
{
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
|
||||
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
|
||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
|
||||
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||
GGML_LOG_ERROR("add template specialization for this size\n");
|
||||
GGML_ABORT("add template specialization for this size");
|
||||
}
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||
GGML_LOG_ERROR("add template specialization for this size\n");
|
||||
GGML_ABORT("add template specialization for this size");
|
||||
}
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
||||
GGML_LOG_ERROR("add template specialization for this type\n");
|
||||
GGML_ABORT("add template specialization for this type");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
use_vec_kernel = true;
|
||||
|
||||
switch (ne00) {
|
||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
||||
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
||||
case 128:
|
||||
{
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
|
||||
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
|
||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
|
||||
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break;
|
||||
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
||||
GGML_LOG_ERROR("add template specialization for this type\n");
|
||||
GGML_ABORT("add template specialization for this type");
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case 256:
|
||||
{
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
|
||||
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
|
||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
|
||||
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break;
|
||||
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
||||
GGML_LOG_ERROR("add template specialization for this type\n");
|
||||
GGML_ABORT("add template specialization for this type");
|
||||
}
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||
|
|
@ -2942,10 +3145,16 @@ static void ggml_metal_encode_node(
|
|||
GGML_ASSERT(nqptg % 8 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
// 16*32*(nsg)
|
||||
// the shared memory needed for the simdgroups to load the KV cache
|
||||
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
|
||||
//
|
||||
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
|
||||
|
||||
int64_t nsgmax = 2;
|
||||
|
||||
while (true) {
|
||||
const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
|
||||
const size_t smem = FATTN_SMEM(nsgmax);
|
||||
if (smem > device.maxThreadgroupMemoryLength) {
|
||||
break;
|
||||
}
|
||||
|
|
@ -2956,16 +3165,15 @@ static void ggml_metal_encode_node(
|
|||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
||||
|
||||
const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
//printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
|
||||
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
|
||||
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
||||
|
||||
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
#undef FATTN_SMEM
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
} else {
|
||||
// half1x4 kernel
|
||||
// half4x4 kernel
|
||||
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||
|
||||
|
|
@ -2973,8 +3181,28 @@ static void ggml_metal_encode_node(
|
|||
GGML_ASSERT(nqptg % 1 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
// ne00 + 2*ncpsg*(nsg)
|
||||
// for each query, we load it as f16 in shared memory (ne00)
|
||||
// and store the attention scores (nqptg x ncpsg) as f32
|
||||
//
|
||||
// 2*ne00*(nsg)
|
||||
// each simdgroup has a full f32 head vector in shared mem to accumulate results
|
||||
//
|
||||
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
|
||||
|
||||
int64_t nsgmax = 2;
|
||||
|
||||
while (true) {
|
||||
const size_t smem = FATTN_SMEM(nsgmax);
|
||||
if (smem > device.maxThreadgroupMemoryLength) {
|
||||
break;
|
||||
}
|
||||
nsgmax *= 2;
|
||||
}
|
||||
nsgmax /= 2;
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
|
||||
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
|
||||
|
||||
int64_t nsg = 1;
|
||||
while (nsg <= nsgt) {
|
||||
|
|
@ -2982,12 +3210,12 @@ static void ggml_metal_encode_node(
|
|||
}
|
||||
nsg /= 2;
|
||||
|
||||
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
//printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
|
||||
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
|
||||
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
||||
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
#undef FATTN_SMEM
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
}
|
||||
} break;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue