CPU/CUDA: fix (GQA) mul mat back, add CUDA support (#11380)
This commit is contained in:
parent
1af6945eb0
commit
8137b4bb2b
7 changed files with 156 additions and 61 deletions
|
|
@ -34,6 +34,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
|
||||
CUBLAS_CHECK(cublasSetStream(handle, stream));
|
||||
|
||||
const int64_t lda = nb01 / sizeof(float);
|
||||
const int64_t ldc = nb1 / sizeof(float);
|
||||
|
||||
const bool src1_T = ggml_is_transposed(src1);
|
||||
const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
|
||||
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
|
||||
|
|
@ -57,9 +60,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
CUBLAS_CHECK(
|
||||
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
ne0, ne1, ne01,
|
||||
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00,
|
||||
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
|
||||
src1_d + i3 *s13 + i2 *s12, ldb,
|
||||
&beta, dst_d + i3 *s3 + i2 *s2, ne0));
|
||||
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue