llama/ggml: add LLM training support (#10544)
* llama/ggml: add LLM training support more compact progress bar llama_save_model_to_file llama_opt_param_filter ggml_graph_dup force_grads refactor ggml_opt, fix test-opt * remove logits_all * refactor CUDA implementation for ACC * reset graph at beginning of opt period
This commit is contained in:
parent
064cc596ac
commit
10d2af0eaa
31 changed files with 1415 additions and 359 deletions
|
@ -23,6 +23,7 @@ add_library(llama
|
|||
llama-memory.cpp
|
||||
llama-mmap.cpp
|
||||
llama-model-loader.cpp
|
||||
llama-model-saver.cpp
|
||||
llama-model.cpp
|
||||
llama-quant.cpp
|
||||
llama-sampling.cpp
|
||||
|
|
|
@ -359,7 +359,9 @@ llama_context::llama_context(
|
|||
}
|
||||
}
|
||||
|
||||
llama_context::~llama_context() = default;
|
||||
llama_context::~llama_context() {
|
||||
ggml_opt_free(opt_ctx);
|
||||
}
|
||||
|
||||
void llama_context::synchronize() {
|
||||
ggml_backend_sched_synchronize(sched.get());
|
||||
|
@ -1846,6 +1848,215 @@ void llama_context::perf_reset() {
|
|||
t_p_eval_us = n_p_eval = 0;
|
||||
}
|
||||
|
||||
//
|
||||
// training
|
||||
//
|
||||
|
||||
static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
|
||||
if (!tensor || tensor->type != GGML_TYPE_F32) {
|
||||
return;
|
||||
}
|
||||
if (!param_filter(tensor, userdata)) {
|
||||
return;
|
||||
}
|
||||
if (strcmp(tensor->name, "token_embd.weight") == 0) {
|
||||
return; // FIXME
|
||||
}
|
||||
if (strcmp(tensor->name, "rope_freqs.weight") == 0) {
|
||||
return; // FIXME
|
||||
}
|
||||
ggml_set_param(tensor);
|
||||
}
|
||||
|
||||
void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) {
|
||||
GGML_ASSERT(!opt_ctx);
|
||||
model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx();
|
||||
const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train);
|
||||
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
|
||||
GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0);
|
||||
GGML_ASSERT(n_batch % n_ubatch == 0);
|
||||
|
||||
ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
|
||||
opt_params.opt_period = n_batch / n_ubatch;
|
||||
opt_params.get_opt_pars = lopt_params.get_opt_pars;
|
||||
opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
|
||||
|
||||
opt_ctx = ggml_opt_init(opt_params);
|
||||
|
||||
llama_opt_param_filter param_filter = lopt_params.param_filter;
|
||||
void * param_filter_ud = lopt_params.param_filter_ud;
|
||||
|
||||
//llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME
|
||||
llama_set_param(model->type_embd, param_filter, param_filter_ud);
|
||||
llama_set_param(model->pos_embd, param_filter, param_filter_ud);
|
||||
llama_set_param(model->tok_norm, param_filter, param_filter_ud);
|
||||
llama_set_param(model->tok_norm_b, param_filter, param_filter_ud);
|
||||
llama_set_param(model->output_norm, param_filter, param_filter_ud);
|
||||
llama_set_param(model->output_norm_b, param_filter, param_filter_ud);
|
||||
llama_set_param(model->output, param_filter, param_filter_ud);
|
||||
llama_set_param(model->output_b, param_filter, param_filter_ud);
|
||||
llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
|
||||
llama_set_param(model->cls, param_filter, param_filter_ud);
|
||||
llama_set_param(model->cls_b, param_filter, param_filter_ud);
|
||||
llama_set_param(model->cls_out, param_filter, param_filter_ud);
|
||||
llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
|
||||
|
||||
for (struct llama_layer & layer : model->layers) {
|
||||
for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
|
||||
llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void llama_context::opt_epoch_iter(
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result,
|
||||
const std::vector<llama_token> & tokens,
|
||||
const std::vector<llama_token> & labels_sparse,
|
||||
llama_batch & batch,
|
||||
ggml_opt_epoch_callback callback,
|
||||
bool train,
|
||||
int64_t idata_in_loop,
|
||||
int64_t ndata_in_loop,
|
||||
int64_t t_loop_start) {
|
||||
GGML_ASSERT(opt_ctx);
|
||||
const uint32_t n_ctx = llama_model_n_ctx_train(&model);
|
||||
const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
|
||||
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
|
||||
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->clear();
|
||||
llama_kv_cache_guard kv_guard(kv_self);
|
||||
|
||||
for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
|
||||
batch.n_tokens = n_batch;
|
||||
for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
|
||||
batch.token [pos_batch] = tokens[pos_ctx + pos_batch];
|
||||
batch.pos [pos_batch] = pos_ctx + pos_batch;
|
||||
batch.n_seq_id[pos_batch] = 1;
|
||||
batch.seq_id [pos_batch][0] = 0;
|
||||
batch.logits [pos_batch] = true;
|
||||
}
|
||||
|
||||
const auto n_tokens_all = batch.n_tokens;
|
||||
|
||||
n_queued_tokens += n_tokens_all;
|
||||
|
||||
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
|
||||
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
||||
|
||||
embd_seq.clear();
|
||||
|
||||
int64_t n_outputs_all = n_tokens_all;
|
||||
|
||||
llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
|
||||
|
||||
// reserve output buffer
|
||||
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
|
||||
GGML_ABORT("TODO: handle this error");
|
||||
};
|
||||
|
||||
for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
|
||||
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
|
||||
|
||||
n_outputs = ubatch.n_tokens;
|
||||
|
||||
// TODO: not sure if this is needed
|
||||
if (!kv_self->find_slot(ubatch)) {
|
||||
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
|
||||
|
||||
GGML_ABORT("TODO: handle this error");
|
||||
}
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
|
||||
|
||||
struct ggml_context * ctx_compute_opt;
|
||||
{
|
||||
const size_t size_gf = ggml_graph_size(gf);
|
||||
const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true);
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ size_meta,
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
ctx_compute_opt = ggml_init(params);
|
||||
}
|
||||
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
|
||||
ggml_opt_alloc(opt_ctx, train);
|
||||
res->set_inputs(&ubatch);
|
||||
{
|
||||
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
|
||||
GGML_ASSERT(labels->ne[1] == n_ubatch);
|
||||
ggml_set_zero(labels);
|
||||
const float onef = 1.0f;
|
||||
for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) {
|
||||
const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch;
|
||||
GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]);
|
||||
ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
|
||||
}
|
||||
}
|
||||
ggml_opt_eval(opt_ctx, result);
|
||||
if (callback) {
|
||||
callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
|
||||
}
|
||||
ggml_free(ctx_compute_opt);
|
||||
}
|
||||
}
|
||||
|
||||
kv_guard.commit();
|
||||
}
|
||||
|
||||
void llama_context::opt_epoch(
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result_train,
|
||||
ggml_opt_result_t result_eval,
|
||||
int64_t idata_split,
|
||||
ggml_opt_epoch_callback callback_train,
|
||||
ggml_opt_epoch_callback callback_eval) {
|
||||
const uint32_t n_ctx = this->n_ctx();
|
||||
const uint32_t n_batch = std::min(cparams.n_batch, n_ctx);
|
||||
const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
|
||||
const int64_t ndata = ggml_opt_dataset_ndata(dataset);
|
||||
|
||||
GGML_ASSERT(idata_split >= 0);
|
||||
GGML_ASSERT(idata_split <= ndata);
|
||||
|
||||
const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;
|
||||
|
||||
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
||||
std::vector<llama_token> tokens(n_ctx);
|
||||
std::vector<llama_token> labels_sparse(n_ctx);
|
||||
|
||||
int64_t idata = 0;
|
||||
|
||||
int64_t t_loop_start = ggml_time_us();
|
||||
int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
|
||||
for (; idata < idata_split; ++idata) {
|
||||
constexpr bool train = true;
|
||||
const int64_t idata_in_loop = idata*ubatch_per_ctx;
|
||||
|
||||
ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
|
||||
opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch,
|
||||
callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
|
||||
}
|
||||
|
||||
t_loop_start = ggml_time_us();
|
||||
ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
|
||||
for (; idata < ndata; ++idata) {
|
||||
constexpr bool train = false;
|
||||
const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
|
||||
|
||||
ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
|
||||
opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch,
|
||||
callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
|
||||
//
|
||||
// interface implementation
|
||||
//
|
||||
|
@ -2464,3 +2675,34 @@ void llama_perf_context_print(const llama_context * ctx) {
|
|||
void llama_perf_context_reset(llama_context * ctx) {
|
||||
ctx->perf_reset();
|
||||
}
|
||||
|
||||
//
|
||||
// training
|
||||
//
|
||||
|
||||
bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) {
|
||||
GGML_UNUSED(tensor);
|
||||
GGML_UNUSED(userdata);
|
||||
return true;
|
||||
}
|
||||
|
||||
void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) {
|
||||
ctx->opt_init(model, lopt_params);
|
||||
}
|
||||
|
||||
void llama_opt_epoch(
|
||||
struct llama_context * ctx,
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result_train,
|
||||
ggml_opt_result_t result_eval,
|
||||
int64_t idata_split,
|
||||
ggml_opt_epoch_callback callback_train,
|
||||
ggml_opt_epoch_callback callback_eval) {
|
||||
ctx->opt_epoch(
|
||||
dataset,
|
||||
result_train,
|
||||
result_eval,
|
||||
idata_split,
|
||||
callback_train,
|
||||
callback_eval);
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
#include "llama-adapter.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
#include "ggml-opt.h"
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
@ -133,6 +134,32 @@ struct llama_context {
|
|||
llama_perf_context_data perf_get_data() const;
|
||||
void perf_reset();
|
||||
|
||||
//
|
||||
// training
|
||||
//
|
||||
|
||||
void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
|
||||
|
||||
void opt_epoch(
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result_train,
|
||||
ggml_opt_result_t result_eval,
|
||||
int64_t idata_split,
|
||||
ggml_opt_epoch_callback callback_train,
|
||||
ggml_opt_epoch_callback callback_eval);
|
||||
|
||||
void opt_epoch_iter(
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result,
|
||||
const std::vector<llama_token> & tokens,
|
||||
const std::vector<llama_token> & labels_sparse,
|
||||
llama_batch & batch,
|
||||
ggml_opt_epoch_callback callback,
|
||||
bool train,
|
||||
int64_t idata_in_loop,
|
||||
int64_t ndata_in_loop,
|
||||
int64_t t_loop_start);
|
||||
|
||||
private:
|
||||
//
|
||||
// output
|
||||
|
@ -212,6 +239,9 @@ private:
|
|||
|
||||
ggml_context_ptr ctx_compute;
|
||||
|
||||
// training
|
||||
ggml_opt_context_t opt_ctx = nullptr;
|
||||
|
||||
ggml_threadpool_t threadpool = nullptr;
|
||||
ggml_threadpool_t threadpool_batch = nullptr;
|
||||
|
||||
|
|
|
@ -971,6 +971,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
|||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
//cb(inp->tokens, "inp_tokens", -1);
|
||||
ggml_set_input(inp->tokens);
|
||||
res->t_tokens = inp->tokens;
|
||||
|
||||
cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
|
||||
|
||||
|
|
|
@ -298,6 +298,7 @@ class llm_graph_result_i {
|
|||
public:
|
||||
virtual ~llm_graph_result_i() = default;
|
||||
|
||||
virtual ggml_tensor * get_tokens() = 0;
|
||||
virtual ggml_tensor * get_logits() = 0;
|
||||
virtual ggml_tensor * get_embd() = 0;
|
||||
virtual ggml_tensor * get_embd_pooled() = 0;
|
||||
|
@ -312,6 +313,7 @@ class llm_graph_result : public llm_graph_result_i {
|
|||
public:
|
||||
virtual ~llm_graph_result() = default;
|
||||
|
||||
ggml_tensor * get_tokens() override { return t_tokens; }
|
||||
ggml_tensor * get_logits() override { return t_logits; }
|
||||
ggml_tensor * get_embd() override { return t_embd; }
|
||||
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
|
||||
|
@ -328,6 +330,7 @@ public:
|
|||
}
|
||||
|
||||
// important graph nodes
|
||||
ggml_tensor * t_tokens = nullptr;
|
||||
ggml_tensor * t_logits = nullptr;
|
||||
ggml_tensor * t_embd = nullptr;
|
||||
ggml_tensor * t_embd_pooled = nullptr;
|
||||
|
|
|
@ -301,12 +301,12 @@ namespace GGUFMeta {
|
|||
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
|
||||
|
||||
switch (arr_info.gt) {
|
||||
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
|
||||
case GGUF_TYPE_INT32: GGML_ASSERT(
|
||||
(std::is_same<T, int32_t>::value) ||
|
||||
(std::is_same<T, uint32_t>::value)); break;
|
||||
case GGUF_TYPE_UINT32:
|
||||
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
|
||||
(std::is_same<T, uint32_t>::value)); break;
|
||||
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
|
||||
default:
|
||||
throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
|
||||
throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
|
||||
}
|
||||
|
||||
result.resize(arr_info.length);
|
||||
|
@ -330,12 +330,12 @@ namespace GGUFMeta {
|
|||
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
|
||||
|
||||
switch (arr_info.gt) {
|
||||
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
|
||||
case GGUF_TYPE_INT32: GGML_ASSERT(
|
||||
(std::is_same<T, int32_t>::value) ||
|
||||
(std::is_same<T, uint32_t>::value)); break;
|
||||
case GGUF_TYPE_UINT32:
|
||||
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
|
||||
(std::is_same<T, uint32_t>::value)); break;
|
||||
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
|
||||
default:
|
||||
throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
|
||||
throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
|
||||
}
|
||||
|
||||
if (arr_info.length > N_MAX) {
|
||||
|
|
281
src/llama-model-saver.cpp
Normal file
281
src/llama-model-saver.cpp
Normal file
|
@ -0,0 +1,281 @@
|
|||
#include "llama-model-saver.h"
|
||||
|
||||
#include "gguf.h"
|
||||
|
||||
#include "llama.h"
|
||||
#include "llama-hparams.h"
|
||||
#include "llama-model.h"
|
||||
#include "llama-vocab.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
llama_model_saver::llama_model_saver(const struct llama_model & model) : model(model), llm_kv(model.arch) {
|
||||
gguf_ctx = gguf_init_empty();
|
||||
}
|
||||
|
||||
llama_model_saver::~llama_model_saver() {
|
||||
gguf_free(gguf_ctx);
|
||||
}
|
||||
|
||||
void llama_model_saver::add_kv(const enum llm_kv key, const uint32_t value) {
|
||||
gguf_set_val_u32(gguf_ctx, llm_kv(key).c_str(), value);
|
||||
}
|
||||
|
||||
void llama_model_saver::add_kv(const enum llm_kv key, const int32_t value) {
|
||||
gguf_set_val_i32(gguf_ctx, llm_kv(key).c_str(), value);
|
||||
}
|
||||
|
||||
void llama_model_saver::add_kv(const enum llm_kv key, const float value) {
|
||||
gguf_set_val_f32(gguf_ctx, llm_kv(key).c_str(), value);
|
||||
}
|
||||
|
||||
void llama_model_saver::add_kv(const enum llm_kv key, const bool value) {
|
||||
gguf_set_val_bool(gguf_ctx, llm_kv(key).c_str(), value);
|
||||
}
|
||||
|
||||
void llama_model_saver::add_kv(const enum llm_kv key, const char * value) {
|
||||
gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), value);
|
||||
}
|
||||
|
||||
[[noreturn]]
|
||||
void llama_model_saver::add_kv(const enum llm_kv key, const char value) {
|
||||
GGML_UNUSED(key);
|
||||
GGML_UNUSED(value);
|
||||
GGML_ABORT("fatal error"); // this should never be called, only needed to make the template below compile
|
||||
}
|
||||
|
||||
template <typename Container>
|
||||
void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, const bool per_layer) {
|
||||
const size_t n_values = per_layer ? size_t(model.hparams.n_layer) : value.size();
|
||||
GGML_ASSERT(n_values <= value.size());
|
||||
|
||||
if (n_values == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (per_layer) {
|
||||
bool all_values_the_same = true;
|
||||
for (size_t i = 1; i < n_values; ++i) {
|
||||
if (value[i] != value[0]) {
|
||||
all_values_the_same = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (all_values_the_same) {
|
||||
add_kv(key, value[0]);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (std::is_same<typename Container::value_type, uint8_t>::value) {
|
||||
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT8, value.data(), n_values);
|
||||
} else if (std::is_same<typename Container::value_type, int8_t>::value) {
|
||||
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT8, value.data(), n_values);
|
||||
} else if (std::is_same<typename Container::value_type, uint32_t>::value) {
|
||||
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT32, value.data(), n_values);
|
||||
} else if (std::is_same<typename Container::value_type, int32_t>::value) {
|
||||
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT32, value.data(), n_values);
|
||||
} else if (std::is_same<typename Container::value_type, float>::value) {
|
||||
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_FLOAT32, value.data(), n_values);
|
||||
} else if (std::is_same<Container, std::string>::value) {
|
||||
gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), reinterpret_cast<const char *>(value.data()));
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model_saver::add_kv(const enum llm_kv key, const std::vector<std::string> & value) {
|
||||
std::vector<const char *> tmp(value.size());
|
||||
for (size_t i = 0; i < value.size(); ++i) {
|
||||
tmp[i] = value[i].c_str();
|
||||
}
|
||||
gguf_set_arr_str(gguf_ctx, llm_kv(key).c_str(), tmp.data(), tmp.size());
|
||||
}
|
||||
|
||||
void llama_model_saver::add_tensor(const struct ggml_tensor * tensor) {
|
||||
if (!tensor) {
|
||||
return;
|
||||
}
|
||||
if (gguf_find_tensor(gguf_ctx, tensor->name) >= 0) {
|
||||
GGML_ASSERT(std::string(tensor->name) == "rope_freqs.weight"); // FIXME
|
||||
return;
|
||||
}
|
||||
gguf_add_tensor(gguf_ctx, tensor);
|
||||
}
|
||||
|
||||
void llama_model_saver::add_kv_from_model() {
|
||||
const llama_hparams & hparams = model.hparams;
|
||||
const llama_vocab & vocab = model.vocab;
|
||||
|
||||
const int32_t n_vocab = vocab.n_tokens();
|
||||
std::vector<std::string> tokens(n_vocab);
|
||||
std::vector<float> scores(n_vocab);
|
||||
std::vector<int32_t> token_types(n_vocab);
|
||||
|
||||
for (int32_t id = 0; id < n_vocab; ++id) {
|
||||
const llama_vocab::token_data & token_data = vocab.get_token_data(id);
|
||||
|
||||
tokens[id] = token_data.text;
|
||||
scores[id] = token_data.score;
|
||||
|
||||
switch(token_data.attr) {
|
||||
case LLAMA_TOKEN_ATTR_UNKNOWN: token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN; break;
|
||||
case LLAMA_TOKEN_ATTR_UNUSED: token_types[id] = LLAMA_TOKEN_TYPE_UNUSED; break;
|
||||
case LLAMA_TOKEN_ATTR_NORMAL: token_types[id] = LLAMA_TOKEN_TYPE_NORMAL; break;
|
||||
case LLAMA_TOKEN_ATTR_CONTROL: token_types[id] = LLAMA_TOKEN_TYPE_CONTROL; break;
|
||||
case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break;
|
||||
case LLAMA_TOKEN_ATTR_BYTE: token_types[id] = LLAMA_TOKEN_TYPE_BYTE; break;
|
||||
case LLAMA_TOKEN_ATTR_UNDEFINED:
|
||||
default: token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED; break;
|
||||
}
|
||||
}
|
||||
|
||||
// add_kv(LLM_KV_GENERAL_TYPE, ???);
|
||||
add_kv(LLM_KV_GENERAL_ARCHITECTURE, model.arch_name());
|
||||
// add_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION, ???);
|
||||
// add_kv(LLM_KV_GENERAL_ALIGNMENT, ???);
|
||||
add_kv(LLM_KV_GENERAL_NAME, model.name);
|
||||
// add_kv(LLM_KV_GENERAL_AUTHOR, ???);
|
||||
// add_kv(LLM_KV_GENERAL_VERSION, ???);
|
||||
// add_kv(LLM_KV_GENERAL_URL, ???);
|
||||
// add_kv(LLM_KV_GENERAL_DESCRIPTION, ???);
|
||||
// add_kv(LLM_KV_GENERAL_LICENSE, ???);
|
||||
// add_kv(LLM_KV_GENERAL_SOURCE_URL, ???);
|
||||
// add_kv(LLM_KV_GENERAL_SOURCE_HF_REPO, ???);
|
||||
|
||||
add_kv(LLM_KV_VOCAB_SIZE, vocab.n_tokens());
|
||||
add_kv(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
|
||||
add_kv(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
|
||||
add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer);
|
||||
add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
|
||||
add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true);
|
||||
add_kv(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
add_kv(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res);
|
||||
// add_kv(LLM_KV_TENSOR_DATA_LAYOUT, ???);
|
||||
add_kv(LLM_KV_EXPERT_COUNT, hparams.n_expert);
|
||||
add_kv(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used);
|
||||
add_kv(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
|
||||
add_kv(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
|
||||
add_kv(LLM_KV_POOLING_TYPE, uint32_t(hparams.pooling_type));
|
||||
add_kv(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
|
||||
add_kv(LLM_KV_DECODER_START_TOKEN_ID, hparams.dec_start_token_id);
|
||||
add_kv(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping);
|
||||
add_kv(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping);
|
||||
add_kv(LLM_KV_SWIN_NORM, hparams.swin_norm);
|
||||
add_kv(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers);
|
||||
add_kv(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim);
|
||||
add_kv(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim);
|
||||
add_kv(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale);
|
||||
add_kv(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale);
|
||||
|
||||
add_kv(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, true);
|
||||
add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true);
|
||||
add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
|
||||
add_kv(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv);
|
||||
add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k);
|
||||
add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v);
|
||||
add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
add_kv(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
|
||||
add_kv(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q);
|
||||
add_kv(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
|
||||
add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts);
|
||||
add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
|
||||
add_kv(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);
|
||||
|
||||
const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train;
|
||||
|
||||
add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot);
|
||||
add_kv(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train);
|
||||
// add_kv(LLM_KV_ROPE_SCALE_LINEAR, rope_scaling_factor); // old name
|
||||
add_kv(LLM_KV_ROPE_SCALING_TYPE, llama_rope_scaling_type_name(hparams.rope_scaling_type_train));
|
||||
add_kv(LLM_KV_ROPE_SCALING_FACTOR, rope_scaling_factor);
|
||||
add_kv(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor);
|
||||
add_kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn);
|
||||
add_kv(LLM_KV_ROPE_SCALING_FINETUNED, hparams.rope_finetuned);
|
||||
add_kv(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul);
|
||||
|
||||
// TODO: implement split file support
|
||||
// add_kv(LLM_KV_SPLIT_NO, ???);
|
||||
// add_kv(LLM_KV_SPLIT_COUNT, ???);
|
||||
// add_kv(LLM_KV_SPLIT_TENSORS_COUNT, ???);
|
||||
|
||||
add_kv(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
|
||||
add_kv(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
|
||||
add_kv(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
|
||||
add_kv(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
|
||||
add_kv(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms);
|
||||
|
||||
add_kv(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
|
||||
|
||||
add_kv(LLM_KV_TOKENIZER_MODEL, vocab.get_tokenizer_model());
|
||||
add_kv(LLM_KV_TOKENIZER_PRE, vocab.get_tokenizer_pre());
|
||||
add_kv(LLM_KV_TOKENIZER_LIST, tokens);
|
||||
add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE, token_types);
|
||||
add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, vocab.n_token_types());
|
||||
add_kv(LLM_KV_TOKENIZER_SCORES, scores);
|
||||
add_kv(LLM_KV_TOKENIZER_MERGES, vocab.get_bpe_merges());
|
||||
// FIXME llama_token is type i32 but when reading in a GGUF file u32 is expected, not an issue for writing though
|
||||
add_kv(LLM_KV_TOKENIZER_BOS_ID, uint32_t(vocab.token_bos()));
|
||||
add_kv(LLM_KV_TOKENIZER_EOS_ID, uint32_t(vocab.token_eos()));
|
||||
add_kv(LLM_KV_TOKENIZER_EOT_ID, uint32_t(vocab.token_eot()));
|
||||
add_kv(LLM_KV_TOKENIZER_EOM_ID, uint32_t(vocab.token_eom()));
|
||||
add_kv(LLM_KV_TOKENIZER_UNK_ID, uint32_t(vocab.token_unk()));
|
||||
add_kv(LLM_KV_TOKENIZER_SEP_ID, uint32_t(vocab.token_sep()));
|
||||
add_kv(LLM_KV_TOKENIZER_PAD_ID, uint32_t(vocab.token_pad()));
|
||||
// add_kv(LLM_KV_TOKENIZER_CLS_ID, uint32_t(vocab.token_bos())); // deprecated
|
||||
// add_kv(LLM_KV_TOKENIZER_MASK_ID, ???);
|
||||
add_kv(LLM_KV_TOKENIZER_ADD_BOS, vocab.get_add_bos());
|
||||
add_kv(LLM_KV_TOKENIZER_ADD_EOS, vocab.get_add_eos());
|
||||
add_kv(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.get_add_space_prefix());
|
||||
add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.get_remove_extra_whitespaces());
|
||||
add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, vocab.get_precompiled_charsmap());
|
||||
// add_kv(LLM_KV_TOKENIZER_HF_JSON, ???);
|
||||
// add_kv(LLM_KV_TOKENIZER_RWKV, ???);
|
||||
add_kv(LLM_KV_TOKENIZER_FIM_PRE_ID, uint32_t(vocab.token_fim_pre()));
|
||||
add_kv(LLM_KV_TOKENIZER_FIM_SUF_ID, uint32_t(vocab.token_fim_suf()));
|
||||
add_kv(LLM_KV_TOKENIZER_FIM_MID_ID, uint32_t(vocab.token_fim_mid()));
|
||||
add_kv(LLM_KV_TOKENIZER_FIM_PAD_ID, uint32_t(vocab.token_fim_pad()));
|
||||
add_kv(LLM_KV_TOKENIZER_FIM_REP_ID, uint32_t(vocab.token_fim_rep()));
|
||||
add_kv(LLM_KV_TOKENIZER_FIM_SEP_ID, uint32_t(vocab.token_fim_sep()));
|
||||
|
||||
// TODO: implement LoRA support
|
||||
// add_kv(LLM_KV_ADAPTER_TYPE, ???);
|
||||
// add_kv(LLM_KV_ADAPTER_LORA_ALPHA, ???);
|
||||
|
||||
// deprecated
|
||||
// add_kv(LLM_KV_TOKENIZER_PREFIX_ID, ???);
|
||||
// add_kv(LLM_KV_TOKENIZER_SUFFIX_ID, ???);
|
||||
// add_kv(LLM_KV_TOKENIZER_MIDDLE_ID, ???);
|
||||
}
|
||||
|
||||
void llama_model_saver::add_tensors_from_model() {
|
||||
if (std::string(model.output->name) != std::string(model.tok_embd->name)) {
|
||||
add_tensor(model.tok_embd); // some models use the same tensor for tok_embd and output
|
||||
}
|
||||
add_tensor(model.type_embd);
|
||||
add_tensor(model.pos_embd);
|
||||
add_tensor(model.tok_norm);
|
||||
add_tensor(model.tok_norm_b);
|
||||
add_tensor(model.output_norm);
|
||||
add_tensor(model.output_norm_b);
|
||||
add_tensor(model.output);
|
||||
add_tensor(model.output_b);
|
||||
add_tensor(model.output_norm_enc);
|
||||
add_tensor(model.cls);
|
||||
add_tensor(model.cls_b);
|
||||
add_tensor(model.cls_out);
|
||||
add_tensor(model.cls_out_b);
|
||||
|
||||
for (const struct llama_layer & layer : model.layers) {
|
||||
for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
|
||||
add_tensor(reinterpret_cast<const struct ggml_tensor * const *>(&layer)[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model_saver::save(const std::string & path_model) {
|
||||
gguf_write_to_file(gguf_ctx, path_model.c_str(), false);
|
||||
}
|
||||
|
37
src/llama-model-saver.h
Normal file
37
src/llama-model-saver.h
Normal file
|
@ -0,0 +1,37 @@
|
|||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
#include "llama-arch.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
struct llama_model_saver {
|
||||
struct gguf_context * gguf_ctx = nullptr;
|
||||
const struct llama_model & model;
|
||||
const struct LLM_KV llm_kv;
|
||||
|
||||
llama_model_saver(const struct llama_model & model);
|
||||
~llama_model_saver();
|
||||
|
||||
void add_kv(enum llm_kv key, uint32_t value);
|
||||
void add_kv(enum llm_kv key, int32_t value);
|
||||
void add_kv(enum llm_kv key, float value);
|
||||
void add_kv(enum llm_kv key, bool value);
|
||||
void add_kv(enum llm_kv key, const char * value);
|
||||
|
||||
[[noreturn]]
|
||||
void add_kv(enum llm_kv key, char value); // needed to make the template below compile
|
||||
|
||||
template <typename Container>
|
||||
void add_kv(enum llm_kv key, const Container & value, bool per_layer = false);
|
||||
|
||||
void add_kv(enum llm_kv key, const std::vector<std::string> & value);
|
||||
|
||||
void add_tensor(const struct ggml_tensor * tensor);
|
||||
|
||||
void add_kv_from_model();
|
||||
|
||||
void add_tensors_from_model();
|
||||
|
||||
void save(const std::string & path_model);
|
||||
};
|
|
@ -117,6 +117,10 @@ static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_
|
|||
{ LLAMA_ROPE_SCALING_TYPE_LONGROPE, "longrope" },
|
||||
};
|
||||
|
||||
std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type) {
|
||||
return LLAMA_ROPE_SCALING_TYPES.at(rope_scaling_type);
|
||||
}
|
||||
|
||||
static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
|
||||
for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
|
||||
if (kv.second == name) {
|
||||
|
@ -4264,7 +4268,7 @@ uint64_t llama_model::n_elements() const {
|
|||
}
|
||||
|
||||
void llama_model::print_info() const {
|
||||
const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
|
||||
const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train);
|
||||
|
||||
auto print_f = [](const std::function<uint32_t(uint32_t)> & f, uint32_t n) {
|
||||
bool is_var = false;
|
||||
|
@ -4325,7 +4329,7 @@ void llama_model::print_info() const {
|
|||
LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn);
|
||||
LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type);
|
||||
LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type);
|
||||
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type);
|
||||
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str());
|
||||
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
|
||||
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
|
||||
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
|
||||
|
|
|
@ -96,6 +96,8 @@ enum llm_type {
|
|||
LLM_TYPE_235B_A22B,
|
||||
};
|
||||
|
||||
std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type);
|
||||
|
||||
struct llama_layer_posnet {
|
||||
// resnet
|
||||
struct ggml_tensor * norm1 = nullptr;
|
||||
|
|
|
@ -519,7 +519,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
nthread = std::thread::hardware_concurrency();
|
||||
}
|
||||
|
||||
// mmap consistently increases speed Linux, and also increases speed on Windows with
|
||||
// mmap consistently increases speed on Linux, and also increases speed on Windows with
|
||||
// hot cache. It may cause a slowdown on macOS, possibly related to free memory.
|
||||
#if defined(__linux__) || defined(_WIN32)
|
||||
constexpr bool use_mmap = true;
|
||||
|
@ -529,7 +529,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
|
||||
llama_model_kv_override * kv_overrides = nullptr;
|
||||
if (params->kv_overrides) {
|
||||
auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
|
||||
auto * v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
|
||||
kv_overrides = v->data();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
#include "llama-vocab.h"
|
||||
|
||||
#include "ggml.h"
|
||||
#include "gguf.h"
|
||||
#include "llama-impl.h"
|
||||
#include "llama-model-loader.h"
|
||||
|
||||
|
@ -1234,6 +1236,9 @@ struct fragment_buffer_variant {
|
|||
struct llama_vocab::impl {
|
||||
uint32_t n_token_types = 0; // for BERT-style token types
|
||||
|
||||
std::string tokenizer_model;
|
||||
std::string tokenizer_pre;
|
||||
|
||||
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
|
||||
enum llama_vocab_pre_type pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||
|
||||
|
@ -1369,9 +1374,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||
|
||||
// determine vocab type
|
||||
{
|
||||
std::string tokenizer_model;
|
||||
std::string tokenizer_pre;
|
||||
|
||||
ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
|
||||
ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
|
||||
|
||||
|
@ -1466,7 +1468,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||
|
||||
const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
|
||||
if (precompiled_charsmap_keyidx != -1) {
|
||||
size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
|
||||
const gguf_type pc_type = gguf_get_arr_type(ctx, precompiled_charsmap_keyidx);
|
||||
GGML_ASSERT(pc_type == GGUF_TYPE_INT8 || pc_type == GGUF_TYPE_UINT8);
|
||||
|
||||
const size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
|
||||
const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
|
||||
precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
|
||||
#ifdef IS_BIG_ENDIAN
|
||||
|
@ -2789,6 +2794,14 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||
pimpl->load(ml, kv);
|
||||
}
|
||||
|
||||
std::string llama_vocab::get_tokenizer_model() const {
|
||||
return pimpl->tokenizer_model;
|
||||
}
|
||||
|
||||
std::string llama_vocab::get_tokenizer_pre() const {
|
||||
return pimpl->tokenizer_pre;
|
||||
}
|
||||
|
||||
enum llama_vocab_type llama_vocab::get_type() const {
|
||||
return pimpl->type;
|
||||
}
|
||||
|
@ -3011,6 +3024,20 @@ int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string
|
|||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> llama_vocab::get_bpe_merges() const {
|
||||
std::vector<std::string> result(pimpl->bpe_ranks.size());
|
||||
|
||||
for (const auto & pair : pimpl->bpe_ranks) {
|
||||
result[pair.second] = pair.first.first + " " + pair.first.second;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<char> llama_vocab::get_precompiled_charsmap() const {
|
||||
return pimpl->precompiled_charsmap;
|
||||
}
|
||||
|
||||
int32_t llama_vocab::tokenize(
|
||||
const char * text,
|
||||
int32_t text_len,
|
||||
|
|
|
@ -21,6 +21,9 @@ struct llama_vocab {
|
|||
|
||||
void load(llama_model_loader & ml, const LLM_KV & kv);
|
||||
|
||||
std::string get_tokenizer_model() const;
|
||||
std::string get_tokenizer_pre() const;
|
||||
|
||||
enum llama_vocab_type get_type() const;
|
||||
enum llama_vocab_pre_type get_pre_type() const;
|
||||
|
||||
|
@ -80,6 +83,9 @@ struct llama_vocab {
|
|||
int max_token_len() const;
|
||||
|
||||
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
|
||||
std::vector<std::string> get_bpe_merges() const;
|
||||
|
||||
std::vector<char> get_precompiled_charsmap() const;
|
||||
|
||||
int32_t tokenize(
|
||||
const char * text,
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
#include "llama-mmap.h"
|
||||
#include "llama-vocab.h"
|
||||
#include "llama-model-loader.h"
|
||||
#include "llama-model-saver.h"
|
||||
#include "llama-model.h"
|
||||
|
||||
#include "ggml.h"
|
||||
|
@ -253,6 +254,13 @@ struct llama_model * llama_model_load_from_splits(
|
|||
return llama_model_load_from_file_impl(splits.front(), splits, params);
|
||||
}
|
||||
|
||||
void llama_model_save_to_file(const struct llama_model * model, const char * path_model) {
|
||||
llama_model_saver ms(*model);
|
||||
ms.add_kv_from_model();
|
||||
ms.add_tensors_from_model();
|
||||
ms.save(path_model);
|
||||
}
|
||||
|
||||
//
|
||||
// chat templates
|
||||
//
|
||||
|
@ -338,3 +346,4 @@ const char * llama_print_system_info(void) {
|
|||
|
||||
return s.c_str();
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue