llama : DeepSeek V2/V3 MLA implementation (#12801)
* Merged using squash to remove all noise commit messages
* Force flash attention off for `LLM_ARCH_DEEPSEEK2` - embedding too large
* Removed 3 conts (2x RoPE and 1x RMS-norm)
* Changed to use `<cmath>` instead of `<math.h>`
* Reverted removal of the 3 conts
* Used `reshape` in `llm_graph_context::build_attn_mha()`
* Use `k_pe = ggml_reshape`
* Removed the 3 conts again
* Removed the 3D views of `wk_b` and `wv_b`, and just save as 3D in GGUF
* Removed MQA optimisation from `build_attn_mha()` as there are no gains now
* Simplified `is_mla` branch in `llm_build_deepseek2()`
* Removed `build_attn_mla` and added `nullptr` to all `build_attn` calls
* Fixed call to `build_attn` in `llm_build_t5_enc`
Parent: eccc7a1602
Commit: daa422881a
13 changed files with 289 additions and 165 deletions
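For readers new to MLA, the "absorption" trick that the `v_mla` tensor below implements can be sketched informally as follows. This note is an editorial summary of the DeepSeek-V2 attention algebra, not text from the PR; the symbols a_{h,t}, c_t and W^{UV}_h are paper-style notation, not identifiers from the code. Per head h, the attention output is a weighted sum of value vectors, and with MLA every value vector is an up-projection of a latent c_t shared by all heads:

    o_h = \sum_t a_{h,t} \, v_{h,t} = \sum_t a_{h,t} \, W^{UV}_h c_t = W^{UV}_h \Big( \sum_t a_{h,t} \, c_t \Big)

Because the up-projection can be pulled out of the sum, attention can be applied once to the cached latents (MQA-style, a single shared "value" per token), and the per-head matrix (stored as `wv_b` in the GGUF and passed into the graph as `v_mla`) is applied afterwards. That is exactly what the new `if (v_mla)` branch in `build_attn_mha()` below does.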
@@ -1188,6 +1188,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * v,
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
+        ggml_tensor * v_mla,
         bool v_trans,
         float kq_scale) const {
     //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
@@ -1199,7 +1200,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     //const auto & n_embd_head_k = hparams.n_embd_head_k;
     //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
-    const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
+    // note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
+    const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];
 
     const auto n_tokens = q->ne[1];
     const auto n_head = q->ne[2];
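To make the shape selection above concrete, here is a tiny self-contained sketch; the 512/128 numbers and the `have_mla` flag are illustrative only, not values read from any model.

    #include <cstdint>
    #include <cstdio>

    // Sketch of the head-size selection in build_attn_mha(): without v_mla the per-head
    // value size comes from v itself (respecting a possibly transposed V layout); with
    // v_mla, v holds the compressed latent, so the head size is the output dimension of
    // the "decompression" matrix instead.
    int main() {
        const int64_t v_ne[2]     = {512, 32};   // hypothetical v: ne[0]=512 is the latent width here, ne[1]=n_kv
        const int64_t v_mla_ne[2] = {512, 128};  // hypothetical v_mla: {latent width, per-head value size}

        const bool v_trans  = false;
        const bool have_mla = true;              // stands in for v_mla != nullptr

        const int64_t n_embd_head_v = !have_mla ? (v_trans ? v_ne[1] : v_ne[0])
                                                : v_mla_ne[1];

        printf("n_embd_head_v = %lld\n", (long long) n_embd_head_v); // prints 128 here
        return 0;
    }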
@@ -1267,6 +1269,11 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
 
+        // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
+        if (v_mla) {
+            kqv = ggml_mul_mat(ctx0, v_mla, kqv);
+        }
+
         ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
         cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
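As a standalone illustration of what that extra `ggml_mul_mat` does to the tensor shapes, here is a minimal sketch against the public ggml API. All dimensions (a 512-wide latent, 128-wide value heads, 16 heads, 4 tokens) and names such as `kqv_compressed` are made up for the example, not taken from the PR; only the shape bookkeeping is the point.

    // Minimal sketch (assumed dimensions): "decompressing" an MQA-style attention
    // result back to per-head outputs with one batched mul_mat, mirroring the
    // if (v_mla) branch above. Build-only, no compute: shapes are known at graph
    // construction time.
    #include "ggml.h"
    #include <cstdint>
    #include <cstdio>

    int main() {
        ggml_init_params params = {
            /*.mem_size   =*/ 16u*1024*1024,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true, // metadata only, no tensor data allocated
        };
        ggml_context * ctx = ggml_init(params);

        const int64_t kv_lora_rank  = 512; // hypothetical latent width
        const int64_t n_embd_head_v = 128; // hypothetical per-head value size
        const int64_t n_head   = 16;
        const int64_t n_tokens = 4;

        // attention already applied in latent space: one compressed vector per token and head
        ggml_tensor * kqv_compressed = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, kv_lora_rank, n_tokens, n_head);
        // per-head up-projection, the role played by v_mla (wv_b) in the diff
        ggml_tensor * v_mla = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, kv_lora_rank, n_embd_head_v, n_head);

        // batched over heads: [kv_lora_rank, n_tokens, n_head] -> [n_embd_head_v, n_tokens, n_head]
        ggml_tensor * kqv = ggml_mul_mat(ctx, v_mla, kqv_compressed);

        printf("kqv shape: [%lld, %lld, %lld]\n",
               (long long) kqv->ne[0], (long long) kqv->ne[1], (long long) kqv->ne[2]);

        ggml_free(ctx);
        return 0;
    }

Batching the matmul along the head dimension gives each head its own up-projection while the cached V data stays at the shared latent width.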
@@ -1304,6 +1311,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     GGML_UNUSED(n_tokens);
@@ -1325,7 +1333,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
 
     cb(cur, "kqv_out", il);
 
@@ -1379,6 +1387,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1464,7 +1473,7 @@ ggml_tensor * llm_graph_context::build_attn(
                 ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
                 0);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1504,6 +1513,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1523,7 +1533,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
 
     cb(cur, "kqv_out", il);
 
@@ -1692,4 +1702,3 @@ void llm_graph_context::build_pooling(
 
     ggml_build_forward_expand(gf, cur);
 }
-