CUDA: add option to compile without FlashAttention (#12025)
This commit is contained in:
parent
36c258ee92
commit
a28e0d5eb1
13 changed files with 46 additions and 31 deletions
|
|
@@ -839,10 +839,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifndef NEW_MMA_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
|
|
@@ -933,6 +930,9 @@ static __global__ void flash_attn_ext_f16(
|
|||
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
||||
}
|
||||
|
||||
template <int D, int ncols1, int ncols2>
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue