metal : add BF16 support (#8439)

* ggml : add initial BF16 support

ggml-ci

* metal : add mul_mat_id BF16 support

ggml-ci

* metal : check for bfloat support on the Metal device

ggml-ci

* metal : better var names [no ci]

* metal : do not build bfloat kernels when not supported

ggml-ci

* metal : try to fix BF16 support check

ggml-ci

* metal : this should correctly check bfloat support
This commit is contained in:
Georgi Gerganov 2024-11-06 19:53:51 +02:00 committed by GitHub
parent b11f9ba9b8
commit 5c333e0140
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 317 additions and 184 deletions

View file

@ -12,6 +12,20 @@ using namespace metal;
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
//
// cmd:
// .../usr/bin/metal -dM -E -c ggml/src/ggml-metal.metal
// .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal.metal
//
#if __METAL_VERSION__ < 310
#define GGML_METAL_NO_BFLOAT
#endif
#if !defined(GGML_METAL_NO_BFLOAT)
typedef matrix<bfloat, 4, 4> bfloat4x4;
#endif
constexpr constant static float kvalues_iq4nl_f[16] = {
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
@ -27,6 +41,13 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
reg = (type4x4)(*src);
}
#if !defined(GGML_METAL_NO_BFLOAT)
template <typename type4x4>
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src);
}
#endif
template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
@ -2041,6 +2062,10 @@ typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
#if !defined(GGML_METAL_NO_BFLOAT)
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float, float4>;
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
#endif
template<typename T, typename T4>
kernel void kernel_mul_mv_1row(
@ -2110,6 +2135,9 @@ kernel void kernel_mul_mv_1row(
typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
#if !defined(GGML_METAL_NO_BFLOAT)
template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>;
#endif
// Assumes row size (ne00) is a multiple of 4
template<typename T, typename T4>
@ -2169,6 +2197,9 @@ kernel void kernel_mul_mv_l4(
typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
#if !defined(GGML_METAL_NO_BFLOAT)
template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat, bfloat4>;
#endif
static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
@ -3565,10 +3596,17 @@ kernel void kernel_cpy(
typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
#if !defined(GGML_METAL_NO_BFLOAT)
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
#endif
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
#if !defined(GGML_METAL_NO_BFLOAT)
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
#endif
kernel void kernel_cpy_f32_q8_0(
device const float * src0,
@ -6473,6 +6511,9 @@ typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
#if !defined(GGML_METAL_NO_BFLOAT)
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
#endif
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
@ -6504,6 +6545,9 @@ typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, de
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
#if !defined(GGML_METAL_NO_BFLOAT)
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
#endif
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
@ -6532,6 +6576,9 @@ typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
#if !defined(GGML_METAL_NO_BFLOAT)
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1, dequantize_bf16>;
#endif
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
@ -6755,6 +6802,9 @@ typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
#if !defined(GGML_METAL_NO_BFLOAT)
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
#endif
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;