diff --git a/.devops/rocm.Dockerfile b/.devops/rocm.Dockerfile
index 1c00f1b9..f7f5b3cd 100644
--- a/.devops/rocm.Dockerfile
+++ b/.devops/rocm.Dockerfile
@@ -1,11 +1,11 @@
 ARG UBUNTU_VERSION=24.04
 
 # This needs to generally match the container host's environment.
-ARG ROCM_VERSION=6.3
-ARG AMDGPU_VERSION=6.3
+ARG ROCM_VERSION=6.3.4
+ARG AMDGPU_VERSION=6.3.4
 
 # Target the CUDA build image
-ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
+ARG BASE_ROCM_DEV_CONTAINER=docker.io/rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
 
 ### Build image
 FROM ${BASE_ROCM_DEV_CONTAINER} AS build
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index c4783a6d..be282897 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -683,7 +683,7 @@ jobs:
     env:
       OPENBLAS_VERSION: 0.3.23
       SDE_VERSION: 9.33.0-2024-01-07
-      VULKAN_VERSION: 1.4.309.0
+      VULKAN_VERSION: 1.4.313.2
 
     strategy:
       matrix:
@@ -736,7 +736,7 @@
         id: get_vulkan
         if: ${{ matrix.build == 'kompute-x64' || matrix.build == 'vulkan-x64' }}
         run: |
-          curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe"
+          curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/vulkansdk-windows-X64-${env:VULKAN_VERSION}.exe"
           & "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install
           Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}"
           Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin"
diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp
index d38a74f9..637891f5 100644
--- a/common/json-schema-to-grammar.cpp
+++ b/common/json-schema-to-grammar.cpp
@@ -41,49 +41,6 @@ static std::string build_repetition(const std::string & item_rule, int min_items
     return result;
 }
 
-/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */
-class string_view {
-    const std::string & _str;
-    const size_t _start;
-    const size_t _end;
-public:
-    string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {}
-
-    size_t size() const {
-        return _end - _start;
-    }
-
-    size_t length() const {
-        return size();
-    }
-
-    operator std::string() const {
-        return str();
-    }
-
-    std::string str() const {
-        return _str.substr(_start, _end - _start);
-    }
-
-    string_view substr(size_t pos, size_t len = std::string::npos) const {
-        return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len);
-    }
-
-    char operator[](size_t pos) const {
-        auto index = _start + pos;
-        if (index >= _end) {
-            throw std::out_of_range("string_view index out of range");
-        }
-        return _str[_start + pos];
-    }
-
-    bool operator==(const string_view & other) const {
-        std::string this_str = *this;
-        std::string other_str = other;
-        return this_str == other_str;
-    }
-};
-
 static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
     auto has_min = min_value != std::numeric_limits<int>::min();
     auto has_max = max_value != std::numeric_limits<int>::max();
@@ -112,14 +69,14 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
         }
         out << "}";
     };
-    std::function<void(const string_view &, const string_view &)> uniform_range =
-        [&](const string_view & from, const string_view & to) {
+    std::function<void(const std::string_view &, const std::string_view &)> uniform_range =
+        [&](const std::string_view & from, const std::string_view & to) {
             size_t i = 0;
             while (i < from.length() && i < to.length() && from[i] == to[i]) {
                 i++;
             }
             if (i > 0) {
-                out << "\"" << from.substr(0, i).str() << "\"";
+                out << "\"" << from.substr(0, i) << "\"";
             }
             if (i < from.length() && i < to.length()) {
                 if (i > 0) {
diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp
index 2aee0a91..cf117804 100644
--- a/examples/simple-chat/simple-chat.cpp
+++ b/examples/simple-chat/simple-chat.cpp
@@ -98,7 +98,7 @@ int main(int argc, char ** argv) {
     auto generate = [&](const std::string & prompt) {
         std::string response;
 
-        const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == 0;
+        const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == -1;
 
         // tokenize the prompt
         const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 364efcae..86c4d29a 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -241,8 +241,18 @@ static bool fp16_mma_available(const int cc) {
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
     return false;
 #else
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
+    if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
+        return true;
+    } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
+        return true;
+#else
+        return false;
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
+    } else {
+        return false;
+    }
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }
 
@@ -362,6 +372,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }
 
+// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
+template<bool norm>
+static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
+
+    float sum = 0.0f;
+    for (int i = col; i < ncols; i += blockDim.x) {
+        sum += x[row * ncols + i];
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    if (col != 0) {
+        return;
+    }
+
+    dst[row] = norm ? sum / ncols : sum;
+}
+
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 5bab92e3..462db71e 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -37,6 +37,7 @@
 #include "ggml-cuda/ssm-scan.cuh"
 #include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
+#include "ggml-cuda/mean.cuh"
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
@@ -99,8 +100,7 @@ int ggml_cuda_get_device() {
 static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
     ggml_cuda_set_device(device);
     cudaError_t err;
-    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
-    {
+    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
         err = cudaMallocManaged(ptr, size);
 #if defined(GGML_USE_HIP)
         if (err == hipSuccess) {
@@ -118,9 +118,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
             err = cudaMalloc(ptr, size);
         }
 #endif // defined(GGML_USE_HIP)
-    }
-    else
-    {
+    } else {
         err = cudaMalloc(ptr, size);
     }
     return err;
@@ -2357,6 +2355,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SUM_ROWS:
             ggml_cuda_op_sum_rows(ctx, dst);
             break;
+        case GGML_OP_MEAN:
+            ggml_cuda_op_mean(ctx, dst);
+            break;
         case GGML_OP_SSM_CONV:
             ggml_cuda_op_ssm_conv(ctx, dst);
             break;
@@ -3260,6 +3261,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
             return true;
diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu
new file mode 100644
index 00000000..4b238a39
--- /dev/null
+++ b/ggml/src/ggml-cuda/mean.cu
@@ -0,0 +1,19 @@
+#include "mean.cuh"
+
+void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows, 1, 1);
+    reduce_rows_f32<true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+}
diff --git a/ggml/src/ggml-cuda/mean.cuh b/ggml/src/ggml-cuda/mean.cuh
new file mode 100644
index 00000000..2b9b1043
--- /dev/null
+++ b/ggml/src/ggml-cuda/mean.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/sumrows.cu b/ggml/src/ggml-cuda/sumrows.cu
index 38dbf1b5..2eee08fa 100644
--- a/ggml/src/ggml-cuda/sumrows.cu
+++ b/ggml/src/ggml-cuda/sumrows.cu
@@ -1,25 +1,9 @@
 #include "sumrows.cuh"
 
-static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockIdx.x;
-    const int col = threadIdx.x;
-
-    float sum = 0.0f;
-    for (int i = col; i < ncols; i += blockDim.x) {
-        sum += x[row * ncols + i];
-    }
-
-    sum = warp_reduce_sum(sum);
-
-    if (col == 0) {
-        dst[row] = sum;
-    }
-}
-
 void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     const dim3 block_dims(WARP_SIZE, 1, 1);
     const dim3 block_nums(nrows, 1, 1);
-    k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    reduce_rows_f32<false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
 }
 
 void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);
 
-    sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows, 1, 1);
+
+    reduce_rows_f32<false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
 }
diff --git a/ggml/src/ggml-cuda/sumrows.cuh b/ggml/src/ggml-cuda/sumrows.cuh
index 191db1c1..3431c599 100644
--- a/ggml/src/ggml-cuda/sumrows.cuh
+++ b/ggml/src/ggml-cuda/sumrows.cuh
@@ -1,5 +1,4 @@
 #include "common.cuh"
 
 void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
-
 void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py
index 3b08f613..3f541b0c 100644
--- a/gguf-py/gguf/vocab.py
+++ b/gguf-py/gguf/vocab.py
@@ -167,81 +167,81 @@ class SpecialVocab:
                 tokenizer_config['bos_token'] = special_bos = special_cls
             if not special_eos and special_sep and tokenizer_config:
                 tokenizer_config['eos_token'] = special_eos = special_sep
-            post_processor = tokenizer.get('post_processor', {})
-            for processor in post_processor.get('processors', [post_processor]):
-                if processor.get('type') == 'RobertaProcessing':
-                    self.add_special_token['bos'] = True
-                    self.add_special_token['eos'] = True
-                    self.add_special_token['sep'] = True
-                    if not special_cls and tokenizer_config:
-                        special_cls = processor.get('cls', [special_bos])[0]
-                        tokenizer_config['cls_token'] = special_cls
-                    if not special_sep and tokenizer_config:
-                        special_sep = processor.get('sep', [special_eos])[0]
-                        tokenizer_config['sep_token'] = special_sep
-                    continue
-                # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
-                # Only works with simple templates, **will** get it wrong on unusual sequences
-                if processor.get('type') == 'TemplateProcessing':
-                    tmpl_single = processor.get('single', [])
-                    tmpl_pair = processor.get('pair', [])
-                    special_first = None
-                    special_last = None
-                    if len(tmpl_single) > 1:
-                        if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
-                            if not tokenizer_config:
-                                special_bos = special_first
-                            self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
-                            if special_first not in (special_bos, special_cls):
-                                logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing')
-                        if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
-                            if not tokenizer_config:
-                                special_eos = special_last
-                            elif special_last != special_eos:
-                                if 'eot' not in self.special_token_types:
-                                    self.special_token_types = tuple(self.special_token_types) + ('eot', )
-                                    tokenizer_config['eot_token'] = special_eos
-                                elif 'eom' not in self.special_token_types:
-                                    self.special_token_types = tuple(self.special_token_types) + ('eom', )
-                                    tokenizer_config['eom_token'] = special_eos
-                                else:
-                                    logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
-                                    tokenizer_config['eos_token'] = special_eos = special_last
-                            self.add_special_token['eos'] = True if special_last == special_eos else False
-                            if special_last != special_eos:
-                                logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing')
-                    if tmpl_pair:
-                        seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
-                        seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
-                        if (special_first and seq_start == 0) or (special_last and seq_stop is None):
-                            logger.warning('TemplateProcessing leading/trailing special tokens do not match TemplateProcessing')
-                        if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
-                            tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
-                            tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
-                            if tmpl_a != 'A' or tmpl_b != 'B':
-                                logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing')
-                            # A [sep] [eos] B
-                            if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
-                                add_sep = False
-                                if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
-                                    if special_entry in (special_sep, special_eos) and not special_last:
-                                        add_sep = True
-                                    if special_entry not in (special_sep, special_eos):
-                                        logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing')
-                                else:
-                                    logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing')
-                                if len(tmpl_pair) == 2:
-                                    if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
-                                        if special_entry in (special_sep, special_eos):
-                                            add_sep = True
-                                        if special_entry not in (special_sep, special_eos):
-                                            logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing')
-                                    else:
-                                        logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing')
-                                self.add_special_token['sep'] = add_sep
-                                if add_sep and not special_sep and tokenizer_config:
-                                    tokenizer_config['sep_token'] = special_eos
-                    continue
+            if post_processor := tokenizer.get('post_processor'):
+                for processor in post_processor.get('processors', [post_processor]):
+                    if processor.get('type') == 'RobertaProcessing':
+                        self.add_special_token['bos'] = True
+                        self.add_special_token['eos'] = True
+                        self.add_special_token['sep'] = True
+                        if not special_cls and tokenizer_config:
+                            special_cls = processor.get('cls', [special_bos])[0]
+                            tokenizer_config['cls_token'] = special_cls
+                        if not special_sep and tokenizer_config:
+                            special_sep = processor.get('sep', [special_eos])[0]
+                            tokenizer_config['sep_token'] = special_sep
+                        continue
+                    # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
+                    # Only works with simple templates, **will** get it wrong on unusual sequences
+                    if processor.get('type') == 'TemplateProcessing':
+                        tmpl_single = processor.get('single', [])
+                        tmpl_pair = processor.get('pair', [])
+                        special_first = None
+                        special_last = None
+                        if len(tmpl_single) > 1:
+                            if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
+                                if not tokenizer_config:
+                                    special_bos = special_first
+                                self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
+                                if special_first not in (special_bos, special_cls):
+                                    logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing')
+                            if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
+                                if not tokenizer_config:
+                                    special_eos = special_last
+                                elif special_last != special_eos:
+                                    if 'eot' not in self.special_token_types:
+                                        self.special_token_types = tuple(self.special_token_types) + ('eot', )
+                                        tokenizer_config['eot_token'] = special_eos
+                                    elif 'eom' not in self.special_token_types:
+                                        self.special_token_types = tuple(self.special_token_types) + ('eom', )
+                                        tokenizer_config['eom_token'] = special_eos
+                                    else:
+                                        logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
+                                        tokenizer_config['eos_token'] = special_eos = special_last
+                                self.add_special_token['eos'] = True if special_last == special_eos else False
+                                if special_last != special_eos:
+                                    logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing')
+                        if tmpl_pair:
+                            seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
+                            seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
+                            if (special_first and seq_start == 0) or (special_last and seq_stop is None):
+                                logger.warning('TemplateProcessing leading/trailing special tokens do not match TemplateProcessing')
+                            if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
+                                tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
+                                tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
+                                if tmpl_a != 'A' or tmpl_b != 'B':
+                                    logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing')
+                                # A [sep] [eos] B
+                                if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
+                                    add_sep = False
+                                    if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
+                                        if special_entry in (special_sep, special_eos) and not special_last:
+                                            add_sep = True
+                                        if special_entry not in (special_sep, special_eos):
+                                            logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing')
+                                    else:
+                                        logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing')
+                                    if len(tmpl_pair) == 2:
+                                        if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
+                                            if special_entry in (special_sep, special_eos):
+                                                add_sep = True
+                                            if special_entry not in (special_sep, special_eos):
+                                                logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing')
+                                        else:
+                                            logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing')
+                                    self.add_special_token['sep'] = add_sep
+                                    if add_sep and not special_sep and tokenizer_config:
+                                        tokenizer_config['sep_token'] = special_eos
+                        continue
         if not tokenizer_config:
             return True
         chat_template_alt = None
diff --git a/include/llama.h b/include/llama.h
index b04720be..f4123d14 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -390,6 +390,7 @@ extern "C" {
         void * imatrix;                      // pointer to importance matrix data
         void * kv_overrides;                 // pointer to vector containing overrides
         void * tensor_types;                 // pointer to vector containing tensor types
+        void * prune_layers;                 // pointer to vector containing layer indices to prune
     } llama_model_quantize_params;
 
     typedef struct llama_logit_bias {
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 8cf45732..43229e19 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -1,5 +1,4 @@
 #include "llama-quant.h"
-
 #include "llama-impl.h"
 #include "llama-model.h"
 #include "llama-model-loader.h"
@@ -27,6 +26,56 @@ static void zeros(std::ofstream & file, size_t n) {
     }
 }
 
+static std::string remap_layer(const std::string & orig_name, const std::vector<int> & prune, std::map<int, std::string> & mapped, int & next_id) {
+    if (prune.empty()) {
+        return orig_name;
+    }
+
+    static const std::regex pattern(R"(blk\.(\d+)\.)");
+    if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
+        const int blk = std::stoi(match[1]);
+        std::string new_name = orig_name;
+
+        if (mapped.count(blk)) {
+            // Already mapped, do nothing
+        } else if (std::find(prune.begin(), prune.end(), blk) != prune.end()) {
+            mapped[blk] = "";
+        } else if (blk < prune.front()) {
+            mapped[blk] = std::to_string(blk);
+            next_id = blk + 1;
+        } else {
+            mapped[blk] = std::to_string(next_id);
+            ++next_id;
+        }
+
+        return mapped[blk].empty() ? mapped[blk] : new_name.replace(match.position(1), match.length(1), mapped[blk]);
+    }
+
+    return orig_name;
+}
+
+static std::string remap_imatrix (const std::string & orig_name, const std::map<int, std::string> & mapped) {
+    if (mapped.empty()) {
+        return orig_name;
+    }
+
+    static const std::regex pattern(R"(blk\.(\d+)\.)");
+    if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
+        const std::string blk(match[1]);
+        std::string new_name = orig_name;
+
+        for (const auto & p : mapped) {
+            if (p.second == blk) {
+                LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first);
+                return new_name.replace(match.position(1), match.length(1), std::to_string(p.first));
+            }
+        }
+        GGML_ABORT("\n%s: imatrix mapping error for %s\n", __func__, orig_name.c_str());
+    }
+
+    return orig_name;
+}
+
 struct quantize_state_impl {
     const llama_model & model;
     const llama_model_quantize_params * params;
@@ -568,6 +617,11 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     const size_t align = GGUF_DEFAULT_ALIGNMENT;
     gguf_context_ptr ctx_out { gguf_init_empty() };
 
+    std::vector<int> prune_list = {};
+    if (params->prune_layers) {
+        prune_list = *static_cast<const std::vector<int> *>(params->prune_layers);
+    }
+
     // copy the KV pairs from the input file
     gguf_set_kv     (ctx_out.get(), ml.meta.get());
     gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
@@ -597,12 +651,32 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         }
     }
 
+    std::map<int, std::string> mapped;
+    int blk_id = 0;
+    int pruned_attention_w = 0;
+
     // make a list of weights
     std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
     tensors.reserve(ml.weights_map.size());
     for (const auto & it : ml.weights_map) {
+        const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
+        if (remapped_name.empty()) {
+            if (it.first.find("attn_v.weight") != std::string::npos ||
+                it.first.find("attn_qkv.weight") != std::string::npos ||
+                it.first.find("attn_kv_b.weight") != std::string::npos) {
+                pruned_attention_w++;
+            }
+            LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
+            continue;
+        } else if (remapped_name != it.first) {
+            ggml_set_name(it.second.tensor, remapped_name.c_str());
+            LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
+        }
         tensors.push_back(&it.second);
     }
+    if (!prune_list.empty()) {
+        gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_BLOCK_COUNT).c_str(), blk_id);
+    }
 
     // keep_split requires that the weights are sorted by split index
     if (params->keep_split) {
@@ -640,7 +714,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         if (llama_model_has_encoder(&model)) {
             n_attn_layer *= 3;
         }
-        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
+        GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
     }
 
     size_t total_size_org = 0;
@@ -681,7 +755,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         for (size_t i = 0; i < ctx_outs.size(); ++i) {
             gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
             gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
-            gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors);
+            gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), (int32_t)tensors.size());
         }
     }
 
@@ -832,7 +906,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
 
         const float * imatrix = nullptr;
         if (imatrix_data) {
-            auto it = imatrix_data->find(tensor->name);
+            auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped));
             if (it == imatrix_data->end()) {
                 LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
             } else {
@@ -947,6 +1021,7 @@ llama_model_quantize_params llama_model_quantize_default_params() {
         /*.imatrix                    =*/ nullptr,
         /*.kv_overrides               =*/ nullptr,
         /*.tensor_type                =*/ nullptr,
+        /*.prune_layers               =*/ nullptr
     };
 
     return result;
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 772bee34..7be7f220 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -4652,6 +4652,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
 
     test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
 
+    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
+
     return test_cases;
 }
 
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 30283d6f..a990520e 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -2211,6 +2211,9 @@ struct clip_model_loader {
                 {
                     hparams.rope_theta = 10000.0f;
                     hparams.warmup_image_size = hparams.patch_size * 8;
+                    // Mistral Small 2506 needs 1024x1024 image size cap to prevent OOM
+                    // ref: https://github.com/ggml-org/llama.cpp/issues/14310
+                    hparams.image_size = 1024;
                     get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
                 } break;
             case PROJECTOR_TYPE_GEMMA3:
diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp
index 3f54af7c..8acc7651 100644
--- a/tools/quantize/quantize.cpp
+++ b/tools/quantize/quantize.cpp
@@ -107,13 +107,11 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
     return false;
 }
 
