Compare commits

10 commits: aa0ef5c578 ... ccc4c16970

Commits (SHA1):

ccc4c16970
3a9457df96
fa4a9f2a1c
238005c2dc
66aba7aca9
f1f5e82df6
af3373f1ad
5d5c066de8
40bfa04c95
aa064b2eb7

17 changed files with 296 additions and 168 deletions

@@ -1,11 +1,11 @@
 ARG UBUNTU_VERSION=24.04

 # This needs to generally match the container host's environment.
-ARG ROCM_VERSION=6.3
-ARG AMDGPU_VERSION=6.3
+ARG ROCM_VERSION=6.3.4
+ARG AMDGPU_VERSION=6.3.4

 # Target the CUDA build image
-ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
+ARG BASE_ROCM_DEV_CONTAINER=docker.io/rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete

 ### Build image
 FROM ${BASE_ROCM_DEV_CONTAINER} AS build
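
Since ROCM_VERSION and AMDGPU_VERSION are plain ARGs declared before the FROM line, the pinned level can also be overridden at build time without editing the file, e.g. `docker build --build-arg ROCM_VERSION=6.3.4 --build-arg AMDGPU_VERSION=6.3.4 -f <this Dockerfile> .` (a hypothetical invocation; the Dockerfile path and build context are not shown in this diff).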

4 .github/workflows/build.yml vendored

@@ -683,7 +683,7 @@ jobs:
 env:
   OPENBLAS_VERSION: 0.3.23
   SDE_VERSION: 9.33.0-2024-01-07
-  VULKAN_VERSION: 1.4.309.0
+  VULKAN_VERSION: 1.4.313.2

 strategy:
   matrix:
@@ -736,7 +736,7 @@ jobs:
 id: get_vulkan
 if: ${{ matrix.build == 'kompute-x64' || matrix.build == 'vulkan-x64' }}
 run: |
-  curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe"
+  curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/vulkansdk-windows-X64-${env:VULKAN_VERSION}.exe"
   & "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install
   Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}"
   Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin"

@@ -41,49 +41,6 @@ static std::string build_repetition(const std::string & item_rule, int min_items
     return result;
 }

-/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */
-class string_view {
-    const std::string & _str;
-    const size_t _start;
-    const size_t _end;
-public:
-    string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {}
-
-    size_t size() const {
-        return _end - _start;
-    }
-
-    size_t length() const {
-        return size();
-    }
-
-    operator std::string() const {
-        return str();
-    }
-
-    std::string str() const {
-        return _str.substr(_start, _end - _start);
-    }
-
-    string_view substr(size_t pos, size_t len = std::string::npos) const {
-        return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len);
-    }
-
-    char operator[](size_t pos) const {
-        auto index = _start + pos;
-        if (index >= _end) {
-            throw std::out_of_range("string_view index out of range");
-        }
-        return _str[_start + pos];
-    }
-
-    bool operator==(const string_view & other) const {
-        std::string this_str = *this;
-        std::string other_str = other;
-        return this_str == other_str;
-    }
-};
-
 static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
     auto has_min = min_value != std::numeric_limits<int>::min();
     auto has_max = max_value != std::numeric_limits<int>::max();
@@ -112,14 +69,14 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
         }
         out << "}";
     };
-    std::function<void(const string_view &, const string_view &)> uniform_range =
-        [&](const string_view & from, const string_view & to) {
+    std::function<void(const std::string_view &, const std::string_view &)> uniform_range =
+        [&](const std::string_view & from, const std::string_view & to) {
             size_t i = 0;
             while (i < from.length() && i < to.length() && from[i] == to[i]) {
                 i++;
            }
             if (i > 0) {
-                out << "\"" << from.substr(0, i).str() << "\"";
+                out << "\"" << from.substr(0, i) << "\"";
            }
             if (i < from.length() && i < to.length()) {
                 if (i > 0) {
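
An aside on why the switch from the local helper to std::string_view is a drop-in change (a sketch, not part of the diff): std::string_view::substr takes a (pos, count) pair just like the call sites above use it, and a string_view streams to an ostream directly, which is why the explicit .str() conversion disappears.

```cpp
// Minimal standalone illustration (C++17); names and values are made up.
#include <iostream>
#include <string>
#include <string_view>

int main() {
    const std::string from = "abcdef";
    std::string_view view = from;                 // non-owning view over `from`
    std::string_view prefix = view.substr(0, 3);  // (pos, count) -> "abc"
    std::cout << '"' << prefix << '"' << '\n';    // streams directly, no .str() needed
}
```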

@@ -98,7 +98,7 @@ int main(int argc, char ** argv) {
     auto generate = [&](const std::string & prompt) {
         std::string response;

-        const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == 0;
+        const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == -1;

         // tokenize the prompt
         const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
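
Why the empty-history check moves from 0 to -1 (an inference from the change itself and the usual llama.h convention; treat it as an assumption rather than part of the diff): position 0 is already occupied once the first token has been decoded, so only a negative sentinel can mean "nothing in the sequence yet".

```cpp
// Sketch of the assumed contract:
//   no tokens in sequence 0 yet -> llama_memory_seq_pos_max(mem, 0) == -1
//   one token decoded           -> llama_memory_seq_pos_max(mem, 0) ==  0  (a valid position)
const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == -1;
```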

@@ -241,8 +241,18 @@ static bool fp16_mma_available(const int cc) {
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
     return false;
 #else
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
+    if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
+        return true;
+    } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
+        return true;
+#else
+        return false;
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
+    } else {
+        return false;
+    }
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }

@@ -362,6 +372,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }

+// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
+template<bool norm>
+static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
+
+    float sum = 0.0f;
+    for (int i = col; i < ncols; i += blockDim.x) {
+        sum += x[row * ncols + i];
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    if (col != 0) {
+        return;
+    }
+
+    dst[row] = norm ? sum / ncols : sum;
+}
+
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
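
For orientation, here is how the new templated kernel is meant to be launched (a sketch that mirrors the launches added in mean.cu and sumrows.cu further down in this diff): one block of WARP_SIZE threads per row, so each row is reduced by a single warp and lane 0 writes the result.

```cuda
// Hypothetical host-side launch; x points to nrows * ncols contiguous floats on the device.
const dim3 block_dims(WARP_SIZE, 1, 1);  // one warp per block
const dim3 block_nums(nrows, 1, 1);      // one block per row
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols); // row sums
reduce_rows_f32</*norm=*/true ><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols); // row means
```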

@@ -37,6 +37,7 @@
 #include "ggml-cuda/ssm-scan.cuh"
 #include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
+#include "ggml-cuda/mean.cuh"
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
@@ -99,8 +100,7 @@ int ggml_cuda_get_device() {
 static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
     ggml_cuda_set_device(device);
     cudaError_t err;
-    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
-    {
+    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
         err = cudaMallocManaged(ptr, size);
 #if defined(GGML_USE_HIP)
         if (err == hipSuccess) {
@@ -118,9 +118,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
             err = cudaMalloc(ptr, size);
         }
 #endif // defined(GGML_USE_HIP)
-    }
-    else
-    {
+    } else {
         err = cudaMalloc(ptr, size);
     }
     return err;
@@ -2357,6 +2355,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SUM_ROWS:
             ggml_cuda_op_sum_rows(ctx, dst);
             break;
+        case GGML_OP_MEAN:
+            ggml_cuda_op_mean(ctx, dst);
+            break;
         case GGML_OP_SSM_CONV:
             ggml_cuda_op_ssm_conv(ctx, dst);
             break;
@@ -3260,6 +3261,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
             return true;
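
To make the effect of these dispatch and supports_op additions concrete, a hedged usage sketch (not part of the diff; shapes illustrative): a graph node built with ggml_mean() can now be executed by the CUDA backend instead of falling back to the CPU.

```cpp
// ggml_mean() reduces along ne[0]; a 256 x 256 F32 input yields a 1 x 256 result.
struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 256, 256);
struct ggml_tensor * m = ggml_mean(ctx, x);  // dispatched to ggml_cuda_op_mean on CUDA devices
```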

19 ggml/src/ggml-cuda/mean.cu Normal file

@@ -0,0 +1,19 @@
+#include "mean.cuh"
+
+void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows, 1, 1);
+    reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+}

3 ggml/src/ggml-cuda/mean.cuh Normal file

@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

@@ -1,25 +1,9 @@
 #include "sumrows.cuh"

-static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockIdx.x;
-    const int col = threadIdx.x;
-
-    float sum = 0.0f;
-    for (int i = col; i < ncols; i += blockDim.x) {
-        sum += x[row * ncols + i];
-    }
-
-    sum = warp_reduce_sum(sum);
-
-    if (col == 0) {
-        dst[row] = sum;
-    }
-}
-
 void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     const dim3 block_dims(WARP_SIZE, 1, 1);
     const dim3 block_nums(nrows, 1, 1);
-    k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
 }

 void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);

-    sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows, 1, 1);
+
+    reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
 }

@@ -1,5 +1,4 @@
 #include "common.cuh"

 void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
-
 void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

@@ -167,7 +167,7 @@ class SpecialVocab:
 tokenizer_config['bos_token'] = special_bos = special_cls
 if not special_eos and special_sep and tokenizer_config:
 tokenizer_config['eos_token'] = special_eos = special_sep
-post_processor = tokenizer.get('post_processor', {})
+if post_processor := tokenizer.get('post_processor'):
 for processor in post_processor.get('processors', [post_processor]):
 if processor.get('type') == 'RobertaProcessing':
 self.add_special_token['bos'] = True
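
The motivation for the guard (an interpretation, not stated in the diff): with the old form, a tokenizer.json that contains "post_processor": null makes tokenizer.get('post_processor', {}) return None, because the default only applies when the key is missing, and the following .get() call then fails. The walrus-operator form skips the block for both a missing and a null post_processor.

```python
# Standalone illustration; the dict literal is made up.
tokenizer = {"post_processor": None}

old = tokenizer.get("post_processor", {})   # -> None (default is NOT used for null values)
# old.get("processors", [old])              # would raise AttributeError on None

if post_processor := tokenizer.get("post_processor"):
    # skipped here: falsy for both a missing key and an explicit null
    print(post_processor.get("processors", [post_processor]))
```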

@@ -390,6 +390,7 @@ extern "C" {
     void * imatrix; // pointer to importance matrix data
     void * kv_overrides; // pointer to vector containing overrides
     void * tensor_types; // pointer to vector containing tensor types
+    void * prune_layers; // pointer to vector containing layer indices to prune
 } llama_model_quantize_params;

 typedef struct llama_logit_bias {
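
How the new field is meant to be filled in by callers (a sketch inferred from the implementation and CLI changes below, not an official snippet): like kv_overrides and tensor_types, prune_layers is an opaque pointer that the quantization code casts back to a std::vector<int> of layer indices.

```cpp
// Hypothetical caller-side setup; layer ids and file names are made up.
std::vector<int> prune_layers = {20, 21, 22};

llama_model_quantize_params params = llama_model_quantize_default_params();
params.prune_layers = &prune_layers;  // consumed as const std::vector<int> * in llama-quant.cpp
// llama_model_quantize("model-f32.gguf", "model-pruned.gguf", &params);
```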

@@ -1,5 +1,4 @@
 #include "llama-quant.h"
-
 #include "llama-impl.h"
 #include "llama-model.h"
 #include "llama-model-loader.h"
@@ -27,6 +26,56 @@ static void zeros(std::ofstream & file, size_t n) {
     }
 }

+static std::string remap_layer(const std::string & orig_name, const std::vector<int> & prune, std::map<int, std::string> & mapped, int & next_id) {
+    if (prune.empty()) {
+        return orig_name;
+    }
+
+    static const std::regex pattern(R"(blk\.(\d+)\.)");
+    if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
+        const int blk = std::stoi(match[1]);
+        std::string new_name = orig_name;
+
+        if (mapped.count(blk)) {
+            // Already mapped, do nothing
+        } else if (std::find(prune.begin(), prune.end(), blk) != prune.end()) {
+            mapped[blk] = "";
+        } else if (blk < prune.front()) {
+            mapped[blk] = std::to_string(blk);
+            next_id = blk + 1;
+        } else {
+            mapped[blk] = std::to_string(next_id);
+            ++next_id;
+        }
+
+        return mapped[blk].empty() ? mapped[blk] : new_name.replace(match.position(1), match.length(1), mapped[blk]);
+    }
+
+    return orig_name;
+}
+
+static std::string remap_imatrix (const std::string & orig_name, const std::map<int, std::string> & mapped) {
+    if (mapped.empty()) {
+        return orig_name;
+    }
+
+    static const std::regex pattern(R"(blk\.(\d+)\.)");
+    if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
+        const std::string blk(match[1]);
+        std::string new_name = orig_name;
+
+        for (const auto & p : mapped) {
+            if (p.second == blk) {
+                LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first);
+                return new_name.replace(match.position(1), match.length(1), std::to_string(p.first));
+            }
+        }
+        GGML_ABORT("\n%s: imatrix mapping error for %s\n", __func__, orig_name.c_str());
+    }
+
+    return orig_name;
+}
+
 struct quantize_state_impl {
     const llama_model & model;
     const llama_model_quantize_params * params;
@@ -568,6 +617,11 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     const size_t align = GGUF_DEFAULT_ALIGNMENT;
     gguf_context_ptr ctx_out { gguf_init_empty() };

+    std::vector<int> prune_list = {};
+    if (params->prune_layers) {
+        prune_list = *static_cast<const std::vector<int> *>(params->prune_layers);
+    }
+
     // copy the KV pairs from the input file
     gguf_set_kv (ctx_out.get(), ml.meta.get());
     gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
@@ -597,12 +651,32 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         }
     }

+    std::map<int, std::string> mapped;
+    int blk_id = 0;
+    int pruned_attention_w = 0;
+
     // make a list of weights
     std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
     tensors.reserve(ml.weights_map.size());
     for (const auto & it : ml.weights_map) {
+        const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
+        if (remapped_name.empty()) {
+            if (it.first.find("attn_v.weight") != std::string::npos ||
+                it.first.find("attn_qkv.weight") != std::string::npos ||
+                it.first.find("attn_kv_b.weight") != std::string::npos) {
+                pruned_attention_w++;
+            }
+            LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
+            continue;
+        } else if (remapped_name != it.first) {
+            ggml_set_name(it.second.tensor, remapped_name.c_str());
+            LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
+        }
         tensors.push_back(&it.second);
     }
+    if (!prune_list.empty()) {
+        gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_BLOCK_COUNT).c_str(), blk_id);
+    }

     // keep_split requires that the weights are sorted by split index
     if (params->keep_split) {
@@ -640,7 +714,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         if (llama_model_has_encoder(&model)) {
             n_attn_layer *= 3;
         }
-        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
+        GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
     }

     size_t total_size_org = 0;
@@ -681,7 +755,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         for (size_t i = 0; i < ctx_outs.size(); ++i) {
             gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
             gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
-            gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors);
+            gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), (int32_t)tensors.size());
         }
     }

@@ -832,7 +906,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::

         const float * imatrix = nullptr;
         if (imatrix_data) {
-            auto it = imatrix_data->find(tensor->name);
+            auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped));
             if (it == imatrix_data->end()) {
                 LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
             } else {
@@ -947,6 +1021,7 @@ llama_model_quantize_params llama_model_quantize_default_params() {
         /*.imatrix =*/ nullptr,
         /*.kv_overrides =*/ nullptr,
         /*.tensor_type =*/ nullptr,
+        /*.prune_layers =*/ nullptr
     };

     return result;
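
A worked example of what remap_layer and remap_imatrix do (illustrative numbers, not from the diff): pruning layers 2 and 3 from a 6-block model keeps blocks 0 and 1 as-is, drops 2 and 3, renumbers 4 and 5 to close the gap, and rewrites the block-count metadata; remap_imatrix then maps the renumbered tensor names back to their original block ids so the right importance-matrix entries are still found.

```text
prune_list = {2, 3}, original block count = 6

blk.0.*  ->  blk.0.*        (kept, below the first pruned layer)
blk.1.*  ->  blk.1.*
blk.2.*  ->  (dropped)
blk.3.*  ->  (dropped)
blk.4.*  ->  blk.2.*        (renumbered; imatrix lookups still resolve to blk.4.*)
blk.5.*  ->  blk.3.*        (renumbered; imatrix lookups still resolve to blk.5.*)

LLM_KV_BLOCK_COUNT rewritten from 6 to 4
```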

@@ -4652,6 +4652,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {

     test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));

+    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
+
     return test_cases;
 }

@@ -2211,6 +2211,9 @@ struct clip_model_loader {
         {
             hparams.rope_theta = 10000.0f;
             hparams.warmup_image_size = hparams.patch_size * 8;
+            // Mistral Small 2506 needs 1024x1024 image size cap to prevent OOM
+            // ref: https://github.com/ggml-org/llama.cpp/issues/14310
+            hparams.image_size = 1024;
             get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
         } break;
     case PROJECTOR_TYPE_GEMMA3:

@@ -107,13 +107,11 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
     return false;
 }

-// usage:
-// ./llama-quantize [--allow-requantize] [--leave-output-tensor] [--pure] models/llama/ggml-model.gguf [models/llama/ggml-model-quant.gguf] type [nthreads]
-//
 [[noreturn]]
 static void usage(const char * executable) {
-    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type]\n", executable);
-    printf(" [--token-embedding-type] [--tensor-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
+    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights]\n", executable);
+    printf(" [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--prune-layers] [--keep-split] [--override-kv]\n");
+    printf(" model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
     printf(" --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
     printf(" --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
     printf(" --pure: Disable k-quant mixtures and quantize all tensors to the same type\n");
@@ -124,6 +122,8 @@ static void usage(const char * executable) {
     printf(" --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n");
     printf(" --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n");
     printf(" Advanced option to selectively quantize tensors. May be specified multiple times.\n");
+    printf(" --prune-layers L0,L1,L2...comma-separated list of layer numbers to prune from the model\n");
+    printf(" Advanced option to remove all tensors from the given layers\n");
     printf(" --keep-split: will generate quantized model in the same shards as input\n");
     printf(" --override-kv KEY=TYPE:VALUE\n");
     printf(" Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
@@ -286,6 +286,32 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
     return true;
 }

+static bool parse_layer_prune(const char * data, std::vector<int> & prune_layers) {
+    if (!data) {
+        printf("\n%s: no layer pruning ids provided\n\n", __func__);
+        return false;
+    }
+
+    const auto block_ids = string_split<std::string>(data, ',');
+    for (const auto & block_id : block_ids) {
+        int id;
+        try {
+            id = std::stoi(block_id);
+        } catch (...) {
+            id = -1;
+        }
+        if (id < 0) {
+            printf("\n%s: invalid layer id '%s'\n\n", __func__, block_id.c_str());
+            return false;
+        }
+        prune_layers.emplace_back(id);
+    }
+
+    sort(prune_layers.begin(), prune_layers.end());
+    prune_layers.erase(std::unique(prune_layers.begin(), prune_layers.end()), prune_layers.end());
+    return true;
+}
+
 int main(int argc, char ** argv) {
     if (argc < 3) {
         usage(argv[0]);
@@ -298,6 +324,7 @@ int main(int argc, char ** argv) {
     std::vector<std::string> included_weights, excluded_weights;
     std::vector<llama_model_kv_override> kv_overrides;
     std::vector<tensor_quantization> tensor_types;
+    std::vector<int> prune_layers;

     for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
         if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
@@ -324,6 +351,10 @@ int main(int argc, char ** argv) {
             if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_types)) {
                 usage(argv[0]);
             }
+        } else if (strcmp(argv[arg_idx], "--prune-layers") == 0) {
+            if (arg_idx == argc-1 || !parse_layer_prune(argv[++arg_idx], prune_layers)) {
+                usage(argv[0]);
+            }
         } else if (strcmp(argv[arg_idx], "--override-kv") == 0) {
             if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) {
                 usage(argv[0]);
@@ -411,6 +442,9 @@ int main(int argc, char ** argv) {
     if (!tensor_types.empty()) {
         params.tensor_types = &tensor_types;
     }
+    if (!prune_layers.empty()) {
+        params.prune_layers = &prune_layers;
+    }

     llama_backend_init();
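
Putting the new flag together with the rest of the diff, a plausible end-to-end invocation (file names and layer ids hypothetical): `./llama-quantize --prune-layers 20,21,22 model-f32.gguf model-pruned-q4_k_m.gguf Q4_K_M 8` drops the listed layers, renumbers the remaining blocks, and quantizes the result to Q4_K_M with 8 threads.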
@ -9,6 +9,9 @@
|
|||
#include <nlohmann/json.hpp>
|
||||
|
||||
#if defined(_WIN32)
|
||||
# ifndef NOMINMAX
|
||||
# define NOMINMAX
|
||||
# endif
|
||||
# include <windows.h>
|
||||
# include <io.h>
|
||||
#else
|
||||
|
@ -939,17 +942,30 @@ static int apply_chat_template(const struct common_chat_templates * tmpls, Llama
|
|||
// Function to tokenize the prompt
|
||||
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
|
||||
std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
|
||||
const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == 0;
|
||||
|
||||
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
|
||||
prompt_tokens.resize(n_prompt_tokens);
|
||||
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
|
||||
true) < 0) {
|
||||
printe("failed to tokenize the prompt\n");
|
||||
const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == -1;
|
||||
int n_tokens = prompt.size() + 2 * is_first;
|
||||
prompt_tokens.resize(n_tokens);
|
||||
n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
|
||||
prompt_tokens.data(), prompt_tokens.size(),
|
||||
is_first, /*parse_special =*/true);
|
||||
if (n_tokens == std::numeric_limits<int32_t>::min()) {
|
||||
printe("tokenization failed: input too large\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
return n_prompt_tokens;
|
||||
if (n_tokens < 0) {
|
||||
prompt_tokens.resize(-n_tokens);
|
||||
int check = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
|
||||
prompt_tokens.data(), prompt_tokens.size(),
|
||||
is_first, /*parse_special =*/true);
|
||||
if (check != -n_tokens) {
|
||||
printe("failed to tokenize the prompt (size mismatch)\n");
|
||||
return -1;
|
||||
}
|
||||
n_tokens = check;
|
||||
} else {
|
||||
prompt_tokens.resize(n_tokens);
|
||||
}
|
||||
return n_tokens;
|
||||
}
|
||||
|
||||
// Check if we have enough space in the context to evaluate this batch
|
||||
|
|