HIP: enable vec fattn on RDNA4 (#14323)

This commit is contained in:
uvos 2025-06-22 16:51:23 +02:00 committed by GitHub
parent 5d5c066de8
commit af3373f1ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 14 additions and 7 deletions

View file

@ -241,8 +241,18 @@ static bool fp16_mma_available(const int cc) {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
return false; return false;
#else #else
return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) || if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc); GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
return true;
} else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
return true;
#else
return false;
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
} else {
return false;
}
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
} }

View file

@ -100,8 +100,7 @@ int ggml_cuda_get_device() {
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
ggml_cuda_set_device(device); ggml_cuda_set_device(device);
cudaError_t err; cudaError_t err;
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
{
err = cudaMallocManaged(ptr, size); err = cudaMallocManaged(ptr, size);
#if defined(GGML_USE_HIP) #if defined(GGML_USE_HIP)
if (err == hipSuccess) { if (err == hipSuccess) {
@ -119,9 +118,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
err = cudaMalloc(ptr, size); err = cudaMalloc(ptr, size);
} }
#endif // defined(GGML_USE_HIP) #endif // defined(GGML_USE_HIP)
} } else {
else
{
err = cudaMalloc(ptr, size); err = cudaMalloc(ptr, size);
} }
return err; return err;