llama : allow custom list of swa_layers (#13726)
This commit is contained in:
parent
9ecf3e66a3
commit
8a2afb7520
3 changed files with 54 additions and 23 deletions
|
@@ -574,7 +574,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
|
||||
hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
|
||||
hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full
|
||||
hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
|
||||
|
||||
switch (hparams.n_expert) {
|
||||
case 16: type = LLM_TYPE_17B_16E; break;
|
||||
|
@@ -863,7 +863,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
|
||||
hparams.n_swa = 0;
|
||||
hparams.n_swa_pattern = 1;
|
||||
hparams.set_swa_pattern(1);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_PHIMOE:
|
||||
|
@@ -935,7 +935,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
{
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
hparams.n_swa = 4096; // default value of gemma 2
|
||||
hparams.n_swa_pattern = 2;
|
||||
hparams.set_swa_pattern(2);
|
||||
hparams.attn_soft_cap = true;
|
||||
|
||||
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
|
||||
|
@@ -953,7 +953,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
case LLM_ARCH_GEMMA3:
|
||||
{
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
hparams.n_swa_pattern = 6;
|
||||
hparams.set_swa_pattern(6);
|
||||
|
||||
hparams.rope_freq_base_train_swa = 10000.0f;
|
||||
hparams.rope_freq_scale_train_swa = 1.0f;
|
||||
|
@@ -1038,7 +1038,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
case LLM_ARCH_COHERE2:
|
||||
{
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
hparams.n_swa_pattern = 4;
|
||||
hparams.set_swa_pattern(4);
|
||||
|
||||
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
|
||||
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
|
||||
|
@@ -4320,7 +4320,7 @@ void llama_model::print_info() const {
|
|||
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
|
||||
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
|
||||
LLAMA_LOG_INFO("%s: n_swa_pattern = %u\n", __func__, hparams.n_swa_pattern);
|
||||
LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
|
||||
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
|
||||
|
@@ -13216,7 +13216,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
|
||||
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
GGML_ASSERT(hparams.n_swa_pattern != 1);
|
||||
GGML_ASSERT(hparams.is_swa_any());
|
||||
|
||||
res = new llama_kv_cache_unified_iswa(
|
||||
*this,
|
||||
|
@@ -13230,7 +13230,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
cparams.n_batch,
|
||||
padding);
|
||||
} else {
|
||||
GGML_ASSERT(hparams.n_swa_pattern == 1);
|
||||
GGML_ASSERT(!hparams.is_swa_any());
|
||||
|
||||
res = new llama_kv_cache_unified(
|
||||
*this,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue