CUDA: use mma PTX instructions for FlashAttention (#11583)
* CUDA: use mma PTX instructions for FlashAttention * __shfl_sync workaround for movmatrix * add __shfl_sync to HIP Co-authored-by: Diego Devesa <slarengh@gmail.com>
This commit is contained in:
parent
84ec8a58f7
commit
864a0b67a6
29 changed files with 2058 additions and 998 deletions
|
@ -1,5 +1,6 @@
|
|||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
#include "fattn-mma-f16.cuh"
|
||||
#include "fattn-tile-f16.cuh"
|
||||
#include "fattn-tile-f32.cuh"
|
||||
#include "fattn-vec-f16.cuh"
|
||||
|
@ -7,144 +8,56 @@
|
|||
#include "fattn-wmma-f16.cuh"
|
||||
#include "fattn.cuh"
|
||||
|
||||
#include <cstdint>
|
||||
template <int cols_per_block>
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||
|
||||
if (prec != GGML_PREC_DEFAULT) {
|
||||
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
|
||||
constexpr int cols_per_block = 16;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
constexpr int cols_per_block = 32;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
// case 256:
|
||||
// ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
|
||||
// break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
|
||||
constexpr int cols_per_block = 8;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32) {
|
||||
constexpr int cols_per_block = 16;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 32;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
if (Q->ne[1] <= 8) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 16) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst);
|
||||
}
|
||||
|
||||
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
||||
if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
|
||||
|
@ -322,11 +235,19 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||
return;
|
||||
}
|
||||
|
||||
if (!fp16_mma_available(cc)) {
|
||||
if (Q->ne[1] <= 8) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
if (!new_mma_available(cc)) {
|
||||
if (prec == GGML_PREC_DEFAULT) {
|
||||
if (Q->ne[1] <= 8) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
|
||||
}
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
|
||||
if (Q->ne[1] <= 8) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
@ -341,5 +262,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||
}
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
||||
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
|
||||
if (cc == GGML_CUDA_CC_VOLTA) {
|
||||
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue