metal : fix build and some more comments (#10229)
This commit is contained in:
parent
bb38cdd8ba
commit
39a334a9aa
2 changed files with 6 additions and 4 deletions
|
|
@ -3356,8 +3356,8 @@ kernel void kernel_flash_attn_ext_vec(
|
|||
const short D4 = D/4;
|
||||
const short D16 = D/16;
|
||||
const short NW = N_SIMDWIDTH;
|
||||
const short NL = NW/4;
|
||||
const short SH = 2*C; // shared memory per simdgroup
|
||||
const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
|
||||
const short SH = 2*C; // shared memory per simdgroup
|
||||
|
||||
const short T = D + nsg*SH; // shared memory size per query in (half)
|
||||
|
||||
|
|
@ -3448,7 +3448,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|||
|
||||
// Q*K^T
|
||||
{
|
||||
// each simdgroup processes 1 query and 4 keys
|
||||
// each simdgroup processes 1 query and 4 (NW/NL) keys
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
|
||||
|
||||
|
|
@ -3646,7 +3646,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|||
half, half4, half4x4, \
|
||||
half4x4
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
|
||||
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue