HIP: implement FlashAttention via rocWMMA for CDNA and RDNA3+ (#12032)
Adds GGML_HIP_ROCWMMA_FATTN and rocwmma header check Adds rocWMMA support to fattn-wmma-f16 --- Signed-off-by: Carl Klemm <carl@uvos.xyz> Co-authored-by: Johannes Gäßler <johannesg@5d6.de> Co-authored-by: Ben Jackson <ben@ben.com>
This commit is contained in:
parent
dfd6b2c0be
commit
becade5de7
6 changed files with 145 additions and 95 deletions
|
@ -250,10 +250,18 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
|
||||
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||
|
||||
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||
if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
|
||||
#if defined(GGML_HIP_ROCWMMA_FATTN)
|
||||
if (fp16_mma_available(cc)) {
|
||||
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
||||
return;
|
||||
}
|
||||
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
|
||||
|
||||
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
} else {
|
||||
|
@ -291,7 +299,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
|
||||
K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
|
||||
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) {
|
||||
if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0 && !mma_fast_for_bs1) {
|
||||
if (prec == GGML_PREC_DEFAULT) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
return;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue