llama : allow custom list of swa_layers (#13726)
This commit is contained in:
parent
9ecf3e66a3
commit
8a2afb7520
3 changed files with 54 additions and 23 deletions
|
@ -2,6 +2,26 @@
|
|||
|
||||
#include "ggml.h"
|
||||
|
||||
llama_hparams::llama_hparams() {
|
||||
swa_layers.fill(false);
|
||||
}
|
||||
|
||||
void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
|
||||
}
|
||||
}
|
||||
|
||||
bool llama_hparams::is_swa_any() const {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
if (swa_layers[il]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_head(uint32_t il) const {
|
||||
if (il < n_layer) {
|
||||
return n_head_arr[il];
|
||||
|
@ -72,7 +92,7 @@ uint32_t llama_hparams::n_embd_v_s() const {
|
|||
|
||||
bool llama_hparams::is_swa(uint32_t il) const {
|
||||
if (il < n_layer) {
|
||||
return n_swa_pattern == 0 || (il % n_swa_pattern < (n_swa_pattern - 1));
|
||||
return swa_layers[il];
|
||||
}
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue