server: tests: passkey challenge / self-extend with context shift demo (#5832)

* server: tests: add models endpoint scenario

* server: /v1/models add some metadata

* server: tests: add debug field in context before scenario

* server: tests: download model from HF, add batch size

* server: tests: add passkey test

* server: tests: add group attention params

* server: do not truncate prompt tokens if self-extend through group attention is enabled

* server: logs: do not truncate log values

* server: tests - passkey - first good working value of nga

* server: tests: fix server timeout

* server: tests: fix passkey, add doc, fix regex content matching, fix timeout

* server: tests: fix regex content matching

* server: tests: schedule slow tests on master

* server: metrics: fix when no prompt processed

* server: tests: self-extend add llama-2-7B and Mixtral-8x7B-v0.1

* server: tests: increase timeout for completion

* server: tests: keep only the PHI-2 test

* server: tests: passkey add a negative test
This commit is contained in:
Pierrick Hymbert 2024-03-02 22:00:14 +01:00 committed by GitHub
parent 4a6e2d6142
commit 9731134296
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 363 additions and 112 deletions

View file

@ -441,8 +441,8 @@ struct llama_server_context
const int ga_w = params.grp_attn_w;
if (ga_n != 1) {
GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
@ -1709,8 +1709,8 @@ struct llama_server_context
}
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
// if input prompt is too big, truncate it
if (slot.n_prompt_tokens >= slot.n_ctx)
// if input prompt is too big, truncate it, if group attention self-extend is disabled
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx)
{
const int n_left = slot.n_ctx - slot.params.n_keep;
const int n_block_size = n_left / 2;
@ -1785,9 +1785,11 @@ struct llama_server_context
}
LOG_INFO("slot progression", {
{ "slot_id", slot.id },
{ "task_id", slot.task_id },
{ "n_past", slot.n_past },
{ "slot_id", slot.id },
{ "task_id", slot.task_id },
{ "n_past", slot.n_past },
{ "n_past_se", slot.n_past_se },
{ "ga_i", slot.ga_i },
{ "n_prompt_tokens_processed", slot.n_prompt_tokens_processed }
});
}
@ -2001,6 +2003,17 @@ struct llama_server_context
LOG_VERBOSE("slots updated", {});
return true;
}
json model_meta() {
return json{
{"vocab_type", llama_vocab_type(model)},
{"n_vocab", llama_n_vocab(model)},
{"n_ctx_train", llama_n_ctx_train(model)},
{"n_embd", llama_n_embd(model)},
{"n_params", llama_model_n_params(model)},
{"size", llama_model_size(model)},
};
}
};
static void server_print_usage(const char *argv0, const gpt_params &params,
@ -2911,9 +2924,10 @@ int main(int argc, char **argv)
for (const auto& metric_def : metrics_def) {
std::string name = metric_def["name"];
std::string help = metric_def["help"];
prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
<< "# TYPE llamacpp:" << name << " " << type << "\n"
<< "llamacpp:" << name << " " << metric_def["value"] << "\n";
auto value = json_value(metric_def, "value", 0);
prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
<< "# TYPE llamacpp:" << name << " " << type << "\n"
<< "llamacpp:" << name << " " << value << "\n";
}
}
@ -2994,6 +3008,7 @@ int main(int argc, char **argv)
state.store(SERVER_STATE_READY);
LOG_INFO("model loaded", {});
}
const auto model_meta = llama.model_meta();
if (sparams.chat_template.empty()) { // custom chat template is not supplied
// check if the template comes with the model is supported by us
@ -3143,7 +3158,7 @@ int main(int argc, char **argv)
}
});
svr.Get("/v1/models", [&params](const httplib::Request& req, httplib::Response& res)
svr.Get("/v1/models", [&params, &model_meta](const httplib::Request& req, httplib::Response& res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
std::time_t t = std::time(0);
@ -3152,10 +3167,11 @@ int main(int argc, char **argv)
{"object", "list"},
{"data", {
{
{"id", params.model_alias},
{"object", "model"},
{"created", t},
{"owned_by", "llamacpp"}
{"id", params.model_alias},
{"object", "model"},
{"created", t},
{"owned_by", "llamacpp"},
{"meta", model_meta}
},
}}
};