cuda : fix dmmv cols requirement to 2*GGML_CUDA_DMMV_X (#8800)

* cuda : fix dmmv cols requirement to 2*GGML_CUDA_DMMV_X

* update asserts

* only use dmmv for supported types

* add test
Author: slaren, 2024-08-01 15:26:22 +02:00 (committed by GitHub)
parent c8a0090922
commit 7a11eb3a26
4 changed files with 22 additions and 11 deletions

ggml/src/ggml-cuda.cu

@@ -1885,10 +1885,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
 
-    bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
+    bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[0] >= GGML_CUDA_DMMV_X*2
-        && src1->ne[1] == 1;
+        && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
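
Context for the change: the dmmv kernels consume two GGML_CUDA_DMMV_X-wide chunks of a row per iteration, so src0->ne[0] must be a multiple of 2*GGML_CUDA_DMMV_X. The old check (multiple of GGML_CUDA_DMMV_X and at least 2*GGML_CUDA_DMMV_X) still let through sizes such as 96 with the default GGML_CUDA_DMMV_X of 32. The new ggml_cuda_dmmv_type_supported helper lives in the dmmv source files, outside this hunk; a minimal sketch of what it plausibly checks, assuming the supported set is the types dmmv has dequantize kernels for (per the "only use dmmv for supported types" bullet above):

    // Sketch, not the hunk above: return true only for types that have
    // dmmv dequantize kernels (the legacy quantized types and F16).
    bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) {
        switch (src0_type) {
            case GGML_TYPE_Q4_0:
            case GGML_TYPE_Q4_1:
            case GGML_TYPE_Q5_0:
            case GGML_TYPE_Q5_1:
            case GGML_TYPE_Q8_0:
            case GGML_TYPE_F16:
                return true;
            default:
                return false;
        }
    }

    // "update asserts" would then tighten the column checks in the dmmv
    // launchers to match the new requirement, e.g.:
    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);

With this gate in place, row sizes that are only a multiple of GGML_CUDA_DMMV_X (but not of 2*GGML_CUDA_DMMV_X) fall through to the other mul_mat paths instead of reaching a kernel that cannot handle them.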