-// usage:
-//  ./llama-quantize [--allow-requantize] [--leave-output-tensor] [--pure] models/llama/ggml-model.gguf [models/llama/ggml-model-quant.gguf] type [nthreads]
-//
 [[noreturn]]
 static void usage(const char * executable) {
-    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type]\n", executable);
-    printf("       [--token-embedding-type] [--tensor-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
+    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights]\n", executable);
+    printf("       [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--prune-layers] [--keep-split] [--override-kv]\n");
+    printf("       model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
     printf("  --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
     printf("  --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
     printf("  --pure: Disable k-quant mixtures and quantize all tensors to the same type\n");
@@ -124,6 +122,8 @@ static void usage(const char * executable) {
     printf("  --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n");
    printf("  --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n");
     printf("      Advanced option to selectively quantize tensors. May be specified multiple times.\n");
+    printf("  --prune-layers L0,L1,L2...comma-separated list of layer numbers to prune from the model\n");
+    printf("      Advanced option to remove all tensors from the given layers\n");
     printf("  --keep-split: will generate quantized model in the same shards as input\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("      Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
@@ -286,6 +286,32 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
     return true;
 }
 
+static bool parse_layer_prune(const char * data, std::vector<int> & prune_layers) {
+    if (!data) {
+        printf("\n%s: no layer pruning ids provided\n\n", __func__);
+        return false;
+    }
+
+    const auto block_ids = string_split<std::string>(data, ',');
+    for (const auto & block_id : block_ids) {
+        int id;
+        try {
+            id = std::stoi(block_id);
+        } catch (...) {
+            id = -1;
+        }
+        if (id < 0) {
+            printf("\n%s: invalid layer id '%s'\n\n", __func__, block_id.c_str());
+            return false;
+        }
+        prune_layers.emplace_back(id);
+    }
+
+    sort(prune_layers.begin(), prune_layers.end());
+    prune_layers.erase(std::unique(prune_layers.begin(), prune_layers.end()), prune_layers.end());
+    return true;
+}
+
 int main(int argc, char ** argv) {
     if (argc < 3) {
         usage(argv[0]);
@@ -298,6 +324,7 @@ int main(int argc, char ** argv) {
     std::vector<std::string> included_weights, excluded_weights;
     std::vector<llama_model_kv_override> kv_overrides;
     std::vector<tensor_quantization> tensor_types;
+    std::vector<int> prune_layers;
 
     for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
         if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
@@ -324,6 +351,10 @@ int main(int argc, char ** argv) {
             if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_types)) {
                 usage(argv[0]);
             }
+        } else if (strcmp(argv[arg_idx], "--prune-layers") == 0) {
+            if (arg_idx == argc-1 || !parse_layer_prune(argv[++arg_idx], prune_layers)) {
+                usage(argv[0]);
+            }
         } else if (strcmp(argv[arg_idx], "--override-kv") == 0) {
             if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) {
                 usage(argv[0]);
@@ -411,6 +442,9 @@ int main(int argc, char ** argv) {
     if (!tensor_types.empty()) {
         params.tensor_types = &tensor_types;
     }
+    if (!prune_layers.empty()) {
+        params.prune_layers = &prune_layers;
+    }
 
     llama_backend_init();
 
diff --git a/tools/run/run.cpp b/tools/run/run.cpp
index c65afd61..6fe728c6 100644
--- a/tools/run/run.cpp
+++ b/tools/run/run.cpp
@@ -9,6 +9,9 @@
 #include
 
 #if defined(_WIN32)
+#    ifndef NOMINMAX
+#        define NOMINMAX
+#    endif
 #    include <windows.h>
 #    include <io.h>
 #else
@@ -939,17 +942,30 @@ static int apply_chat_template(const struct common_chat_templates * tmpls, Llama
 // Function to tokenize the prompt
 static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt, std::vector<llama_token> & prompt_tokens,
                            const LlamaData & llama_data) {
-    const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == 0;
-
-    const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
-    prompt_tokens.resize(n_prompt_tokens);
-    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
-                       true) < 0) {
-        printe("failed to tokenize the prompt\n");
+    const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == -1;
+    int n_tokens = prompt.size() + 2 * is_first;
+    prompt_tokens.resize(n_tokens);
+    n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
+                              prompt_tokens.data(), prompt_tokens.size(),
+                              is_first, /*parse_special =*/true);
+    if (n_tokens == std::numeric_limits<int32_t>::min()) {
+        printe("tokenization failed: input too large\n");
         return -1;
     }
-
-    return n_prompt_tokens;
+    if (n_tokens < 0) {
+        prompt_tokens.resize(-n_tokens);
+        int check = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
+                                   prompt_tokens.data(), prompt_tokens.size(),
+                                   is_first, /*parse_special =*/true);
+        if (check != -n_tokens) {
+            printe("failed to tokenize the prompt (size mismatch)\n");
+            return -1;
+        }
+        n_tokens = check;
+    } else {
+        prompt_tokens.resize(n_tokens);
+    }
+    return n_tokens;
 }
 
 // Check if we have enough space in the context to evaluate this batch