convert : converting mmproj for Qwen2/2.5VL from convert_hf_to_gguf (#13209)

* wip

* qwen2.5vl ok

* vision: fix models missing "text_config"

* add test

* fix test repo name

* fix 32B model

* Revert "fix 32B model"

This reverts commit 651752f1ae25fe8a01c1e57c18cf2eca80b2774e.

* clarify about 32B

* rm qwen surgery script

* update llava/readme

* move V_ENC_EMBD_PATCH handling to Qwen2VLVisionModel
This commit is contained in:
Xuan-Son Nguyen 2025-05-02 17:17:15 +02:00 committed by GitHub
parent c642bc014c
commit 074e42ab31
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 132 additions and 233 deletions

View file

@ -35,6 +35,16 @@ llama-mtmd-cli -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF
# Pixtral 12B
llama-mtmd-cli -hf ggml-org/pixtral-12b-GGUF
# Qwen 2 VL
llama-mtmd-cli -hf ggml-org/Qwen2-VL-2B-Instruct-GGUF
llama-mtmd-cli -hf ggml-org/Qwen2-VL-7B-Instruct-GGUF
# Qwen 2.5 VL
llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-3B-Instruct-GGUF
llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-7B-Instruct-GGUF
llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-32B-Instruct-GGUF
llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-72B-Instruct-GGUF
# Mistral Small 3.1 24B (IQ2_M quantization)
llama-mtmd-cli -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF --chat-template mistral-v7
```
@ -60,7 +70,17 @@ Built upon `clip.cpp` (similar to `llava.cpp`), `libmtmd` offers several advanta
## How to obtain `mmproj`
Multimodal projector (`mmproj`) files are specific to each model architecture. Please refer to the relevant guide for instructions on how to obtain or create them:
Multimodal projector (`mmproj`) files are specific to each model architecture.
For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` flag to get the `mmproj` file:
- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) - Note: 1B variant does not have vision support
- SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
- SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
- Qwen 2 VL and Qwen 2.5 VL (from [Qwen](https://huggingface.co/Qwen))
- [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
For older models, please refer to the relevant guide for instructions on how to obtain or create them:
- [LLaVA](../../docs/multimodal/llava.md)
- [MobileVLM](../../docs/multimodal/MobileVLM.md)
@ -70,10 +90,3 @@ Multimodal projector (`mmproj`) files are specific to each model architecture. P
- [MiniCPM-o 2.6](../../docs/multimodal/minicpmo2.6.md)
- [IBM Granite Vision](../../docs/multimodal/granitevision.md)
- [Google Gemma 3](../../docs/multimodal/gemma3.md)
For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` flag to get the `mmproj` file:
- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) - Note: 1B variant does not have vision support
- SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
- SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
- [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)

View file

@ -1,217 +0,0 @@
import argparse
from typing import Dict, List, Optional
import torch
import numpy as np
from gguf import *
from transformers import (
AutoProcessor,
Qwen2VLConfig,
Qwen2VLProcessor,
Qwen2VLForConditionalGeneration,
Qwen2_5_VLConfig, # type: ignore[reportAttributeAccessIssue]
Qwen2_5_VLForConditionalGeneration, # type: ignore[reportAttributeAccessIssue]
)
VISION = "clip.vision"
def k(raw_key: str, arch: str) -> str:
return raw_key.format(arch=arch)
def get_n_wa_pattern(fullatt_block_indexes: Optional[List[int]]):
if fullatt_block_indexes is None:
return 0
n_wa = fullatt_block_indexes[0]
for a, b in zip(fullatt_block_indexes, fullatt_block_indexes[1:]):
if b - a - 1 != n_wa:
raise ValueError(
f"window/full attention layer should have fix pattern of "
f"for each full-attention layer followed by {n_wa} window-attention layers"
)
return n_wa + 1
class VL2:
@staticmethod
def to_gguf_name(name: str) -> str:
og = name
name = name.replace("text_model", "t").replace("vision_model", "v")
name = name.replace("blocks", "blk").replace("embeddings.", "")
name = name.replace("attn.", "attn_")
name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.")
# name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln")
name = name.replace("norm1", "ln1").replace("norm2", "ln2")
name = name.replace("merger.mlp", 'mm')
print(f"[to_gguf_name] {og} --> {name}")
return name
@classmethod
def find_vision_tensors(cls, qwen2vl, dtype) -> Dict[str, np.ndarray]:
vision_model = qwen2vl.visual
tensor_map = {}
for name, ten in vision_model.state_dict().items():
ten = ten.numpy()
if 'qkv' in name:
if ten.ndim == 2: # weight
c3, _ = ten.shape
else: # bias
c3 = ten.shape[0]
assert c3 % 3 == 0
c = c3 // 3
wq = ten[:c]
wk = ten[c: c * 2]
wv = ten[c * 2:]
tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq
tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk
tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv
elif 'merger' in name:
if name.endswith("ln_q.weight"):
tensor_map['v.post_ln.weight'] = ten
elif name.endswith("ln_q.bias"):
tensor_map['v.post_ln.bias'] = ten
else:
# "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias"
tensor_map[cls.to_gguf_name(name)] = ten
elif 'patch_embed.proj.weight' in name:
# NOTE: split Conv3D into Conv2Ds
c1, c2, kt, kh, kw = ten.shape
assert kt == 2, "Current implmentation only support temporal_patch_size of 2"
tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...]
tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
else:
tensor_map[cls.to_gguf_name(f"vision_model.{name}")] = ten
for new_name, ten in tensor_map.items():
if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
tensor_map[new_name] = ten.astype(np.float32)
else:
tensor_map[new_name] = ten.astype(dtype)
tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder
return tensor_map
class VL25(VL2):
@staticmethod
def to_gguf_name(name: str) -> str:
og = name
name = name.replace("text_model", "t").replace("vision_model", "v")
name = name.replace("blocks", "blk").replace("embeddings.", "")
name = name.replace("attn.", "attn_")
name = name.replace("mlp.down_proj", "ffn_down").replace("mlp.up_proj", "ffn_up")
name = name.replace("mlp.gate_proj", "ffn_gate").replace("proj.", "out.")
name = name.replace("norm1", "ln1").replace("norm2", "ln2")
name = name.replace("merger.mlp", 'mm')
print(f"[vl25][to_gguf_name] {og} --> {name}")
return name
def main(args):
if args.data_type == 'fp32':
dtype = torch.float32
np_dtype = np.float32
ftype = 0
elif args.data_type == 'fp16':
dtype = torch.float16
np_dtype = np.float16
ftype = 1
else:
raise ValueError()
local_model = False
model_path = ""
model_name = args.model_name
print("model_name: ", model_name)
if args.model_type == "qwen2vl":
qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
model_name, torch_dtype=dtype, device_map="cpu"
)
cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType]
vcfg = cfg.vision_config
else:
qwen2vl = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_name, torch_dtype=dtype, device_map="cpu"
)
cfg: Qwen2_5_VLConfig = qwen2vl.config # type: ignore[reportAssignmentType]
vcfg = cfg.vision_config
if os.path.isdir(model_name):
local_model = True
if model_name.endswith(os.sep):
model_name = model_name[:-1]
model_path = model_name
model_name = os.path.basename(model_name)
fname_out = f"{model_name.replace('/', '-').lower()}-vision.gguf"
fout = GGUFWriter(path=fname_out, arch="clip")
fout.add_description("image encoder for Qwen2VL")
fout.add_file_type(ftype)
fout.add_bool("clip.has_text_encoder", False)
fout.add_bool("clip.has_vision_encoder", True)
fout.add_bool("clip.has_qwen2vl_merger", True)
print(cfg.vision_config)
if 'silu' in cfg.vision_config.hidden_act.lower():
fout.add_bool("clip.use_silu", True)
fout.add_bool("clip.use_gelu", False)
elif 'gelu' in cfg.vision_config.hidden_act.lower():
fout.add_bool("clip.use_silu", False)
fout.add_bool("clip.use_gelu", 'quick' not in cfg.vision_config.hidden_act.lower())
else:
raise ValueError()
if args.model_type == "qwen2.5vl":
fout.add_uint32("clip.vision.n_wa_pattern", get_n_wa_pattern(vcfg.fullatt_block_indexes))
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size)
fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size)
fout.add_string("clip.projector_type", "qwen2.5vl_merger")
else:
fout.add_string("clip.projector_type", "qwen2vl_merger")
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
if args.model_type == "qwen2.5vl":
tensor_map = VL25.find_vision_tensors(qwen2vl, np_dtype)
else:
tensor_map = VL2.find_vision_tensors(qwen2vl, np_dtype)
for name, data in tensor_map.items():
fout.add_tensor(name, data)
fout.add_uint32("clip.vision.patch_size", vcfg.patch_size)
fout.add_uint32("clip.vision.image_size", 14 * 40) # some reasonable size that is divable by (14*2)
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads)
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth)
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 0) # not sure what this does, put 0 here as a placeholder
fout.add_name(model_name)
"""
HACK: Since vision rope related parameter aren't stored in the `Qwen2VLConfig,
it will be hardcoded in the `clip_image_build_graph` from `clip.cpp`.
"""
if local_model:
processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_path)
else:
processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name)
fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean) # type: ignore[reportAttributeAccessIssue]
fout.add_array("clip.vision.image_std", processor.image_processor.image_std) # type: ignore[reportAttributeAccessIssue]
fout.write_header_to_file()
fout.write_kv_data_to_file()
fout.write_tensors_to_file()
fout.close()
print("save model as: ", fname_out)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
parser.add_argument("--model_type", nargs='?', choices=['qwen2vl', 'qwen2.5vl'], default="qwen2vl")
parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32")
args = parser.parse_args()
main(args)

View file

@ -36,12 +36,6 @@ add_test() {
arr_tmpl+=("$tmpl")
}
add_test_big() {
if [ "$RUN_BIG_TESTS" = true ]; then
add_test "$@"
fi
}
add_test "llama-mtmd-cli" "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0"
add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M"
add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0"
@ -58,8 +52,16 @@ add_test "llama-mtmd-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
# to test the big models, run: ./tests.sh big
add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
add_test_big "llama-mtmd-cli" "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7"
if [ "$RUN_BIG_TESTS" = true ]; then
add_test "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
add_test "llama-mtmd-cli" "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7"
add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-7B-Instruct-GGUF:Q4_K_M"
add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M"
# add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra
# add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M" # too big
fi
# these models always give the wrong answer, not sure why
# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"