CUDA: app option to compile without FlashAttention (#12025)
This commit is contained in:
parent
36c258ee92
commit
a28e0d5eb1
13 changed files with 46 additions and 31 deletions
|
@ -69,6 +69,10 @@ if (CUDAToolkit_FOUND)
|
|||
add_compile_definitions(GGML_CUDA_NO_VMM)
|
||||
endif()
|
||||
|
||||
if (NOT GGML_CUDA_FA)
|
||||
add_compile_definitions(GGML_CUDA_NO_FA)
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
|
||||
add_compile_definitions(GGML_CUDA_F16)
|
||||
endif()
|
||||
|
|
|
@ -204,9 +204,9 @@ typedef float2 dfloat2;
|
|||
#define CP_ASYNC_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
|
||||
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
#define FLASH_ATTN_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
|
||||
static bool fp16_available(const int cc) {
|
||||
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
|
||||
|
|
|
@ -839,10 +839,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifndef NEW_MMA_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
|
@ -933,6 +930,9 @@ static __global__ void flash_attn_ext_f16(
|
|||
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
||||
}
|
||||
|
||||
template <int D, int ncols1, int ncols2>
|
||||
|
|
|
@ -44,12 +44,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifdef FP16_AVAILABLE
|
||||
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
|
@ -290,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FP16_AVAILABLE
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
}
|
||||
|
||||
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
|
||||
|
|
|
@ -44,10 +44,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
|
@ -285,6 +282,9 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
}
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
||||
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
|
||||
|
|
|
@ -41,12 +41,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifdef FP16_AVAILABLE
|
||||
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
|
@ -300,7 +295,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FP16_AVAILABLE
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
}
|
||||
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
||||
|
|
|
@ -41,10 +41,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
|
@ -281,6 +278,9 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
||||
|
|
|
@ -51,7 +51,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
|
@ -425,7 +425,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
}
|
||||
|
||||
constexpr int get_max_power_of_2(int x) {
|
||||
|
|
|
@ -3203,7 +3203,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_FLASH_ATTN_EXT: {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
return false;
|
||||
#endif
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
|
||||
return false;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue