clip : Add Qwen2.5VL support (#12402)
* implment vision model architecture, gguf convertor * handle window attention inputs * add debug utils * fix few incorrect tensor memory layout * move position id remap out of ggml to avoid int32 cuda operations * cleaning up * ignore transformers Qwen2_5_xxx type check * remove not so often use `qwen2vl-cli` debug functions * remove commented-out code blocks * fix attn weight scaling after rebase * add `PROJECTOR_TYPE_QWEN2_5_VL` * remove `KEY_USE_GLU_MLP`, `KEY_USE_RMS_NORM` * replace `KEY_FULLATTN_BLK_IDX` with `KEY_WIN_ATTN_PATTERN` * remove `attn_window_size` from gguf * fix model conversion * clean up * fix merging problem * add test --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
This commit is contained in:
parent
2d451c8059
commit
ca2bb89eac
6 changed files with 594 additions and 102 deletions
|
@ -1,14 +1,16 @@
|
|||
import argparse
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from gguf import *
|
||||
from transformers import (
|
||||
Qwen2VLForConditionalGeneration,
|
||||
Qwen2VLProcessor,
|
||||
AutoProcessor,
|
||||
Qwen2VLConfig
|
||||
Qwen2VLConfig,
|
||||
Qwen2VLProcessor,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
Qwen2_5_VLConfig, # type: ignore[reportAttributeAccessIssue]
|
||||
Qwen2_5_VLForConditionalGeneration, # type: ignore[reportAttributeAccessIssue]
|
||||
)
|
||||
|
||||
|
||||
|
@ -19,61 +21,93 @@ def k(raw_key: str, arch: str) -> str:
|
|||
return raw_key.format(arch=arch)
|
||||
|
||||
|
||||
def to_gguf_name(name: str) -> str:
|
||||
og = name
|
||||
name = name.replace("text_model", "t").replace("vision_model", "v")
|
||||
name = name.replace("blocks", "blk").replace("embeddings.", "")
|
||||
name = name.replace("attn.", "attn_")
|
||||
name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.")
|
||||
# name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln")
|
||||
name = name.replace("norm1", "ln1").replace("norm2", "ln2")
|
||||
name = name.replace("merger.mlp", 'mm')
|
||||
print(f"[to_gguf_name] {og} --> {name}")
|
||||
return name
|
||||
def get_n_wa_pattern(fullatt_block_indexes: Optional[List[int]]):
|
||||
if fullatt_block_indexes is None:
|
||||
return 0
|
||||
n_wa = fullatt_block_indexes[0]
|
||||
for a, b in zip(fullatt_block_indexes, fullatt_block_indexes[1:]):
|
||||
if b - a - 1 != n_wa:
|
||||
raise ValueError(
|
||||
f"window/full attention layer should have fix pattern of "
|
||||
f"for each full-attention layer followed by {n_wa} window-attention layers"
|
||||
)
|
||||
return n_wa + 1
|
||||
|
||||
|
||||
def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
|
||||
vision_model = qwen2vl.visual
|
||||
tensor_map = {}
|
||||
for name, ten in vision_model.state_dict().items():
|
||||
ten = ten.numpy()
|
||||
if 'qkv' in name:
|
||||
if ten.ndim == 2: # weight
|
||||
c3, _ = ten.shape
|
||||
else: # bias
|
||||
c3 = ten.shape[0]
|
||||
assert c3 % 3 == 0
|
||||
c = c3 // 3
|
||||
wq = ten[:c]
|
||||
wk = ten[c: c * 2]
|
||||
wv = ten[c * 2:]
|
||||
tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq
|
||||
tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk
|
||||
tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv
|
||||
elif 'merger' in name:
|
||||
if name.endswith("ln_q.weight"):
|
||||
tensor_map['v.post_ln.weight'] = ten
|
||||
elif name.endswith("ln_q.bias"):
|
||||
tensor_map['v.post_ln.bias'] = ten
|
||||
class VL2:
|
||||
|
||||
@staticmethod
|
||||
def to_gguf_name(name: str) -> str:
|
||||
og = name
|
||||
name = name.replace("text_model", "t").replace("vision_model", "v")
|
||||
name = name.replace("blocks", "blk").replace("embeddings.", "")
|
||||
name = name.replace("attn.", "attn_")
|
||||
name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.")
|
||||
# name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln")
|
||||
name = name.replace("norm1", "ln1").replace("norm2", "ln2")
|
||||
name = name.replace("merger.mlp", 'mm')
|
||||
print(f"[to_gguf_name] {og} --> {name}")
|
||||
return name
|
||||
|
||||
@classmethod
|
||||
def find_vision_tensors(cls, qwen2vl, dtype) -> Dict[str, np.ndarray]:
|
||||
vision_model = qwen2vl.visual
|
||||
tensor_map = {}
|
||||
for name, ten in vision_model.state_dict().items():
|
||||
ten = ten.numpy()
|
||||
if 'qkv' in name:
|
||||
if ten.ndim == 2: # weight
|
||||
c3, _ = ten.shape
|
||||
else: # bias
|
||||
c3 = ten.shape[0]
|
||||
assert c3 % 3 == 0
|
||||
c = c3 // 3
|
||||
wq = ten[:c]
|
||||
wk = ten[c: c * 2]
|
||||
wv = ten[c * 2:]
|
||||
tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq
|
||||
tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk
|
||||
tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv
|
||||
elif 'merger' in name:
|
||||
if name.endswith("ln_q.weight"):
|
||||
tensor_map['v.post_ln.weight'] = ten
|
||||
elif name.endswith("ln_q.bias"):
|
||||
tensor_map['v.post_ln.bias'] = ten
|
||||
else:
|
||||
# "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias"
|
||||
tensor_map[cls.to_gguf_name(name)] = ten
|
||||
elif 'patch_embed.proj.weight' in name:
|
||||
# NOTE: split Conv3D into Conv2Ds
|
||||
c1, c2, kt, kh, kw = ten.shape
|
||||
assert kt == 2, "Current implmentation only support temporal_patch_size of 2"
|
||||
tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...]
|
||||
tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
|
||||
else:
|
||||
# "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias"
|
||||
tensor_map[to_gguf_name(name)] = ten
|
||||
elif 'patch_embed.proj.weight' in name:
|
||||
# NOTE: split Conv3D into Conv2Ds
|
||||
c1, c2, kt, kh, kw = ten.shape
|
||||
assert kt == 2, "Current implmentation only support temporal_patch_size of 2"
|
||||
tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...]
|
||||
tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
|
||||
else:
|
||||
tensor_map[to_gguf_name(f"vision_model.{name}")] = ten
|
||||
tensor_map[cls.to_gguf_name(f"vision_model.{name}")] = ten
|
||||
|
||||
for new_name, ten in tensor_map.items():
|
||||
if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
|
||||
tensor_map[new_name] = ten.astype(np.float32)
|
||||
else:
|
||||
tensor_map[new_name] = ten.astype(dtype)
|
||||
tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder
|
||||
return tensor_map
|
||||
for new_name, ten in tensor_map.items():
|
||||
if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
|
||||
tensor_map[new_name] = ten.astype(np.float32)
|
||||
else:
|
||||
tensor_map[new_name] = ten.astype(dtype)
|
||||
tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder
|
||||
return tensor_map
|
||||
|
||||
|
||||
class VL25(VL2):
|
||||
|
||||
@staticmethod
|
||||
def to_gguf_name(name: str) -> str:
|
||||
og = name
|
||||
name = name.replace("text_model", "t").replace("vision_model", "v")
|
||||
name = name.replace("blocks", "blk").replace("embeddings.", "")
|
||||
name = name.replace("attn.", "attn_")
|
||||
name = name.replace("mlp.down_proj", "ffn_down").replace("mlp.up_proj", "ffn_up")
|
||||
name = name.replace("mlp.gate_proj", "ffn_gate").replace("proj.", "out.")
|
||||
name = name.replace("norm1", "ln1").replace("norm2", "ln2")
|
||||
name = name.replace("merger.mlp", 'mm')
|
||||
print(f"[vl25][to_gguf_name] {og} --> {name}")
|
||||
return name
|
||||
|
||||
|
||||
def main(args):
|
||||
|
@ -82,7 +116,7 @@ def main(args):
|
|||
np_dtype = np.float32
|
||||
ftype = 0
|
||||
elif args.data_type == 'fp16':
|
||||
dtype = torch.float32
|
||||
dtype = torch.float16
|
||||
np_dtype = np.float16
|
||||
ftype = 1
|
||||
else:
|
||||
|
@ -92,11 +126,18 @@ def main(args):
|
|||
model_path = ""
|
||||
model_name = args.model_name
|
||||
print("model_name: ", model_name)
|
||||
qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=dtype, device_map="cpu"
|
||||
)
|
||||
cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType]
|
||||
vcfg = cfg.vision_config
|
||||
if args.model_type == "qwen2vl":
|
||||
qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=dtype, device_map="cpu"
|
||||
)
|
||||
cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType]
|
||||
vcfg = cfg.vision_config
|
||||
else:
|
||||
qwen2vl = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=dtype, device_map="cpu"
|
||||
)
|
||||
cfg: Qwen2_5_VLConfig = qwen2vl.config # type: ignore[reportAssignmentType]
|
||||
vcfg = cfg.vision_config
|
||||
|
||||
if os.path.isdir(model_name):
|
||||
local_model = True
|
||||
|
@ -113,7 +154,6 @@ def main(args):
|
|||
fout.add_bool("clip.has_text_encoder", False)
|
||||
fout.add_bool("clip.has_vision_encoder", True)
|
||||
fout.add_bool("clip.has_qwen2vl_merger", True)
|
||||
fout.add_string("clip.projector_type", "qwen2vl_merger")
|
||||
|
||||
print(cfg.vision_config)
|
||||
if 'silu' in cfg.vision_config.hidden_act.lower():
|
||||
|
@ -125,14 +165,25 @@ def main(args):
|
|||
else:
|
||||
raise ValueError()
|
||||
|
||||
tensor_map = find_vision_tensors(qwen2vl, np_dtype)
|
||||
if args.model_type == "qwen2.5vl":
|
||||
fout.add_uint32("clip.vision.n_wa_pattern", get_n_wa_pattern(vcfg.fullatt_block_indexes))
|
||||
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size)
|
||||
fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size)
|
||||
fout.add_string("clip.projector_type", "qwen2.5vl_merger")
|
||||
else:
|
||||
fout.add_string("clip.projector_type", "qwen2vl_merger")
|
||||
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
|
||||
fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
|
||||
|
||||
if args.model_type == "qwen2.5vl":
|
||||
tensor_map = VL25.find_vision_tensors(qwen2vl, np_dtype)
|
||||
else:
|
||||
tensor_map = VL2.find_vision_tensors(qwen2vl, np_dtype)
|
||||
for name, data in tensor_map.items():
|
||||
fout.add_tensor(name, data)
|
||||
|
||||
fout.add_uint32("clip.vision.patch_size", vcfg.patch_size)
|
||||
fout.add_uint32("clip.vision.image_size", 14 * 40) # some reasonable size that is divable by (14*2)
|
||||
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
|
||||
fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
|
||||
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads)
|
||||
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
|
||||
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth)
|
||||
|
@ -160,6 +211,7 @@ def main(args):
|
|||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
|
||||
parser.add_argument("--model_type", nargs='?', choices=['qwen2vl', 'qwen2.5vl'], default="qwen2vl")
|
||||
parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue