CUDA: Improve flash decoding kernel GPU occupancy for BS=1 case (#12183)
- Find out active blocks per SM using cudaOccupancyMaxActiveBlocksPerMultiprocessor API. Use this value to determine the optimal parallel_blocks value. - Prefer vector flash attention kernels over MMA kernel for BS=1 Fixes Issue: #12182 --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
parent
a9b59288e2
commit
517b5ddbf0
12 changed files with 214 additions and 266 deletions
|
@ -259,6 +259,10 @@ static std::string var_to_str(ggml_type type) {
|
|||
return ggml_type_name(type);
|
||||
}
|
||||
|
||||
static std::string var_to_str(ggml_prec prec) {
|
||||
return prec == GGML_PREC_F32 ? "f32" : "def";
|
||||
}
|
||||
|
||||
static std::string var_to_str(ggml_op_pool pool) {
|
||||
switch (pool) {
|
||||
case GGML_OP_POOL_AVG: return "avg";
|
||||
|
@ -3206,11 +3210,12 @@ struct test_flash_attn_ext : public test_case {
|
|||
const float max_bias; // ALiBi
|
||||
const float logit_softcap; // Gemma 2
|
||||
|
||||
const ggml_prec prec;
|
||||
const ggml_type type_KV;
|
||||
std::array<int32_t, 4> permute;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR10(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
|
||||
return VARS_TO_STR11(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
|
@ -3225,9 +3230,9 @@ struct test_flash_attn_ext : public test_case {
|
|||
}
|
||||
|
||||
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
|
||||
bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
|
||||
std::array<int32_t, 4> permute = {0, 1, 2, 3})
|
||||
: hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
|
||||
bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
|
||||
ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
|
||||
: hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
|
||||
|
@ -3261,6 +3266,7 @@ struct test_flash_attn_ext : public test_case {
|
|||
}
|
||||
|
||||
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap);
|
||||
ggml_flash_attn_ext_set_prec(out, prec);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
|
@ -4376,11 +4382,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
for (int kv : { 512, 1024, }) {
|
||||
if (nr != 1 && kv != 512) continue;
|
||||
for (int nb : { 1, 3, 32, 35, }) {
|
||||
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV));
|
||||
// run fewer test cases permuted
|
||||
if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
|
||||
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
|
||||
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
|
||||
if (hs != 128 && prec == GGML_PREC_DEFAULT) continue;
|
||||
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||
test_cases.emplace_back(new test_flash_attn_ext(
|
||||
hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
|
||||
// run fewer test cases permuted
|
||||
if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
|
||||
test_cases.emplace_back(new test_flash_attn_ext(
|
||||
hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue