metal : add BF16 support (#8439)
* ggml : add initial BF16 support ggml-ci * metal : add mul_mat_id BF16 support ggml-ci * metal : check for bfloat support on the Metal device ggml-ci * metal : better var names [no ci] * metal : do not build bfloat kernels when not supported ggml-ci * metal : try to fix BF16 support check ggml-ci * metal : this should correctly check bfloat support
This commit is contained in:
parent
b11f9ba9b8
commit
5c333e0140
4 changed files with 317 additions and 184 deletions
|
|
@ -12,6 +12,20 @@ using namespace metal;
|
|||
|
||||
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
||||
|
||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
||||
//
|
||||
// cmd:
|
||||
// .../usr/bin/metal -dM -E -c ggml/src/ggml-metal.metal
|
||||
// .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal.metal
|
||||
//
|
||||
#if __METAL_VERSION__ < 310
|
||||
#define GGML_METAL_NO_BFLOAT
|
||||
#endif
|
||||
|
||||
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||
typedef matrix<bfloat, 4, 4> bfloat4x4;
|
||||
#endif
|
||||
|
||||
constexpr constant static float kvalues_iq4nl_f[16] = {
|
||||
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
||||
};
|
||||
|
|
@ -27,6 +41,13 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
|
|||
reg = (type4x4)(*src);
|
||||
}
|
||||
|
||||
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||
template <typename type4x4>
|
||||
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
|
||||
reg = (type4x4)(*src);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
||||
|
|
@ -2041,6 +2062,10 @@ typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
|
|||
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
|
||||
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
|
||||
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
|
||||
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float, float4>;
|
||||
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
|
||||
#endif
|
||||
|
||||
template<typename T, typename T4>
|
||||
kernel void kernel_mul_mv_1row(
|
||||
|
|
@ -2110,6 +2135,9 @@ kernel void kernel_mul_mv_1row(
|
|||
typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
|
||||
|
||||
template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
|
||||
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||
template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>;
|
||||
#endif
|
||||
|
||||
// Assumes row size (ne00) is a multiple of 4
|
||||
template<typename T, typename T4>
|
||||
|
|
@ -2169,6 +2197,9 @@ kernel void kernel_mul_mv_l4(
|
|||
typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
|
||||
|
||||
template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
|
||||
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||
template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat, bfloat4>;
|
||||
#endif
|
||||
|
||||
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
||||
|
|
@ -3565,10 +3596,17 @@ kernel void kernel_cpy(
|
|||
|
||||
typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
|
||||
|
||||
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
|
||||
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
|
||||
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
|
||||
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
|
||||
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
|
||||
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
|
||||
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
|
||||
#endif
|
||||
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
|
||||
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
|
||||
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
|
||||
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
|
||||
#endif
|
||||
|
||||
kernel void kernel_cpy_f32_q8_0(
|
||||
device const float * src0,
|
||||
|
|
@ -6473,6 +6511,9 @@ typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
|
|||
|
||||
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
|
||||
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
|
||||
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
|
||||
#endif
|
||||
|
||||
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
||||
|
||||
|
|
@ -6504,6 +6545,9 @@ typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, de
|
|||
|
||||
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
||||
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
||||
|
|
@ -6532,6 +6576,9 @@ typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
|
|||
|
||||
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
||||
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1, dequantize_bf16>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
|
||||
|
|
@ -6755,6 +6802,9 @@ typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float
|
|||
|
||||
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
|
||||
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
|
||||
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue