llama : add Mixtral support (#4406)
* convert : support Mixtral as LLAMA arch
* convert : fix n_ff typo
* llama : model loading
* ggml : sync latest ggml_mul_mat_id
* llama : update graph to support MoE
* llama : fix cur -> cur_expert
* llama : first working version
* llama : fix expert weighting in the FFN
* ggml : ggml_get_rows support 2D indexing [n_tokens, n_experts] (cpu only)
* ggml : add n_as argument to ggml_mul_mat_id
* ggml : fix ggml_get_rows to take into account ne02 / ne11
* metal : add more general support for ggml_get_rows + tests
* llama : add basic support for offloading moe with CUDA
* metal : add/mul/div use general kernel when src1 not cont
* metal : reduce the kernel launches for ggml_mul_mat_id
* ggml : get_rows : support non-contiguous tensors with gaps, generalize up to 3D
* ggml : update get_rows f16 and q
* cuda : support non-contiguous src1 in get_rows
* llama : offload missing ffn_moe_silu
* metal : fix ggml_get_rows to work with non-cont src1
* metal : add indirect mat-vec kernels for all quantization types
* llama : do not quantize expert gating tensors
* llama : add n_expert and n_expert_used to hparams + change quants
* test-backend-ops : add moe test
* cuda : fix get_rows when ncols is odd
* convert : determine n_ctx correctly
* metal : fix ggml_mul_mat_id for F32
* test-backend-ops : make experts more evenly probable (test_moe)
* test-backend-ops : cleanup, add moe test for batches
* test-backend-ops : add cpy from f32 -> all types test
* test-backend-ops : fix dequantize block offset
* llama : fix hard-coded number of experts
* test-backend-ops : simplify and disable slow tests to avoid CI timeout
* test-backend-ops : disable MOE test with thread sanitizer
* cuda : fix mul_mat_id with multi gpu
* convert : use 1e6 rope_freq_base for mixtral
* convert : fix style
* convert : support safetensors format
* gguf-py : bump version
* metal : add cpy f16 -> f32 kernel
* metal : fix binary ops for ne10 % 4 != 0
* test-backend-ops : add one more sum_rows test
* ggml : do not use BLAS with ggml_mul_mat_id
* convert-hf : support for mixtral-instruct (#4428)
* convert : typo fix, add additional hyperparameters, use LLaMA arch for Mixtral-instruct
* convert : use sentencepiece tokenizer for Mixtral-instruct
* convert : make flake8 happy
* metal : fix soft_max kernels

  ref: https://github.com/ggerganov/ggml/pull/621/commits/1914017863d2f9ab8ecc0281cc2a56d683668b92

* metal : limit kernels to not use more than the allowed threads

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Radek Pilar <github@mrkva.eu>
parent fecac45658
commit 799a1cb13b

14 changed files with 2370 additions and 395 deletions
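The commit message above adds a Mixtral-style mixture-of-experts FFN to the graph: a router scores all n_expert experts per token, the top n_expert_used are selected, their routing weights are renormalized, and the selected experts' SwiGLU FFN outputs are summed with those weights. This is what the generalized ggml_get_rows indexing ([n_tokens, n_experts]) and the new n_as argument of ggml_mul_mat_id serve on the backends. Below is a minimal CPU sketch of that math only, assuming dense row-major weight buffers and reusing the hparam names n_expert / n_expert_used from this commit; it is illustrative and not the actual llama.cpp graph code.

#include <algorithm>
#include <cmath>
#include <vector>

struct moe_ffn_ref {
    int n_embd, n_ff, n_expert, n_expert_used;

    std::vector<std::vector<float>> router;                        // gating weights: [n_expert][n_embd]
    struct expert { std::vector<std::vector<float>> w1, w2, w3; }; // w1/w3: [n_ff][n_embd], w2: [n_embd][n_ff]
    std::vector<expert> experts;

    static float dot(const std::vector<float> & a, const std::vector<float> & b) {
        float s = 0.0f;
        for (size_t i = 0; i < a.size(); i++) s += a[i]*b[i];
        return s;
    }

    // forward pass for a single token x: [n_embd]
    std::vector<float> forward(const std::vector<float> & x) const {
        // router logits + softmax over experts (ggml_mul_mat + soft_max in the graph)
        std::vector<float> p(n_expert);
        float maxl = -INFINITY;
        for (int e = 0; e < n_expert; e++) { p[e] = dot(router[e], x); maxl = std::max(maxl, p[e]); }
        float sum = 0.0f;
        for (int e = 0; e < n_expert; e++) { p[e] = std::exp(p[e] - maxl); sum += p[e]; }
        for (int e = 0; e < n_expert; e++) { p[e] /= sum; }

        // pick the n_expert_used highest-probability experts and renormalize their weights
        // (the per-token expert indices are what ggml_get_rows gathers with 2D indexing)
        std::vector<int> idx(n_expert);
        for (int e = 0; e < n_expert; e++) idx[e] = e;
        std::partial_sort(idx.begin(), idx.begin() + n_expert_used, idx.end(),
                [&p](int a, int b) { return p[a] > p[b]; });
        float wsum = 0.0f;
        for (int k = 0; k < n_expert_used; k++) wsum += p[idx[k]];

        // weighted sum of the selected experts' SwiGLU FFNs
        // (the per-expert matmuls are what ggml_mul_mat_id dispatches over its n_as experts)
        std::vector<float> out(n_embd, 0.0f);
        for (int k = 0; k < n_expert_used; k++) {
            const expert & ex = experts[idx[k]];
            const float w = p[idx[k]] / wsum;
            std::vector<float> h(n_ff);
            for (int i = 0; i < n_ff; i++) {
                const float g = dot(ex.w1[i], x);          // gate projection
                const float u = dot(ex.w3[i], x);          // up projection
                h[i] = (g / (1.0f + std::exp(-g))) * u;    // SiLU(gate) * up
            }
            for (int j = 0; j < n_embd; j++) out[j] += w * dot(ex.w2[j], h);
        }
        return out;
    }
};
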
ggml-cuda.cu (297 changes)

@@ -1,13 +1,15 @@
#include <algorithm>
#include <assert.h>
#include <atomic>
#include <cinttypes>
#include <cstddef>
#include <cstdint>
#include <cinttypes>
#include <float.h>
#include <limits>
#include <stdint.h>
#include <stdio.h>
#include <atomic>
#include <assert.h>
#include <vector>

#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
@@ -1684,31 +1686,65 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
}

template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) {
    const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2;
    const int row = blockDim.y*blockIdx.y + threadIdx.y;
static __global__ void k_get_rows(
        const void * src0, const int32_t * src1, dst_t * dst,
        int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
        /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
        /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
        /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
        size_t s10, size_t s11, size_t s12/*, size_t s13*/) {

    if (col >= ncols) {
    const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;

    if (i00 >= ne00) {
        return;
    }

    const int r = y[row];
    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];

    // copy x[r*ncols + col] to dst[row*ncols + col]
    const int xi = r*ncols + col;
    const int di = row*ncols + col;
    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;

    const int ib = xi/qk; // block index
    const int iqs = (xi%qk)/qr; // quant index
    const int iybs = di - di%qk; // y block start index
    const int ib = i00/qk; // block index
    const int iqs = (i00%qk)/qr; // quant index
    const int iybs = i00 - i00%qk; // dst block start index
    const int y_offset = qr == 1 ? 1 : qk/2;

    // dequantize
    dfloat2 v;
    dequantize_kernel(x, ib, iqs, v);
    dequantize_kernel(src0_row, ib, iqs, v);

    dst[iybs + iqs + 0] = v.x;
    dst[iybs + iqs + y_offset] = v.y;
    dst_row[iybs + iqs + 0] = v.x;
    dst_row[iybs + iqs + y_offset] = v.y;
}

template<typename src0_t, typename dst_t>
static __global__ void k_get_rows_float(
        const src0_t * src0, const int32_t * src1, dst_t * dst,
        int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
        /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
        /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
        /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
        size_t s10, size_t s11, size_t s12/*, size_t s13*/) {

    const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;

    if (i00 >= ne00) {
        return;
    }

    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];

    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);

    dst_row[i00] = src0_row[i00];
}

template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
@@ -5053,11 +5089,69 @@ static __global__ void im2col_f32_f16(
}

template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                          const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {

    GGML_TENSOR_BINARY_OP_LOCALS

    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
    const int block_num_x = (ncols + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
    const dim3 block_nums(block_num_x, nrows, 1);
    k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
    const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
    const dim3 block_nums(block_num_x, ne10, ne11*ne12);

    // strides in elements
    //const size_t s0 = nb0 / ggml_element_size(dst);
    const size_t s1 = nb1 / ggml_element_size(dst);
    const size_t s2 = nb2 / ggml_element_size(dst);
    const size_t s3 = nb3 / ggml_element_size(dst);

    const size_t s10 = nb10 / ggml_element_size(src1);
    const size_t s11 = nb11 / ggml_element_size(src1);
    const size_t s12 = nb12 / ggml_element_size(src1);
    //const size_t s13 = nb13 / ggml_element_size(src1);

    GGML_ASSERT(ne00 % 2 == 0);

    k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
        src0_dd, src1_dd, dst_dd,
        ne00, /*ne01, ne02, ne03,*/
        /*ne10, ne11,*/ ne12, /*ne13,*/
        /* s0,*/ s1, s2, s3,
        /* nb00,*/ nb01, nb02, nb03,
        s10, s11, s12/*, s13*/);

    (void) dst;
}

template<typename src0_t>
static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                                const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {

    GGML_TENSOR_BINARY_OP_LOCALS

    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
    const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
    const dim3 block_nums(block_num_x, ne10, ne11*ne12);

    // strides in elements
    //const size_t s0 = nb0 / ggml_element_size(dst);
    const size_t s1 = nb1 / ggml_element_size(dst);
    const size_t s2 = nb2 / ggml_element_size(dst);
    const size_t s3 = nb3 / ggml_element_size(dst);

    const size_t s10 = nb10 / ggml_element_size(src1);
    const size_t s11 = nb11 / ggml_element_size(src1);
    const size_t s12 = nb12 / ggml_element_size(src1);
    //const size_t s13 = nb13 / ggml_element_size(src1);

    k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
        src0_dd, src1_dd, dst_dd,
        ne00, /*ne01, ne02, ne03,*/
        /*ne10, ne11,*/ ne12, /*ne13,*/
        /* s0,*/ s1, s2, s3,
        /* nb00,*/ nb01, nb02, nb03,
        s10, s11, s12/*, s13*/);

    (void) dst;
}

template<float (*bin_op)(const float, const float)>
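In the launches above, blockIdx.y walks i10 and blockIdx.z walks the flattened (i11, i12) pair, so each thread handles elements of one gathered row. For reference, a plain CPU version of the same gather for the F32 path, with strides passed in elements exactly as s1/s2/s3 and s10/s11/s12 are computed above (an illustrative sketch, not part of the diff):

#include <cstdint>
#include <cstring>

// dst[i10,i11,i12][:] = src0[ src1[i10,i11,i12] ][:], with src0 rows addressed via byte
// strides nb01/nb02/nb03 and dst/src1 addressed via element strides
static void get_rows_f32_ref(
        const float * src0, const int32_t * src1, float * dst,
        int64_t ne00,                              // row length in elements
        int64_t ne10, int64_t ne11, int64_t ne12,  // dims of the index tensor src1
        size_t nb01, size_t nb02, size_t nb03,     // src0 strides in bytes
        size_t s1,  size_t s2,  size_t s3,         // dst strides in elements
        size_t s10, size_t s11, size_t s12) {      // src1 strides in elements
    for (int64_t i12 = 0; i12 < ne12; i12++) {
        for (int64_t i11 = 0; i11 < ne11; i11++) {
            for (int64_t i10 = 0; i10 < ne10; i10++) {
                const int32_t i01 = src1[i10*s10 + i11*s11 + i12*s12];
                const float * src0_row = (const float *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
                float * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
                memcpy(dst_row, src0_row, ne00*sizeof(float));
            }
        }
    }
}
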
@@ -5069,7 +5163,6 @@ struct bin_bcast_cuda {

        GGML_TENSOR_BINARY_OP_LOCALS

        int nr0 = ne10/ne0;
        int nr1 = ne11/ne1;
        int nr2 = ne12/ne2;
@@ -5117,26 +5210,28 @@ struct bin_bcast_cuda {
        int64_t ne12 = cne1[2];
        int64_t ne13 = cne1[3];

        //size_t nb0 = cnb0[0];
        size_t nb0 = cnb0[0];
        size_t nb1 = cnb0[1];
        size_t nb2 = cnb0[2];
        size_t nb3 = cnb0[3];

        //size_t nb10 = cnb1[0];
        size_t nb10 = cnb1[0];
        size_t nb11 = cnb1[1];
        size_t nb12 = cnb1[2];
        size_t nb13 = cnb1[3];

        //size_t s0 = nb0 / sizeof(src1_t);
        size_t s0 = nb0 / sizeof(src1_t);
        size_t s1 = nb1 / sizeof(src1_t);
        size_t s2 = nb2 / sizeof(src1_t);
        size_t s3 = nb3 / sizeof(src1_t);

        //size_t s10 = nb10 / sizeof(src1_t);
        size_t s10 = nb10 / sizeof(src1_t);
        size_t s11 = nb11 / sizeof(src1_t);
        size_t s12 = nb12 / sizeof(src1_t);
        size_t s13 = nb13 / sizeof(src1_t);

        GGML_ASSERT(s0 == 1);
        GGML_ASSERT(s10 == 1);

        const int block_size = 128;
@@ -6447,36 +6542,34 @@ static void ggml_cuda_op_get_rows(

    GGML_ASSERT(src1->type == GGML_TYPE_I32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));
    GGML_ASSERT(ggml_is_contiguous(dst));

    const int ncols = src0->ne[0];
    const int nrows = ggml_nelements(src1);
    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));

    const int32_t * src1_i32 = (const int32_t *) src1_d;

    switch (src0->type) {
        case GGML_TYPE_F16:
            get_rows_cuda<1, 1, convert_f16>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
            get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
            break;
        case GGML_TYPE_F32:
            get_rows_cuda<1, 1, convert_f32>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
            get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
            break;
        case GGML_TYPE_Q4_0:
            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
            break;
        case GGML_TYPE_Q4_1:
            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
            break;
        case GGML_TYPE_Q5_0:
            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
            break;
        case GGML_TYPE_Q5_1:
            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
            break;
        case GGML_TYPE_Q8_0:
            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
            break;
        default:
            // TODO: k-quants
@@ -8234,36 +8327,69 @@ static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
}
#endif

static void ggml_cuda_mul_mat_id(const ggml_tensor * _src0, const ggml_tensor * _src1, ggml_tensor * dst) {
static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
#if 0
//#ifdef CUDA_USE_TENSOR_CORES
// const bool use_tensor_cores = true;
//#else
// const bool use_tensor_cores = false;
//#endif

    ggml_cuda_mul_mat_id_cublas(dst);

    // TODO: mmq/mmv support
#else
    const struct ggml_tensor * ids = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];
    const int id = dst->op_params[0];

    int32_t * ids_dev = (int32_t *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];

    int32_t a_id;
    CUDA_CHECK(cudaMemcpyAsync(&a_id, ids_dev + id, sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
    CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));

    GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
    const struct ggml_tensor * src0 = dst->src[a_id + 2];

    ggml_cuda_mul_mat(src0, src1, dst);
#endif

    (void) _src0;
    (void) _src1;
    GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);

    const struct ggml_tensor * ids = src0;
    const int32_t id = ((int32_t *) dst->op_params)[0];
    const int32_t n_as = ((int32_t *) dst->op_params)[1];

    std::vector<char> ids_host(ggml_nbytes(ids));

    if (ids->backend == GGML_BACKEND_GPU) {
        const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
        CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
        CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
    } else {
        memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
    }

    const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
    const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;

    ggml_tensor_extra_gpu src1_row_extra;
    ggml_tensor_extra_gpu dst_row_extra;

    ggml_tensor src1_row = *src1;
    ggml_tensor dst_row = *dst;

    src1_row.ne[1] = 1;
    dst_row.ne[1] = 1;

    src1_row.nb[2] = src1_row.nb[1];
    dst_row.nb[2] = dst_row.nb[1];

    src1_row.nb[3] = src1_row.nb[1];
    dst_row.nb[3] = dst_row.nb[1];

    src1_row.extra = &src1_row_extra;
    dst_row.extra = &dst_row_extra;

    for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
        //int32_t row_id;
        //CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
        //CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));

        const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);

        GGML_ASSERT(row_id >= 0 && row_id < n_as);

        const struct ggml_tensor * src0_row = dst->src[row_id + 2];

        src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
        src1_row.data = (char *) src1->data + i01*src1->nb[1];

        dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
        dst_row.data = (char *) dst->data + i01*dst->nb[1];

        ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
    }
}

static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
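In short, the loop above replaces the earlier single-expert shortcut: the ids tensor is copied to the host once, and each row of src1 is multiplied by whichever of the n_as expert matrices its id selects. A dense CPU sketch of that semantics, assuming row-major buffers (a hypothetical reference helper, not the ggml op itself):

#include <cstdint>
#include <vector>

// dst row i01 = as[ ids[id + i01*ids_stride] ] * src1 row i01
// as: n_as expert matrices, each ne01 x ne00 row-major; src1: n_rows x ne00; dst: n_rows x ne01
static void mul_mat_id_ref(
        const std::vector<const float *> & as,
        const int32_t * ids, int64_t ids_stride, int id,
        const float * src1, float * dst,
        int64_t ne00, int64_t ne01, int64_t n_rows) {
    for (int64_t i01 = 0; i01 < n_rows; i01++) {
        const int32_t row_id = ids[id + i01*ids_stride];   // expert chosen for this row
        const float * a = as[row_id];
        const float * x = src1 + i01*ne00;
        float       * y = dst  + i01*ne01;
        for (int64_t r = 0; r < ne01; r++) {
            float sum = 0.0f;
            for (int64_t c = 0; c < ne00; c++) {
                sum += a[r*ne00 + c]*x[c];
            }
            y[r] = sum;
        }
    }
}
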
@@ -9181,6 +9307,45 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
                }
                return true;
            } break;
        case GGML_OP_GET_ROWS:
            {
                switch (op->src[0]->type) {
                    case GGML_TYPE_F16:
                    case GGML_TYPE_F32:
                    case GGML_TYPE_Q4_0:
                    case GGML_TYPE_Q4_1:
                    case GGML_TYPE_Q5_0:
                    case GGML_TYPE_Q5_1:
                    case GGML_TYPE_Q8_0:
                        return true;
                    default:
                        return false;
                }
            } break;
        case GGML_OP_CPY:
            {
                ggml_type src0_type = op->src[0]->type;
                ggml_type src1_type = op->src[1]->type;
                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
                    return true;
                }
                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
                    return true;
                }
                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
                    return true;
                }
                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
                    return true;
                }
                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
                    return true;
                }
                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
                    return true;
                }
                return false;
            } break;
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
@@ -9188,7 +9353,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
        case GGML_OP_TRANSPOSE:
        case GGML_OP_NORM:
        case GGML_OP_REPEAT:
        case GGML_OP_GET_ROWS:
        case GGML_OP_DUP:
        case GGML_OP_ADD:
        case GGML_OP_MUL:
@@ -9197,7 +9361,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
        case GGML_OP_SCALE:
        case GGML_OP_SQR:
        case GGML_OP_CLAMP:
        case GGML_OP_CPY:
        case GGML_OP_CONT:
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_SOFT_MAX:
@@ -9264,7 +9427,9 @@ static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * use
    UNUSED(params);
}

extern "C" int ggml_backend_cuda_reg_devices() {
extern "C" int ggml_backend_cuda_reg_devices();

int ggml_backend_cuda_reg_devices() {
    int device_count = ggml_cuda_get_device_count();
    //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
    for (int i = 0; i < device_count; i++) {