graph : make FA compatible with MLA + add initial Metal kernels (#12953)

* graph : make MLA compatible with FA

* metal : add exp FA kernels for DeepSeek models

ggml-ci

* llama : minor naming updates

ggml-ci

* ggml : disable FA for DS head sizes

* tests : add FA tests for MLA shapes

ggml-ci
Georgi Gerganov 2025-04-17 18:16:36 +03:00 committed by GitHub
parent 207c22ec2d
commit 2f74c354c0
8 changed files with 117 additions and 26 deletions
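
The new FLASH_ATTN_EXT test cases for MLA-shaped inputs ("tests : add FA tests for MLA shapes" above) can be exercised with the backend-ops test binary; the command below is a hypothetical invocation, assuming the usual op filter and build path:

    ./build/bin/test-backend-ops test -o FLASH_ATTN_EXT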


@@ -484,7 +484,7 @@ ggml_tensor * llama_context::build_rope_shift(
     // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
     // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float yarn_attn_factor_scaled = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
+    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;

     ggml_tensor * tmp;
@@ -504,14 +504,14 @@ ggml_tensor * llama_context::build_rope_shift(
         tmp = ggml_rope_ext_inplace(ctx0, tmp,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);

         tmp = ggml_cpy(ctx0, tmp, cur);
     } else {
         // we rotate only the first n_rot dimensions
         tmp = ggml_rope_ext_inplace(ctx0, cur,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
     }

     return tmp;
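
The two hunks above only rename the pre-scaled factor; the reason it is pre-scaled at all is sketched below. This is a standalone check, not part of the diff: when ext_factor != 0, ggml's YaRN RoPE multiplies the rotated values by attn_factor * (1.0f + 0.1f * logf(1.0f / freq_scale)), so passing the reciprocal as the attention factor cancels that internal scale to 1 and leaves the DeepSeek-specific magnitude correction to kq_scale (see llm_build_deepseek2 further down). freq_scale = 0.025f is just an example value.

    // standalone sketch: why DEEPSEEK2 passes 1 / (1 + 0.1*ln(1/freq_scale)) as the attention factor
    #include <cmath>
    #include <cstdio>

    int main() {
        const float freq_scale = 0.025f; // example YaRN scaling factor (not from a real config)

        const float yarn_attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
        const float scale_in_rope    = yarn_attn_factor * (1.0f + 0.1f * logf(1.0f / freq_scale));

        printf("attn_factor passed to rope : %f\n", yarn_attn_factor);
        printf("effective scale inside rope: %f (cancels to 1.0)\n", scale_in_rope);
        return 0;
    }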
@@ -2278,11 +2278,6 @@ llama_context * llama_init_from_model(
         params.flash_attn = false;
     }

-    if (params.flash_attn && model->arch == LLM_ARCH_DEEPSEEK2) {
-        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Deepseek2 - forcing off\n", __func__);
-        params.flash_attn = false;
-    }
-
     if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
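
With the architecture check removed, flash attention is requested for DeepSeek2 models the same way as for any other arch; backends without kernels for the DeepSeek head sizes reject the op via their supports_op check instead ("ggml : disable FA for DS head sizes" above), so the scheduler can fall back rather than llama.cpp forcing the flag off. A hypothetical invocation (the model path is a placeholder):

    llama-cli -m deepseek-v2-lite-q4_k_m.gguf -fa -c 8192 -p "Hello"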


@@ -1200,9 +1200,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     //const auto & n_embd_head_k = hparams.n_embd_head_k;
     //const auto & n_embd_head_v = hparams.n_embd_head_v;

-    // note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
-    const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];
-
     const auto n_tokens = q->ne[1];
     const auto n_head   = q->ne[2];
     const auto n_kv     = k->ne[1];
@@ -1231,7 +1228,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

-        cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
+        if (v_mla) {
+            cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
+            cur = ggml_mul_mat(ctx0, v_mla, cur);
+        }
+
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
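
The hunk above is the core of the FA+MLA change: with the absorption optimization the flash attention output carries kv_lora_rank values per head, and v_mla projects each head back to the full value head size before the usual 2D reshape. A minimal standalone sketch of the shape flow, assuming DeepSeek2-like dimensions (kv_lora_rank = 512, n_embd_head_v = 128, n_head = 16, n_tokens = 4 are illustrative, and the stand-in tensor replaces the real ggml_flash_attn_ext output):

    #include "ggml.h"
    #include <cstdio>

    int main() {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 256u*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx0 = ggml_init(ip);

        const int64_t kv_lora_rank  = 512; // per-head width of the FA output (v_mla->ne[0])
        const int64_t n_embd_head_v = 128; // final value head size           (v_mla->ne[1])
        const int64_t n_head        = 16;
        const int64_t n_tokens      = 4;

        // stand-in for the ggml_flash_attn_ext result, viewed as [kv_lora_rank*n_head, n_tokens]
        ggml_tensor * cur   = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, kv_lora_rank*n_head, n_tokens);
        // per-head decompression matrices, broadcast over tokens
        ggml_tensor * v_mla = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, kv_lora_rank, n_embd_head_v, n_head);

        // same op sequence as the FA branch above:
        cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens); // [512, 1, 16, 4]
        cur = ggml_mul_mat(ctx0, v_mla, cur);                                // [128, 1, 16, 4]
        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);       // [2048, 4]

        printf("output: [%lld, %lld]\n", (long long) cur->ne[0], (long long) cur->ne[1]);

        ggml_free(ctx0);
        return 0;
    }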
@@ -1274,9 +1276,9 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             kqv = ggml_mul_mat(ctx0, v_mla, kqv);
         }

-        ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+        cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);

-        cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);

         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU

@@ -10050,7 +10050,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
         // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
         const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
         const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(n_embd_head_k));
-        const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
+        const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));

         ggml_tensor * cur;
         ggml_tensor * inpL;
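
Note on the hunk above: mscale is computed from the llm_graph_context member attn_factor before the new local declaration, so after the rename the local attn_factor shadows the member for the RoPE calls below, while kq_scale keeps the squared YaRN magnitude correction. A tiny standalone check with placeholder hyper-parameters (freq_scale, rope_yarn_log_mul and n_embd_head_k are illustrative, not read from a model):

    #include <cmath>
    #include <cstdio>

    int main() {
        const float attn_factor_member = 1.0f;   // stands in for the llm_graph_context member
        const float freq_scale         = 0.025f; // placeholder YaRN scaling factor
        const float rope_yarn_log_mul  = 0.1f;   // placeholder hparams.rope_yarn_log_mul
        const int   n_embd_head_k      = 192;    // placeholder K head size

        const float mscale   = attn_factor_member * (1.0f + rope_yarn_log_mul * logf(1.0f / freq_scale));
        const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(n_embd_head_k));

        // the local attn_factor that is passed to ggml_rope_ext further down
        const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));

        printf("mscale = %f, kq_scale = %f, attn_factor for rope = %f\n", mscale, kq_scale, attn_factor);
        return 0;
    }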
@@ -10127,13 +10127,13 @@ struct llm_build_deepseek2 : public llm_graph_context {
                 q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
+                        ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(q_pe, "q_pe", il);

                 k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
+                        ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(k_pe, "k_pe", il);