CUDA: batched+noncont MMQ, refactor bs>1 MoE code (#13199)
This commit is contained in:
parent 6f67cf1f48
commit e1e8e0991f
9 changed files with 869 additions and 440 deletions
@@ -1551,7 +1551,7 @@ static void ggml_cuda_op_mul_mat(
         if (src1_on_device && src1_is_contiguous) {
             quantize_src1(
-                dev[id].src1_ddf, dev[id].src1_ddq, src0->type, ne10,
+                dev[id].src1_ddf, nullptr, dev[id].src1_ddq, src0->type, ne10,
                 nb11/sizeof(float), nb12/sizeof(float), nb13/sizeof(float),
                 src1_padded_col_size, ne11, ne12, ne13, stream);
             CUDA_CHECK(cudaGetLastError());
@@ -1649,7 +1649,7 @@ static void ggml_cuda_op_mul_mat(
             if (quantize_src1 && !src1_is_contiguous) {
                 quantize_src1(
-                    src1_ddf_i, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10,
+                    src1_ddf_i, nullptr, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10,
                     src1_padded_col_size, src1_ncols, 1, 1, stream);
                 CUDA_CHECK(cudaGetLastError());
             }
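In both quantize_src1 call sites above, the function pointer now takes an extra argument (passed as nullptr here; it presumably carries per-row expert ids in the MoE path) and explicit strides expressed in float elements instead of assuming a contiguous src1. As a minimal sketch, assuming hypothetical stride names s11/s12/s13, strided addressing of such a tensor looks like this:

```cuda
// Illustrative only: element addressing for a strided float tensor, matching the
// stride arguments passed to quantize_src1 above (nb11/sizeof(float), ...).
// s11/s12/s13 are hypothetical names for the row/plane/batch strides in floats.
#include <cstdint>

__host__ __device__ inline const float * src1_row_ptr(
        const float * src1, int64_t i11, int64_t i12, int64_t i13,
        int64_t s11, int64_t s12, int64_t s13) {
    // For contiguous data this reduces to src1 + i11*ne10 + i12*ne11*ne10 + i13*ne12*ne11*ne10.
    return src1 + i11*s11 + i12*s12 + i13*s13;
}
```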
@@ -1949,6 +1949,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_vec_q) {
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
+    } else if (!split && use_mul_mat_q) {
+        ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
         !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
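For orientation, the new branch above routes batched quantized matrix multiplications through MMQ before the FP16 multi-batch fallback is considered. The following is a condensed paraphrase of the dispatch order, not the real function; the use_* flags stand in for heuristics computed earlier in ggml_cuda_mul_mat:

```cuda
// Condensed paraphrase of the single-GPU dispatch order after this change (illustrative only).
enum class mul_mat_path { MMV, MMVQ, MMQ, BATCHED_F16, GENERIC };

static mul_mat_path choose_path(bool split, bool use_mmv, bool use_mmvq, bool use_mmq, bool batched_f16_ok) {
    if (!split && use_mmv)        return mul_mat_path::MMV;         // mat-vec on dequantized/float data
    if (!split && use_mmvq)       return mul_mat_path::MMVQ;        // quantized mat-vec
    if (!split && use_mmq)        return mul_mat_path::MMQ;         // quantized mat-mat, now also for batched shapes
    if (!split && batched_f16_ok) return mul_mat_path::BATCHED_F16; // multi-batch without FlashAttention
    return mul_mat_path::GENERIC;                                   // ggml_cuda_op_mul_mat fallback
}
```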
@@ -1964,183 +1966,145 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         }
     }
 
-struct mmid_row_mapping {
-    int32_t i1;
-    int32_t i2;
-};
-
-static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
-        int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
-        const char * __restrict__ ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
-        int64_t ne11, int64_t ne10,
-        size_t nb11, size_t nb12) {
-    int32_t iid1 = blockIdx.x;
-    int32_t id = blockIdx.y;
-
-    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
-
-    if (row_id_i != i02) {
-        return;
-    }
-
-    const int64_t i11 = id % ne11;
-    const int64_t i12 = iid1;
-
-    __shared__ int src1_row;
-    if (threadIdx.x == 0) {
-        src1_row = atomicAdd(cur_src1_row, 1);
-        row_mapping[src1_row] = {id, iid1};
-    }
-    __syncthreads();
-
-    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
-    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
-
-    for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
-        src1_row_contiguous[i] = src1_row_original[i];
-    }
-}
-
-static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
-        const mmid_row_mapping * __restrict__ row_mapping,
-        int64_t ne0,
-        size_t nb1, size_t nb2) {
-    int32_t i = blockIdx.x;
-
-    const int32_t i1 = row_mapping[i].i1;
-    const int32_t i2 = row_mapping[i].i2;
-
-    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
-    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
-
-    for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
-        dst_row_original[j] = dst_row_contiguous[j];
-    }
-}
-
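The two kernels removed above implemented the previous bs>1 MoE path: for every candidate (token, slot) pair, a thread block checked whether that pair selected the current expert and, if so, claimed the next slot in a compacted buffer with atomicAdd, recording a mapping so result rows could later be scattered back. A simplified standalone illustration of that compaction idea, with made-up buffer names and a plain row-major layout, is sketched below; it is not the removed ggml kernel:

```cuda
// Simplified illustration of atomicAdd-based row compaction for one expert
// (hypothetical names, float rows of width ne10, one block per candidate row).
#include <cstdint>

struct row_mapping { int32_t i1; int32_t i2; };

__global__ void compact_rows_for_expert(
        const float * rows_in, float * rows_out, const int32_t * expert_ids,
        int32_t expert, int * cur_row, row_mapping * mapping, int ne10) {
    const int candidate = blockIdx.x;              // one candidate (token, slot) pair per block
    if (expert_ids[candidate] != expert) {
        return;                                    // this row was not routed to `expert`
    }

    __shared__ int dst_row;
    if (threadIdx.x == 0) {
        dst_row = atomicAdd(cur_row, 1);           // claim the next free slot in the compacted buffer
        mapping[dst_row] = {candidate, expert};    // remember where the result row must go back
    }
    __syncthreads();

    for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
        rows_out[dst_row*ne10 + i] = rows_in[candidate*ne10 + i];
    }
}
```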
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * ids = dst->src[2];
 
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers");
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && ne2 == 1) {
-        if (ggml_is_quantized(src0->type)) {
-            ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
-        } else {
-            ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
-        }
-        return;
-    }
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers");
+
+    if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        if (ne2 == 1) {
+            if (ggml_is_quantized(src0->type)) {
+                ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
+            } else {
+                ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
+            }
+            return;
+        }
+
+        if (ggml_cuda_should_use_mmq(src0->type, cc, ne12)) {
+            ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
+            return;
+        }
+    }
 
     cudaStream_t stream = ctx.stream();
 
-    const int64_t n_as = ne02;
-    const int64_t n_ids = ids->ne[0];
+    GGML_ASSERT(nb12 % nb11 == 0);
+    GGML_ASSERT(nb2 % nb1 == 0);
+
+    const ggml_type type_src1_sorted = (src0->type == GGML_TYPE_F16 && !fast_fp16_hardware_available(cc))
+        || ggml_is_quantized(src0->type) ? GGML_TYPE_F32 : src0->type;
+    const ggml_type type_dst_sorted = GGML_TYPE_F32;
+    const size_t ts_src1_sorted = ggml_type_size(type_src1_sorted);
+    const size_t ts_dst_sorted = ggml_type_size(type_dst_sorted);
+
+    const int64_t n_expert_used = ids->ne[0];
+    const int64_t ne_get_rows = ne12 * n_expert_used;
+
+    std::vector<int32_t> ids_to_sorted_host;
+    ids_to_sorted_host.reserve(2*ne_get_rows);
+    std::vector<int32_t> ids_from_sorted_host(ne_get_rows);
+
+    ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool(), 2*ne_get_rows);
+
+    std::vector<int32_t> tokens_per_expert(ne02);
+
+    ggml_cuda_pool_alloc<char> src1_sorted(ctx.pool(), ne12*n_expert_used*ne10*ts_src1_sorted);
+    ggml_cuda_pool_alloc<char> dst_sorted(ctx.pool(), ne2 *n_expert_used* ne0*ts_dst_sorted);
+
     std::vector<char> ids_host(ggml_nbytes(ids));
-    const char * ids_dev = (const char *) ids->data;
-    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
     CUDA_CHECK(cudaStreamSynchronize(stream));
 
-    ggml_tensor src0_row = *src0;
-    ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row = *dst;
-
-    char * src0_original = (char *) src0->data;
-    char * src1_original = (char *) src1->data;
-    char * dst_original = (char *) dst->data;
-
-    src0_row.ne[2] = 1;
-    src0_row.ne[3] = 1;
-    src0_row.nb[3] = nb02;
-
-    src1_row.ne[1] = 1;
-    src1_row.ne[2] = 1;
-    src1_row.ne[3] = 1;
-    src1_row.nb[2] = nb11;
-    src1_row.nb[3] = nb11;
-
-    dst_row.ne[1] = 1;
-    dst_row.ne[2] = 1;
-    dst_row.ne[3] = 1;
-    dst_row.nb[2] = nb1;
-    dst_row.nb[3] = nb1;
-
-    ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
-    ggml_cuda_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
-
-    src1_row.data = src1_contiguous.get();
-    dst_row.data = dst_contiguous.get();
 
-    for (int64_t i02 = 0; i02 < n_as; i02++) {
-        int64_t num_src1_rows = 0;
-
-        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-            for (int64_t id = 0; id < n_ids; id++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-
-                GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
-
-                if (row_id_i != i02) {
-                    continue;
-                }
-
-                num_src1_rows++;
-            }
-        }
-
-        if (num_src1_rows == 0) {
-            continue;
-        }
+    for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
+        for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
+            for (int64_t iex = 0; iex < n_expert_used; ++iex) {
+                const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
+                assert(expert_to_use >= 0 && expert_to_use < ne02);
+                if (expert_to_use == i02) {
+                    ids_from_sorted_host[i12*n_expert_used + iex] = ids_to_sorted_host.size();
+                    ids_to_sorted_host.push_back(i12*ne11 + iex % ne11);
+                    tokens_per_expert[i02]++;
+                    break;
+                }
+            }
+        }
+    }
+    GGML_ASSERT(ids_to_sorted_host.size() == size_t(ne_get_rows));
+
+    ids_to_sorted_host.insert(ids_to_sorted_host.end(), ids_from_sorted_host.begin(), ids_from_sorted_host.end());
+
+    CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_to_sorted_host.data(), 2*ne_get_rows*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    const int32_t * ids_to_sorted = ids_buf_dev.ptr + 0*ne_get_rows;
+    const int32_t * ids_from_sorted = ids_buf_dev.ptr + 1*ne_get_rows;
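The host loop added above produces three tables per micro-batch: ids_to_sorted (which src1 row to gather for each slot, grouped by expert), ids_from_sorted (where each (token, slot) result ends up in the sorted buffer), and tokens_per_expert. The following toy program, with invented sizes and a single routing table, reproduces that bookkeeping for illustration; it is not the ggml code:

```cuda
// Toy reproduction of the index-table construction above (illustrative, ne11 == 1).
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int64_t n_expert      = 4;              // ne02
    const int64_t n_tokens      = 3;              // ne12
    const int64_t n_expert_used = 2;              // ids->ne[0]
    // routing table: ids[token][slot] = expert index
    const int32_t ids[3][2] = {{0, 2}, {2, 3}, {0, 3}};

    std::vector<int32_t> ids_to_sorted;                           // gather order, grouped by expert
    std::vector<int32_t> ids_from_sorted(n_tokens*n_expert_used); // scatter positions
    std::vector<int32_t> tokens_per_expert(n_expert, 0);

    for (int64_t e = 0; e < n_expert; ++e) {
        for (int64_t t = 0; t < n_tokens; ++t) {
            for (int64_t s = 0; s < n_expert_used; ++s) {
                if (ids[t][s] == e) {
                    ids_from_sorted[t*n_expert_used + s] = (int32_t) ids_to_sorted.size();
                    ids_to_sorted.push_back((int32_t) t); // gather token t's row into expert e's block
                    tokens_per_expert[e]++;
                    break; // at most one row per (expert, token) pair, matching the code above
                }
            }
        }
    }
    assert((int64_t) ids_to_sorted.size() == n_tokens*n_expert_used);

    for (int64_t e = 0; e < n_expert; ++e) {
        printf("expert %lld: %d rows\n", (long long) e, tokens_per_expert[e]);
    }
    return 0;
}
```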
 
+    get_rows_cuda(src1->data, src1->type, ids_to_sorted, src1_sorted.ptr, type_src1_sorted,
+        ne10, nb11, nb12, nb13,
+        ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t),
+        ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, stream);
+    CUDA_CHECK(cudaGetLastError());
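get_rows_cuda is used here as an indexed, type-converting row gather: source row ids_to_sorted[i] is copied into row i of the expert-sorted staging buffer, converting to type_src1_sorted on the way. A bare-bones gather kernel in the same spirit, with made-up names and F32-only data, might look like this:

```cuda
// Bare-bones indexed row gather (illustrative only; the real get_rows_cuda also
// handles multiple dims, byte strides, and type conversion).
#include <cstdint>

__global__ void gather_rows_f32(
        const float * __restrict__ src, const int32_t * __restrict__ row_ids,
        float * __restrict__ dst, int64_t ncols) {
    const int64_t dst_row = blockIdx.x;          // one destination row per block
    const int64_t src_row = row_ids[dst_row];    // which source row to pull

    for (int64_t i = threadIdx.x; i < ncols; i += blockDim.x) {
        dst[dst_row*ncols + i] = src[src_row*ncols + i];
    }
}
```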
 
+    char * src1_data_cur = (char *) src1_sorted.ptr;
+    char * dst_data_cur = (char *) dst_sorted.ptr;
+    for (int64_t i02 = 0; i02 < ne02; ++i02) {
+        if (tokens_per_expert[i02] == 0) {
+            continue;
+        }
 
-        ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
-        ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
-        CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
+        ggml_tensor src0_slice = *src0;
+        src0_slice.ne[2] = 1;
+        src0_slice.nb[3] = src0_slice.nb[2];
+        src0_slice.data = (char *) src0->data + i02*nb02;
 
-        {
-            dim3 block_dims(std::min((unsigned int)ne10, 768u));
-            dim3 grid_dims(ids->ne[1], n_ids);
-            k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-                src1_original, src1_contiguous.get(),
-                dev_cur_src1_row.get(), dev_row_mapping.get(),
-                ids_dev, i02, ids->nb[1], ids->nb[0],
-                ne11, ne10,
-                nb11, nb12);
-            CUDA_CHECK(cudaGetLastError());
-        }
+        ggml_tensor src1_slice;
+        memset(&src1_slice, 0, sizeof(src1_slice));
+        src1_slice.buffer = src1->buffer;
+        src1_slice.type = type_src1_sorted;
+        src1_slice.ne[0] = ne10;
+        src1_slice.ne[1] = tokens_per_expert[i02];
+        src1_slice.ne[2] = 1;
+        src1_slice.ne[3] = 1;
+        src1_slice.nb[0] = ts_src1_sorted;
+        src1_slice.nb[1] = src1_slice.ne[0] * src1_slice.nb[0];
+        src1_slice.nb[2] = src1_slice.ne[1] * src1_slice.nb[1];
+        src1_slice.nb[3] = src1_slice.ne[2] * src1_slice.nb[2];
+        src1_slice.data = src1_data_cur;
 
-        src0_row.data = src0_original + i02*nb02;
+        ggml_tensor dst_slice;
+        memset(&dst_slice, 0, sizeof(dst_slice));
+        dst_slice.buffer = dst->buffer;
+        dst_slice.type = type_dst_sorted;
+        dst_slice.ne[0] = ne0;
+        dst_slice.ne[1] = tokens_per_expert[i02];
+        dst_slice.ne[2] = 1;
+        dst_slice.ne[3] = 1;
+        dst_slice.nb[0] = ts_dst_sorted;
+        dst_slice.nb[1] = dst_slice.ne[0] * dst_slice.nb[0];
+        dst_slice.nb[2] = dst_slice.ne[1] * dst_slice.nb[1];
+        dst_slice.nb[3] = dst_slice.ne[2] * dst_slice.nb[2];
+        dst_slice.data = dst_data_cur;
 
-        GGML_ASSERT(nb11 == sizeof(float)*ne10);
-        GGML_ASSERT(nb1 == sizeof(float)*ne0);
+        ggml_cuda_mul_mat(ctx, &src0_slice, &src1_slice, &dst_slice);
+        CUDA_CHECK(cudaGetLastError());
 
-        src1_row.ne[1] = num_src1_rows;
-        src1_row.nb[1] = nb11;
-        src1_row.nb[2] = num_src1_rows*nb11;
-        src1_row.nb[3] = num_src1_rows*nb11;
-
-        dst_row.ne[1] = num_src1_rows;
-        dst_row.nb[1] = nb1;
-        dst_row.nb[2] = num_src1_rows*nb1;
-        dst_row.nb[3] = num_src1_rows*nb1;
-
-        ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
-
-        {
-            dim3 block_dims(std::min((unsigned int)ne0, 768u));
-            dim3 grid_dims(num_src1_rows);
-            k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-                dst_original, dst_contiguous.get(),
-                dev_row_mapping.get(),
-                ne0,
-                nb1, nb2);
-            CUDA_CHECK(cudaGetLastError());
-        }
+        src1_data_cur += src1_slice.nb[2];
+        dst_data_cur += dst_slice.nb[2];
     }
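Inside the loop above, each expert's gathered rows are described as a plain contiguous 2-D tensor (src1_slice, dst_slice) over the staging buffers so the regular ggml_cuda_mul_mat entry point, including MMQ, can run one dense multiplication per expert. A minimal sketch of that view construction, using a simplified struct rather than ggml_tensor, is shown below:

```cuda
// Sketch of the "view" idea used above: describe a contiguous block of n_rows
// rows inside a staging buffer as a standalone 2-D tensor so a regular mat-mul
// entry point can run on it. Simplified struct, not ggml_tensor.
#include <cstddef>
#include <cstdint>

struct view_2d {
    void  * data;
    int64_t ne[2]; // elements per row, number of rows
    size_t  nb[2]; // byte strides: element, row
};

static view_2d make_row_block_view(char * base, int64_t row_offset,
                                   int64_t n_rows, int64_t n_cols, size_t type_size) {
    view_2d v;
    v.data  = base + row_offset*n_cols*type_size; // start of this expert's rows
    v.ne[0] = n_cols;
    v.ne[1] = n_rows;
    v.nb[0] = type_size;
    v.nb[1] = n_cols*type_size;                   // rows are packed back to back
    return v;
}
```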
 
+    get_rows_cuda(dst_sorted.ptr, type_dst_sorted, ids_from_sorted, dst->data, dst->type,
+        ne0, ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted,
+        ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t),
+        nb1, nb2, nb3, stream);
 }
 
 static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
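Taken together, the refactor replaces the per-expert copy kernels with a gather, dense per-expert GEMM, and scatter pipeline. The self-contained CPU toy below walks through that flow end to end with invented sizes; the real code performs the gather, the per-expert multiplications, and the scatter on the GPU with get_rows_cuda and ggml_cuda_mul_mat:

```cuda
// CPU toy of the gather -> per-expert GEMM -> scatter pipeline that the
// refactored ggml_cuda_mul_mat_id implements on the GPU (one expert per token
// for simplicity; all names and sizes are made up).
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    const int n_expert = 2, n_tokens = 3, ne10 = 4, ne0 = 2;

    // Per-expert weights W[e]: ne0 x ne10, here W[e][r][c] = e + 1 everywhere.
    std::vector<float> W(n_expert*ne0*ne10);
    for (int e = 0; e < n_expert; ++e)
        for (int i = 0; i < ne0*ne10; ++i) W[e*ne0*ne10 + i] = float(e + 1);

    std::vector<float> src1(n_tokens*ne10, 1.0f);     // every token row is all ones
    const int ids[n_tokens] = {1, 0, 1};              // routing: token -> expert

    // 1. Build gather/scatter tables grouped by expert.
    std::vector<int> ids_to_sorted, ids_from_sorted(n_tokens), tokens_per_expert(n_expert, 0);
    for (int e = 0; e < n_expert; ++e)
        for (int t = 0; t < n_tokens; ++t)
            if (ids[t] == e) {
                ids_from_sorted[t] = (int) ids_to_sorted.size();
                ids_to_sorted.push_back(t);
                tokens_per_expert[e]++;
            }

    // 2. Gather: src1 rows reordered so each expert's tokens are contiguous.
    std::vector<float> src1_sorted(n_tokens*ne10);
    for (int r = 0; r < n_tokens; ++r)
        for (int c = 0; c < ne10; ++c)
            src1_sorted[r*ne10 + c] = src1[ids_to_sorted[r]*ne10 + c];

    // 3. One dense mat-mul per expert on its contiguous block of rows.
    std::vector<float> dst_sorted(n_tokens*ne0, 0.0f);
    int row0 = 0;
    for (int e = 0; e < n_expert; ++e) {
        for (int r = 0; r < tokens_per_expert[e]; ++r)
            for (int o = 0; o < ne0; ++o) {
                float acc = 0.0f;
                for (int c = 0; c < ne10; ++c)
                    acc += W[e*ne0*ne10 + o*ne10 + c] * src1_sorted[(row0 + r)*ne10 + c];
                dst_sorted[(row0 + r)*ne0 + o] = acc;
            }
        row0 += tokens_per_expert[e];
    }

    // 4. Scatter back to token order.
    std::vector<float> dst(n_tokens*ne0);
    for (int t = 0; t < n_tokens; ++t)
        for (int o = 0; o < ne0; ++o)
            dst[t*ne0 + o] = dst_sorted[ids_from_sorted[t]*ne0 + o];

    // Token t routed to expert ids[t]: expect (ids[t]+1) * ne10 in every output.
    for (int t = 0; t < n_tokens; ++t)
        assert(dst[t*ne0] == float((ids[t] + 1)*ne10));
    printf("ok\n");
    return 0;
}
```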