llama: Add support for RWKV v7 architecture (#12412)

* ggml: Add op l2_norm

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* ggml: Add op rwkv_wkv7

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: Add support for RWKV7 and ARWKV7 models

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: fix inference with RWKV6Qwen2

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: add more (a)rwkv7 variants in size

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Apply code-format changes

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* fix MUSA build

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: fix shape error with rwkv using llama-parallel

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
Molly Sophia 2025-03-18 07:27:50 +08:00 committed by GitHub
parent 60c902926c
commit 7dfad387e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
35 changed files with 2948 additions and 438 deletions

View file

@ -118,22 +118,26 @@ class Keys:
TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
CLAMP_KQV = "{arch}.attention.clamp_kqv"
KEY_LENGTH = "{arch}.attention.key_length"
VALUE_LENGTH = "{arch}.attention.value_length"
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon"
GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups"
CAUSAL = "{arch}.attention.causal"
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
SLIDING_WINDOW = "{arch}.attention.sliding_window"
SCALE = "{arch}.attention.scale"
HEAD_COUNT = "{arch}.attention.head_count"
HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
CLAMP_KQV = "{arch}.attention.clamp_kqv"
KEY_LENGTH = "{arch}.attention.key_length"
VALUE_LENGTH = "{arch}.attention.value_length"
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon"
GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups"
CAUSAL = "{arch}.attention.causal"
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
DECAY_LORA_RANK = "{arch}.attention.decay_lora_rank"
ICLR_LORA_RANK = "{arch}.attention.iclr_lora_rank"
VALUE_RESIDUAL_MIX_LORA_RANK = "{arch}.attention.value_residual_mix_lora_rank"
GATE_LORA_RANK = "{arch}.attention.gate_lora_rank"
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
SLIDING_WINDOW = "{arch}.attention.sliding_window"
SCALE = "{arch}.attention.scale"
class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
@ -257,6 +261,8 @@ class MODEL_ARCH(IntEnum):
STARCODER2 = auto()
RWKV6 = auto()
RWKV6QWEN2 = auto()
RWKV7 = auto()
ARWKV7 = auto()
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
@ -329,8 +335,20 @@ class MODEL_TENSOR(IntEnum):
SSM_A = auto()
SSM_D = auto()
SSM_OUT = auto()
TIME_MIX_W0 = auto()
TIME_MIX_W1 = auto()
TIME_MIX_W2 = auto()
TIME_MIX_A0 = auto()
TIME_MIX_A1 = auto()
TIME_MIX_A2 = auto()
TIME_MIX_V0 = auto()
TIME_MIX_V1 = auto()
TIME_MIX_V2 = auto()
TIME_MIX_G1 = auto()
TIME_MIX_G2 = auto()
TIME_MIX_K_K = auto()
TIME_MIX_K_A = auto()
TIME_MIX_R_K = auto()
TIME_MIX_LERP_X = auto()
TIME_MIX_LERP_K = auto()
TIME_MIX_LERP_V = auto()
@ -445,6 +463,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
MODEL_ARCH.RWKV7: "rwkv7",
MODEL_ARCH.ARWKV7: "arwkv7",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
@ -517,8 +537,20 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
MODEL_TENSOR.TIME_MIX_A0: "blk.{bid}.time_mix_a0",
MODEL_TENSOR.TIME_MIX_A1: "blk.{bid}.time_mix_a1",
MODEL_TENSOR.TIME_MIX_A2: "blk.{bid}.time_mix_a2",
MODEL_TENSOR.TIME_MIX_V0: "blk.{bid}.time_mix_v0",
MODEL_TENSOR.TIME_MIX_V1: "blk.{bid}.time_mix_v1",
MODEL_TENSOR.TIME_MIX_V2: "blk.{bid}.time_mix_v2",
MODEL_TENSOR.TIME_MIX_G1: "blk.{bid}.time_mix_g1",
MODEL_TENSOR.TIME_MIX_G2: "blk.{bid}.time_mix_g2",
MODEL_TENSOR.TIME_MIX_K_K: "blk.{bid}.time_mix_k_k",
MODEL_TENSOR.TIME_MIX_K_A: "blk.{bid}.time_mix_k_a",
MODEL_TENSOR.TIME_MIX_R_K: "blk.{bid}.time_mix_r_k",
MODEL_TENSOR.TIME_MIX_LERP_X: "blk.{bid}.time_mix_lerp_x",
MODEL_TENSOR.TIME_MIX_LERP_K: "blk.{bid}.time_mix_lerp_k",
MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v",
@ -1172,6 +1204,68 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.RWKV7: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_NORM_2,
MODEL_TENSOR.TIME_MIX_LERP_FUSED,
MODEL_TENSOR.TIME_MIX_W0,
MODEL_TENSOR.TIME_MIX_W1,
MODEL_TENSOR.TIME_MIX_W2,
MODEL_TENSOR.TIME_MIX_A0,
MODEL_TENSOR.TIME_MIX_A1,
MODEL_TENSOR.TIME_MIX_A2,
MODEL_TENSOR.TIME_MIX_V0,
MODEL_TENSOR.TIME_MIX_V1,
MODEL_TENSOR.TIME_MIX_V2,
MODEL_TENSOR.TIME_MIX_G1,
MODEL_TENSOR.TIME_MIX_G2,
MODEL_TENSOR.TIME_MIX_K_K,
MODEL_TENSOR.TIME_MIX_K_A,
MODEL_TENSOR.TIME_MIX_R_K,
MODEL_TENSOR.TIME_MIX_KEY,
MODEL_TENSOR.TIME_MIX_VALUE,
MODEL_TENSOR.TIME_MIX_RECEPTANCE,
MODEL_TENSOR.TIME_MIX_LN,
MODEL_TENSOR.TIME_MIX_OUTPUT,
MODEL_TENSOR.CHANNEL_MIX_LERP_K,
MODEL_TENSOR.CHANNEL_MIX_KEY,
MODEL_TENSOR.CHANNEL_MIX_VALUE,
],
MODEL_ARCH.ARWKV7: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.TIME_MIX_LERP_FUSED,
MODEL_TENSOR.TIME_MIX_W0,
MODEL_TENSOR.TIME_MIX_W1,
MODEL_TENSOR.TIME_MIX_W2,
MODEL_TENSOR.TIME_MIX_A0,
MODEL_TENSOR.TIME_MIX_A1,
MODEL_TENSOR.TIME_MIX_A2,
MODEL_TENSOR.TIME_MIX_V0,
MODEL_TENSOR.TIME_MIX_V1,
MODEL_TENSOR.TIME_MIX_V2,
MODEL_TENSOR.TIME_MIX_G1,
MODEL_TENSOR.TIME_MIX_G2,
MODEL_TENSOR.TIME_MIX_K_K,
MODEL_TENSOR.TIME_MIX_K_A,
MODEL_TENSOR.TIME_MIX_R_K,
MODEL_TENSOR.TIME_MIX_KEY,
MODEL_TENSOR.TIME_MIX_VALUE,
MODEL_TENSOR.TIME_MIX_RECEPTANCE,
MODEL_TENSOR.TIME_MIX_LN,
MODEL_TENSOR.TIME_MIX_OUTPUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.MAMBA: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,