Refactor gguf scripts to improve metadata handling (#11909)

* Refactor gguf scripts to improve metadata handling

Added contents method to ReaderField class
Added endianess property to GGUFReader class

* update scripts

* fix import

* remove unused import

* attempt to work around flake and pyright errors

* second attempt

* give up, ignore type

* bump version

* apply newbyteorder fixes
This commit is contained in:
Sigbjørn Skjæret 2025-02-26 14:04:48 +01:00 committed by GitHub
parent 3567ee3a94
commit 69050a11be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 88 additions and 81 deletions

View file

@ -6,6 +6,7 @@ from __future__ import annotations
import logging
import os
import sys
from collections import OrderedDict
from typing import Any, Literal, NamedTuple, TypeVar, Union
@ -15,7 +16,6 @@ import numpy.typing as npt
from .quants import quant_shape_to_byte_shape
if __name__ == "__main__":
import sys
from pathlib import Path
# Allow running file in package as a script.
@ -28,6 +28,7 @@ from gguf.constants import (
GGUF_VERSION,
GGMLQuantizationType,
GGUFValueType,
GGUFEndian,
)
logger = logging.getLogger(__name__)
@ -53,6 +54,48 @@ class ReaderField(NamedTuple):
types: list[GGUFValueType] = []
def contents(self, index_or_slice: int | slice = slice(None)) -> Any:
if self.types:
to_string = lambda x: str(x.tobytes(), encoding='utf-8') # noqa: E731
main_type = self.types[0]
if main_type == GGUFValueType.ARRAY:
sub_type = self.types[-1]
if sub_type == GGUFValueType.STRING:
indices = self.data[index_or_slice]
if isinstance(index_or_slice, int):
return to_string(self.parts[indices]) # type: ignore
else:
return [to_string(self.parts[idx]) for idx in indices] # type: ignore
else:
# FIXME: When/if _get_field_parts() support multi-dimensional arrays, this must do so too
# Check if it's unsafe to perform slice optimization on data
# if any(True for idx in self.data if len(self.parts[idx]) != 1):
# optim_slice = slice(None)
# else:
# optim_slice = index_or_slice
# index_or_slice = slice(None)
# if isinstance(optim_slice, int):
# return self.parts[self.data[optim_slice]].tolist()[0]
# else:
# return [pv for idx in self.data[optim_slice] for pv in self.parts[idx].tolist()][index_or_slice]
if isinstance(index_or_slice, int):
return self.parts[self.data[index_or_slice]].tolist()[0]
else:
return [pv for idx in self.data[index_or_slice] for pv in self.parts[idx].tolist()]
if main_type == GGUFValueType.STRING:
return to_string(self.parts[-1])
else:
return self.parts[-1].tolist()[0]
return None
class ReaderTensor(NamedTuple):
name: str
@ -101,10 +144,19 @@ class GGUFReader:
# If we get 0 here that means it's (probably) a GGUF file created for
# the opposite byte order of the machine this script is running on.
self.byte_order = 'S'
temp_version = temp_version.newbyteorder(self.byte_order)
temp_version = temp_version.view(temp_version.dtype.newbyteorder(self.byte_order))
version = temp_version[0]
if version not in READER_SUPPORTED_VERSIONS:
raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle')
if sys.byteorder == "little":
# Host is little endian
host_endian = GGUFEndian.LITTLE
swapped_endian = GGUFEndian.BIG
else:
# Sorry PDP or other weird systems that don't use BE or LE.
host_endian = GGUFEndian.BIG
swapped_endian = GGUFEndian.LITTLE
self.endianess = swapped_endian if self.byte_order == "S" else host_endian
self.fields: OrderedDict[str, ReaderField] = OrderedDict()
self.tensors: list[ReaderTensor] = []
offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))
@ -146,11 +198,7 @@ class GGUFReader:
itemsize = int(np.empty([], dtype = dtype).itemsize)
end_offs = offset + itemsize * count
arr = self.data[offset:end_offs].view(dtype=dtype)[:count]
if override_order is not None:
return arr.view(arr.dtype.newbyteorder(override_order))
if self.byte_order == 'S':
return arr.view(arr.dtype.newbyteorder(self.byte_order))
return arr
return arr.view(arr.dtype.newbyteorder(self.byte_order if override_order is None else override_order))
def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
if field.name in self.fields:
@ -192,6 +240,7 @@ class GGUFReader:
offs += int(alen.nbytes)
aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
data_idxs: list[int] = []
# FIXME: Handle multi-dimensional arrays properly instead of flattening
for idx in range(alen[0]):
curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0])
if idx == 0: