ggml : change ggml_graph_compute() API to not require context (#1999)
* ggml_graph_compute: deprecate using ggml_context, try resolve issue #287 * rewrite: no longer consider backward compitability; plan and make_plan * minor: rename ctx as plan; const * remove ggml_graph_compute from tests/test-grad0.c, but current change breaks backward * add static ggml_graph_compute_sugar() * minor: update comments * reusable buffers * ggml : more consistent naming + metal fixes * ggml : fix docs * tests : disable grad / opt + minor naming changes * ggml : add ggml_graph_compute_with_ctx() - backwards compatible API - deduplicates a lot of copy-paste * ci : enable test-grad0 * examples : factor out plan allocation into a helper function * llama : factor out plan stuff into a helper function * ci : fix env * llama : fix duplicate symbols + refactor example benchmark * ggml : remove obsolete assert + refactor n_tasks section * ggml : fix indentation in switch * llama : avoid unnecessary bool * ggml : remove comments from source file and match order in header --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
7242140283
commit
1d656d6360
13 changed files with 571 additions and 449 deletions
|
@ -10,6 +10,8 @@
|
|||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
#pragma GCC diagnostic ignored "-Wdouble-promotion"
|
||||
|
||||
#define MAX_NARGS 3
|
||||
|
||||
#undef MIN
|
||||
|
@ -49,7 +51,7 @@ float frand(void) {
|
|||
|
||||
int irand(int n) {
|
||||
if (n == 0) return 0;
|
||||
else return rand()%n;
|
||||
return rand()%n;
|
||||
}
|
||||
|
||||
void get_random_dims(int64_t * dims, int ndims) {
|
||||
|
@ -159,12 +161,14 @@ struct ggml_tensor * get_random_tensor_int(
|
|||
float get_element(const struct ggml_tensor * t, int idx) {
|
||||
if (t->type == GGML_TYPE_F32) {
|
||||
return ((float *)t->data)[idx];
|
||||
} else if (t->type == GGML_TYPE_I32) {
|
||||
return ((int32_t *)t->data)[idx];
|
||||
} else {
|
||||
assert(false);
|
||||
return INFINITY;
|
||||
}
|
||||
|
||||
if (t->type == GGML_TYPE_I32) {
|
||||
return ((int32_t *)t->data)[idx];
|
||||
}
|
||||
|
||||
assert(false);
|
||||
return INFINITY;
|
||||
}
|
||||
|
||||
void set_element(struct ggml_tensor * t, int idx, float value) {
|
||||
|
@ -215,15 +219,14 @@ bool check_gradient(
|
|||
}
|
||||
|
||||
struct ggml_cgraph gf = ggml_build_forward (f);
|
||||
gf.n_threads = n_threads;
|
||||
|
||||
struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
|
||||
gb.n_threads = n_threads;
|
||||
|
||||
ggml_graph_compute(ctx0, &gf);
|
||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||
|
||||
ggml_graph_reset (&gf);
|
||||
ggml_set_f32 (f->grad, 1.0f);
|
||||
ggml_graph_compute(ctx0, &gb);
|
||||
|
||||
ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
|
||||
|
||||
// ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
|
||||
// ggml_graph_dump_dot(&gb, &gf, "test-grad0-backward.dot");
|
||||
|
@ -236,15 +239,16 @@ bool check_gradient(
|
|||
const float xm = x0 - eps;
|
||||
const float xp = x0 + eps;
|
||||
set_element(x[i], k, xp);
|
||||
ggml_graph_compute(ctx0, &gf);
|
||||
|
||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||
|
||||
const float f0 = ggml_get_f32_1d(f, 0);
|
||||
|
||||
set_element(x[i], k, xm);
|
||||
ggml_graph_compute(ctx0, &gf);
|
||||
|
||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||
|
||||
const float f1 = ggml_get_f32_1d(f, 0);
|
||||
|
||||
const float g0 = (f0 - f1)/(2.0f*eps);
|
||||
|
||||
set_element(x[i], k, x0);
|
||||
|
@ -252,12 +256,13 @@ bool check_gradient(
|
|||
// compute gradient using backward graph
|
||||
ggml_graph_reset (&gf);
|
||||
ggml_set_f32 (f->grad, 1.0f);
|
||||
ggml_graph_compute(ctx0, &gb);
|
||||
|
||||
ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
|
||||
|
||||
const float g1 = get_element(x[i]->grad, k);
|
||||
|
||||
const float error_abs = fabsf(g0 - g1);
|
||||
const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabs(g0) : 0;
|
||||
const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0;
|
||||
|
||||
if (error_abs > max_error_abs || error_rel > max_error_rel) {
|
||||
printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue