musa: fix all warnings, re-enable -DLLAMA_FATAL_WARNINGS=ON in ci and update doc (#12611)
* musa: fix all warnings Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * musa: enable -DLLAMA_FATAL_WARNINGS=ON in run.sh Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * musa: update ci doc (install ccache) Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * fix Windows build issue Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * Address review comments Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * Address review comments Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> --------- Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
This commit is contained in:
parent
d3f1f0acfb
commit
492d7f1ff7
20 changed files with 191 additions and 77 deletions
|
|
@ -34,6 +34,10 @@ static __global__ void conv_transpose_1d_kernel(
|
|||
}
|
||||
}
|
||||
dst[global_index] = accumulator;
|
||||
GGML_UNUSED(p0); GGML_UNUSED(d0); GGML_UNUSED(src0_ne3);
|
||||
GGML_UNUSED(src1_ne3); GGML_UNUSED(dst_ne3);
|
||||
GGML_UNUSED(src1_ne1); GGML_UNUSED(dst_ne1);
|
||||
GGML_UNUSED(src1_ne2); GGML_UNUSED(dst_ne2);
|
||||
}
|
||||
|
||||
static void conv_transpose_1d_f32_f32_cuda(
|
||||
|
|
@ -75,8 +79,6 @@ void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor
|
|||
const int p0 = 0;//opts[3];
|
||||
const int d0 = 1;//opts[4];
|
||||
|
||||
const int64_t kernel_size = ggml_nelements(src0);
|
||||
const int64_t input_size = ggml_nelements(src1);
|
||||
const int64_t output_size = ggml_nelements(dst);
|
||||
|
||||
conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue