mtmd : support Qwen 2.5 Omni (input audio+vision, no audio output) (#13784)

* mtmd : allow multiple modalities at the same time

* refactor mtmd tokenizer

* fix compile

* add missing SinusoidsPositionEmbedding

* first working version

* fix style

* stricter validation of n_embd

* refactor if..else to switch

* fix regression

* add test for 3B

* update docs

* fix tokenizing with add_special

* add more tests

* fix test case "huge"

* rm redundant code

* set_position_mrope_1d rm n_tokens
Author: Xuan-Son Nguyen · 2025-05-27 14:06:10 +02:00 · committed by GitHub
parent 72b090da2c
commit bc583e3c63
12 changed files with 1148 additions and 744 deletions


@@ -432,6 +432,9 @@ class ModelBase:
         if "llm_config" in config:
             # rename for InternVL
             config["text_config"] = config["llm_config"]
+        if "thinker_config" in config:
+            # rename for Qwen2.5-Omni
+            config["text_config"] = config["thinker_config"]["text_config"]
         return config

     @classmethod
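For context: Qwen2.5-Omni's config.json nests its language model under thinker_config (the talker and token2wav parts carry their own configs), so the rename above lets the rest of the converter keep reading text_config as usual. A minimal sketch of the shape involved (abridged, illustrative values):

```python
# abridged sketch of a Qwen2.5-Omni config.json; values are illustrative
config = {
    "thinker_config": {
        "text_config":   {"num_hidden_layers": 28, "hidden_size": 2048},
        "vision_config": {"depth": 32},
        "audio_config":  {"d_model": 1280},
    },
}

# same rename as in ModelBase.load_hparams above
if "thinker_config" in config:
    config["text_config"] = config["thinker_config"]["text_config"]

assert config["text_config"]["hidden_size"] == 2048
```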
@@ -1121,18 +1124,21 @@ class MmprojModel(ModelBase):
     preprocessor_config: dict[str, Any]
     global_config: dict[str, Any]

+    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]
+
     has_vision_encoder: bool = True  # by default
     has_audio_encoder: bool = False

+    # for models having multiple encoders, we need to separate their hparams
+    hparams_vision: dict[str, Any] | None = None
+    hparams_audio: dict[str, Any] | None = None
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

         if self.model_arch != gguf.MODEL_ARCH.MMPROJ:
             raise TypeError("MmprojModel must be subclassed with model_arch = gguf.MODEL_ARCH.MMPROJ")

-        if self.has_vision_encoder and self.has_audio_encoder:
-            raise NotImplementedError("both vision + audio not supported yet")
-
         # get n_embd of the text model
         if "text_config" not in self.hparams:
             self.hparams["text_config"] = {}
@@ -1143,22 +1149,32 @@ class MmprojModel(ModelBase):
         assert self.n_embd_text > 0, "n_embd not found in hparams"

         # move vision config to the top level, while preserving the original hparams in global_config
-        self.global_config = self.hparams
+        import copy
+        self.global_config = copy.deepcopy(self.hparams)
+        self.hparams_vision = self.get_vision_config()
+        self.hparams_audio = self.get_audio_config()

-        if "vision_config" in self.hparams:
-            self.hparams = self.hparams["vision_config"]
-        elif "audio_config" in self.hparams:
-            self.hparams = self.hparams["audio_config"]
-        else:
+        if self.hparams_vision is None and self.hparams_audio is None:
             raise ValueError("vision_config / audio_config not found in hparams")

-        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
+        # for compat with vision-only models
+        self.hparams = self.hparams_vision or self.hparams_audio or self.hparams
+
+        # TODO @ngxson : this is a hack to support both vision and audio encoders
+        have_multiple_encoders = self.has_audio_encoder and self.has_vision_encoder
+        self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True)
+
         self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)

         # load preprocessor config
         with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
             self.preprocessor_config = json.load(f)

+    def get_vision_config(self) -> dict[str, Any] | None:
+        return self.global_config.get("vision_config")
+
+    def get_audio_config(self) -> dict[str, Any] | None:
+        return self.global_config.get("audio_config")
+
     def set_type(self):
         self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
@@ -1170,26 +1186,26 @@ class MmprojModel(ModelBase):
             self.gguf_writer.add_vision_projection_dim(self.n_embd_text)

             # vision config
-            self.gguf_writer.add_vision_image_size(self.find_hparam(["image_size"]))
-            self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
-            self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
-            self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
-            self.gguf_writer.add_vision_block_count(self.block_count)
-            self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
+            self.gguf_writer.add_vision_image_size(self.find_vparam(["image_size"]))
+            self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"]))
+            self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"]))
+            self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"]))
+            self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys))
+            self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"]))

             # preprocessor config
             self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"])
             self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"])

-        elif self.has_audio_encoder:
+        if self.has_audio_encoder:
             self.gguf_writer.add_clip_has_audio_encoder(True)
             self.gguf_writer.add_audio_projection_dim(self.n_embd_text)

             # audio config
-            self.gguf_writer.add_audio_embedding_length(self.find_hparam(["hidden_size"]))
-            self.gguf_writer.add_audio_feed_forward_length(self.find_hparam(["intermediate_size"]))
-            self.gguf_writer.add_audio_block_count(self.block_count)
-            self.gguf_writer.add_audio_head_count(self.find_hparam(["num_attention_heads"]))
+            self.gguf_writer.add_audio_embedding_length(self.find_aparam(["hidden_size"]))
+            self.gguf_writer.add_audio_feed_forward_length(self.find_aparam(["intermediate_size"]))
+            self.gguf_writer.add_audio_block_count(self.find_aparam(self.n_block_keys))
+            self.gguf_writer.add_audio_head_count(self.find_aparam(["num_attention_heads"]))

         else:
             raise ValueError("MmprojModel must have either vision or audio encoder")
@@ -1197,6 +1213,22 @@ class MmprojModel(ModelBase):
     def write_vocab(self):
         raise ValueError("MmprojModel does not support vocab writing")

+    def find_vparam(self, keys: Iterable[str], optional: bool = False) -> Any:
+        assert self.hparams_vision is not None
+        return self._find_param(self.hparams_vision, keys, optional)
+
+    def find_aparam(self, keys: Iterable[str], optional: bool = False) -> Any:
+        assert self.hparams_audio is not None
+        return self._find_param(self.hparams_audio, keys, optional)
+
+    def _find_param(self, obj: dict[str, Any], keys: Iterable[str], optional: bool = False) -> Any:
+        key = next((k for k in keys if k in obj), None)
+        if key is not None:
+            return obj[key]
+        if optional:
+            return None
+        raise KeyError(f"could not find any of: {keys}")
+

 @ModelBase.register("GPTNeoXForCausalLM")
 class GPTNeoXModel(TextModel):
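The new helpers mirror find_hparam but are scoped to one modality's hparams dict: the first key that exists wins, and optional lookups return None instead of raising. A toy illustration of those semantics (a standalone copy of the lookup for demonstration, not the converter itself):

```python
from typing import Any
from collections.abc import Iterable

# hypothetical standalone version of MmprojModel._find_param
def find_param(obj: dict[str, Any], keys: Iterable[str], optional: bool = False) -> Any:
    key = next((k for k in keys if k in obj), None)
    if key is not None:
        return obj[key]
    if optional:
        return None
    raise KeyError(f"could not find any of: {keys}")

vision_cfg = {"depth": 32, "hidden_size": 1280}
assert find_param(vision_cfg, ["num_hidden_layers", "depth"]) == 32  # first match wins
assert find_param(vision_cfg, ["n_heads"], optional=True) is None    # optional -> None
```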
@@ -2674,7 +2706,12 @@ class Qwen2Model(TextModel):
         yield from super().modify_tensors(data_torch, name, bid)


-@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
+@ModelBase.register(
+    "Qwen2VLModel",
+    "Qwen2VLForConditionalGeneration",
+    "Qwen2_5_VLForConditionalGeneration",
+    "Qwen2_5OmniModel",
+)
 class Qwen2VLModel(TextModel):
     model_arch = gguf.MODEL_ARCH.QWEN2VL
@@ -2692,8 +2729,11 @@ class Qwen2VLModel(TextModel):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
-        if name.startswith("visual."):
-            # skip visual tensors
+        if name.startswith("thinker."):
+            name = name.replace("thinker.", "")
+        if name.startswith("visual") or name.startswith("audio") or \
+                name.startswith("talker") or name.startswith("token2wav"):
+            # skip multimodal tensors
             return []
         return [(self.map_tensor_name(name), data_torch)]
@@ -2702,21 +2742,27 @@
 class Qwen2VLVisionModel(MmprojModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.hparams["image_size"] = self.hparams.get("image_size", 560)
+        assert self.hparams_vision is not None
+        self.hparams_vision["image_size"] = self.hparams_vision.get("image_size", 560)
         # rename config.json values
-        self.hparams["num_attention_heads"] = self.hparams.get("num_heads")
-        self.hparams["num_hidden_layers"] = self.hparams.get("depth")
-        if "embed_dim" in self.hparams:  # qwen2vl
-            self.hparams["intermediate_size"] = self.hparams.get("hidden_size")
-            self.hparams["hidden_size"] = self.hparams.get("embed_dim")
+        self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
+        self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")
+        if "embed_dim" in self.hparams_vision:  # qwen2vl
+            self.hparams_vision["intermediate_size"] = self.hparams_vision.get("hidden_size")
+            self.hparams_vision["hidden_size"] = self.hparams_vision.get("embed_dim")

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        hparams = self.hparams
-        if self.global_config['model_type'] == 'qwen2_vl':
+        assert self.hparams_vision is not None
+        hparams = self.hparams_vision
+        model_type = self.global_config['model_type']
+        if model_type == 'qwen2_vl':
             self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2VL)
-        elif self.global_config['model_type'] == 'qwen2_5_vl':
-            self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25VL)
+        elif model_type == 'qwen2_5_vl' or model_type == 'qwen2_5_omni':
+            if model_type == 'qwen2_5_omni':
+                self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25O)
+            else:
+                self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25VL)
             self.gguf_writer.add_vision_use_silu(True)
             # find n_wa_pattern (window attention pattern)
             fullatt_block_indexes = hparams.get("fullatt_block_indexes")
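The hunk cuts off right where n_wa_pattern is computed. Conceptually, fullatt_block_indexes lists the vision blocks that use full attention, and the repeat period of the window-attention pattern falls out of their spacing. A sketch of that derivation (the index values are the ones Qwen2.5-VL typically ships; the converter's exact validation may differ):

```python
# Qwen2.5-VL style config: blocks 7, 15, 23, 31 use full attention,
# i.e. every 8th block; the remaining blocks use window attention
fullatt_block_indexes = [7, 15, 23, 31]

n_wa_pattern = fullatt_block_indexes[0] + 1
for i, idx in enumerate(fullatt_block_indexes):
    # indexes must be evenly spaced for a single repeating pattern to exist
    if (i + 1) * n_wa_pattern - 1 != idx:
        raise ValueError(f"unsupported fullatt_block_indexes: {fullatt_block_indexes}")

assert n_wa_pattern == 8
```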
@@ -2774,6 +2820,66 @@ class Qwen2VLVisionModel(MmprojModel):
         return []  # skip other tensors


+@ModelBase.register("Qwen2_5OmniModel")
+class Qwen25OmniModel(Qwen2VLVisionModel):
+    has_vision_encoder = True
+    has_audio_encoder = True
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_audio is not None
+        self.hparams_audio["hidden_size"] = self.hparams_audio["d_model"]
+        self.hparams_audio["intermediate_size"] = self.hparams_audio["encoder_ffn_dim"]
+        self.hparams_audio["num_attention_heads"] = self.hparams_audio["encoder_attention_heads"]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        assert self.hparams_audio is not None
+        self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["num_mel_bins"])
+        self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-5))
+
+    def get_vision_config(self) -> dict[str, Any] | None:
+        return self.global_config["thinker_config"].get("vision_config")
+
+    def get_audio_config(self) -> dict[str, Any] | None:
+        return self.global_config["thinker_config"].get("audio_config")
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        # SinusoidsPositionEmbedding
+        assert self.hparams_audio is not None
+        max_timescale = 10000
+        length = 1500
+        channels = self.hparams_audio["hidden_size"]
+        log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float())
+        scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+        pos_embd = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1).to(dtype=torch.float32)
+        yield ("audio_tower.embed_positions.weight", pos_embd)
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        del bid, new_name, n_dims  # unused
+        if ".conv" in name and ".weight" in name:
+            return gguf.GGMLQuantizationType.F16
+        return False
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.startswith("thinker."):
+            name = name.replace("thinker.", "")
+
+        if name.startswith("audio_tower"):
+            # process audio tensors
+            if "conv1.bias" in name or "conv2.bias" in name:
+                # transpose conv1 and conv2 bias
+                data_torch = data_torch.unsqueeze(-1)
+            if "audio_bos_eos_token" in name:
+                # this tensor is left unused in transformers code
+                # https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py#L1809
+                return []
+            return [(self.map_tensor_name(name), data_torch)]
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("InternVisionModel")
 class InternVisionModel(MmprojModel):
     def set_gguf_parameters(self):
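The table emitted by generate_extra_tensors above is the classic Whisper-style sinusoidal position embedding: sin over the first half of the channels, cos over the second, with log-spaced timescales. A self-contained sketch that reproduces and sanity-checks it (1280 is an assumed d_model here; the converter reads the real value from hparams_audio):

```python
import numpy as np
import torch

def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
    # log-spaced inverse timescales, one per sin/cos channel pair
    log_inc = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_inc * torch.arange(channels // 2).float())
    scaled_time = torch.arange(length)[:, None] * inv_timescales[None, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

pos = sinusoids(1500, 1280)  # 1500 positions, assumed d_model = 1280
assert pos.shape == (1500, 1280)
assert torch.allclose(pos[0, :640], torch.zeros(640))  # sin(0) = 0
assert torch.allclose(pos[0, 640:], torch.ones(640))   # cos(0) = 1
```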