vulkan: Implement grouped query attention in the coopmat2 FA shader (#12559)
When adjacent batches of Q share the same batches of K/V, batch them into the same workgroup. For example, when: dst(128,32,1,1) = FA(q(128,1,32,1), k(128,16640,8,1), v(128,16640,8,1)) previously we would run 32 workgroups computing 1 result each, now we will run 8 workgroups computing 4 results each. This doesn't directly translate to better performance (at least when you have >=32 SMs), but in a subsequent change I'll enable split_k which will scale much better with 4x fewer workgroups.
This commit is contained in:
parent
92e3006bb6
commit
be0a0f8cae
2 changed files with 71 additions and 20 deletions
|
@ -31,6 +31,7 @@
|
|||
|
||||
#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
|
||||
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
|
||||
static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
|
||||
|
||||
#define VK_VENDOR_ID_AMD 0x1002
|
||||
#define VK_VENDOR_ID_APPLE 0x106b
|
||||
|
@ -501,6 +502,8 @@ struct vk_flash_attn_push_constants {
|
|||
uint32_t n_head_log2;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
uint32_t gqa_ratio;
|
||||
};
|
||||
|
||||
struct vk_op_push_constants {
|
||||
|
@ -5402,7 +5405,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
const uint32_t nbm1 = mask ? mask->nb[1] : 0;
|
||||
|
||||
const uint32_t D = neq0;
|
||||
const uint32_t N = neq1;
|
||||
uint32_t N = neq1;
|
||||
const uint32_t KV = nek1;
|
||||
|
||||
GGML_ASSERT(ne0 == D);
|
||||
|
@ -5460,6 +5463,22 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
vk_pipeline pipeline = pipelines[aligned];
|
||||
assert(pipeline);
|
||||
|
||||
uint32_t gqa_ratio = 1;
|
||||
uint32_t qk_ratio = neq2 / nek2;
|
||||
uint32_t workgroups_x = (uint32_t)neq1;
|
||||
uint32_t workgroups_y = (uint32_t)neq2;
|
||||
uint32_t workgroups_z = (uint32_t)neq3;
|
||||
|
||||
if (N == 1 && qk_ratio > 1 && is_pow2(qk_ratio) && gqa_ratio <= flash_attention_num_small_rows &&
|
||||
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
|
||||
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
|
||||
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
|
||||
// and change addressing calculations to index Q's dimension 2.
|
||||
gqa_ratio = qk_ratio;
|
||||
N = gqa_ratio;
|
||||
workgroups_y /= N;
|
||||
}
|
||||
|
||||
if (dryrun) {
|
||||
// Request descriptor sets
|
||||
ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
|
||||
|
@ -5549,7 +5568,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
|
||||
nbm1,
|
||||
scale, max_bias, logit_softcap,
|
||||
mask != nullptr, n_head_log2, m0, m1 };
|
||||
mask != nullptr, n_head_log2, m0, m1, gqa_ratio };
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
||||
{
|
||||
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
|
||||
|
@ -5558,7 +5577,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
|
||||
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
|
||||
},
|
||||
sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 });
|
||||
sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
|
||||
}
|
||||
|
||||
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue