CPU/CUDA: Gemma 2 FlashAttention support (#8542)

* CPU/CUDA: Gemma 2 FlashAttention support

* apply logit_softcap to scale in kernel

* disable logit softcapping tests on Metal

* remove metal check
This commit is contained in:
Johannes Gäßler 2024-08-24 21:34:59 +02:00 committed by GitHub
parent 8f824ffe8e
commit e11bd856d5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 319 additions and 79 deletions

View file

@ -4,7 +4,7 @@
#define FATTN_KQ_STRIDE_TILE_F16 64
template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16(
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
@ -44,6 +45,12 @@ static __global__ void flash_attn_tile_ext_f16(
const int ne2,
const int ne3) {
#ifdef FP16_AVAILABLE
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
}
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
@ -154,7 +161,13 @@ static __global__ void flash_attn_tile_ext_f16(
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
half sum;
if (use_logit_softcap) {
const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
sum = logit_softcap * tanhf(tmp.x + tmp.y);
} else {
sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
}
sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
@ -270,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16(
#endif // FP16_AVAILABLE
}
template <int cols_per_block, int parallel_blocks>
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
} break;
default: {
@ -296,24 +309,45 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const int32_t precision = KQV->op_params[2];
const int32_t precision = KQV->op_params[3];
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
}
}