llama : allow custom list of swa_layers (#13726)
This commit is contained in:
parent
9ecf3e66a3
commit
8a2afb7520
3 changed files with 54 additions and 23 deletions
|
@ -102,20 +102,12 @@ struct llama_hparams {
|
|||
|
||||
// Sliding Window Attention (SWA)
|
||||
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
|
||||
uint32_t n_swa = 0; // the size of the sliding window (0 - no SWA)
|
||||
uint32_t n_swa_pattern = 1; // this value n means that every nth layer is dense (i.e. non-SWA)
|
||||
// by default n == 1, all layers are dense
|
||||
// note that if n_swa_pattern == 0, all layers are SWA
|
||||
// example: n_swa_pattern = 3
|
||||
// il == 0: swa
|
||||
// il == 1: swa
|
||||
// il == 2: dense
|
||||
// il == 3: swa
|
||||
// il == 4: swa
|
||||
// il == 5: dense
|
||||
// il == 6: swa
|
||||
// etc ...
|
||||
// the size of the sliding window (0 - no SWA)
|
||||
uint32_t n_swa = 0;
|
||||
// if swa_layers[il] == true, then layer il is SWA
|
||||
// if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
|
||||
// by default, all layers are dense
|
||||
std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
|
||||
|
||||
// for State Space Models
|
||||
uint32_t ssm_d_conv = 0;
|
||||
|
@ -153,6 +145,25 @@ struct llama_hparams {
|
|||
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
|
||||
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
|
||||
|
||||
llama_hparams();
|
||||
|
||||
// this value n_pattern means that every nth layer is dense (i.e. non-SWA)
|
||||
// note that if n_pattern == 0, all layers are SWA
|
||||
// if n_pattern == 1, all layers are dense
|
||||
// example: n_pattern = 3
|
||||
// il == 0: swa
|
||||
// il == 1: swa
|
||||
// il == 2: dense
|
||||
// il == 3: swa
|
||||
// il == 4: swa
|
||||
// il == 5: dense
|
||||
// il == 6: swa
|
||||
// etc ...
|
||||
void set_swa_pattern(uint32_t n_pattern);
|
||||
|
||||
// return true if one of the layers is SWA
|
||||
bool is_swa_any() const;
|
||||
|
||||
uint32_t n_head(uint32_t il = 0) const;
|
||||
|
||||
uint32_t n_head_kv(uint32_t il = 0) const;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue