feat: Support Moore Threads GPU (#8383)
* Update doc for MUSA Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * Add GGML_MUSA in Makefile Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * Add GGML_MUSA in CMake Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * CUDA => MUSA Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * MUSA adds support for __vsubss4 Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * Fix CI build failure Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> --------- Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
This commit is contained in:
parent
5e2727fe03
commit
e54c35e4fb
9 changed files with 329 additions and 30 deletions
|
@ -167,7 +167,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
|||
for (int id = 0; id < info.device_count; ++id) {
|
||||
int device_vmm = 0;
|
||||
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
||||
CUdevice device;
|
||||
CU_CHECK(cuDeviceGet(&device, id));
|
||||
CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
|
||||
|
@ -179,7 +179,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
|||
alloc_prop.location.id = id;
|
||||
CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
|
||||
}
|
||||
#endif // !defined(GGML_USE_HIPBLAS)
|
||||
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
||||
info.devices[id].vmm = !!device_vmm;
|
||||
|
||||
cudaDeviceProp prop;
|
||||
|
@ -315,7 +315,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
|
|||
};
|
||||
|
||||
// pool with virtual memory
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
||||
struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
|
||||
|
||||
|
@ -409,14 +409,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
|||
GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
|
||||
}
|
||||
};
|
||||
#endif // !defined(GGML_USE_HIPBLAS)
|
||||
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
||||
|
||||
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
||||
if (ggml_cuda_info().devices[device].vmm) {
|
||||
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
|
||||
}
|
||||
#endif
|
||||
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
||||
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
|
||||
}
|
||||
|
||||
|
@ -1341,7 +1341,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
|
|||
static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
|
||||
void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
|
||||
|
||||
#if !defined(GGML_USE_HIPBLAS)
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
||||
// cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
|
||||
cudaMemcpy3DPeerParms p = {};
|
||||
p.dstDevice = dstDevice;
|
||||
|
@ -1355,7 +1355,7 @@ static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
|
|||
GGML_UNUSED(dstDevice);
|
||||
GGML_UNUSED(srcDevice);
|
||||
return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
|
||||
#endif // !defined(GGML_USE_HIPBLAS)
|
||||
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
||||
}
|
||||
|
||||
static void ggml_cuda_op_mul_mat(
|
||||
|
@ -1828,6 +1828,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
|||
}
|
||||
}
|
||||
#else
|
||||
#ifdef GGML_USE_MUSA
|
||||
GGML_ASSERT(false);
|
||||
#else // !GGML_USE_MUSA
|
||||
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
||||
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
||||
// use cublasGemmStridedBatchedEx
|
||||
|
@ -1870,6 +1873,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
|||
cu_compute_type,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
}
|
||||
#endif // GGML_USE_MUSA
|
||||
#endif
|
||||
|
||||
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
|
||||
|
@ -3027,7 +3031,7 @@ GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size
|
|||
return false;
|
||||
}
|
||||
|
||||
#if CUDART_VERSION >= 11100
|
||||
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
|
||||
cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
|
||||
if (err != cudaSuccess) {
|
||||
// clear the error
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue