Compare commits

..

10 commits

Author SHA1 Message Date
ccc4c16970 Update rocm build
Some checks failed
EditorConfig Checker / editorconfig (push) Has been cancelled
2025-06-23 11:21:23 +02:00
Jeff Bolz
3a9457df96
vulkan: update windows SDK in CI (#14334) 2025-06-23 10:19:24 +02:00
Ed Addario
fa4a9f2a1c
quantize : handle user-defined pruning of whole layers (blocks) (#13037) 2025-06-22 23:16:26 +02:00
Sigbjørn Skjæret
238005c2dc
gguf-py : fix SpecialVocab parsing when post_processor is null (#14330) 2025-06-22 19:46:17 +02:00
Ruikai Peng
66aba7aca9
run : avoid double tokenization (#14327)
* run : avoid double tokenization by adopting common_tokenize heuristic

* build : fix windows gcc and clang warnings

* lint : fixed trailing whitepace

* run : fix is_first flag
2025-06-23 01:28:06 +08:00
Georgi Gerganov
f1f5e82df6
examples : fix is_first logic for tokenization (#14329)
ggml-ci
2025-06-22 20:10:07 +03:00
uvos
af3373f1ad
HIP: enable vec fattn on RDNA4 (#14323) 2025-06-22 16:51:23 +02:00
yuiseki
5d5c066de8
mtmd : fix Pixtral OOM with large images by capping image_size to 1024 (#14326)
Mistral Small 2506 models using Pixtral vision encoder were running out
of GPU memory when processing images larger than 1024x1024 pixels due to
exponential memory growth from unlimited image size.

This fix applies the same 1024x1024 limit used by Qwen2VL models to
prevent OOM issues while maintaining compatibility with existing models.
2025-06-22 14:44:57 +02:00
Sigbjørn Skjæret
40bfa04c95
common : use std::string_view now that we target c++17 (#14319) 2025-06-22 08:37:43 +03:00
Aman Gupta
aa064b2eb7
CUDA: add mean operation (#14313)
* CUDA: add mean operation

* add back sum_rows_f32_cuda

* Review: early exit if col!=0
2025-06-22 12:39:54 +08:00
17 changed files with 296 additions and 168 deletions

View file

@ -1,11 +1,11 @@
ARG UBUNTU_VERSION=24.04
# This needs to generally match the container host's environment.
ARG ROCM_VERSION=6.3
ARG AMDGPU_VERSION=6.3
ARG ROCM_VERSION=6.3.4
ARG AMDGPU_VERSION=6.3.4
# Target the CUDA build image
ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
ARG BASE_ROCM_DEV_CONTAINER=docker.io/rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
### Build image
FROM ${BASE_ROCM_DEV_CONTAINER} AS build

View file

@ -683,7 +683,7 @@ jobs:
env:
OPENBLAS_VERSION: 0.3.23
SDE_VERSION: 9.33.0-2024-01-07
VULKAN_VERSION: 1.4.309.0
VULKAN_VERSION: 1.4.313.2
strategy:
matrix:
@ -736,7 +736,7 @@ jobs:
id: get_vulkan
if: ${{ matrix.build == 'kompute-x64' || matrix.build == 'vulkan-x64' }}
run: |
curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe"
curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/vulkansdk-windows-X64-${env:VULKAN_VERSION}.exe"
& "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install
Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}"
Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin"

View file

@ -41,49 +41,6 @@ static std::string build_repetition(const std::string & item_rule, int min_items
return result;
}
/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */
class string_view {
const std::string & _str;
const size_t _start;
const size_t _end;
public:
string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {}
size_t size() const {
return _end - _start;
}
size_t length() const {
return size();
}
operator std::string() const {
return str();
}
std::string str() const {
return _str.substr(_start, _end - _start);
}
string_view substr(size_t pos, size_t len = std::string::npos) const {
return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len);
}
char operator[](size_t pos) const {
auto index = _start + pos;
if (index >= _end) {
throw std::out_of_range("string_view index out of range");
}
return _str[_start + pos];
}
bool operator==(const string_view & other) const {
std::string this_str = *this;
std::string other_str = other;
return this_str == other_str;
}
};
static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
auto has_min = min_value != std::numeric_limits<int>::min();
auto has_max = max_value != std::numeric_limits<int>::max();
@ -112,14 +69,14 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
}
out << "}";
};
std::function<void(const string_view &, const string_view &)> uniform_range =
[&](const string_view & from, const string_view & to) {
std::function<void(const std::string_view &, const std::string_view &)> uniform_range =
[&](const std::string_view & from, const std::string_view & to) {
size_t i = 0;
while (i < from.length() && i < to.length() && from[i] == to[i]) {
i++;
}
if (i > 0) {
out << "\"" << from.substr(0, i).str() << "\"";
out << "\"" << from.substr(0, i) << "\"";
}
if (i < from.length() && i < to.length()) {
if (i > 0) {

View file

@ -98,7 +98,7 @@ int main(int argc, char ** argv) {
auto generate = [&](const std::string & prompt) {
std::string response;
const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == 0;
const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == -1;
// tokenize the prompt
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);

View file

@ -241,8 +241,18 @@ static bool fp16_mma_available(const int cc) {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
return false;
#else
return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
return true;
} else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
return true;
#else
return false;
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
} else {
return false;
}
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
}
@ -362,6 +372,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#endif // FP16_AVAILABLE
}
// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
template<bool norm>
static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
sum += x[row * ncols + i];
}
sum = warp_reduce_sum(sum);
if (col != 0) {
return;
}
dst[row] = norm ? sum / ncols : sum;
}
template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll

View file

@ -37,6 +37,7 @@
#include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/mean.cuh"
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
@ -99,8 +100,7 @@ int ggml_cuda_get_device() {
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
ggml_cuda_set_device(device);
cudaError_t err;
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
{
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
err = cudaMallocManaged(ptr, size);
#if defined(GGML_USE_HIP)
if (err == hipSuccess) {
@ -118,9 +118,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
err = cudaMalloc(ptr, size);
}
#endif // defined(GGML_USE_HIP)
}
else
{
} else {
err = cudaMalloc(ptr, size);
}
return err;
@ -2357,6 +2355,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SUM_ROWS:
ggml_cuda_op_sum_rows(ctx, dst);
break;
case GGML_OP_MEAN:
ggml_cuda_op_mean(ctx, dst);
break;
case GGML_OP_SSM_CONV:
ggml_cuda_op_ssm_conv(ctx, dst);
break;
@ -3260,6 +3261,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:
return true;

View file

@ -0,0 +1,19 @@
#include "mean.cuh"
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
}

View file

@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@ -1,25 +1,9 @@
#include "sumrows.cuh"
static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
sum += x[row * ncols + i];
}
sum = warp_reduce_sum(sum);
if (col == 0) {
dst[row] = sum;
}
}
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
}

View file

@ -1,5 +1,4 @@
#include "common.cuh"
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@ -167,81 +167,81 @@ class SpecialVocab:
tokenizer_config['bos_token'] = special_bos = special_cls
if not special_eos and special_sep and tokenizer_config:
tokenizer_config['eos_token'] = special_eos = special_sep
post_processor = tokenizer.get('post_processor', {})
for processor in post_processor.get('processors', [post_processor]):
if processor.get('type') == 'RobertaProcessing':
self.add_special_token['bos'] = True
self.add_special_token['eos'] = True
self.add_special_token['sep'] = True
if not special_cls and tokenizer_config:
special_cls = processor.get('cls', [special_bos])[0]
tokenizer_config['cls_token'] = special_cls
if not special_sep and tokenizer_config:
special_sep = processor.get('sep', [special_eos])[0]
tokenizer_config['sep_token'] = special_sep
continue
# Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
# Only works with simple templates, **will** get it wrong on unusual sequences
if processor.get('type') == 'TemplateProcessing':
tmpl_single = processor.get('single', [])
tmpl_pair = processor.get('pair', [])
special_first = None
special_last = None
if len(tmpl_single) > 1:
if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
if not tokenizer_config:
special_bos = special_first
self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
if special_first not in (special_bos, special_cls):
logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
if not tokenizer_config:
special_eos = special_last
elif special_last != special_eos:
if 'eot' not in self.special_token_types:
self.special_token_types = tuple(self.special_token_types) + ('eot', )
tokenizer_config['eot_token'] = special_eos
elif 'eom' not in self.special_token_types:
self.special_token_types = tuple(self.special_token_types) + ('eom', )
tokenizer_config['eom_token'] = special_eos
else:
logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
tokenizer_config['eos_token'] = special_eos = special_last
self.add_special_token['eos'] = True if special_last == special_eos else False
if special_last != special_eos:
logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
if tmpl_pair:
seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
if (special_first and seq_start == 0) or (special_last and seq_stop is None):
logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
if tmpl_a != 'A' or tmpl_b != 'B':
logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
# A [sep] [eos] B
if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
add_sep = False
if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
if special_entry in (special_sep, special_eos) and not special_last:
add_sep = True
if special_entry not in (special_sep, special_eos):
logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
else:
logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
if len(tmpl_pair) == 2:
if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
if special_entry in (special_sep, special_eos):
if post_processor := tokenizer.get('post_processor'):
for processor in post_processor.get('processors', [post_processor]):
if processor.get('type') == 'RobertaProcessing':
self.add_special_token['bos'] = True
self.add_special_token['eos'] = True
self.add_special_token['sep'] = True
if not special_cls and tokenizer_config:
special_cls = processor.get('cls', [special_bos])[0]
tokenizer_config['cls_token'] = special_cls
if not special_sep and tokenizer_config:
special_sep = processor.get('sep', [special_eos])[0]
tokenizer_config['sep_token'] = special_sep
continue
# Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
# Only works with simple templates, **will** get it wrong on unusual sequences
if processor.get('type') == 'TemplateProcessing':
tmpl_single = processor.get('single', [])
tmpl_pair = processor.get('pair', [])
special_first = None
special_last = None
if len(tmpl_single) > 1:
if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
if not tokenizer_config:
special_bos = special_first
self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
if special_first not in (special_bos, special_cls):
logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
if not tokenizer_config:
special_eos = special_last
elif special_last != special_eos:
if 'eot' not in self.special_token_types:
self.special_token_types = tuple(self.special_token_types) + ('eot', )
tokenizer_config['eot_token'] = special_eos
elif 'eom' not in self.special_token_types:
self.special_token_types = tuple(self.special_token_types) + ('eom', )
tokenizer_config['eom_token'] = special_eos
else:
logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
tokenizer_config['eos_token'] = special_eos = special_last
self.add_special_token['eos'] = True if special_last == special_eos else False
if special_last != special_eos:
logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
if tmpl_pair:
seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
if (special_first and seq_start == 0) or (special_last and seq_stop is None):
logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
if tmpl_a != 'A' or tmpl_b != 'B':
logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
# A [sep] [eos] B
if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
add_sep = False
if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
if special_entry in (special_sep, special_eos) and not special_last:
add_sep = True
if special_entry not in (special_sep, special_eos):
logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
else:
logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
self.add_special_token['sep'] = add_sep
if add_sep and not special_sep and tokenizer_config:
tokenizer_config['sep_token'] = special_eos
continue
logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
if len(tmpl_pair) == 2:
if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
if special_entry in (special_sep, special_eos):
add_sep = True
if special_entry not in (special_sep, special_eos):
logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
else:
logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
self.add_special_token['sep'] = add_sep
if add_sep and not special_sep and tokenizer_config:
tokenizer_config['sep_token'] = special_eos
continue
if not tokenizer_config:
return True
chat_template_alt = None

View file

@ -390,6 +390,7 @@ extern "C" {
void * imatrix; // pointer to importance matrix data
void * kv_overrides; // pointer to vector containing overrides
void * tensor_types; // pointer to vector containing tensor types
void * prune_layers; // pointer to vector containing layer indices to prune
} llama_model_quantize_params;
typedef struct llama_logit_bias {

View file

@ -1,5 +1,4 @@
#include "llama-quant.h"
#include "llama-impl.h"
#include "llama-model.h"
#include "llama-model-loader.h"
@ -27,6 +26,56 @@ static void zeros(std::ofstream & file, size_t n) {
}
}
static std::string remap_layer(const std::string & orig_name, const std::vector<int> & prune, std::map<int, std::string> & mapped, int & next_id) {
if (prune.empty()) {
return orig_name;
}
static const std::regex pattern(R"(blk\.(\d+)\.)");
if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
const int blk = std::stoi(match[1]);
std::string new_name = orig_name;
if (mapped.count(blk)) {
// Already mapped, do nothing
} else if (std::find(prune.begin(), prune.end(), blk) != prune.end()) {
mapped[blk] = "";
} else if (blk < prune.front()) {
mapped[blk] = std::to_string(blk);
next_id = blk + 1;
} else {
mapped[blk] = std::to_string(next_id);
++next_id;
}
return mapped[blk].empty() ? mapped[blk] : new_name.replace(match.position(1), match.length(1), mapped[blk]);
}
return orig_name;
}
static std::string remap_imatrix (const std::string & orig_name, const std::map<int, std::string> & mapped) {
if (mapped.empty()) {
return orig_name;
}
static const std::regex pattern(R"(blk\.(\d+)\.)");
if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
const std::string blk(match[1]);
std::string new_name = orig_name;
for (const auto & p : mapped) {
if (p.second == blk) {
LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first);
return new_name.replace(match.position(1), match.length(1), std::to_string(p.first));
}
}
GGML_ABORT("\n%s: imatrix mapping error for %s\n", __func__, orig_name.c_str());
}
return orig_name;
}
struct quantize_state_impl {
const llama_model & model;
const llama_model_quantize_params * params;
@ -568,6 +617,11 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
const size_t align = GGUF_DEFAULT_ALIGNMENT;
gguf_context_ptr ctx_out { gguf_init_empty() };
std::vector<int> prune_list = {};
if (params->prune_layers) {
prune_list = *static_cast<const std::vector<int> *>(params->prune_layers);
}
// copy the KV pairs from the input file
gguf_set_kv (ctx_out.get(), ml.meta.get());
gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
@ -597,12 +651,32 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
}
}
std::map<int, std::string> mapped;
int blk_id = 0;
int pruned_attention_w = 0;
// make a list of weights
std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
tensors.reserve(ml.weights_map.size());
for (const auto & it : ml.weights_map) {
const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
if (remapped_name.empty()) {
if (it.first.find("attn_v.weight") != std::string::npos ||
it.first.find("attn_qkv.weight") != std::string::npos ||
it.first.find("attn_kv_b.weight") != std::string::npos) {
pruned_attention_w++;
}
LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
continue;
} else if (remapped_name != it.first) {
ggml_set_name(it.second.tensor, remapped_name.c_str());
LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
}
tensors.push_back(&it.second);
}
if (!prune_list.empty()) {
gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_BLOCK_COUNT).c_str(), blk_id);
}
// keep_split requires that the weights are sorted by split index
if (params->keep_split) {
@ -640,7 +714,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
if (llama_model_has_encoder(&model)) {
n_attn_layer *= 3;
}
GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
}
size_t total_size_org = 0;
@ -681,7 +755,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
for (size_t i = 0; i < ctx_outs.size(); ++i) {
gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors);
gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), (int32_t)tensors.size());
}
}
@ -832,7 +906,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
const float * imatrix = nullptr;
if (imatrix_data) {
auto it = imatrix_data->find(tensor->name);
auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped));
if (it == imatrix_data->end()) {
LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
} else {
@ -947,6 +1021,7 @@ llama_model_quantize_params llama_model_quantize_default_params() {
/*.imatrix =*/ nullptr,
/*.kv_overrides =*/ nullptr,
/*.tensor_type =*/ nullptr,
/*.prune_layers =*/ nullptr
};
return result;

View file

@ -4652,6 +4652,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
return test_cases;
}

View file

@ -2211,6 +2211,9 @@ struct clip_model_loader {
{
hparams.rope_theta = 10000.0f;
hparams.warmup_image_size = hparams.patch_size * 8;
// Mistral Small 2506 needs 1024x1024 image size cap to prevent OOM
// ref: https://github.com/ggml-org/llama.cpp/issues/14310
hparams.image_size = 1024;
get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
} break;
case PROJECTOR_TYPE_GEMMA3:

View file

@ -107,13 +107,11 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
return false;
}
// usage:
// ./llama-quantize [--allow-requantize] [--leave-output-tensor] [--pure] models/llama/ggml-model.gguf [models/llama/ggml-model-quant.gguf] type [nthreads]
//
[[noreturn]]
static void usage(const char * executable) {
printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type]\n", executable);
printf(" [--token-embedding-type] [--tensor-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights]\n", executable);
printf(" [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--prune-layers] [--keep-split] [--override-kv]\n");
printf(" model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
printf(" --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
printf(" --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
printf(" --pure: Disable k-quant mixtures and quantize all tensors to the same type\n");
@ -124,6 +122,8 @@ static void usage(const char * executable) {
printf(" --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n");
printf(" --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n");
printf(" Advanced option to selectively quantize tensors. May be specified multiple times.\n");
printf(" --prune-layers L0,L1,L2...comma-separated list of layer numbers to prune from the model\n");
printf(" Advanced option to remove all tensors from the given layers\n");
printf(" --keep-split: will generate quantized model in the same shards as input\n");
printf(" --override-kv KEY=TYPE:VALUE\n");
printf(" Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
@ -286,6 +286,32 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
return true;
}
static bool parse_layer_prune(const char * data, std::vector<int> & prune_layers) {
if (!data) {
printf("\n%s: no layer pruning ids provided\n\n", __func__);
return false;
}
const auto block_ids = string_split<std::string>(data, ',');
for (const auto & block_id : block_ids) {
int id;
try {
id = std::stoi(block_id);
} catch (...) {
id = -1;
}
if (id < 0) {
printf("\n%s: invalid layer id '%s'\n\n", __func__, block_id.c_str());
return false;
}
prune_layers.emplace_back(id);
}
sort(prune_layers.begin(), prune_layers.end());
prune_layers.erase(std::unique(prune_layers.begin(), prune_layers.end()), prune_layers.end());
return true;
}
int main(int argc, char ** argv) {
if (argc < 3) {
usage(argv[0]);
@ -298,6 +324,7 @@ int main(int argc, char ** argv) {
std::vector<std::string> included_weights, excluded_weights;
std::vector<llama_model_kv_override> kv_overrides;
std::vector<tensor_quantization> tensor_types;
std::vector<int> prune_layers;
for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
@ -324,6 +351,10 @@ int main(int argc, char ** argv) {
if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_types)) {
usage(argv[0]);
}
} else if (strcmp(argv[arg_idx], "--prune-layers") == 0) {
if (arg_idx == argc-1 || !parse_layer_prune(argv[++arg_idx], prune_layers)) {
usage(argv[0]);
}
} else if (strcmp(argv[arg_idx], "--override-kv") == 0) {
if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) {
usage(argv[0]);
@ -411,6 +442,9 @@ int main(int argc, char ** argv) {
if (!tensor_types.empty()) {
params.tensor_types = &tensor_types;
}
if (!prune_layers.empty()) {
params.prune_layers = &prune_layers;
}
llama_backend_init();

View file

@ -9,6 +9,9 @@
#include <nlohmann/json.hpp>
#if defined(_WIN32)
# ifndef NOMINMAX
# define NOMINMAX
# endif
# include <windows.h>
# include <io.h>
#else
@ -939,17 +942,30 @@ static int apply_chat_template(const struct common_chat_templates * tmpls, Llama
// Function to tokenize the prompt
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == 0;
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
prompt_tokens.resize(n_prompt_tokens);
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
true) < 0) {
printe("failed to tokenize the prompt\n");
const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == -1;
int n_tokens = prompt.size() + 2 * is_first;
prompt_tokens.resize(n_tokens);
n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
prompt_tokens.data(), prompt_tokens.size(),
is_first, /*parse_special =*/true);
if (n_tokens == std::numeric_limits<int32_t>::min()) {
printe("tokenization failed: input too large\n");
return -1;
}
return n_prompt_tokens;
if (n_tokens < 0) {
prompt_tokens.resize(-n_tokens);
int check = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
prompt_tokens.data(), prompt_tokens.size(),
is_first, /*parse_special =*/true);
if (check != -n_tokens) {
printe("failed to tokenize the prompt (size mismatch)\n");
return -1;
}
n_tokens = check;
} else {
prompt_tokens.resize(n_tokens);
}
return n_tokens;
}
// Check if we have enough space in the context to evaluate this batch