CUDA: use async data loading for FlashAttention (#11894)

* CUDA: use async data loading for FlashAttention

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
This commit is contained in:
Johannes Gäßler 2025-02-17 14:03:24 +01:00 committed by GitHub
parent f7b1116af1
commit 73e2ed3ce3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 744 additions and 739 deletions

View file

@ -1,7 +1,252 @@
#include "common.cuh"
#include "cp-async.cuh"
#include "mma.cuh"
#include "fattn-common.cuh"
using namespace ggml_cuda_mma;
typedef tile<16, 8, half2> tile_A;
typedef tile< 8, 8, half2> tile_B;
typedef tile<16, 8, float> tile_C_KQ;
typedef tile<16, 4, half2> tile_C_VKQ;
template<int D, int nwarps, int KQ_stride>
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
// If cp.async is available, load up to the highest power of 2 in D asynchronously:
#ifdef CP_ASYNC_AVAILABLE
static_assert(D >= 64 && D < 512, "bad D");
constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128);
const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV);
constexpr int preload = 64;
constexpr int h2_per_chunk = 16/sizeof(half2);
constexpr int chunks_per_row = k0_sync_start / h2_per_chunk;
constexpr int stride_i = WARP_SIZE / chunks_per_row;
#pragma unroll
for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) {
const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row);
const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk;
cp_async_cg_16<preload>(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k);
}
#else
constexpr int k0_sync_start = 0;
#endif // CP_ASYNC_AVAILABLE
static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start");
// If D is not a power of 2, the rest is loaded synchronously.
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
#pragma unroll
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k);
const int k0_stop = D/2 - (D/2) % (1*stride_k);
const int stride_i = WARP_SIZE / stride_k;
if (k0_start == k0_stop || k0_stop <= k0_sync_start) {
continue;
}
#pragma unroll
for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) {
const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
tile_KV[i*D2_padded + k] = KV[i*stride_KV + k];
}
}
}
}
template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2,
const half2 * const __restrict__ V_h2,
const half * const __restrict__ maskh,
float2 * const __restrict__ dstk,
float2 * const __restrict__ dstk_fixup,
const float scale,
const float slope,
const float logit_softcap,
const int ne01,
const int ne02,
const int stride_Q,
const int stride_KV,
const int stride_mask,
const int jt,
half2 * const __restrict__ tile_K,
half2 * const __restrict__ tile_V,
const tile_B * const __restrict__ Q_B,
tile_C_VKQ * const __restrict__ VKQ_C,
float2 & KQ_max,
float2 & KQ_rowsum,
const int kb0) {
#ifdef NEW_MMA_AVAILABLE
constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.
constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
const int k_VKQ_0 = kb0*KQ_stride;
tile_C_KQ KQ_C[KQ_stride/(np*tile_C_KQ::I)];
#ifdef CP_ASYNC_AVAILABLE
cp_async_wait_all();
__syncthreads();
flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
#else
flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV);
__syncthreads();
#endif // CP_ASYNC_AVAILABLE
// Calculate tile of KQ:
#pragma unroll
for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*tile_A::I) {
const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) {
tile_A K_A;
load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, ((tile_B *) Q_B)[k_KQ_0/tile_A::J]);
}
}
#ifndef CP_ASYNC_AVAILABLE
__syncthreads(); // Only needed if tile_K == tile_V.
#endif // CP_ASYNC_AVAILABLE
if (use_logit_softcap) {
static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int i = 0; i < KQ_stride/(np*tile_C_KQ::I); ++i) {
#pragma unroll
for (int l = 0; l < tile_C_KQ::ne; ++l) {
KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
}
}
}
if (maskh) {
static_assert(KQ_stride % (np *tile_C_KQ::I) == 0, "bad loop size");
static_assert(ncols % (nwarps/np*tile_C_KQ::J) == 0, "bad loop size");
#pragma unroll
for (int i00 = 0; i00 < KQ_stride; i00 += np*tile_C_KQ::I) {
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
#pragma unroll
for (int l = 0; l < tile_C_KQ::ne; ++l) {
const int i = i0 + tile_C_KQ::get_i(l);
const int j = (threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l);
KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
}
}
}
// Calculate softmax for each KQ column using the current max. value.
// The divisor is stored in KQ_rowsum and will be applied at the end.
float2 KQ_max_new = KQ_max;
static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) {
#pragma unroll
for (int l0 = 0; l0 < tile_C_KQ::ne; l0 += 2) {
KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
}
}
// Values per KQ column are spread across 8 threads, does not need full warp reduce:
#pragma unroll
for (int offset = 16; offset > 2; offset >>= 1) {
KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
}
float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) {
#pragma unroll
for (int l = 0; l < tile_C_KQ::ne; ++l) {
const float KQ_max_l = l % 2 == 0 ? KQ_max_new.x : KQ_max_new.y;
const float diff = KQ_C[k].x[l] - KQ_max_l;
KQ_C[k].x[l] = expf(diff);
if (l % 2 == 0) {
KQ_rowsum_add.x += KQ_C[k].x[l];
} else {
KQ_rowsum_add.y += KQ_C[k].x[l];
}
}
}
{
const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
const float2 KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
KQ_max = KQ_max_new;
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y;
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
#pragma unroll
for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
#pragma unroll
for (int l = 0; l < tile_C_VKQ::ne; ++l) {
VKQ_C[i].x[l] *= KQ_max_scale_h2;
}
}
}
// Convert KQ C tiles into B tiles for VKQ calculation:
tile_B B[KQ_stride/(np*2*tile_B::J)];
static_assert(KQ_stride % (np*2*tile_B::J) == 0, "bad loop size");
#pragma unroll
for (int k = 0; k < KQ_stride/(np*2*tile_B::J); ++k) {
B[k] = get_transposed(get_half2(KQ_C[k]));
}
#ifdef CP_ASYNC_AVAILABLE
cp_async_wait_all();
__syncthreads();
if (!last_iter) {
flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + (k_VKQ_0 + KQ_stride)*stride_KV, tile_K, stride_KV);
}
#else
flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
__syncthreads();
#endif // CP_ASYNC_AVAILABLE
// Calculate VKQ tile:
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) {
static_assert((KQ_stride/2) % (np*tile_A::J) == 0, "bad loop size");
#pragma unroll
for (int k00 = 0; k00 < KQ_stride/2; k00 += np*tile_A::J) {
const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
tile_A A;
load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
}
}
#ifndef CP_ASYNC_AVAILABLE
__syncthreads(); // Only needed if tile_K == tile_V.
#endif // CP_ASYNC_AVAILABLE
#else
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const float2 * const __restrict__ Q_f2,
@ -13,61 +258,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const float scale,
const float slope,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3,
const int stride_Q,
const int stride_KV,
const int stride_mask,
const int jt,
const int kb0_start,
const int kb0_stop) {
#ifdef NEW_MMA_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
typedef mma_A_I16K8<half2> mma_A;
typedef mma_B_J8K8<half2> mma_B;
typedef mma_C_I16J8<float> mma_C_KQ;
typedef mma_C_I16J8<half2> mma_C_VKQ;
static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps");
constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column.
static_assert(nwarps*tile_B::I % ncols == 0, "bad nwarps");
constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.
static_assert(D % nwarps == 0, "bad D");
static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");
constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements.
const int stride_Q = nb01 / sizeof(float2);
const int stride_KV = nb11 / sizeof(half2);
const int stride_mask = nb31 / sizeof(half);
// Temporary shared buffer for loading K/V data with KQ_stride*D logical elements:
extern __shared__ half2 tile_K[];
#ifdef CP_ASYNC_AVAILABLE
half2 * tile_V = tile_K + KQ_stride*D2_padded;
#else
half2 * tile_V = tile_K;
#endif // CP_ASYNC_AVAILABLE
mma_B Q_B[D/(2*mma_B::K)];
mma_C_VKQ VKQ_C[D/mma_C_VKQ::I];
tile_B Q_B[D/(2*tile_B::J)];
tile_C_VKQ VKQ_C[D/tile_C_VKQ::I];
float2 KQ_rowsum = {0.0f, 0.0f};
float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
float2 KQ_max_scale = {0.0f, 0.0f};
float2 KQ_rowsum = {0.0f, 0.0f};
float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
// Temporarily load Q data into tile_KV, will be loaded into registers afterwards.
// Temporarily load Q data into tile_K, will be loaded into registers afterwards.
// The loading is done with decreasing granularity for D for better memory bandwidth.
const half2 scale_h2 = make_half2(scale, scale);
#pragma unroll
@ -76,6 +300,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int k0_stop = D/2 - (D/2) % (1*stride_k);
const int stride_j = WARP_SIZE / stride_k;
if (k0_start == k0_stop) {
continue;
}
if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
break;
}
@ -90,14 +318,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
tile_K[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
}
} else {
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f);
tile_K[j*D2_padded + k] = make_half2(0.0f, 0.0f);
}
}
}
@ -106,198 +334,42 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
__syncthreads();
{
const int j0 = (threadIdx.y / np) * mma_B::J;
const int j0 = (threadIdx.y / np) * tile_B::I;
#pragma unroll
for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded);
for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded);
}
}
__syncthreads();
// Preload K data for first iteration when using cp_async:
#ifdef CP_ASYNC_AVAILABLE
flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + kb0_start*KQ_stride*stride_KV, tile_K, stride_KV);
#endif // CP_ASYNC_AVAILABLE
// Iterate over ne11 == previous tokens:
for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) {
const int k_VKQ_0 = kb0*KQ_stride;
mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)];
// Load K data into tile with decreasing granularity for D for better memory bandwidth:
static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
#pragma unroll
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
const int k0_stop = D/2 - (D/2) % (1*stride_k);
const int stride_i = WARP_SIZE / stride_k;
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) {
const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
#pragma unroll
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) {
const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ];
}
}
}
__syncthreads();
// Calculate tile of KQ:
#pragma unroll
for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) {
const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) {
mma_A K_A;
K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]);
}
}
__syncthreads();
if (use_logit_softcap) {
static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) {
#pragma unroll
for (int l = 0; l < mma_C_KQ::ne; ++l) {
KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
}
}
}
if (maskh) {
static_assert(KQ_stride % (np *mma_C_KQ::I) == 0, "bad loop size");
static_assert(ncols % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size");
#pragma unroll
for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) {
const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I;
#pragma unroll
for (int l = 0; l < mma_C_KQ::ne; ++l) {
const int i = i0 + mma_C_KQ::get_i(l);
const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l);
KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
}
}
}
// Calculate softmax for each KQ column using the current max. value.
// The divisor is stored in KQ_rowsum and will be applied at the end.
float2 KQ_max_new = KQ_max;
static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
#pragma unroll
for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) {
KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
}
}
// Values per KQ column are spread across 8 threads, does not need full warp reduce:
#pragma unroll
for (int offset = 16; offset > 2; offset >>= 1) {
KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
}
{
const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
if (diff.x <= SOFTMAX_FTZ_THRESHOLD) {
KQ_max_scale.x = 0.0f;
}
if (diff.y <= SOFTMAX_FTZ_THRESHOLD) {
KQ_max_scale.y = 0.0f;
}
KQ_max = KQ_max_new;
}
float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
#pragma unroll
for (int l = 0; l < mma_C_KQ::ne; ++l) {
const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y;
const float diff = KQ_C[k].x[l] - KQ_max_l;
KQ_C[k].x[l] = expf(diff);
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
KQ_C[k].x[l] = 0.0f;
}
if (l % 2 == 0) {
KQ_rowsum_add.x += KQ_C[k].x[l];
} else {
KQ_rowsum_add.y += KQ_C[k].x[l];
}
}
}
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y;
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
#pragma unroll
for (int i = 0; i < D/mma_C_VKQ::I; ++i) {
#pragma unroll
for (int l = 0; l < mma_C_VKQ::ne; ++l) {
VKQ_C[i].x[l] *= KQ_max_scale_h2;
}
}
// Convert KQ C tiles into B tiles for VKQ calculation:
mma_B B[KQ_stride/(np*2*mma_B::K)];
static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size");
#pragma unroll
for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) {
B[k] = KQ_C[k].to_mma_B();
}
// Load V data into tile with decreasing granularity for D for better memory bandwidth:
static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
#pragma unroll
for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i);
const int i0_stop = D/2 - (D/2) % (1*stride_i);
const int stride_k = WARP_SIZE / stride_i;
#pragma unroll
for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) {
const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i);
#pragma unroll
for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) {
const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i);
tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V];
}
}
}
__syncthreads();
// Calculate VKQ tile:
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) {
static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size");
#pragma unroll
for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) {
const int k0 = k00 + (threadIdx.y % np)*mma_A::K;
mma_A A;
A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]);
}
}
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
constexpr bool last_iter = false;
flash_attn_ext_f16_iter<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
}
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
constexpr bool last_iter = true;
flash_attn_ext_f16_iter<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
}
// With cp_async there is no __syncthreads at the end of the iter,
// there can be a race condition on shared memory access for combining/writing back results.
#ifdef CP_ASYNC_AVAILABLE
if (nwarps*tile_B::I > KQ_stride) {
__syncthreads();
}
#endif // CP_ASYNC_AVAILABLE
// Finally, sum up partial KQ rowsums.
// The partial sums are spread across 8 threads each, does not need full reduce.
@ -310,26 +382,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
// Write VKQ accumulators to shared memory in column-major format.
// It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
// Also for np > 1 the combination is done via these values in shared memory.
const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data
const int j_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // j combine write data
#pragma unroll
for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format.
for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
#pragma unroll
for (int l = 0; l < mma_B::ne; ++l) {
const int k = k0 + mma_B::get_k(l);
for (int l = 0; l < tile_B::ne; ++l) {
const int k = k0 + tile_B::get_j(l);
tile_KV[j_cwd*D2_padded + k] = B.x[l];
tile_K[j_cwd*D2_padded + k] = B.x[l];
}
}
const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset
const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
const int j_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // j combine write meta offset
const int j_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum
if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) {
if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
// Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
((float2 *) tile_K)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
}
__syncthreads();
@ -337,11 +409,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
static_assert(np == 1 || np == 2 || np == 4, "bad np");
if (np == 1) {
// No combination is needed, the meta data can be directly written from registers to VRAM.
if (needs_fixup && threadIdx.x < mma_B::J) {
if (needs_fixup && threadIdx.x < tile_B::I) {
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
dstk_fixup_meta[j_cwm] = KQ_cmr;
}
if (is_fixup && threadIdx.x < mma_B::J) {
if (is_fixup && threadIdx.x < tile_B::I) {
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
dstk_fixup_meta[j_cwm] = KQ_cmr;
}
@ -350,42 +422,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
// Warps with threadIdx.y % np != 0 must NOT return early.
// All threads must return simultaneously to avoid race conditions with work on the next tile.
float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2;
float * meta_j = (float *) tile_K + (threadIdx.y*tile_B::I + threadIdx.x)*D2_padded + D/2;
float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp.
if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
KQ_cm = meta_j[0];
}
float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps.
#pragma unroll
for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) {
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
}
const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp.
float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps.
if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
KQ_crs = KQ_cms*meta_j[1];
}
#pragma unroll
for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) {
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
}
// Write back combined meta data:
if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
meta_j[0] = KQ_cmn; // Combined max. KQ values.
meta_j[1] = KQ_crs; // Combined KQ rowsums.
meta_j[2] = KQ_cms; // KQ max scales per parallel warp.
if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
*((float2 *) meta_j) = make_float2(KQ_cms, KQ_crs); // Combined KQ max scale + rowsum.
}
if (needs_fixup && threadIdx.x < mma_B::J) {
if (needs_fixup && threadIdx.x < tile_B::I) {
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
}
if (is_fixup && threadIdx.x < mma_B::J) {
if (is_fixup && threadIdx.x < tile_B::I) {
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
}
}
@ -404,6 +474,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int k0_stop = D/2 - (D/2) % (1*stride_k);
const int stride_j = WARP_SIZE / stride_k;
if (k0_start == k0_stop) {
continue;
}
if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
break;
}
@ -411,12 +485,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
#pragma unroll
for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) {
const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J;
const int j_tile_K = (j_dst/tile_B::I)*(np*tile_B::I) + j_dst % tile_B::I;
if (!is_fixup && jt*ncols + j_dst >= ne01) {
continue;
}
const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2;
const float * meta_j = (const float *) tile_K + j_tile_K*D2_padded + D/2;
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
@ -424,8 +498,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
float2 dstk_val = make_float2(0.0f, 0.0f);
#pragma unroll
for (int ip = 0; ip < np; ++ip) {
const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2];
const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]);
const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*tile_B::I*D2_padded + 0];
const float2 dstk_val_add = __half22float2(tile_K[(j_tile_K + ip*tile_B::I)*D2_padded + k]);
dstk_val.x += dstk_val_add.x*KQ_crs;
dstk_val.y += dstk_val_add.y*KQ_crs;
}
@ -450,7 +524,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
__syncthreads();
}
#else
NO_DEVICE_CODE;
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
@ -494,6 +568,11 @@ static __global__ void flash_attn_ext_f16(
const int ne1,
const int ne2,
const int ne3) {
#ifndef NEW_MMA_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // NEW_MMA_AVAILABLE
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
@ -504,6 +583,10 @@ static __global__ void flash_attn_ext_f16(
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const int stride_Q = nb01 / sizeof(float2);
const int stride_KV = nb11 / sizeof(half2);
const int stride_mask = nb31 / sizeof(half);
const int iter_k = ne11 / KQ_stride;
const int iter_j = (ne01 + (ncols - 1)) / ncols;
@ -535,14 +618,12 @@ static __global__ void flash_attn_ext_f16(
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
jt, kb0_start, kb0_stop);
ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
jt, kb0_start, kb0_stop);
ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
}
kbc += iter_k;
@ -571,24 +652,27 @@ static __global__ void flash_attn_ext_f16(
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
jt, kb0_start, kb0_stop);
ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
}
template <int D, int cols_per_block>
void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
typedef mma_A_I16K8<half2> mma_A;
typedef mma_B_J8K8<half2> mma_B;
typedef tile<16, 8, half2> tile_A;
typedef tile< 8, 8, half2> tile_B;
static_assert(D % mma_B::K == 0, "bad D");
static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block");
static_assert(D % tile_B::J == 0, "bad D");
static_assert(cols_per_block % tile_B::I == 0, "bad cols_per_block");
const ggml_tensor * KQV = dst;
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
constexpr int KQ_stride = D <= 128 ? 64 : 32;
constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ?
cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8);
constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half);
constexpr int KQ_stride = D <= 128 ? 64 : 32;
constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ?
cols_per_block/tile_B::J * KQ_stride/tile_A::I : (cols_per_block <= 8 ? 4 : 8);
const int nrows_KQ = cp_async_available(cc) ? 2*KQ_stride : KQ_stride;
const int nrows_combine = nwarps*tile_B::J;
const size_t nbytes_shared = std::max(nrows_KQ, nrows_combine) * (D + 8) * sizeof(half);
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));