
* Prefer vector flash decoding kernel for Gemma models

The vector flash decoding kernel was not being picked for models with head dimension 256, a category that includes the Gemma models. Removing this limit improves end-to-end generation-phase throughput by up to 12% for Gemma models.

* Update ggml/src/ggml-cuda/fattn.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
#include "common.cuh"
|
|
#include "fattn-common.cuh"
|
|
#include "fattn-mma-f16.cuh"
|
|
#include "fattn-tile-f16.cuh"
|
|
#include "fattn-tile-f32.cuh"
|
|
#include "fattn-vec-f16.cuh"
|
|
#include "fattn-vec-f32.cuh"
|
|
#include "fattn-wmma-f16.cuh"
|
|
#include "fattn.cuh"
|
|
|
|
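
// Select the mma-based kernel variant whose tile of Q columns (ncols1) is just large
// enough to cover the batch size Q->ne[1]; ncols2 K/V head groups are processed in
// parallel and the product ncols1*ncols2 never exceeds 64.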
template <int D, int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * Q = dst->src[0];

    if (Q->ne[1] <= 8/ncols2) {
        ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
        return;
    }

    if (Q->ne[1] <= 16/ncols2) {
        ggml_cuda_flash_attn_ext_mma_f16_case<D, 16/ncols2, ncols2>(ctx, dst);
        return;
    }

    if (Q->ne[1] <= 32/ncols2) {
        ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
        return;
    }

    ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst);
}
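
// Dispatch the mma-based kernel on the head size D = Q->ne[0]; unsupported head sizes abort.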
template <int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * Q = dst->src[0];

    switch (Q->ne[0]) {
        case 64:
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
            break;
        case 80:
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
            break;
        case 96:
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
            break;
        case 112:
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
            break;
        case 128:
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
            break;
        case 256:
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
            break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}
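
// Entry point for the mma-based kernels: choose how many K/V head groups (ncols2) to
// process per Q column based on the GQA ratio, provided a mask is present and no
// ALiBi bias is used (max_bias == 0).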
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV  = dst;
    const ggml_tensor * Q    = dst->src[0];
    const ggml_tensor * K    = dst->src[1];
    const ggml_tensor * mask = dst->src[3];

    float max_bias = 0.0f;
    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));

    const bool use_gqa_opt = mask && max_bias == 0.0f;

    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
    const int gqa_ratio = Q->ne[2] / K->ne[2];

    if (use_gqa_opt && gqa_ratio % 8 == 0) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
        return;
    }

    if (use_gqa_opt && gqa_ratio == 4) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
        return;
    }

    if (use_gqa_opt && gqa_ratio == 2) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
        return;
    }

    ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
}
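
// Instantiate one dispatch case of the f16-accumulator vector kernel for a given head
// size D and K/V storage types; the surrounding function returns as soon as a case matches.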
#define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
        ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
        return;                                                             \
    }                                                                       \

static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_tensor * Q = dst->src[0];
    ggml_tensor * K = dst->src[1];
    ggml_tensor * V = dst->src[2];
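
    // With GGML_CUDA_FA_ALL_QUANTS defined, a kernel is compiled for every supported
    // combination of K and V quantization types; otherwise only a reduced set of
    // combinations is built, keeping compile time and binary size down.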
#ifdef GGML_CUDA_FA_ALL_QUANTS
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)

    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#else
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)

    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#endif // GGML_CUDA_FA_ALL_QUANTS

    on_no_fattn_vec_case(Q->ne[0]);
}
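
// Same dispatch as above, but for the f32-accumulator vector kernel used when higher
// precision is requested or fast FP16 arithmetic is unavailable.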
#define FATTN_VEC_F32_CASE(D, type_K, type_V)                               \
    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
        ggml_cuda_flash_attn_ext_vec_f32_case<D, type_K, type_V>(ctx, dst); \
        return;                                                             \
    }                                                                       \

static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_tensor * Q = dst->src[0];
    ggml_tensor * K = dst->src[1];
    ggml_tensor * V = dst->src[2];

#ifdef GGML_CUDA_FA_ALL_QUANTS
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)

    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#else
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)

    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#endif // GGML_CUDA_FA_ALL_QUANTS

    on_no_fattn_vec_case(Q->ne[0]);
}
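
// Top-level FlashAttention dispatch: select between the wmma, mma, tile, and vector
// kernels based on the device architecture, requested precision, batch size, and head size.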
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV  = dst;
    const ggml_tensor * Q    = dst->src[0];
    const ggml_tensor * K    = dst->src[1];
    const ggml_tensor * V    = dst->src[2];
    const ggml_tensor * mask = dst->src[3];

    ggml_cuda_set_device(ctx.device);
    const int cc        = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
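
    // AMD GPUs: prefer the rocWMMA FlashAttention kernel when it was built in and FP16 mma is available.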
    if (GGML_CUDA_CC_IS_AMD(cc)) {
#if defined(GGML_HIP_ROCWMMA_FATTN)
        if (fp16_mma_available(cc)) {
            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
            return;
        }
#endif // defined(GGML_HIP_ROCWMMA_FATTN)

        // On AMD the tile kernels perform poorly, use the vec kernel instead:
        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        }
        return;
    }
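
    // Without fast FP16 arithmetic, only the FP32 vector/tile kernels are usable: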
    if (!fast_fp16_available(cc)) {
        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
        }
        return;
    }
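
    // Without FP16 matrix multiplication hardware, choose between the vector and tile
    // kernels in the requested precision: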
    if (!fp16_mma_available(cc)) {
        if (prec == GGML_PREC_DEFAULT) {
            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
                ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
            } else {
                ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
            }
        } else {
            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
                ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
            } else {
                ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
            }
        }
        return;
    }
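
    // For single-token decoding (batch size 1) the vector kernels are preferred, unless
    // the head size is incompatible with them or the mma kernel is expected to be faster: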
    const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
    const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
    const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0;
    if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
        if (prec == GGML_PREC_DEFAULT) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        }
        return;
    }
    // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
    if (fp16_mma_available(cc) && !new_mma_available(cc)) {
        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
        return;
    }

    ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
}