llama: Add support for RWKV v7 architecture (#12412)
* ggml: Add op l2_norm
* ggml: Add op rwkv_wkv7
* llama: Add support for RWKV7 and ARWKV7 models
* llama: fix inference with RWKV6Qwen2
* llama: add more (a)rwkv7 variants in size
* Apply code-format changes
* fix MUSA build
* llama: fix shape error with rwkv using llama-parallel

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
parent 60c902926c
commit 7dfad387e3
35 changed files with 2948 additions and 438 deletions
@@ -908,6 +908,40 @@ class Model:
        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        special_vocab.add_to_gguf(self.gguf_writer)

    def _set_vocab_rwkv_world(self):
        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
        vocab_size = self.hparams.get("vocab_size", 65536)

        tokens: list[bytes] = ['<s>'.encode("utf-8")]
        toktypes: list[int] = [gguf.TokenType.CONTROL]

        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
            lines = f.readlines()
            for line in lines:
                parts = line.split(' ')
                assert len(parts) >= 3
                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
                token = token.encode("utf-8") if isinstance(token, str) else token
                assert isinstance(token, bytes)
                assert len(token) == token_len
                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
                tokens.append(token_text.encode("utf-8"))
                toktypes.append(gguf.TokenType.NORMAL)
        remainder = vocab_size - len(tokens)
        assert remainder >= 0
        for i in range(len(tokens), vocab_size):
            tokens.append(f"[PAD{i}]".encode("utf-8"))
            toktypes.append(gguf.TokenType.UNUSED)

        self.gguf_writer.add_tokenizer_model("rwkv")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)
        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
        special_vocab.chat_template = "rwkv-world"
        # hack: Add '\n\n' as the EOT token to make it chat normally
        special_vocab._set_special_token("eot", 261)
        special_vocab.add_to_gguf(self.gguf_writer)

    def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
        tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
        logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
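For context on the parsing above: each line of `rwkv_vocab_v20230424.txt` is expected to hold a token id, a Python literal for the token (a `str` or `bytes` repr, which may itself contain spaces), and the token's length in bytes. A minimal standalone sketch of that per-line handling, using made-up sample lines rather than the real vocab file:

```python
import ast

# hypothetical sample lines in the "<id> <python-literal> <byte length>" format
sample_lines = [
    "1 '\\x00' 1",          # single byte written as an escaped str literal
    "39 'hello world' 11",  # the literal itself may contain spaces
    "40 b'\\xff\\xfe' 2",   # non-UTF-8 tokens appear as bytes literals
]

for line in sample_lines:
    parts = line.split(' ')
    token = ast.literal_eval(' '.join(parts[1:-1]))
    token_len = int(parts[-1])
    token = token.encode("utf-8") if isinstance(token, str) else token
    assert isinstance(token, bytes) and len(token) == token_len
    print(parts[0], token, token_len)
```

The converter then stores `repr(token)[2:-1]` rather than the raw bytes, presumably so that byte sequences that are not valid UTF-8 can still be written as escaped strings into the GGUF token list.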
@@ -3412,38 +3446,7 @@ class Rwkv6Model(Model):
     model_arch = gguf.MODEL_ARCH.RWKV6
 
     def set_vocab(self):
-        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
-        vocab_size = self.hparams.get("vocab_size", 65536)
-
-        tokens: list[bytes] = ['<s>'.encode("utf-8")]
-        toktypes: list[int] = [gguf.TokenType.CONTROL]
-
-        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
-            lines = f.readlines()
-            for line in lines:
-                parts = line.split(' ')
-                assert len(parts) >= 3
-                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
-                token = token.encode("utf-8") if isinstance(token, str) else token
-                assert isinstance(token, bytes)
-                assert len(token) == token_len
-                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
-                tokens.append(token_text.encode("utf-8"))
-                toktypes.append(gguf.TokenType.NORMAL)
-        remainder = vocab_size - len(tokens)
-        assert remainder >= 0
-        for i in range(len(tokens), vocab_size):
-            tokens.append(f"[PAD{i}]".encode("utf-8"))
-            toktypes.append(gguf.TokenType.UNUSED)
-
-        self.gguf_writer.add_tokenizer_model("rwkv")
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_types(toktypes)
-        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
-        special_vocab.chat_template = "rwkv-world"
-        # hack: Add '\n\n' as the EOT token to make it chat normally
-        special_vocab._set_special_token("eot", 261)
-        special_vocab.add_to_gguf(self.gguf_writer)
+        self._set_vocab_rwkv_world()
 
     def set_gguf_parameters(self):
         block_count = self.hparams["num_hidden_layers"]
@@ -3565,6 +3568,168 @@ class RWKV6Qwen2Model(Rwkv6Model):
        yield (new_name, data)


@Model.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM")
class Rwkv7Model(Model):
    model_arch = gguf.MODEL_ARCH.RWKV7

    def set_vocab(self):
        self._set_vocab_rwkv_world()

    def calc_lora_rank(self, hidden_size, exponent, multiplier):
        return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32

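When a rank is missing from the Hugging Face config, `calc_lora_rank` derives a fallback from the hidden size and rounds it to a positive multiple of 32. A small sketch of the fallback computation with an illustrative hidden size (not taken from any particular checkpoint):

```python
def calc_lora_rank(hidden_size: int, exponent: float, multiplier: float) -> int:
    # scale by hidden_size ** exponent, then snap to the nearest positive multiple of 32
    return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32

hidden_size = 2048  # illustrative value only
print(calc_lora_rank(hidden_size, 0.5, 1.8))  # decay / ICLR fallback
print(calc_lora_rank(hidden_size, 0.5, 1.3))  # value-residual-mix fallback
print(calc_lora_rank(hidden_size, 0.8, 0.6))  # gate fallback
```

The exponent/multiplier pairs here are exactly the ones used in the fallback branches of `set_gguf_parameters` below; rounding to a multiple of 32 keeps the derived ranks nicely aligned.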
    def set_gguf_parameters(self):
        block_count = self.hparams["num_hidden_layers"]
        try:
            head_size = self.hparams["head_size"]
            layer_norm_eps = self.hparams["layer_norm_epsilon"]
        except KeyError:
            head_size = self.hparams["head_dim"]
            layer_norm_eps = self.hparams["norm_eps"]
        hidden_size = self.hparams["hidden_size"]
        intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else (hidden_size * 4)

        # ICLR: In-Context-Learning-Rate
        try:
            lora_rank_decay = self.hparams["lora_rank_decay"] if self.hparams["lora_rank_decay"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
            lora_rank_iclr = self.hparams["lora_rank_iclr"] if self.hparams["lora_rank_iclr"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
            lora_rank_value_residual_mix = self.hparams["lora_rank_value_residual_mix"] if self.hparams["lora_rank_value_residual_mix"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
            lora_rank_gate = self.hparams["lora_rank_gate"] if self.hparams["lora_rank_gate"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
        except KeyError:
            lora_rank_decay = self.hparams["decay_low_rank_dim"] if self.hparams["decay_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
            lora_rank_iclr = self.hparams["a_low_rank_dim"] if self.hparams["a_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
            lora_rank_value_residual_mix = self.hparams["v_low_rank_dim"] if self.hparams["v_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
            lora_rank_gate = self.hparams["gate_low_rank_dim"] if self.hparams["gate_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)

        # RWKV isn't context limited
        self.gguf_writer.add_context_length(1048576)
        self.gguf_writer.add_embedding_length(hidden_size)
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
        self.gguf_writer.add_wkv_head_size(head_size)
        self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
        self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
        self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
        self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
        self.gguf_writer.add_feed_forward_length(intermediate_size)
        self.gguf_writer.add_file_type(self.ftype)

        # required by llama.cpp, unused
        self.gguf_writer.add_head_count(0)

    lerp_weights: dict[int, dict[str, Tensor]] = {}
    lora_needs_transpose: bool = True

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # unify tensor names here to make life easier
        name = name.replace("blocks", "layers").replace("ffn", "feed_forward")
        name = name.replace("self_attn", "attention").replace("attn", "attention")
        name = name.replace("time_mixer.", "")
        # lora layer names in fla-hub's impl
        if "_lora.lora" in name:
            self.lora_needs_transpose = False
            name = name.replace("_lora.lora.0.weight", "1.weight")
            name = name.replace("_lora.lora.2.weight", "2.weight")
            name = name.replace("_lora.lora.2.bias", "0.weight")

        name = name.replace("feed_forward_norm", "ln2")
        name = name.replace("g_norm", "ln_x")

        if "attention.v" in name and "value" not in self.map_tensor_name(name) and bid == 0:
            # some models have dummy v0/v1/v2 on first layer while others don't
            # ignore them all since they are not used
            return

        wkv_has_gate = self.hparams.get("wkv_has_gate", True)
        lerp_list = ["r", "w", "k", "v", "a", "g"] if wkv_has_gate else ["r", "w", "k", "v", "a"]

        if bid is not None and "attention.x_" in name:
            if "attention.x_x" in name:
                # already concatenated
                new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
                data = data_torch.reshape(len(lerp_list), 1, 1, -1)
                yield (new_name, data)
            else:
                try:
                    self.lerp_weights[bid][name] = data_torch
                except KeyError:
                    self.lerp_weights[bid] = {name: data_torch}
                if all(f"model.layers.{bid}.attention.x_{i}" in self.lerp_weights[bid].keys() for i in lerp_list):
                    new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
                    data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.attention.x_{i}"] for i in lerp_list], dim=0)
                    yield (new_name, data)
            return
        else:
            data_torch = data_torch.squeeze()
            new_name = self.map_tensor_name(name)

            if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
                new_name += ".weight"

            if self.lora_needs_transpose and any(
                new_name.endswith(t) for t in [
                    "time_mix_w1.weight", "time_mix_w2.weight",
                    "time_mix_a1.weight", "time_mix_a2.weight",
                    "time_mix_v1.weight", "time_mix_v2.weight",
                    "time_mix_g1.weight", "time_mix_g2.weight",
                ]
            ):
                data_torch = data_torch.transpose(0, 1)

            if 'r_k' in new_name:
                data_torch = data_torch.flatten()

            if bid == 0 and "time_mix_a" in new_name:
                # dummy v0/v1/v2 on first layer
                # easiest way to make llama happy
                yield (new_name.replace("time_mix_a", "time_mix_v"), data_torch)

            yield (new_name, data_torch)

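The `modify_tensors` method above buffers the per-layer token-shift mix tensors (`x_r`, `x_w`, `x_k`, `x_v`, `x_a`, plus `x_g` when the model has a gate) in `lerp_weights` and, once all of them are present for a layer, stacks them into a single `time_mix_lerp_fused` tensor. A minimal sketch of that stacking step; the `(1, 1, hidden_size)` shape and the hidden size itself are assumed here purely for illustration:

```python
import torch

lerp_list = ["r", "w", "k", "v", "a", "g"]  # "g" is dropped when wkv_has_gate is False
hidden_size = 8                             # illustrative size, not from a real model

# stand-ins for the checkpoint's per-parameter mix tensors
lerp_weights = {f"x_{p}": torch.randn(1, 1, hidden_size) for p in lerp_list}

# stack in the fixed r/w/k/v/a/g order expected by the fused tensor
fused = torch.stack([lerp_weights[f"x_{p}"] for p in lerp_list], dim=0)
print(fused.shape)  # torch.Size([6, 1, 1, 8])
```

In the converter itself the buffered keys are the full `model.layers.{bid}.attention.x_{p}` names and the stacked result is emitted as `blk.{bid}.time_mix_lerp_fused.weight`.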
@Model.register("RwkvHybridForCausalLM")
|
||||
class ARwkv7Model(Rwkv7Model):
|
||||
model_arch = gguf.MODEL_ARCH.ARWKV7
|
||||
|
||||
def set_vocab(self):
|
||||
try:
|
||||
self._set_vocab_sentencepiece()
|
||||
except FileNotFoundError:
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
hidden_size = self.hparams["hidden_size"]
|
||||
head_size = self.hparams["head_size"]
|
||||
rms_norm_eps = self.hparams["rms_norm_eps"]
|
||||
intermediate_size = self.hparams["intermediate_size"]
|
||||
wkv_has_gate = self.hparams["wkv_has_gate"]
|
||||
assert self.hparams["wkv_version"] == 7
|
||||
|
||||
# ICLR: In-Context-Learning-Rate
|
||||
lora_rank_decay = 64
|
||||
lora_rank_iclr = 64
|
||||
lora_rank_value_residual_mix = 32
|
||||
lora_rank_gate = 128 if wkv_has_gate else 0
|
||||
|
||||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
self.gguf_writer.add_embedding_length(hidden_size)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
|
||||
self.gguf_writer.add_wkv_head_size(head_size)
|
||||
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
|
||||
self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
|
||||
self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
|
||||
self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
|
||||
self.gguf_writer.add_feed_forward_length(intermediate_size)
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
self.gguf_writer.add_token_shift_count(1)
|
||||
|
||||
# required by llama.cpp, unused
|
||||
self.gguf_writer.add_head_count(0)
|
||||
|
||||
|
||||
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
|
||||
class MambaModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.MAMBA
|
||||
|
|