llama/ggml: add LLM training support (#10544)
* llama/ggml: add LLM training support more compact progress bar llama_save_model_to_file llama_opt_param_filter ggml_graph_dup force_grads refactor ggml_opt, fix test-opt * remove logits_all * refactor CUDA implementation for ACC * reset graph at beginning of opt period
This commit is contained in:
parent
064cc596ac
commit
10d2af0eaa
31 changed files with 1415 additions and 359 deletions
|
@ -359,7 +359,9 @@ llama_context::llama_context(
|
|||
}
|
||||
}
|
||||
|
||||
llama_context::~llama_context() = default;
|
||||
llama_context::~llama_context() {
|
||||
ggml_opt_free(opt_ctx);
|
||||
}
|
||||
|
||||
void llama_context::synchronize() {
|
||||
ggml_backend_sched_synchronize(sched.get());
|
||||
|
@ -1846,6 +1848,215 @@ void llama_context::perf_reset() {
|
|||
t_p_eval_us = n_p_eval = 0;
|
||||
}
|
||||
|
||||
//
|
||||
// training
|
||||
//
|
||||
|
||||
static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
|
||||
if (!tensor || tensor->type != GGML_TYPE_F32) {
|
||||
return;
|
||||
}
|
||||
if (!param_filter(tensor, userdata)) {
|
||||
return;
|
||||
}
|
||||
if (strcmp(tensor->name, "token_embd.weight") == 0) {
|
||||
return; // FIXME
|
||||
}
|
||||
if (strcmp(tensor->name, "rope_freqs.weight") == 0) {
|
||||
return; // FIXME
|
||||
}
|
||||
ggml_set_param(tensor);
|
||||
}
|
||||
|
||||
void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) {
|
||||
GGML_ASSERT(!opt_ctx);
|
||||
model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx();
|
||||
const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train);
|
||||
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
|
||||
GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0);
|
||||
GGML_ASSERT(n_batch % n_ubatch == 0);
|
||||
|
||||
ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
|
||||
opt_params.opt_period = n_batch / n_ubatch;
|
||||
opt_params.get_opt_pars = lopt_params.get_opt_pars;
|
||||
opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
|
||||
|
||||
opt_ctx = ggml_opt_init(opt_params);
|
||||
|
||||
llama_opt_param_filter param_filter = lopt_params.param_filter;
|
||||
void * param_filter_ud = lopt_params.param_filter_ud;
|
||||
|
||||
//llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME
|
||||
llama_set_param(model->type_embd, param_filter, param_filter_ud);
|
||||
llama_set_param(model->pos_embd, param_filter, param_filter_ud);
|
||||
llama_set_param(model->tok_norm, param_filter, param_filter_ud);
|
||||
llama_set_param(model->tok_norm_b, param_filter, param_filter_ud);
|
||||
llama_set_param(model->output_norm, param_filter, param_filter_ud);
|
||||
llama_set_param(model->output_norm_b, param_filter, param_filter_ud);
|
||||
llama_set_param(model->output, param_filter, param_filter_ud);
|
||||
llama_set_param(model->output_b, param_filter, param_filter_ud);
|
||||
llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
|
||||
llama_set_param(model->cls, param_filter, param_filter_ud);
|
||||
llama_set_param(model->cls_b, param_filter, param_filter_ud);
|
||||
llama_set_param(model->cls_out, param_filter, param_filter_ud);
|
||||
llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
|
||||
|
||||
for (struct llama_layer & layer : model->layers) {
|
||||
for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
|
||||
llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void llama_context::opt_epoch_iter(
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result,
|
||||
const std::vector<llama_token> & tokens,
|
||||
const std::vector<llama_token> & labels_sparse,
|
||||
llama_batch & batch,
|
||||
ggml_opt_epoch_callback callback,
|
||||
bool train,
|
||||
int64_t idata_in_loop,
|
||||
int64_t ndata_in_loop,
|
||||
int64_t t_loop_start) {
|
||||
GGML_ASSERT(opt_ctx);
|
||||
const uint32_t n_ctx = llama_model_n_ctx_train(&model);
|
||||
const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
|
||||
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
|
||||
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->clear();
|
||||
llama_kv_cache_guard kv_guard(kv_self);
|
||||
|
||||
for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
|
||||
batch.n_tokens = n_batch;
|
||||
for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
|
||||
batch.token [pos_batch] = tokens[pos_ctx + pos_batch];
|
||||
batch.pos [pos_batch] = pos_ctx + pos_batch;
|
||||
batch.n_seq_id[pos_batch] = 1;
|
||||
batch.seq_id [pos_batch][0] = 0;
|
||||
batch.logits [pos_batch] = true;
|
||||
}
|
||||
|
||||
const auto n_tokens_all = batch.n_tokens;
|
||||
|
||||
n_queued_tokens += n_tokens_all;
|
||||
|
||||
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
|
||||
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
||||
|
||||
embd_seq.clear();
|
||||
|
||||
int64_t n_outputs_all = n_tokens_all;
|
||||
|
||||
llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
|
||||
|
||||
// reserve output buffer
|
||||
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
|
||||
GGML_ABORT("TODO: handle this error");
|
||||
};
|
||||
|
||||
for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
|
||||
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
|
||||
|
||||
n_outputs = ubatch.n_tokens;
|
||||
|
||||
// TODO: not sure if this is needed
|
||||
if (!kv_self->find_slot(ubatch)) {
|
||||
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
|
||||
|
||||
GGML_ABORT("TODO: handle this error");
|
||||
}
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
|
||||
|
||||
struct ggml_context * ctx_compute_opt;
|
||||
{
|
||||
const size_t size_gf = ggml_graph_size(gf);
|
||||
const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true);
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ size_meta,
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
ctx_compute_opt = ggml_init(params);
|
||||
}
|
||||
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
|
||||
ggml_opt_alloc(opt_ctx, train);
|
||||
res->set_inputs(&ubatch);
|
||||
{
|
||||
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
|
||||
GGML_ASSERT(labels->ne[1] == n_ubatch);
|
||||
ggml_set_zero(labels);
|
||||
const float onef = 1.0f;
|
||||
for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) {
|
||||
const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch;
|
||||
GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]);
|
||||
ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
|
||||
}
|
||||
}
|
||||
ggml_opt_eval(opt_ctx, result);
|
||||
if (callback) {
|
||||
callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
|
||||
}
|
||||
ggml_free(ctx_compute_opt);
|
||||
}
|
||||
}
|
||||
|
||||
kv_guard.commit();
|
||||
}
|
||||
|
||||
void llama_context::opt_epoch(
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result_train,
|
||||
ggml_opt_result_t result_eval,
|
||||
int64_t idata_split,
|
||||
ggml_opt_epoch_callback callback_train,
|
||||
ggml_opt_epoch_callback callback_eval) {
|
||||
const uint32_t n_ctx = this->n_ctx();
|
||||
const uint32_t n_batch = std::min(cparams.n_batch, n_ctx);
|
||||
const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
|
||||
const int64_t ndata = ggml_opt_dataset_ndata(dataset);
|
||||
|
||||
GGML_ASSERT(idata_split >= 0);
|
||||
GGML_ASSERT(idata_split <= ndata);
|
||||
|
||||
const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;
|
||||
|
||||
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
||||
std::vector<llama_token> tokens(n_ctx);
|
||||
std::vector<llama_token> labels_sparse(n_ctx);
|
||||
|
||||
int64_t idata = 0;
|
||||
|
||||
int64_t t_loop_start = ggml_time_us();
|
||||
int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
|
||||
for (; idata < idata_split; ++idata) {
|
||||
constexpr bool train = true;
|
||||
const int64_t idata_in_loop = idata*ubatch_per_ctx;
|
||||
|
||||
ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
|
||||
opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch,
|
||||
callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
|
||||
}
|
||||
|
||||
t_loop_start = ggml_time_us();
|
||||
ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
|
||||
for (; idata < ndata; ++idata) {
|
||||
constexpr bool train = false;
|
||||
const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
|
||||
|
||||
ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
|
||||
opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch,
|
||||
callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
|
||||
//
|
||||
// interface implementation
|
||||
//
|
||||
|
@ -2464,3 +2675,34 @@ void llama_perf_context_print(const llama_context * ctx) {
|
|||
void llama_perf_context_reset(llama_context * ctx) {
|
||||
ctx->perf_reset();
|
||||
}
|
||||
|
||||
//
|
||||
// training
|
||||
//
|
||||
|
||||
bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) {
|
||||
GGML_UNUSED(tensor);
|
||||
GGML_UNUSED(userdata);
|
||||
return true;
|
||||
}
|
||||
|
||||
void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) {
|
||||
ctx->opt_init(model, lopt_params);
|
||||
}
|
||||
|
||||
void llama_opt_epoch(
|
||||
struct llama_context * ctx,
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result_train,
|
||||
ggml_opt_result_t result_eval,
|
||||
int64_t idata_split,
|
||||
ggml_opt_epoch_callback callback_train,
|
||||
ggml_opt_epoch_callback callback_eval) {
|
||||
ctx->opt_epoch(
|
||||
dataset,
|
||||
result_train,
|
||||
result_eval,
|
||||
idata_split,
|
||||
callback_train,
|
||||
callback_eval);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue