CANN: Refactor to reduce duplicate code (#12731)

* CANN: Refactor to reduce duplicate code

* CANN: fix review comment
This commit is contained in:
hipudding 2025-04-07 17:10:36 +08:00 committed by GitHub
parent 916c83bfe7
commit d0d5b2232b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 482 additions and 1245 deletions

View file

@ -1300,47 +1300,59 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
ggml_cann_dup(ctx, dst);
break;
case GGML_OP_ADD:
ggml_cann_add(ctx, dst);
case GGML_OP_ADD1:
ggml_cann_binary_op<aclnn_add>(ctx, dst);
break;
case GGML_OP_SUB:
ggml_cann_binary_op<aclnn_sub>(ctx, dst);
break;
case GGML_OP_ACC:
ggml_cann_acc(ctx, dst);
break;
case GGML_OP_MUL:
ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
ggml_cann_binary_op<aclnn_mul>(ctx, dst);
break;
case GGML_OP_DIV:
ggml_cann_mul_div<aclnnDivGetWorkspaceSize, aclnnDiv>(ctx, dst);
ggml_cann_binary_op<aclnn_div>(ctx, dst);
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(dst)) {
case GGML_UNARY_OP_ABS:
GGML_CANN_CALL_UNARY_OP(Abs);
break;
case GGML_UNARY_OP_NEG:
GGML_CANN_CALL_UNARY_OP(Neg);
break;
case GGML_UNARY_OP_GELU:
ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
ctx, dst);
GGML_CANN_CALL_UNARY_OP(Gelu);
break;
case GGML_UNARY_OP_SILU:
ggml_cann_activation<aclnnSiluGetWorkspaceSize, aclnnSilu>(
ctx, dst);
GGML_CANN_CALL_UNARY_OP(Silu);
break;
// TODO: Use faster gelu??
case GGML_UNARY_OP_GELU_QUICK:
ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
ctx, dst);
case GGML_UNARY_OP_GELU_QUICK: {
auto lambda = [](auto ctx, auto acl_src, auto acl_dst) {
GGML_CANN_CALL_ACLNN_OP(GeluV2, acl_src, 0, acl_dst);
};
ggml_cann_unary_op<lambda>(ctx, dst);
}
break;
case GGML_UNARY_OP_TANH:
ggml_cann_activation<aclnnTanhGetWorkspaceSize, aclnnTanh>(
ctx, dst);
GGML_CANN_CALL_UNARY_OP(Tanh);
break;
case GGML_UNARY_OP_RELU:
ggml_cann_activation<aclnnReluGetWorkspaceSize, aclnnRelu>(
ctx, dst);
GGML_CANN_CALL_UNARY_OP(Relu);
break;
case GGML_UNARY_OP_SIGMOID:
GGML_CANN_CALL_UNARY_OP(Sigmoid);
break;
case GGML_UNARY_OP_HARDSIGMOID:
ggml_cann_activation<aclnnHardsigmoidGetWorkspaceSize,
aclnnHardsigmoid>(ctx, dst);
GGML_CANN_CALL_UNARY_OP(Hardsigmoid);
break;
case GGML_UNARY_OP_HARDSWISH:
ggml_cann_activation<aclnnHardswishGetWorkspaceSize,
aclnnHardswish>(ctx, dst);
GGML_CANN_CALL_UNARY_OP(Hardswish);
break;
case GGML_UNARY_OP_EXP:
GGML_CANN_CALL_UNARY_OP(Exp);
break;
default:
return false;
@ -1382,7 +1394,12 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
ggml_cann_scale(ctx, dst);
break;
case GGML_OP_SQR:
ggml_cann_sqr(ctx, dst);
GGML_ASSERT(dst->src[1] == nullptr);
dst->src[1] = dst->src[0];
ggml_cann_binary_op<aclnn_mul>(ctx, dst);
break;
case GGML_OP_SQRT:
GGML_CANN_CALL_UNARY_OP(Sqrt);
break;
case GGML_OP_CLAMP:
ggml_cann_clamp(ctx, dst);
@ -1414,6 +1431,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
case GGML_OP_POOL_2D:
ggml_cann_pool2d(ctx, dst);
break;
case GGML_OP_SUM:
ggml_cann_sum(ctx, dst);
break;
case GGML_OP_SUM_ROWS:
ggml_cann_sum_rows(ctx, dst);
break;
@ -1424,11 +1444,11 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
ggml_cann_argmax(ctx, dst);
break;
case GGML_OP_COS:
ggml_cann_cos(ctx, dst);
ggml_cann_unary_op<aclnn_cos>(ctx, dst);
break;
case GGML_OP_SIN:
ggml_cann_sin(ctx, dst);
break;
ggml_cann_unary_op<aclnn_sin>(ctx, dst);
break;
default:
return false;
}
@ -1679,13 +1699,17 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
switch (op->op) {
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_NEG:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
return true;
default:
return false;
@ -1784,6 +1808,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// value of paddingW should be at most half of kernelW
return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
}
case GGML_OP_SUM:
case GGML_OP_DUP:
case GGML_OP_IM2COL:
case GGML_OP_CONCAT:
@ -1795,11 +1820,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_OP_TRANSPOSE:
case GGML_OP_NORM:
case GGML_OP_ADD:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_RMS_NORM:
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_CLAMP:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX: