llama : use n_swa + n_ubatch cells for SWA cache (#13833)
* llama : use n_swa + n_ubatch cells for SWA cache

  ggml-ci

* llama : add warning about multi-sequence SWA contexts
parent c7e0a2054b
commit 3600cc2886

6 changed files with 24 additions and 11 deletions
@@ -366,6 +366,8 @@ extern "C" {
         bool no_perf;     // measure performance timings
         bool op_offload;  // offload host tensor operations to device
         bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
+                          // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
+                          //       ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
     };

     // model quantization parameters
@@ -502,6 +504,7 @@ extern "C" {
     LLAMA_API int32_t llama_model_n_layer   (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head    (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
+    LLAMA_API int32_t llama_model_n_swa     (const struct llama_model * model);

     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
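The new llama_model_n_swa getter is used like the other model introspection functions. A minimal usage sketch (the model path is illustrative and error handling is reduced to the essentials):

    #include "llama.h"

    #include <cstdio>

    int main() {
        // illustrative path - any GGUF model file works here
        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_model_load_from_file("model.gguf", mparams);
        if (model == nullptr) {
            return 1;
        }

        // a value of 0 indicates that the model does not use sliding-window attention
        const int32_t n_swa = llama_model_n_swa(model);
        printf("n_swa = %d\n", n_swa);

        llama_model_free(model);
        return 0;
    }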
@@ -123,6 +123,11 @@ llama_context::llama_context(
                 __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }

+    if (!params.swa_full && cparams.n_seq_max > 1) {
+        LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
+                __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
+    }
+
     if (!hparams.vocab_only) {
         // GPU backends
         for (auto * dev : model.devices) {
@@ -1731,14 +1731,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
                 bool   swa_full,
             uint32_t   kv_size,
             uint32_t   n_seq_max,
-            uint32_t   n_batch,
+            uint32_t   n_ubatch,
             uint32_t   n_pad) : hparams(model.hparams) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };

     const uint32_t size_base = kv_size;

-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch,  n_pad));
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));

     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
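The sizing logic above reads as a standalone formula: the SWA cache only needs to hold the attention window of every sequence plus the tokens of one micro-batch, rounded up to the cache padding, and it is capped at the base cache size. A minimal sketch, with pad_to reproducing the rounding-up behaviour of ggml's GGML_PAD macro (the example values are illustrative):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // round x up to the next multiple of n (same rounding as ggml's GGML_PAD)
    static uint32_t pad_to(uint32_t x, uint32_t n) {
        return (x + n - 1) / n * n;
    }

    // cells allocated for the SWA cache after this change: the window of every
    // sequence plus one micro-batch, padded, and never larger than the base cache
    static uint32_t swa_cache_size(uint32_t kv_size, uint32_t n_swa, uint32_t n_seq_max,
                                   uint32_t n_ubatch, uint32_t n_pad) {
        return std::min(kv_size, pad_to(n_swa*n_seq_max + n_ubatch, n_pad));
    }

    int main() {
        // e.g. kv_size = 4096, n_swa = 1024, 2 sequences, n_ubatch = 512, padding = 32
        printf("size_swa = %u\n", swa_cache_size(4096, 1024, 2, 512, 32));  // -> 2560
        return 0;
    }

Using n_ubatch instead of n_batch in this term shrinks the cache, presumably because decoding proceeds one micro-batch at a time, so at most n_ubatch new cells can be occupied before the window slides again.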
@@ -339,7 +339,7 @@ public:
                 bool   swa_full,
             uint32_t   kv_size,
             uint32_t   n_seq_max,
-            uint32_t   n_batch,
+            uint32_t   n_ubatch,
             uint32_t   n_pad);

     ~llama_kv_cache_unified_iswa() = default;
@@ -13230,7 +13230,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         params.swa_full,
                         cparams.n_ctx,
                         cparams.n_seq_max,
-                        cparams.n_batch,
+                        cparams.n_ubatch,
                         padding);
                 } else {
                     GGML_ASSERT(!hparams.is_swa_any());
@@ -13593,6 +13593,10 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
     return model->hparams.n_head_kv();
 }

+int32_t llama_model_n_swa(const llama_model * model) {
+    return model->hparams.n_swa;
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const llama_model * model) {
     return llama_model_n_ctx_train(model);
@@ -2016,11 +2016,6 @@ struct server_context {
             params_base.n_cache_reuse = 0;
             SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled");
         }
-
-        if (!params_base.speculative.model.path.empty()) {
-            SRV_ERR("%s\n", "err: speculative decode is not supported by this context");
-            return false;
-        }
     }

     return true;
@@ -3215,8 +3210,14 @@ struct server_context {

                 if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
                     const auto pos_min = llama_kv_self_seq_pos_min(ctx, slot.id);
-                    if (pos_min > 0) {
-                        SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min);
+                    if (pos_min == -1) {
+                        SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min);
+                        GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
+                    }
+
+                    const auto n_swa = llama_model_n_swa(model);
+                    if (pos_min > slot.n_past - n_swa) {
+                        SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
                         SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
                                 "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
                         llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
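The server-side condition added above can be summarized in isolation: reusing the cached prefix is only safe when the oldest position still present in the sequence's KV cache (pos_min) covers the sliding window that the continuation needs. A minimal sketch of that check (function name and example values are illustrative, not the server's actual code):

    #include <cstdint>
    #include <cstdio>

    // With a pruned SWA cache, positions older than (n_past - n_swa) may already be
    // evicted. Reuse is only safe if pos_min, the oldest cached position of the
    // sequence, does not exceed that bound; otherwise the full prompt is re-processed.
    static bool can_reuse_swa_prefix(int32_t n_past, int32_t pos_min, int32_t n_swa) {
        if (pos_min == -1) {
            // no cells cached for this sequence - the server treats this as an internal error
            return false;
        }
        return pos_min <= n_past - n_swa;
    }

    int main() {
        // example: 512 tokens already evaluated, window of 256, oldest cached position 300
        // -> 300 > 512 - 256, so the window is incomplete and reuse is not possible
        printf("reuse: %d\n", can_reuse_swa_prefix(512, 300, 256));
        return 0;
    }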