Refactor gguf scripts to improve metadata handling (#11909)
* Refactor gguf scripts to improve metadata handling Added contents method to ReaderField class Added endianess property to GGUFReader class * update scripts * fix import * remove unused import * attempt to work around flake and pyright errors * second attempt * give up, ignore type * bump version * apply newbyteorder fixes
This commit is contained in:
parent
3567ee3a94
commit
69050a11be
6 changed files with 88 additions and 81 deletions
|
@ -2,12 +2,14 @@
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from gguf.gguf_reader import GGUFReader
|
|
||||||
|
|
||||||
logger = logging.getLogger("reader")
|
logger = logging.getLogger("reader")
|
||||||
|
|
||||||
|
# Necessary to load the local gguf package
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
from gguf.gguf_reader import GGUFReader
|
||||||
|
|
||||||
|
|
||||||
def read_gguf_file(gguf_file_path):
|
def read_gguf_file(gguf_file_path):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -6,6 +6,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any, Literal, NamedTuple, TypeVar, Union
|
from typing import Any, Literal, NamedTuple, TypeVar, Union
|
||||||
|
|
||||||
|
@ -15,7 +16,6 @@ import numpy.typing as npt
|
||||||
from .quants import quant_shape_to_byte_shape
|
from .quants import quant_shape_to_byte_shape
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# Allow running file in package as a script.
|
# Allow running file in package as a script.
|
||||||
|
@ -28,6 +28,7 @@ from gguf.constants import (
|
||||||
GGUF_VERSION,
|
GGUF_VERSION,
|
||||||
GGMLQuantizationType,
|
GGMLQuantizationType,
|
||||||
GGUFValueType,
|
GGUFValueType,
|
||||||
|
GGUFEndian,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -53,6 +54,48 @@ class ReaderField(NamedTuple):
|
||||||
|
|
||||||
types: list[GGUFValueType] = []
|
types: list[GGUFValueType] = []
|
||||||
|
|
||||||
|
def contents(self, index_or_slice: int | slice = slice(None)) -> Any:
|
||||||
|
if self.types:
|
||||||
|
to_string = lambda x: str(x.tobytes(), encoding='utf-8') # noqa: E731
|
||||||
|
main_type = self.types[0]
|
||||||
|
|
||||||
|
if main_type == GGUFValueType.ARRAY:
|
||||||
|
sub_type = self.types[-1]
|
||||||
|
|
||||||
|
if sub_type == GGUFValueType.STRING:
|
||||||
|
indices = self.data[index_or_slice]
|
||||||
|
|
||||||
|
if isinstance(index_or_slice, int):
|
||||||
|
return to_string(self.parts[indices]) # type: ignore
|
||||||
|
else:
|
||||||
|
return [to_string(self.parts[idx]) for idx in indices] # type: ignore
|
||||||
|
else:
|
||||||
|
# FIXME: When/if _get_field_parts() support multi-dimensional arrays, this must do so too
|
||||||
|
|
||||||
|
# Check if it's unsafe to perform slice optimization on data
|
||||||
|
# if any(True for idx in self.data if len(self.parts[idx]) != 1):
|
||||||
|
# optim_slice = slice(None)
|
||||||
|
# else:
|
||||||
|
# optim_slice = index_or_slice
|
||||||
|
# index_or_slice = slice(None)
|
||||||
|
|
||||||
|
# if isinstance(optim_slice, int):
|
||||||
|
# return self.parts[self.data[optim_slice]].tolist()[0]
|
||||||
|
# else:
|
||||||
|
# return [pv for idx in self.data[optim_slice] for pv in self.parts[idx].tolist()][index_or_slice]
|
||||||
|
|
||||||
|
if isinstance(index_or_slice, int):
|
||||||
|
return self.parts[self.data[index_or_slice]].tolist()[0]
|
||||||
|
else:
|
||||||
|
return [pv for idx in self.data[index_or_slice] for pv in self.parts[idx].tolist()]
|
||||||
|
|
||||||
|
if main_type == GGUFValueType.STRING:
|
||||||
|
return to_string(self.parts[-1])
|
||||||
|
else:
|
||||||
|
return self.parts[-1].tolist()[0]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class ReaderTensor(NamedTuple):
|
class ReaderTensor(NamedTuple):
|
||||||
name: str
|
name: str
|
||||||
|
@ -101,10 +144,19 @@ class GGUFReader:
|
||||||
# If we get 0 here that means it's (probably) a GGUF file created for
|
# If we get 0 here that means it's (probably) a GGUF file created for
|
||||||
# the opposite byte order of the machine this script is running on.
|
# the opposite byte order of the machine this script is running on.
|
||||||
self.byte_order = 'S'
|
self.byte_order = 'S'
|
||||||
temp_version = temp_version.newbyteorder(self.byte_order)
|
temp_version = temp_version.view(temp_version.dtype.newbyteorder(self.byte_order))
|
||||||
version = temp_version[0]
|
version = temp_version[0]
|
||||||
if version not in READER_SUPPORTED_VERSIONS:
|
if version not in READER_SUPPORTED_VERSIONS:
|
||||||
raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle')
|
raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle')
|
||||||
|
if sys.byteorder == "little":
|
||||||
|
# Host is little endian
|
||||||
|
host_endian = GGUFEndian.LITTLE
|
||||||
|
swapped_endian = GGUFEndian.BIG
|
||||||
|
else:
|
||||||
|
# Sorry PDP or other weird systems that don't use BE or LE.
|
||||||
|
host_endian = GGUFEndian.BIG
|
||||||
|
swapped_endian = GGUFEndian.LITTLE
|
||||||
|
self.endianess = swapped_endian if self.byte_order == "S" else host_endian
|
||||||
self.fields: OrderedDict[str, ReaderField] = OrderedDict()
|
self.fields: OrderedDict[str, ReaderField] = OrderedDict()
|
||||||
self.tensors: list[ReaderTensor] = []
|
self.tensors: list[ReaderTensor] = []
|
||||||
offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))
|
offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))
|
||||||
|
@ -146,11 +198,7 @@ class GGUFReader:
|
||||||
itemsize = int(np.empty([], dtype = dtype).itemsize)
|
itemsize = int(np.empty([], dtype = dtype).itemsize)
|
||||||
end_offs = offset + itemsize * count
|
end_offs = offset + itemsize * count
|
||||||
arr = self.data[offset:end_offs].view(dtype=dtype)[:count]
|
arr = self.data[offset:end_offs].view(dtype=dtype)[:count]
|
||||||
if override_order is not None:
|
return arr.view(arr.dtype.newbyteorder(self.byte_order if override_order is None else override_order))
|
||||||
return arr.view(arr.dtype.newbyteorder(override_order))
|
|
||||||
if self.byte_order == 'S':
|
|
||||||
return arr.view(arr.dtype.newbyteorder(self.byte_order))
|
|
||||||
return arr
|
|
||||||
|
|
||||||
def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
|
def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
|
||||||
if field.name in self.fields:
|
if field.name in self.fields:
|
||||||
|
@ -192,6 +240,7 @@ class GGUFReader:
|
||||||
offs += int(alen.nbytes)
|
offs += int(alen.nbytes)
|
||||||
aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
|
aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
|
||||||
data_idxs: list[int] = []
|
data_idxs: list[int] = []
|
||||||
|
# FIXME: Handle multi-dimensional arrays properly instead of flattening
|
||||||
for idx in range(alen[0]):
|
for idx in range(alen[0]):
|
||||||
curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0])
|
curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0])
|
||||||
if idx == 0:
|
if idx == 0:
|
||||||
|
|
|
@ -20,22 +20,15 @@ logger = logging.getLogger("gguf-convert-endian")
|
||||||
|
|
||||||
|
|
||||||
def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None:
|
def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None:
|
||||||
if np.uint32(1) == np.uint32(1).newbyteorder("<"):
|
file_endian = reader.endianess.name
|
||||||
# Host is little endian
|
if reader.byte_order == 'S':
|
||||||
host_endian = "little"
|
host_endian = 'BIG' if file_endian == 'LITTLE' else 'LITTLE'
|
||||||
swapped_endian = "big"
|
|
||||||
else:
|
else:
|
||||||
# Sorry PDP or other weird systems that don't use BE or LE.
|
host_endian = file_endian
|
||||||
host_endian = "big"
|
order = host_endian if args.order == "native" else args.order.upper()
|
||||||
swapped_endian = "little"
|
logger.info(f"* Host is {host_endian} endian, GGUF file seems to be {file_endian} endian")
|
||||||
if reader.byte_order == "S":
|
|
||||||
file_endian = swapped_endian
|
|
||||||
else:
|
|
||||||
file_endian = host_endian
|
|
||||||
order = host_endian if args.order == "native" else args.order
|
|
||||||
logger.info(f"* Host is {host_endian.upper()} endian, GGUF file seems to be {file_endian.upper()} endian")
|
|
||||||
if file_endian == order:
|
if file_endian == order:
|
||||||
logger.info(f"* File is already {order.upper()} endian. Nothing to do.")
|
logger.info(f"* File is already {order} endian. Nothing to do.")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
logger.info("* Checking tensors for conversion compatibility")
|
logger.info("* Checking tensors for conversion compatibility")
|
||||||
for tensor in reader.tensors:
|
for tensor in reader.tensors:
|
||||||
|
@ -47,7 +40,7 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None
|
||||||
gguf.GGMLQuantizationType.Q6_K,
|
gguf.GGMLQuantizationType.Q6_K,
|
||||||
):
|
):
|
||||||
raise ValueError(f"Cannot handle type {tensor.tensor_type.name} for tensor {repr(tensor.name)}")
|
raise ValueError(f"Cannot handle type {tensor.tensor_type.name} for tensor {repr(tensor.name)}")
|
||||||
logger.info(f"* Preparing to convert from {file_endian.upper()} to {order.upper()}")
|
logger.info(f"* Preparing to convert from {file_endian} to {order}")
|
||||||
if args.dry_run:
|
if args.dry_run:
|
||||||
return
|
return
|
||||||
logger.warning("*** Warning *** Warning *** Warning **")
|
logger.warning("*** Warning *** Warning *** Warning **")
|
||||||
|
|
|
@ -9,8 +9,6 @@ import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
# Necessary to load the local gguf package
|
# Necessary to load the local gguf package
|
||||||
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
|
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||||
|
@ -21,11 +19,11 @@ logger = logging.getLogger("gguf-dump")
|
||||||
|
|
||||||
|
|
||||||
def get_file_host_endian(reader: GGUFReader) -> tuple[str, str]:
|
def get_file_host_endian(reader: GGUFReader) -> tuple[str, str]:
|
||||||
host_endian = 'LITTLE' if np.uint32(1) == np.uint32(1).newbyteorder("<") else 'BIG'
|
file_endian = reader.endianess.name
|
||||||
if reader.byte_order == 'S':
|
if reader.byte_order == 'S':
|
||||||
file_endian = 'BIG' if host_endian == 'LITTLE' else 'LITTLE'
|
host_endian = 'BIG' if file_endian == 'LITTLE' else 'LITTLE'
|
||||||
else:
|
else:
|
||||||
file_endian = host_endian
|
host_endian = file_endian
|
||||||
return (host_endian, file_endian)
|
return (host_endian, file_endian)
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,12 +43,20 @@ def dump_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
|
||||||
pretty_type = str(field.types[-1].name)
|
pretty_type = str(field.types[-1].name)
|
||||||
|
|
||||||
log_message = f' {n:5}: {pretty_type:10} | {len(field.data):8} | {field.name}'
|
log_message = f' {n:5}: {pretty_type:10} | {len(field.data):8} | {field.name}'
|
||||||
if len(field.types) == 1:
|
if field.types:
|
||||||
curr_type = field.types[0]
|
curr_type = field.types[0]
|
||||||
if curr_type == GGUFValueType.STRING:
|
if curr_type == GGUFValueType.STRING:
|
||||||
log_message += ' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf-8')[:60]))
|
content = field.contents()
|
||||||
elif field.types[0] in reader.gguf_scalar_to_np:
|
if len(content) > 60:
|
||||||
log_message += ' = {0}'.format(field.parts[-1][0])
|
content = content[:57] + '...'
|
||||||
|
log_message += ' = {0}'.format(repr(content))
|
||||||
|
elif curr_type in reader.gguf_scalar_to_np:
|
||||||
|
log_message += ' = {0}'.format(field.contents())
|
||||||
|
else:
|
||||||
|
content = repr(field.contents(slice(6)))
|
||||||
|
if len(field.data) > 6:
|
||||||
|
content = content[:-1] + ', ...]'
|
||||||
|
log_message += ' = {0}'.format(content)
|
||||||
print(log_message) # noqa: NP100
|
print(log_message) # noqa: NP100
|
||||||
if args.no_tensors:
|
if args.no_tensors:
|
||||||
return
|
return
|
||||||
|
@ -82,15 +88,9 @@ def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None:
|
||||||
curr["array_types"] = [t.name for t in field.types][1:]
|
curr["array_types"] = [t.name for t in field.types][1:]
|
||||||
if not args.json_array:
|
if not args.json_array:
|
||||||
continue
|
continue
|
||||||
itype = field.types[-1]
|
curr["value"] = field.contents()
|
||||||
if itype == GGUFValueType.STRING:
|
|
||||||
curr["value"] = [str(bytes(field.parts[idx]), encoding="utf-8") for idx in field.data]
|
|
||||||
else:
|
else:
|
||||||
curr["value"] = [pv for idx in field.data for pv in field.parts[idx].tolist()]
|
curr["value"] = field.contents()
|
||||||
elif field.types[0] == GGUFValueType.STRING:
|
|
||||||
curr["value"] = str(bytes(field.parts[-1]), encoding="utf-8")
|
|
||||||
else:
|
|
||||||
curr["value"] = field.parts[-1].tolist()[0]
|
|
||||||
if not args.no_tensors:
|
if not args.no_tensors:
|
||||||
for idx, tensor in enumerate(reader.tensors):
|
for idx, tensor in enumerate(reader.tensors):
|
||||||
tensors[tensor.name] = {
|
tensors[tensor.name] = {
|
||||||
|
|
|
@ -8,7 +8,6 @@ import sys
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from typing import Any, Sequence, NamedTuple
|
from typing import Any, Sequence, NamedTuple
|
||||||
|
|
||||||
|
@ -27,45 +26,10 @@ class MetadataDetails(NamedTuple):
|
||||||
description: str = ''
|
description: str = ''
|
||||||
|
|
||||||
|
|
||||||
def get_byteorder(reader: gguf.GGUFReader) -> gguf.GGUFEndian:
|
|
||||||
if np.uint32(1) == np.uint32(1).newbyteorder("<"):
|
|
||||||
# Host is little endian
|
|
||||||
host_endian = gguf.GGUFEndian.LITTLE
|
|
||||||
swapped_endian = gguf.GGUFEndian.BIG
|
|
||||||
else:
|
|
||||||
# Sorry PDP or other weird systems that don't use BE or LE.
|
|
||||||
host_endian = gguf.GGUFEndian.BIG
|
|
||||||
swapped_endian = gguf.GGUFEndian.LITTLE
|
|
||||||
|
|
||||||
if reader.byte_order == "S":
|
|
||||||
return swapped_endian
|
|
||||||
else:
|
|
||||||
return host_endian
|
|
||||||
|
|
||||||
|
|
||||||
def decode_field(field: gguf.ReaderField | None) -> Any:
|
|
||||||
if field and field.types:
|
|
||||||
main_type = field.types[0]
|
|
||||||
|
|
||||||
if main_type == gguf.GGUFValueType.ARRAY:
|
|
||||||
sub_type = field.types[-1]
|
|
||||||
|
|
||||||
if sub_type == gguf.GGUFValueType.STRING:
|
|
||||||
return [str(bytes(field.parts[idx]), encoding='utf-8') for idx in field.data]
|
|
||||||
else:
|
|
||||||
return [pv for idx in field.data for pv in field.parts[idx].tolist()]
|
|
||||||
if main_type == gguf.GGUFValueType.STRING:
|
|
||||||
return str(bytes(field.parts[-1]), encoding='utf-8')
|
|
||||||
else:
|
|
||||||
return field.parts[-1][0]
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
|
def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
|
||||||
field = reader.get_field(key)
|
field = reader.get_field(key)
|
||||||
|
|
||||||
return decode_field(field)
|
return field.contents() if field else None
|
||||||
|
|
||||||
|
|
||||||
def find_token(token_list: Sequence[int], token: str) -> Sequence[int]:
|
def find_token(token_list: Sequence[int], token: str) -> Sequence[int]:
|
||||||
|
@ -93,7 +57,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
|
||||||
logger.debug(f'Removing {field.name}')
|
logger.debug(f'Removing {field.name}')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
old_val = MetadataDetails(field.types[0], decode_field(field))
|
old_val = MetadataDetails(field.types[0], field.contents())
|
||||||
val = new_metadata.get(field.name, old_val)
|
val = new_metadata.get(field.name, old_val)
|
||||||
|
|
||||||
if field.name in new_metadata:
|
if field.name in new_metadata:
|
||||||
|
@ -192,7 +156,6 @@ def main() -> None:
|
||||||
reader = gguf.GGUFReader(args.input, 'r')
|
reader = gguf.GGUFReader(args.input, 'r')
|
||||||
|
|
||||||
arch = get_field_data(reader, gguf.Keys.General.ARCHITECTURE)
|
arch = get_field_data(reader, gguf.Keys.General.ARCHITECTURE)
|
||||||
endianess = get_byteorder(reader)
|
|
||||||
|
|
||||||
token_list = get_field_data(reader, gguf.Keys.Tokenizer.LIST) or []
|
token_list = get_field_data(reader, gguf.Keys.Tokenizer.LIST) or []
|
||||||
|
|
||||||
|
@ -230,7 +193,7 @@ def main() -> None:
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
logger.info(f'* Writing: {args.output}')
|
logger.info(f'* Writing: {args.output}')
|
||||||
writer = gguf.GGUFWriter(args.output, arch=arch, endianess=endianess)
|
writer = gguf.GGUFWriter(args.output, arch=arch, endianess=reader.endianess)
|
||||||
|
|
||||||
alignment = get_field_data(reader, gguf.Keys.General.ALIGNMENT)
|
alignment = get_field_data(reader, gguf.Keys.General.ALIGNMENT)
|
||||||
if alignment is not None:
|
if alignment is not None:
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "gguf"
|
name = "gguf"
|
||||||
version = "0.15.0"
|
version = "0.16.0"
|
||||||
description = "Read and write ML models in GGUF for GGML"
|
description = "Read and write ML models in GGUF for GGML"
|
||||||
authors = ["GGML <ggml@ggml.ai>"]
|
authors = ["GGML <ggml@ggml.ai>"]
|
||||||
packages = [
|
packages = [
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue