llama : add AWQ for llama, llama2, mpt, and mistral models (#4593)
* update: awq support llama-7b model
* update: change order
* update: benchmark results for llama2-7b
* update: mistral 7b v1 benchmark
* update: support 4 models
* fix: readme
* update: ready for PR
* update: readme
* fix: readme
* update: change order import
* black
* format code
* update: work for both mpt and awq mpt
* update: readme
* Rename to llm_build_ffn_mpt_awq
* Formatted other files
* Fixed params count
* fix: remove code
* update: more detail for mpt
* fix: readme
* fix: readme
* update: change folder architecture
* fix: common.cpp
* fix: readme
* fix: remove ggml_repeat
* update: cicd
* update: cicd
* update: remove use_awq arg
* update: readme
* llama : adapt plamo to new ffn

ggml-ci

---------

Co-authored-by: Trần Đức Nam <v.namtd12@vinai.io>
Co-authored-by: Le Hoang Anh <v.anhlh33@vinai.io>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
parent 879b690a9e
commit f6793491b5

8 changed files with 443 additions and 5 deletions
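For context on the diff below: AWQ (Activation-aware Weight Quantization) rescales the most activation-salient weight channels before quantizing, and an AWQ-converted checkpoint carries per-layer scale tensors that inference has to divide back out of the activations. Applied to MPT's feed-forward block, each layer computes roughly

    ffn_out = down(gelu(up(x)) / s)

where s is an {n_ff} scale vector stored under the new "blk.%d.ffn.act" tensor name. Only the MPT graph changes at runtime; for the other architectures the scales can presumably be folded into the weights during conversion, whereas MPT's sit behind the nonlinearity and must stay explicit.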
llama.cpp: 27 changes
@@ -354,6 +354,7 @@ enum llm_tensor {
     LLM_TENSOR_FFN_GATE,
     LLM_TENSOR_FFN_DOWN,
     LLM_TENSOR_FFN_UP,
+    LLM_TENSOR_FFN_ACT,
     LLM_TENSOR_FFN_DOWN_EXP,
     LLM_TENSOR_FFN_GATE_EXP,
     LLM_TENSOR_FFN_UP_EXP,
@@ -473,6 +474,7 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_ATTN_OUT,      "blk.%d.attn_output" },
             { LLM_TENSOR_FFN_DOWN,      "blk.%d.ffn_down" },
             { LLM_TENSOR_FFN_UP,        "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_ACT,       "blk.%d.ffn.act" },
         },
     },
     {
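The new LLM_TENSOR_FFN_ACT entry gives the AWQ scales a per-layer GGUF name, "blk.%d.ffn.act", expanded with the block index and a suffix like the other per-block tensors. A minimal sketch of how such a pattern expands (a hypothetical stand-in, not llama.cpp's tn() helper):

```cpp
#include <cstdio>
#include <string>

// Hypothetical stand-in for llama.cpp's tensor-name formatting:
// "%d" in the pattern is replaced by the block (layer) index,
// and a suffix such as "weight" or "scales" is appended.
static std::string tensor_name(const char * pattern, const char * suffix, int il) {
    char buf[128];
    std::snprintf(buf, sizeof(buf), pattern, il);
    return std::string(buf) + "." + suffix;
}

int main() {
    // e.g. layer 3 -> "blk.3.ffn.act.scales"
    std::printf("%s\n", tensor_name("blk.%d.ffn.act", "scales", 3).c_str());
    return 0;
}
```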
@@ -1285,6 +1287,7 @@ struct llama_hparams {
     float f_clamp_kqv;
     float f_max_alibi_bias;

     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only != other.vocab_only) return true;
         if (this->n_vocab != other.n_vocab) return true;
@@ -1388,6 +1391,7 @@ struct llama_layer {
     // ff bias
     struct ggml_tensor * ffn_down_b; // b2
     struct ggml_tensor * ffn_up_b;   // b3
+    struct ggml_tensor * ffn_act;
 };

 struct llama_kv_cell {
@@ -3471,7 +3475,6 @@ static bool llm_load_tensors(
         case LLM_ARCH_MPT:
             {
                 model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);

                 // output
                 {
                     ggml_backend_type backend_norm;
@@ -3509,6 +3512,9 @@ static bool llm_load_tensors(

                     layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, backend_split);
                     layer.ffn_up   = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, backend_split);
+
+                    // AWQ ScaleActivation layer
+                    layer.ffn_act = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, backend, false);
                 }
             } break;
         case LLM_ARCH_STABLELM:
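The trailing `false` marks the scales as optional: a checkpoint converted without AWQ simply has no blk.%d.ffn.act.scales entry, so layer.ffn_act stays NULL and the FFN runs unscaled. A minimal sketch of that pattern, with hypothetical loader and tensor types (not the llama_model_loader API):

```cpp
#include <cstdio>
#include <map>
#include <string>

struct tensor { int n; };

// Hypothetical loader: when required == false, a missing tensor
// yields nullptr instead of failing the whole model load.
static tensor * create_tensor(std::map<std::string, tensor> & weights,
                              const std::string & name, bool required) {
    auto it = weights.find(name);
    if (it == weights.end()) {
        if (required) {
            std::fprintf(stderr, "missing required tensor: %s\n", name.c_str());
        }
        return nullptr; // optional tensor absent: caller checks for NULL
    }
    return &it->second;
}

int main() {
    std::map<std::string, tensor> weights; // no AWQ scales in this checkpoint
    tensor * ffn_act = create_tensor(weights, "blk.0.ffn.act.scales", /*required=*/false);
    std::printf("ffn_act is %s\n", ffn_act ? "loaded" : "NULL (unscaled FFN)");
    return 0;
}
```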
@@ -4039,6 +4045,7 @@ static struct ggml_tensor * llm_build_ffn(
         struct ggml_tensor * gate_b,
         struct ggml_tensor * down,
         struct ggml_tensor * down_b,
+        struct ggml_tensor * act_scales,
         llm_ffn_op_type   type_op,
         llm_ffn_gate_type type_gate,
         const llm_build_cb & cb,
@@ -4083,6 +4090,10 @@ static struct ggml_tensor * llm_build_ffn(
             {
                 cur = ggml_gelu(ctx, cur);
                 cb(cur, "ffn_gelu", il);
+                if (act_scales != NULL) {
+                    cur = ggml_div(ctx, cur, act_scales);
+                    cb(cur, "ffn_act", il);
+                }
             } break;
         case LLM_FFN_RELU:
             {
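This hunk carries the actual AWQ math: when scales were loaded, the GELU output is divided elementwise by them before the down projection. Here cur is an {n_ff, n_tokens} activation matrix and act_scales an {n_ff} vector; the "fix: remove ggml_repeat" entry in the commit log suggests ggml_div broadcasts the scales across tokens directly. A standalone numeric sketch of the same computation (plain C++, not the ggml graph API):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// tanh-approximation GELU; ggml_gelu computes a close variant of this
static float gelu(float x) {
    return 0.5f * x * (1.0f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
}

int main() {
    const int n_ff = 4, n_tokens = 2;
    std::vector<float> cur(n_ff * n_tokens, 1.5f);        // FFN hidden activations {n_ff, n_tokens}
    std::vector<float> scales = {1.0f, 2.0f, 0.5f, 4.0f}; // per-channel AWQ scales {n_ff}

    for (int t = 0; t < n_tokens; ++t) {
        for (int i = 0; i < n_ff; ++i) {
            float & x = cur[t * n_ff + i];
            x = gelu(x) / scales[i]; // scales broadcast across tokens
        }
    }
    std::printf("token0[0..3] = %f %f %f %f\n", cur[0], cur[1], cur[2], cur[3]);
    return 0;
}
```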
@@ -4401,6 +4412,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   NULL,
                         model.layers[il].ffn_gate, NULL,
                         model.layers[il].ffn_down, NULL,
+                        NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             } else {
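The remaining call-site hunks are mechanical: each existing llm_build_ffn caller gains a NULL in the new act_scales slot, leaving every non-MPT architecture's graph functionally unchanged. The single call site that passes a real tensor is MPT's, further down.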
@@ -4580,6 +4592,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   NULL,
                         model.layers[il].ffn_gate, NULL,
                         model.layers[il].ffn_down, NULL,
+                        NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -4694,6 +4707,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   NULL,
                         NULL,                      NULL,
                         model.layers[il].ffn_down, NULL,
+                        NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -4798,6 +4812,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
                         NULL,                      NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -5002,6 +5017,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
                         NULL,                      NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
                         LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -5088,6 +5104,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   NULL,
                         model.layers[il].ffn_gate, NULL,
                         model.layers[il].ffn_down, NULL,
+                        NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -5183,6 +5200,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
                         NULL,                      NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -5268,11 +5286,11 @@ struct llm_build_context {
                         NULL,
                         LLM_NORM, cb, il);
             cb(cur, "ffn_norm", il);

             cur = llm_build_ffn(ctx0, cur,
                     model.layers[il].ffn_up,   NULL,
                     NULL,                      NULL,
                     model.layers[il].ffn_down, NULL,
+                    model.layers[il].ffn_act,
                     LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
             cb(cur, "ffn_out", il);
         }
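This is the one call site that supplies real scales: the MPT graph hands model.layers[il].ffn_act to llm_build_ffn, so an AWQ-converted MPT checkpoint takes the gelu(x)/s path, while a plain MPT checkpoint, whose ffn_act loaded as NULL, falls through to the unscaled GELU inside llm_build_ffn.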
@@ -5381,6 +5399,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   NULL,
                         model.layers[il].ffn_gate, NULL,
                         model.layers[il].ffn_down, NULL,
+                        NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -5493,6 +5512,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   NULL,
                         model.layers[il].ffn_gate, NULL,
                         model.layers[il].ffn_down, NULL,
+                        NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -5600,6 +5620,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
                         NULL,                      NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(ffn_output, "ffn_out", il);
             }
@@ -5703,6 +5724,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   NULL,
                         model.layers[il].ffn_gate, NULL,
                         model.layers[il].ffn_down, NULL,
+                        NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -5887,6 +5909,7 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
     { "ffn_gate",     OFFLOAD_FUNC },
     { "ffn_gate_b",   OFFLOAD_FUNC },
     { "ffn_gate_par", OFFLOAD_FUNC },
+    { "ffn_act",      OFFLOAD_FUNC },
     { "ffn_down",     OFFLOAD_FUNC },
     { "ffn_down_b",   OFFLOAD_FUNC },
     { "ffn_out",      OFFLOAD_FUNC },
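The strings passed to cb(tensor, name, il) name the graph nodes, and k_offload_map keys offload decisions on those names; registering "ffn_act" lets the scaled activation run on the GPU backend like the surrounding FFN nodes. A simplified sketch of that lookup (std::string keys here, whereas the real map uses const char * literals):

```cpp
#include <cstdio>
#include <string>
#include <unordered_map>

enum offload_func_e { OFFLOAD_FUNC_NOP, OFFLOAD_FUNC };

int main() {
    // Simplified stand-in for k_offload_map: node names set via the
    // build callback are looked up to pick an offload policy.
    std::unordered_map<std::string, offload_func_e> offload_map = {
        { "ffn_gelu", OFFLOAD_FUNC },
        { "ffn_act",  OFFLOAD_FUNC }, // the new AWQ-scaled activation node
    };
    std::printf("ffn_act offloadable: %d\n", offload_map.at("ffn_act") == OFFLOAD_FUNC);
    return 0;
}
```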