From aa064b2eb7f7779db5e094a9d8d66d5033557de2 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sun, 22 Jun 2025 12:39:54 +0800 Subject: [PATCH 01/10] CUDA: add mean operation (#14313) * CUDA: add mean operation * add back sum_rows_f32_cuda * Review: early exit if col!=0 --- ggml/src/ggml-cuda/common.cuh | 20 ++++++++++++++++++++ ggml/src/ggml-cuda/ggml-cuda.cu | 5 +++++ ggml/src/ggml-cuda/mean.cu | 19 +++++++++++++++++++ ggml/src/ggml-cuda/mean.cuh | 3 +++ ggml/src/ggml-cuda/sumrows.cu | 23 +++++------------------ ggml/src/ggml-cuda/sumrows.cuh | 1 - tests/test-backend-ops.cpp | 2 ++ 7 files changed, 54 insertions(+), 19 deletions(-) create mode 100644 ggml/src/ggml-cuda/mean.cu create mode 100644 ggml/src/ggml-cuda/mean.cuh diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 364efcae..2f2fce06 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -362,6 +362,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #endif // FP16_AVAILABLE } +// Row reduction kernel template - compute sum (norm=false) or mean (norm=true) +template +static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) { + const int row = blockIdx.x; + const int col = threadIdx.x; + + float sum = 0.0f; + for (int i = col; i < ncols; i += blockDim.x) { + sum += x[row * ncols + i]; + } + + sum = warp_reduce_sum(sum); + + if (col != 0) { + return; + } + + dst[row] = norm ? sum / ncols : sum; +} + template static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 5bab92e3..c6bdd4fb 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -37,6 +37,7 @@ #include "ggml-cuda/ssm-scan.cuh" #include "ggml-cuda/sum.cuh" #include "ggml-cuda/sumrows.cuh" +#include "ggml-cuda/mean.cuh" #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" @@ -2357,6 +2358,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SUM_ROWS: ggml_cuda_op_sum_rows(ctx, dst); break; + case GGML_OP_MEAN: + ggml_cuda_op_mean(ctx, dst); + break; case GGML_OP_SSM_CONV: ggml_cuda_op_ssm_conv(ctx, dst); break; @@ -3260,6 +3264,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_POOL_2D: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_ARGSORT: case GGML_OP_ACC: return true; diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu new file mode 100644 index 00000000..4b238a39 --- /dev/null +++ b/ggml/src/ggml-cuda/mean.cu @@ -0,0 +1,19 @@ +#include "mean.cuh" + +void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums(nrows, 1, 1); + reduce_rows_f32<<>>(src0_d, dst_d, ncols); +} diff --git a/ggml/src/ggml-cuda/mean.cuh b/ggml/src/ggml-cuda/mean.cuh new file mode 100644 index 00000000..2b9b1043 --- /dev/null +++ b/ggml/src/ggml-cuda/mean.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/sumrows.cu b/ggml/src/ggml-cuda/sumrows.cu index 38dbf1b5..2eee08fa 100644 --- a/ggml/src/ggml-cuda/sumrows.cu +++ b/ggml/src/ggml-cuda/sumrows.cu @@ -1,25 +1,9 @@ #include "sumrows.cuh" -static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) { - const int row = blockIdx.x; - const int col = threadIdx.x; - - float sum = 0.0f; - for (int i = col; i < ncols; i += blockDim.x) { - sum += x[row * ncols + i]; - } - - sum = warp_reduce_sum(sum); - - if (col == 0) { - dst[row] = sum; - } -} - void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_nums(nrows, 1, 1); - k_sum_rows_f32<<>>(x, dst, ncols); + reduce_rows_f32<<>>(x, dst, ncols); } void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t ncols = src0->ne[0]; const int64_t nrows = ggml_nrows(src0); - sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream); + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums(nrows, 1, 1); + + reduce_rows_f32<<>>(src0_d, dst_d, ncols); } diff --git a/ggml/src/ggml-cuda/sumrows.cuh b/ggml/src/ggml-cuda/sumrows.cuh index 191db1c1..3431c599 100644 --- a/ggml/src/ggml-cuda/sumrows.cuh +++ b/ggml/src/ggml-cuda/sumrows.cuh @@ -1,5 +1,4 @@ #include "common.cuh" void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream); - void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 772bee34..7be7f220 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4652,6 +4652,8 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1)); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1})); + return test_cases; } From 40bfa04c95c19fb42bafd4e21b5c2a7771846801 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sun, 22 Jun 2025 07:37:43 +0200 Subject: [PATCH 02/10] common : use std::string_view now that we target c++17 (#14319) --- common/json-schema-to-grammar.cpp | 49 ++----------------------------- 1 file changed, 3 insertions(+), 46 deletions(-) diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index d38a74f9..637891f5 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -41,49 +41,6 @@ static std::string build_repetition(const std::string & item_rule, int min_items return result; } -/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */ -class string_view { - const std::string & _str; - const size_t _start; - const size_t _end; -public: - string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {} - - size_t size() const { - return _end - _start; - } - - size_t length() const { - return size(); - } - - operator std::string() const { - return str(); - } - - std::string str() const { - return _str.substr(_start, _end - _start); - } - - string_view substr(size_t pos, size_t len = std::string::npos) const { - return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len); - } - - char operator[](size_t pos) const { - auto index = _start + pos; - if (index >= _end) { - throw std::out_of_range("string_view index out of range"); - } - return _str[_start + pos]; - } - - bool operator==(const string_view & other) const { - std::string this_str = *this; - std::string other_str = other; - return this_str == other_str; - } -}; - static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) { auto has_min = min_value != std::numeric_limits::min(); auto has_max = max_value != std::numeric_limits::max(); @@ -112,14 +69,14 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & } out << "}"; }; - std::function uniform_range = - [&](const string_view & from, const string_view & to) { + std::function uniform_range = + [&](const std::string_view & from, const std::string_view & to) { size_t i = 0; while (i < from.length() && i < to.length() && from[i] == to[i]) { i++; } if (i > 0) { - out << "\"" << from.substr(0, i).str() << "\""; + out << "\"" << from.substr(0, i) << "\""; } if (i < from.length() && i < to.length()) { if (i > 0) { From 5d5c066de8a3d2cb32f04c4d5ad1560945f30bf3 Mon Sep 17 00:00:00 2001 From: yuiseki Date: Sun, 22 Jun 2025 21:44:57 +0900 Subject: [PATCH 03/10] mtmd : fix Pixtral OOM with large images by capping image_size to 1024 (#14326) Mistral Small 2506 models using Pixtral vision encoder were running out of GPU memory when processing images larger than 1024x1024 pixels due to exponential memory growth from unlimited image size. This fix applies the same 1024x1024 limit used by Qwen2VL models to prevent OOM issues while maintaining compatibility with existing models. --- tools/mtmd/clip.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 30283d6f..a990520e 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2211,6 +2211,9 @@ struct clip_model_loader { { hparams.rope_theta = 10000.0f; hparams.warmup_image_size = hparams.patch_size * 8; + // Mistral Small 2506 needs 1024x1024 image size cap to prevent OOM + // ref: https://github.com/ggml-org/llama.cpp/issues/14310 + hparams.image_size = 1024; get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false); } break; case PROJECTOR_TYPE_GEMMA3: From af3373f1adfca56119f3e4de0e6a0a8df8edf3d9 Mon Sep 17 00:00:00 2001 From: uvos Date: Sun, 22 Jun 2025 16:51:23 +0200 Subject: [PATCH 04/10] HIP: enable vec fattn on RDNA4 (#14323) --- ggml/src/ggml-cuda/common.cuh | 14 ++++++++++++-- ggml/src/ggml-cuda/ggml-cuda.cu | 7 ++----- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 2f2fce06..86c4d29a 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -241,8 +241,18 @@ static bool fp16_mma_available(const int cc) { #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) return false; #else - return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) || - GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc); + if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) || + GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) { + return true; + } else if (GGML_CUDA_CC_IS_RDNA4(cc)) { +#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12) + return true; +#else + return false; +#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12) + } else { + return false; + } #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index c6bdd4fb..462db71e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -100,8 +100,7 @@ int ggml_cuda_get_device() { static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { ggml_cuda_set_device(device); cudaError_t err; - if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) - { + if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) { err = cudaMallocManaged(ptr, size); #if defined(GGML_USE_HIP) if (err == hipSuccess) { @@ -119,9 +118,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) err = cudaMalloc(ptr, size); } #endif // defined(GGML_USE_HIP) - } - else - { + } else { err = cudaMalloc(ptr, size); } return err; From f1f5e82df6222dcbaca6396c0de44df259a1694f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 22 Jun 2025 20:10:07 +0300 Subject: [PATCH 05/10] examples : fix is_first logic for tokenization (#14329) ggml-ci --- examples/simple-chat/simple-chat.cpp | 2 +- tools/run/run.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index 2aee0a91..cf117804 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -98,7 +98,7 @@ int main(int argc, char ** argv) { auto generate = [&](const std::string & prompt) { std::string response; - const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == 0; + const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == -1; // tokenize the prompt const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); diff --git a/tools/run/run.cpp b/tools/run/run.cpp index c65afd61..b8556515 100644 --- a/tools/run/run.cpp +++ b/tools/run/run.cpp @@ -939,7 +939,7 @@ static int apply_chat_template(const struct common_chat_templates * tmpls, Llama // Function to tokenize the prompt static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt, std::vector & prompt_tokens, const LlamaData & llama_data) { - const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == 0; + const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == -1; const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); prompt_tokens.resize(n_prompt_tokens); From 66aba7aca9a245d71f1ddf02c7a97223529752a8 Mon Sep 17 00:00:00 2001 From: Ruikai Peng Date: Mon, 23 Jun 2025 01:28:06 +0800 Subject: [PATCH 06/10] run : avoid double tokenization (#14327) * run : avoid double tokenization by adopting common_tokenize heuristic * build : fix windows gcc and clang warnings * lint : fixed trailing whitepace * run : fix is_first flag --- tools/run/run.cpp | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/tools/run/run.cpp b/tools/run/run.cpp index b8556515..6fe728c6 100644 --- a/tools/run/run.cpp +++ b/tools/run/run.cpp @@ -9,6 +9,9 @@ #include #if defined(_WIN32) +# ifndef NOMINMAX +# define NOMINMAX +# endif # include # include #else @@ -940,16 +943,29 @@ static int apply_chat_template(const struct common_chat_templates * tmpls, Llama static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt, std::vector & prompt_tokens, const LlamaData & llama_data) { const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == -1; - - const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); - prompt_tokens.resize(n_prompt_tokens); - if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, - true) < 0) { - printe("failed to tokenize the prompt\n"); + int n_tokens = prompt.size() + 2 * is_first; + prompt_tokens.resize(n_tokens); + n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.size(), + prompt_tokens.data(), prompt_tokens.size(), + is_first, /*parse_special =*/true); + if (n_tokens == std::numeric_limits::min()) { + printe("tokenization failed: input too large\n"); return -1; } - - return n_prompt_tokens; + if (n_tokens < 0) { + prompt_tokens.resize(-n_tokens); + int check = llama_tokenize(vocab, prompt.c_str(), prompt.size(), + prompt_tokens.data(), prompt_tokens.size(), + is_first, /*parse_special =*/true); + if (check != -n_tokens) { + printe("failed to tokenize the prompt (size mismatch)\n"); + return -1; + } + n_tokens = check; + } else { + prompt_tokens.resize(n_tokens); + } + return n_tokens; } // Check if we have enough space in the context to evaluate this batch From 238005c2dc67426cf678baa2d54c881701693288 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sun, 22 Jun 2025 19:46:17 +0200 Subject: [PATCH 07/10] gguf-py : fix SpecialVocab parsing when post_processor is null (#14330) --- gguf-py/gguf/vocab.py | 144 +++++++++++++++++++++--------------------- 1 file changed, 72 insertions(+), 72 deletions(-) diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index 3b08f613..3f541b0c 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -167,81 +167,81 @@ class SpecialVocab: tokenizer_config['bos_token'] = special_bos = special_cls if not special_eos and special_sep and tokenizer_config: tokenizer_config['eos_token'] = special_eos = special_sep - post_processor = tokenizer.get('post_processor', {}) - for processor in post_processor.get('processors', [post_processor]): - if processor.get('type') == 'RobertaProcessing': - self.add_special_token['bos'] = True - self.add_special_token['eos'] = True - self.add_special_token['sep'] = True - if not special_cls and tokenizer_config: - special_cls = processor.get('cls', [special_bos])[0] - tokenizer_config['cls_token'] = special_cls - if not special_sep and tokenizer_config: - special_sep = processor.get('sep', [special_eos])[0] - tokenizer_config['sep_token'] = special_sep - continue - # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added - # Only works with simple templates, **will** get it wrong on unusual sequences - if processor.get('type') == 'TemplateProcessing': - tmpl_single = processor.get('single', []) - tmpl_pair = processor.get('pair', []) - special_first = None - special_last = None - if len(tmpl_single) > 1: - if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'): - if not tokenizer_config: - special_bos = special_first - self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False - if special_first not in (special_bos, special_cls): - logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing') - if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'): - if not tokenizer_config: - special_eos = special_last - elif special_last != special_eos: - if 'eot' not in self.special_token_types: - self.special_token_types = tuple(self.special_token_types) + ('eot', ) - tokenizer_config['eot_token'] = special_eos - elif 'eom' not in self.special_token_types: - self.special_token_types = tuple(self.special_token_types) + ('eom', ) - tokenizer_config['eom_token'] = special_eos - else: - logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!') - tokenizer_config['eos_token'] = special_eos = special_last - self.add_special_token['eos'] = True if special_last == special_eos else False - if special_last != special_eos: - logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing') - if tmpl_pair: - seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0 - seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None - if (special_first and seq_start == 0) or (special_last and seq_stop is None): - logger.warning('TemplateProcessing leading/trailing special tokens do not match TemplateProcessing') - if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]: - tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id') - tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id') - if tmpl_a != 'A' or tmpl_b != 'B': - logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing') - # A [sep] [eos] B - if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]): - add_sep = False - if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'): - if special_entry in (special_sep, special_eos) and not special_last: - add_sep = True - if special_entry not in (special_sep, special_eos): - logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing') - else: - logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing') - if len(tmpl_pair) == 2: - if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'): - if special_entry in (special_sep, special_eos): + if post_processor := tokenizer.get('post_processor'): + for processor in post_processor.get('processors', [post_processor]): + if processor.get('type') == 'RobertaProcessing': + self.add_special_token['bos'] = True + self.add_special_token['eos'] = True + self.add_special_token['sep'] = True + if not special_cls and tokenizer_config: + special_cls = processor.get('cls', [special_bos])[0] + tokenizer_config['cls_token'] = special_cls + if not special_sep and tokenizer_config: + special_sep = processor.get('sep', [special_eos])[0] + tokenizer_config['sep_token'] = special_sep + continue + # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added + # Only works with simple templates, **will** get it wrong on unusual sequences + if processor.get('type') == 'TemplateProcessing': + tmpl_single = processor.get('single', []) + tmpl_pair = processor.get('pair', []) + special_first = None + special_last = None + if len(tmpl_single) > 1: + if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'): + if not tokenizer_config: + special_bos = special_first + self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False + if special_first not in (special_bos, special_cls): + logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing') + if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'): + if not tokenizer_config: + special_eos = special_last + elif special_last != special_eos: + if 'eot' not in self.special_token_types: + self.special_token_types = tuple(self.special_token_types) + ('eot', ) + tokenizer_config['eot_token'] = special_eos + elif 'eom' not in self.special_token_types: + self.special_token_types = tuple(self.special_token_types) + ('eom', ) + tokenizer_config['eom_token'] = special_eos + else: + logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!') + tokenizer_config['eos_token'] = special_eos = special_last + self.add_special_token['eos'] = True if special_last == special_eos else False + if special_last != special_eos: + logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing') + if tmpl_pair: + seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0 + seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None + if (special_first and seq_start == 0) or (special_last and seq_stop is None): + logger.warning('TemplateProcessing leading/trailing special tokens do not match TemplateProcessing') + if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]: + tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id') + tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id') + if tmpl_a != 'A' or tmpl_b != 'B': + logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing') + # A [sep] [eos] B + if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]): + add_sep = False + if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'): + if special_entry in (special_sep, special_eos) and not special_last: add_sep = True if special_entry not in (special_sep, special_eos): - logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing') + logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing') else: - logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing') - self.add_special_token['sep'] = add_sep - if add_sep and not special_sep and tokenizer_config: - tokenizer_config['sep_token'] = special_eos - continue + logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing') + if len(tmpl_pair) == 2: + if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'): + if special_entry in (special_sep, special_eos): + add_sep = True + if special_entry not in (special_sep, special_eos): + logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing') + else: + logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing') + self.add_special_token['sep'] = add_sep + if add_sep and not special_sep and tokenizer_config: + tokenizer_config['sep_token'] = special_eos + continue if not tokenizer_config: return True chat_template_alt = None From fa4a9f2a1ccda2573189a9d4995bdf0bceb41156 Mon Sep 17 00:00:00 2001 From: Ed Addario <29247825+EAddario@users.noreply.github.com> Date: Sun, 22 Jun 2025 22:16:26 +0100 Subject: [PATCH 08/10] quantize : handle user-defined pruning of whole layers (blocks) (#13037) --- include/llama.h | 1 + src/llama-quant.cpp | 83 +++++++++++++++++++++++++++++++++++-- tools/quantize/quantize.cpp | 44 +++++++++++++++++--- 3 files changed, 119 insertions(+), 9 deletions(-) diff --git a/include/llama.h b/include/llama.h index b04720be..f4123d14 100644 --- a/include/llama.h +++ b/include/llama.h @@ -390,6 +390,7 @@ extern "C" { void * imatrix; // pointer to importance matrix data void * kv_overrides; // pointer to vector containing overrides void * tensor_types; // pointer to vector containing tensor types + void * prune_layers; // pointer to vector containing layer indices to prune } llama_model_quantize_params; typedef struct llama_logit_bias { diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 8cf45732..43229e19 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -1,5 +1,4 @@ #include "llama-quant.h" - #include "llama-impl.h" #include "llama-model.h" #include "llama-model-loader.h" @@ -27,6 +26,56 @@ static void zeros(std::ofstream & file, size_t n) { } } +static std::string remap_layer(const std::string & orig_name, const std::vector & prune, std::map & mapped, int & next_id) { + if (prune.empty()) { + return orig_name; + } + + static const std::regex pattern(R"(blk\.(\d+)\.)"); + if (std::smatch match; std::regex_search(orig_name, match, pattern)) { + const int blk = std::stoi(match[1]); + std::string new_name = orig_name; + + if (mapped.count(blk)) { + // Already mapped, do nothing + } else if (std::find(prune.begin(), prune.end(), blk) != prune.end()) { + mapped[blk] = ""; + } else if (blk < prune.front()) { + mapped[blk] = std::to_string(blk); + next_id = blk + 1; + } else { + mapped[blk] = std::to_string(next_id); + ++next_id; + } + + return mapped[blk].empty() ? mapped[blk] : new_name.replace(match.position(1), match.length(1), mapped[blk]); + } + + return orig_name; +} + +static std::string remap_imatrix (const std::string & orig_name, const std::map & mapped) { + if (mapped.empty()) { + return orig_name; + } + + static const std::regex pattern(R"(blk\.(\d+)\.)"); + if (std::smatch match; std::regex_search(orig_name, match, pattern)) { + const std::string blk(match[1]); + std::string new_name = orig_name; + + for (const auto & p : mapped) { + if (p.second == blk) { + LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first); + return new_name.replace(match.position(1), match.length(1), std::to_string(p.first)); + } + } + GGML_ABORT("\n%s: imatrix mapping error for %s\n", __func__, orig_name.c_str()); + } + + return orig_name; +} + struct quantize_state_impl { const llama_model & model; const llama_model_quantize_params * params; @@ -568,6 +617,11 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: const size_t align = GGUF_DEFAULT_ALIGNMENT; gguf_context_ptr ctx_out { gguf_init_empty() }; + std::vector prune_list = {}; + if (params->prune_layers) { + prune_list = *static_cast *>(params->prune_layers); + } + // copy the KV pairs from the input file gguf_set_kv (ctx_out.get(), ml.meta.get()); gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV @@ -597,12 +651,32 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } } + std::map mapped; + int blk_id = 0; + int pruned_attention_w = 0; + // make a list of weights std::vector tensors; tensors.reserve(ml.weights_map.size()); for (const auto & it : ml.weights_map) { + const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id)); + if (remapped_name.empty()) { + if (it.first.find("attn_v.weight") != std::string::npos || + it.first.find("attn_qkv.weight") != std::string::npos || + it.first.find("attn_kv_b.weight") != std::string::npos) { + pruned_attention_w++; + } + LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str()); + continue; + } else if (remapped_name != it.first) { + ggml_set_name(it.second.tensor, remapped_name.c_str()); + LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor)); + } tensors.push_back(&it.second); } + if (!prune_list.empty()) { + gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_BLOCK_COUNT).c_str(), blk_id); + } // keep_split requires that the weights are sorted by split index if (params->keep_split) { @@ -640,7 +714,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: if (llama_model_has_encoder(&model)) { n_attn_layer *= 3; } - GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected"); + GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected"); } size_t total_size_org = 0; @@ -681,7 +755,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: for (size_t i = 0; i < ctx_outs.size(); ++i) { gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i); gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split); - gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors); + gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), (int32_t)tensors.size()); } } @@ -832,7 +906,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: const float * imatrix = nullptr; if (imatrix_data) { - auto it = imatrix_data->find(tensor->name); + auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped)); if (it == imatrix_data->end()) { LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); } else { @@ -947,6 +1021,7 @@ llama_model_quantize_params llama_model_quantize_default_params() { /*.imatrix =*/ nullptr, /*.kv_overrides =*/ nullptr, /*.tensor_type =*/ nullptr, + /*.prune_layers =*/ nullptr }; return result; diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp index 3f54af7c..8acc7651 100644 --- a/tools/quantize/quantize.cpp +++ b/tools/quantize/quantize.cpp @@ -107,13 +107,11 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp return false; } -// usage: -// ./llama-quantize [--allow-requantize] [--leave-output-tensor] [--pure] models/llama/ggml-model.gguf [models/llama/ggml-model-quant.gguf] type [nthreads] -// [[noreturn]] static void usage(const char * executable) { - printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type]\n", executable); - printf(" [--token-embedding-type] [--tensor-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n"); + printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights]\n", executable); + printf(" [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--prune-layers] [--keep-split] [--override-kv]\n"); + printf(" model-f32.gguf [model-quant.gguf] type [nthreads]\n\n"); printf(" --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n"); printf(" --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n"); printf(" --pure: Disable k-quant mixtures and quantize all tensors to the same type\n"); @@ -124,6 +122,8 @@ static void usage(const char * executable) { printf(" --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n"); printf(" --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n"); printf(" Advanced option to selectively quantize tensors. May be specified multiple times.\n"); + printf(" --prune-layers L0,L1,L2...comma-separated list of layer numbers to prune from the model\n"); + printf(" Advanced option to remove all tensors from the given layers\n"); printf(" --keep-split: will generate quantized model in the same shards as input\n"); printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n"); @@ -286,6 +286,32 @@ static bool parse_tensor_type(const char * data, std::vector & prune_layers) { + if (!data) { + printf("\n%s: no layer pruning ids provided\n\n", __func__); + return false; + } + + const auto block_ids = string_split(data, ','); + for (const auto & block_id : block_ids) { + int id; + try { + id = std::stoi(block_id); + } catch (...) { + id = -1; + } + if (id < 0) { + printf("\n%s: invalid layer id '%s'\n\n", __func__, block_id.c_str()); + return false; + } + prune_layers.emplace_back(id); + } + + sort(prune_layers.begin(), prune_layers.end()); + prune_layers.erase(std::unique(prune_layers.begin(), prune_layers.end()), prune_layers.end()); + return true; +} + int main(int argc, char ** argv) { if (argc < 3) { usage(argv[0]); @@ -298,6 +324,7 @@ int main(int argc, char ** argv) { std::vector included_weights, excluded_weights; std::vector kv_overrides; std::vector tensor_types; + std::vector prune_layers; for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) { if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) { @@ -324,6 +351,10 @@ int main(int argc, char ** argv) { if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_types)) { usage(argv[0]); } + } else if (strcmp(argv[arg_idx], "--prune-layers") == 0) { + if (arg_idx == argc-1 || !parse_layer_prune(argv[++arg_idx], prune_layers)) { + usage(argv[0]); + } } else if (strcmp(argv[arg_idx], "--override-kv") == 0) { if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) { usage(argv[0]); @@ -411,6 +442,9 @@ int main(int argc, char ** argv) { if (!tensor_types.empty()) { params.tensor_types = &tensor_types; } + if (!prune_layers.empty()) { + params.prune_layers = &prune_layers; + } llama_backend_init(); From 3a9457df962b5883f2773f1c295e8c19df60d89f Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Mon, 23 Jun 2025 03:19:24 -0500 Subject: [PATCH 09/10] vulkan: update windows SDK in CI (#14334) --- .github/workflows/build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c4783a6d..be282897 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -683,7 +683,7 @@ jobs: env: OPENBLAS_VERSION: 0.3.23 SDE_VERSION: 9.33.0-2024-01-07 - VULKAN_VERSION: 1.4.309.0 + VULKAN_VERSION: 1.4.313.2 strategy: matrix: @@ -736,7 +736,7 @@ jobs: id: get_vulkan if: ${{ matrix.build == 'kompute-x64' || matrix.build == 'vulkan-x64' }} run: | - curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe" + curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/vulkansdk-windows-X64-${env:VULKAN_VERSION}.exe" & "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}" Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin" From ccc4c16970284eba8120f6f07ff151dcf2977f90 Mon Sep 17 00:00:00 2001 From: ver4a Date: Sat, 10 May 2025 01:01:02 +0200 Subject: [PATCH 10/10] Update rocm build --- .devops/rocm.Dockerfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.devops/rocm.Dockerfile b/.devops/rocm.Dockerfile index 1c00f1b9..f7f5b3cd 100644 --- a/.devops/rocm.Dockerfile +++ b/.devops/rocm.Dockerfile @@ -1,11 +1,11 @@ ARG UBUNTU_VERSION=24.04 # This needs to generally match the container host's environment. -ARG ROCM_VERSION=6.3 -ARG AMDGPU_VERSION=6.3 +ARG ROCM_VERSION=6.3.4 +ARG AMDGPU_VERSION=6.3.4 # Target the CUDA build image -ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete +ARG BASE_ROCM_DEV_CONTAINER=docker.io/rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete ### Build image FROM ${BASE_ROCM_DEV_CONTAINER} AS build