SYCL: Reduce most of the compiler warnings (#10748)
* Try to reduce some unused and typecast warnings * Reduce compiler warnings step 2 * add a newline at the end of the file * Initialize nreduce as size_t * [SYCL] Remove pragma directives from mmq.cpp * SYCL: mmq add condition to prevent blocks_per_tile_x_row variable from becoming 0 * SYCL softmax: Initialize nreduce as size_t * ggml-sycl.cpp: fix some trailing whitespaces * SYCL: remove the unused variables instead of commenting it out * SYCL poo2d kernel: set NAN for invalid pooling op * SYCL gemm.hpp: remove pragma directives * SYCL gemm.hpp: use const cast to properly support dnnl::memory * SYCL: wkv6 remove a comment * SYCL: clean comments step 2 * SYCL: clean comments and variables step 3 * SYCL: Use GGML_UNUSED for unused variables * SYCL: remove extra empty lines and a comment * Remove TODO * cleanup spaces * add a stdout for unsupported op * use sycl printf over fprintf * remove prints for CI * SYCL ggml-sycl: pool2D use sycl::nan and remove if-else block --------- Co-authored-by: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com>
This commit is contained in:
parent
d583cd03f6
commit
83ed24a97b
17 changed files with 205 additions and 187 deletions
|
@ -813,7 +813,7 @@ load_tiles_q4_K(const void *__restrict__ vx, int *__restrict__ x_ql,
|
|||
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
|
||||
}
|
||||
|
||||
const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
|
||||
constexpr int blocks_per_tile_x_row = QI4_K > WARP_SIZE ? 1 : WARP_SIZE / QI4_K; // == 1 if QK_K == 256
|
||||
const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
|
||||
|
||||
#pragma unroll
|
||||
|
@ -961,7 +961,7 @@ load_tiles_q5_K(const void *__restrict__ vx, int *__restrict__ x_ql,
|
|||
x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
|
||||
}
|
||||
|
||||
const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
|
||||
constexpr int blocks_per_tile_x_row = QI5_K > WARP_SIZE ? 1 : WARP_SIZE / QI5_K; // == 1 if QK_K == 256
|
||||
const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
|
||||
|
||||
#pragma unroll
|
||||
|
@ -1109,7 +1109,7 @@ load_tiles_q6_K(const void *__restrict__ vx, int *__restrict__ x_ql,
|
|||
dpct::sub_sat());
|
||||
}
|
||||
|
||||
const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
|
||||
constexpr int blocks_per_tile_x_row = QI6_K > WARP_SIZE ? 1 : WARP_SIZE / QI6_K; // == 1 if QK_K == 256
|
||||
const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
|
||||
float * x_dmf = (float *) x_dm;
|
||||
|
||||
|
@ -3020,9 +3020,9 @@ void ggml_sycl_op_mul_mat_q(
|
|||
break;
|
||||
}
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
(void) src1_ddf_i;
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_ddf_i);
|
||||
}
|
||||
catch (sycl::exception const &exc) {
|
||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue