SYCL : Move to compile time oneMKL interface backend selection for NVIDIA backend (#10584)
* [SYCL] Move to Compile Time backend selection on oneMKL Interface for NVIDIA backend Move to compile time selection to backend to avoid latency at run time. Add it to all mkl gemm calls and only for NVIDIA backend. Signed-off-by: nscipione <nicolo.scipione@codeplay.com> * Formatting * Address PR comments to increase readibility --------- Signed-off-by: nscipione <nicolo.scipione@codeplay.com>
This commit is contained in:
parent
98036d5670
commit
40c6d79fb5
4 changed files with 50 additions and 25 deletions
|
@ -1689,9 +1689,14 @@ namespace dpct
|
|||
auto data_a = get_memory<const Ta>(a);
|
||||
auto data_b = get_memory<const Tb>(b);
|
||||
auto data_c = get_memory<Tc>(c);
|
||||
oneapi::mkl::blas::column_major::gemm(
|
||||
q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
|
||||
data_b, ldb, beta_value, data_c, ldc);
|
||||
#ifdef GGML_SYCL_NVIDIA
|
||||
oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
|
||||
a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
|
||||
beta_value, data_c, ldc);
|
||||
#else
|
||||
oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
|
||||
beta_value, data_c, ldc);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename VecT, class BinaryOperation, class = void>
|
||||
|
@ -1754,14 +1759,22 @@ namespace dpct
|
|||
matrix_info->ld_info[2] = ldc;
|
||||
matrix_info->groupsize_info = batch_size;
|
||||
|
||||
#ifdef GGML_SYCL_NVIDIA
|
||||
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
||||
q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
|
||||
matrix_info->size_info, matrix_info->size_info + 1,
|
||||
matrix_info->size_info + 2, matrix_info->value_info,
|
||||
reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
|
||||
reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
||||
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
|
||||
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
|
||||
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
|
||||
matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
|
||||
matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
||||
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
|
||||
&(matrix_info->groupsize_info));
|
||||
#else
|
||||
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
||||
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
|
||||
matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
|
||||
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
||||
matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
|
||||
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
||||
#endif
|
||||
|
||||
q.submit([&](sycl::handler &cgh)
|
||||
{
|
||||
|
@ -1783,10 +1796,16 @@ namespace dpct
|
|||
auto data_a = get_memory<const Ta>(a);
|
||||
auto data_b = get_memory<const Tb>(b);
|
||||
auto data_c = get_memory<Tc>(c);
|
||||
#ifdef GGML_SYCL_NVIDIA
|
||||
oneapi::mkl::blas::column_major::gemm_batch(
|
||||
q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
|
||||
stride_a, data_b, ldb, stride_b, beta_value,
|
||||
data_c, ldc, stride_c, batch_size);
|
||||
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
|
||||
alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
|
||||
batch_size);
|
||||
#else
|
||||
oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
|
||||
stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
|
||||
stride_c, batch_size);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue