llama: Add support for RWKV v7 architecture (#12412)
* ggml: Add op l2_norm Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * ggml: Add op rwkv_wkv7 Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * llama: Add support for RWKV7 and ARWKV7 models Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * llama: fix inference with RWKV6Qwen2 Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * llama: add more (a)rwkv7 variants in size Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * Apply code-format changes Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * fix MUSA build Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * llama: fix shape error with rwkv using llama-parallel Signed-off-by: Molly Sophia <mollysophia379@gmail.com> --------- Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
parent
60c902926c
commit
7dfad387e3
35 changed files with 2948 additions and 438 deletions
|
@ -36,7 +36,7 @@
|
|||
#include "ggml-cuda/tsembd.cuh"
|
||||
#include "ggml-cuda/unary.cuh"
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
#include "ggml-cuda/wkv6.cuh"
|
||||
#include "ggml-cuda/wkv.cuh"
|
||||
#include "ggml-cuda/gla.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
|
@ -2196,6 +2196,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_OP_GROUP_NORM:
|
||||
ggml_cuda_op_group_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_L2_NORM:
|
||||
ggml_cuda_op_l2_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CONCAT:
|
||||
ggml_cuda_op_concat(ctx, dst);
|
||||
break;
|
||||
|
@ -2304,6 +2307,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
ggml_cuda_op_gated_linear_attn(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
ggml_cuda_op_rwkv_wkv7(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
ggml_cuda_cross_entropy_loss_back(ctx, dst);
|
||||
break;
|
||||
|
@ -3161,6 +3167,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
break;
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return true;
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
|
||||
|
@ -3215,6 +3222,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT: {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
|
|
|
@ -201,6 +201,85 @@ static __global__ void rms_norm_back_f32(
|
|||
}
|
||||
}
|
||||
|
||||
// template <int block_size>
|
||||
// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
|
||||
// const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
// const int tid = threadIdx.x;
|
||||
|
||||
// float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
// for (int col = tid; col < ncols; col += block_size) {
|
||||
// const float xi = x[row*ncols + col];
|
||||
// tmp += xi * xi;
|
||||
// }
|
||||
|
||||
// // sum up partial sums
|
||||
// tmp = warp_reduce_sum(tmp);
|
||||
// if (block_size > WARP_SIZE) {
|
||||
// __shared__ float s_sum[32];
|
||||
// int warp_id = threadIdx.x / WARP_SIZE;
|
||||
// int lane_id = threadIdx.x % WARP_SIZE;
|
||||
// if (lane_id == 0) {
|
||||
// s_sum[warp_id] = tmp;
|
||||
// }
|
||||
// __syncthreads();
|
||||
// tmp = s_sum[lane_id];
|
||||
// tmp = warp_reduce_sum(tmp);
|
||||
// }
|
||||
|
||||
// // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
|
||||
// const float scale = rsqrtf(fmaxf(tmp, eps * eps));
|
||||
|
||||
// for (int col = tid; col < ncols; col += block_size) {
|
||||
// dst[row*ncols + col] = scale * x[row*ncols + col];
|
||||
// }
|
||||
// }
|
||||
|
||||
template <int block_size>
|
||||
static __global__ void l2_norm_f32(
|
||||
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
|
||||
const int64_t stride_sample, const float eps) {
|
||||
const int nrows = gridDim.x;
|
||||
const int nchannels = gridDim.y;
|
||||
|
||||
const int row = blockIdx.x;
|
||||
const int channel = blockIdx.y;
|
||||
const int sample = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
x += sample*stride_sample + channel*stride_channel + row*stride_row;
|
||||
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[col];
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
// sum up partial sums
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
if constexpr (block_size > WARP_SIZE) {
|
||||
static_assert(block_size == 1024, "unexpected block_size");
|
||||
__shared__ float s_sum[32];
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
s_sum[warp_id] = tmp;
|
||||
}
|
||||
__syncthreads();
|
||||
tmp = s_sum[lane_id];
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
}
|
||||
|
||||
// from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
|
||||
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[col] = scale * x[col];
|
||||
}
|
||||
}
|
||||
|
||||
static void norm_f32_cuda(
|
||||
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
|
||||
|
@ -248,6 +327,19 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
|
|||
}
|
||||
}
|
||||
|
||||
static void l2_norm_f32_cuda(
|
||||
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
|
||||
const dim3 blocks_num(nrows, nchannels, nsamples);
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
|
@ -340,3 +432,27 @@ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
|
|||
|
||||
rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS;
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
const size_t ts0 = ggml_type_size(src0->type);
|
||||
GGML_ASSERT(nb00 == ts0);
|
||||
const int64_t s01 = nb01 / ts0;
|
||||
const int64_t s02 = nb02 / ts0;
|
||||
const int64_t s03 = nb03 / ts0;
|
||||
|
||||
l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
|
||||
}
|
||||
|
|
|
@ -7,3 +7,5 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
|||
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
199
ggml/src/ggml-cuda/wkv.cu
Normal file
199
ggml/src/ggml-cuda/wkv.cu
Normal file
|
@ -0,0 +1,199 @@
|
|||
#include "common.cuh"
|
||||
#include "wkv.cuh"
|
||||
|
||||
template <int block_size>
|
||||
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
|
||||
const int tid = threadIdx.x;
|
||||
const int bid = blockIdx.x;
|
||||
|
||||
const int head_size = block_size;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
float state[head_size];
|
||||
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
_tf[tid] = tf[head_i * head_size + tid];
|
||||
__syncthreads();
|
||||
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
|
||||
__syncthreads();
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
__syncthreads();
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0;
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
const float4& k = (float4&)(_k[j]);
|
||||
const float4& r = (float4&)(_r[j]);
|
||||
const float4& tf = (float4&)(_tf[j]);
|
||||
const float4& td = (float4&)(_td[j]);
|
||||
float4& s = (float4&)(state[j]);
|
||||
float4 kv;
|
||||
|
||||
kv.x = k.x * _v;
|
||||
kv.y = k.y * _v;
|
||||
kv.z = k.z * _v;
|
||||
kv.w = k.w * _v;
|
||||
|
||||
y += r.x * (tf.x * kv.x + s.x);
|
||||
y += r.y * (tf.y * kv.y + s.y);
|
||||
y += r.z * (tf.z * kv.z + s.z);
|
||||
y += r.w * (tf.w * kv.w + s.w);
|
||||
|
||||
s.x = s.x * td.x + kv.x;
|
||||
s.y = s.y * td.y + kv.y;
|
||||
s.z = s.z * td.z + kv.z;
|
||||
s.w = s.w * td.w + kv.w;
|
||||
}
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <int block_size>
|
||||
static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {
|
||||
const int tid = threadIdx.x;
|
||||
const int bid = blockIdx.x;
|
||||
|
||||
const int head_size = block_size;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
float state[head_size];
|
||||
__shared__ float _r[head_size], _w[head_size], _k[head_size], _a[head_size], _b[head_size];
|
||||
|
||||
#ifndef GGML_USE_MUSA
|
||||
#pragma unroll
|
||||
#endif
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
|
||||
}
|
||||
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
|
||||
__syncthreads();
|
||||
_r[tid] = r[t];
|
||||
_w[tid] = w[t];
|
||||
_k[tid] = k[t];
|
||||
_a[tid] = a[t];
|
||||
_b[tid] = b[t];
|
||||
__syncthreads();
|
||||
|
||||
float sa = 0;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < head_size; j += 4)
|
||||
{
|
||||
const float4& a = (float4&)(_a[j]);
|
||||
const float4& s = (float4&)(state[j]);
|
||||
sa += a.x * s.x;
|
||||
sa += a.y * s.y;
|
||||
sa += a.z * s.z;
|
||||
sa += a.w * s.w;
|
||||
}
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0;
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
const float4& r = (float4&)(_r[j]);
|
||||
const float4& w = (float4&)(_w[j]);
|
||||
const float4& k = (float4&)(_k[j]);
|
||||
const float4& b = (float4&)(_b[j]);
|
||||
float4& s = (float4&)(state[j]);
|
||||
float4 kv;
|
||||
|
||||
kv.x = k.x * _v;
|
||||
kv.y = k.y * _v;
|
||||
kv.z = k.z * _v;
|
||||
kv.w = k.w * _v;
|
||||
|
||||
s.x = s.x * w.x + kv.x + sa * b.x;
|
||||
s.y = s.y * w.y + kv.y + sa * b.y;
|
||||
s.z = s.z * w.z + kv.z + sa * b.z;
|
||||
s.w = s.w * w.w + kv.w + sa * b.w;
|
||||
|
||||
y += s.x * r.x;
|
||||
y += s.y * r.y;
|
||||
y += s.z * r.z;
|
||||
y += s.w * r.w;
|
||||
}
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const float * k_d = (const float *)dst->src[0]->data;
|
||||
const float * v_d = (const float *)dst->src[1]->data;
|
||||
const float * r_d = (const float *)dst->src[2]->data;
|
||||
const float * tf_d = (const float *)dst->src[3]->data;
|
||||
const float * td_d = (const float *)dst->src[4]->data;
|
||||
const float * s_d = (const float *)dst->src[5]->data;
|
||||
|
||||
const int64_t B = dst->src[5]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
float * dst_d = (float *)dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
|
||||
|
||||
if (C / H == CUDA_WKV_BLOCK_SIZE) {
|
||||
rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
|
||||
} else {
|
||||
rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const float * r_d = (const float *)dst->src[0]->data;
|
||||
const float * w_d = (const float *)dst->src[1]->data;
|
||||
const float * k_d = (const float *)dst->src[2]->data;
|
||||
const float * v_d = (const float *)dst->src[3]->data;
|
||||
const float * a_d = (const float *)dst->src[4]->data;
|
||||
const float * b_d = (const float *)dst->src[5]->data;
|
||||
const float * s_d = (const float *)dst->src[6]->data;
|
||||
|
||||
const int64_t B = dst->src[6]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
float * dst_d = (float *)dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
|
||||
|
||||
if (C / H == CUDA_WKV_BLOCK_SIZE) {
|
||||
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
|
||||
} else {
|
||||
rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
|
||||
}
|
||||
}
|
|
@ -3,3 +3,5 @@
|
|||
#define CUDA_WKV_BLOCK_SIZE 64
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|
@ -1,89 +0,0 @@
|
|||
#include "common.cuh"
|
||||
#include "wkv6.cuh"
|
||||
|
||||
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
|
||||
const int tid = threadIdx.x;
|
||||
const int bid = blockIdx.x;
|
||||
|
||||
const int head_size = CUDA_WKV_BLOCK_SIZE;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
float state[head_size];
|
||||
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
_tf[tid] = tf[head_i * head_size + tid];
|
||||
__syncthreads();
|
||||
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
|
||||
__syncthreads();
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
__syncthreads();
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0;
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
const float4& k = (float4&)(_k[j]);
|
||||
const float4& r = (float4&)(_r[j]);
|
||||
const float4& tf = (float4&)(_tf[j]);
|
||||
const float4& td = (float4&)(_td[j]);
|
||||
float4& s = (float4&)(state[j]);
|
||||
float4 kv;
|
||||
|
||||
kv.x = k.x * _v;
|
||||
kv.y = k.y * _v;
|
||||
kv.z = k.z * _v;
|
||||
kv.w = k.w * _v;
|
||||
|
||||
y += r.x * (tf.x * kv.x + s.x);
|
||||
y += r.y * (tf.y * kv.y + s.y);
|
||||
y += r.z * (tf.z * kv.z + s.z);
|
||||
y += r.w * (tf.w * kv.w + s.w);
|
||||
|
||||
s.x = s.x * td.x + kv.x;
|
||||
s.y = s.y * td.y + kv.y;
|
||||
s.z = s.z * td.z + kv.z;
|
||||
s.w = s.w * td.w + kv.w;
|
||||
}
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const float * k_d = (const float *)dst->src[0]->data;
|
||||
const float * v_d = (const float *)dst->src[1]->data;
|
||||
const float * r_d = (const float *)dst->src[2]->data;
|
||||
const float * tf_d = (const float *)dst->src[3]->data;
|
||||
const float * td_d = (const float *)dst->src[4]->data;
|
||||
const float * s_d = (const float *)dst->src[5]->data;
|
||||
|
||||
const int64_t B = dst->src[5]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
float * dst_d = (float *)dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64
|
||||
|
||||
rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue