llama : DeepSeek V2/V3 MLA implementation (#12801)
* Merged using squash to remove all noise commit messages
* Force flash attention off for `LLM_ARCH_DEEPSEEK2` - embedding too large
* Removed 3 conts (2x RoPE and 1x RMS-norm)
* Changed to use `<cmath>` instead of `<math.h>`
* Reverted removal of the 3 conts
* Used `reshape` in `llm_graph_context::build_attn_mha()`
* Use `k_pe = ggml_reshape`
* Removed the 3 conts again
* Removed the 3D views of `wk_b` and `wv_b`, and just save as 3D in GGUF
* Removed MQA optimisation from `build_attn_mha()` as no gains now
* Simplified `is_mla` branch in `llm_build_deepseek2()`
* Removed `build_attn_mla` and added `nullptr` to all `build_attn` calls
* Fixed call to `build_attn` in `llm_build_t5_enc`
Parent: eccc7a1602
Commit: daa422881a

13 changed files with 289 additions and 165 deletions
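The converter hunks below touch the `DeepseekV2Model` class of the GGUF converter. For orientation, here is a minimal PyTorch sketch, not the llama.cpp implementation, of the "absorption" optimisation the commit message and the in-diff notes refer to: the cache holds one compressed latent per token, queries are projected into that latent space with the per-head `k_b` weights, and the attention output is decompressed with `v_b`. All dimensions are made-up stand-ins, and the RoPE half of the key is omitted for brevity.

```python
import torch

# Illustrative dimensions only (not DeepSeek's real hparams); RoPE part omitted.
n_head           = 4    # query heads
qk_nope_head_dim = 8    # per-head query/key dim (non-RoPE part)
v_head_dim       = 8    # per-head value dim
kv_lora_rank     = 16   # compressed latent dim = what gets cached per token
n_tokens         = 5

# Compressed KV cache: a single latent vector per token ("MQA with 1 group").
c_kv = torch.randn(n_tokens, kv_lora_rank)

# Per-head decompression weights. Orientation here is whatever makes the torch
# matmuls line up; the converter below stores k_b transposed for llama.cpp.
w_kb = torch.randn(n_head, qk_nope_head_dim, kv_lora_rank)  # latent -> key, per head
w_vb = torch.randn(n_head, kv_lora_rank, v_head_dim)        # latent -> value, per head

# New-token queries, one per head.
q_nope = torch.randn(n_head, 1, qk_nope_head_dim)

# "Absorb" w_kb into the query so attention scores are computed directly
# against the compressed cache -- K is never decompressed.
q_latent = torch.bmm(q_nope, w_kb)                    # (n_head, 1, kv_lora_rank)
scores   = q_latent @ c_kv.T / qk_nope_head_dim**0.5  # (n_head, 1, n_tokens)
probs    = torch.softmax(scores, dim=-1)

# Attend over the latent cache, then decompress the result with w_vb.
out = torch.bmm(probs @ c_kv, w_vb)                   # (n_head, 1, v_head_dim)
print(out.shape)                                      # torch.Size([4, 1, 8])
```

This is why the converter splits `kv_b_proj` into separate `k_b_proj` / `v_b_proj` tensors, with `k_b` transposed so it can be applied on the query side (see the last hunk).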
@@ -4422,6 +4422,10 @@ class DeepseekV2Model(Model):
         self._set_vocab_gpt2()
 
     def set_gguf_parameters(self):
+
+        # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group)
+        self.hparams["num_key_value_heads"] = 1
+
         super().set_gguf_parameters()
         hparams = self.hparams
 
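A quick illustration of the note above, with made-up head counts: GQA with a single KV group degenerates to MQA, so every query head reads the same cached entry.

```python
# Illustrative only: GQA with one KV group is just MQA.
n_head    = 16   # made-up number of query heads
n_head_kv = 1    # what the converter forces for MLA models
n_gqa     = n_head // n_head_kv
assert n_gqa == n_head  # all query heads share the single cached K/V entry
```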
@@ -4430,8 +4434,13 @@ class DeepseekV2Model(Model):
         if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
             self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
         self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
-        self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
-        self.gguf_writer.add_value_length(hparams["v_head_dim"])
+
+        # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
+        self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"])
+        self.gguf_writer.add_value_length(hparams["kv_lora_rank"])
+        self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
+        self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
+
         self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
         self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
         self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
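Back-of-the-envelope on what the new key/value lengths mean for the per-token cache. The hparams below are DeepSeek-V2-scale values used purely for illustration, not read from any config; the `*_mla` keys keep the original per-head dimensions so the runtime can still decompress back to MHA-shaped heads.

```python
# Rough per-token KV-cache size implied by the metadata above.
# Hparams are DeepSeek-V2-scale illustrative values, not authoritative.
n_head           = 128
qk_nope_head_dim = 128
qk_rope_head_dim = 64
v_head_dim       = 128
kv_lora_rank     = 512

# If K/V were cached per head (the *_mla lengths describe this decompressed view):
per_token_mha = n_head * ((qk_nope_head_dim + qk_rope_head_dim) + v_head_dim)  # 40960

# With MLA exposed as MQA: one KV "head" whose key is the latent plus the RoPE
# part and whose value is the latent itself (the new key/value lengths above):
per_token_mla = 1 * ((kv_lora_rank + qk_rope_head_dim) + kv_lora_rank)         # 1088

print(f"{per_token_mha} vs {per_token_mla} cache elements per token "
      f"(~{per_token_mha / per_token_mla:.1f}x smaller)")
```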
@@ -4500,6 +4509,26 @@ class DeepseekV2Model(Model):
             else:
                 return []
 
+        # note: MLA with the absorption optimization, needs these two split and k_b_proj transposed
+        if name.endswith("kv_b_proj.weight"):
+            name_kb = name.replace("kv_b_proj", "k_b_proj")
+            name_vb = name.replace("kv_b_proj", "v_b_proj")
+
+            n_head_kv = self.hparams["num_key_value_heads"]
+            v_head_dim = self.hparams["v_head_dim"]
+            qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
+
+            assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
+
+            kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
+            k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
+            k_b = k_b.transpose(1, 2)
+
+            return [
+                (self.map_tensor_name(name_kb), k_b),
+                (self.map_tensor_name(name_vb), v_b)
+            ]
+
         return [(self.map_tensor_name(name), data_torch)]
 
     def prepare_tensors(self):
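The split above can be re-run standalone to see the tensor shapes that end up in the GGUF. This sketch uses a dummy `kv_b_proj.weight` with made-up dimensions and assumes the per-head row layout is `[k_nope; v]`, as the split order in the hunk does.

```python
# Standalone re-run of the kv_b_proj split, on a dummy tensor with made-up dims.
import torch

n_head_kv        = 4
qk_nope_head_dim = 8
v_head_dim       = 8
kv_lora_rank     = 16

# kv_b_proj.weight maps the latent (columns) to per-head [k_nope; v] rows.
data_torch = torch.randn(n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
k_b = k_b.transpose(1, 2)   # transposed so it can be applied on the query side

print(k_b.shape)  # (n_head_kv, kv_lora_rank, qk_nope_head_dim) -> torch.Size([4, 16, 8])
print(v_b.shape)  # (n_head_kv, v_head_dim, kv_lora_rank)       -> torch.Size([4, 8, 16])
```

Both results are per-head 3-D tensors, matching the commit note about saving `wk_b`/`wv_b` as 3D in the GGUF instead of creating 3D views at load time.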