convert : ability to lazy-load safetensors remotely without downloading to disk (#12820)

* gguf util : add SafetensorRemote

* fix style

* convert: add --remote option

* convert : allow using lazy remote tensors

It's a bit slow for now since everything is blocking and single-threaded.

* correct metadata.name

* small style fix

* support HF_TOKEN

* convert : use writeable buffer for remote lazy tensors

* convert : fix flake8 lint regarding lamdba assigment

* multithreaded download

* multithread: print debug

* fix style

* Revert "multithreaded download"

This reverts commit 42fc895ace385edc972ad819c76c704aeea61791.

* bring back _get_request_headers

---------

Co-authored-by: Francis Couture-Harpin <git@compilade.net>
This commit is contained in:
Xuan-Son Nguyen 2025-04-10 17:24:44 +02:00 committed by GitHub
parent fe5b78c896
commit 64eda5deb9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 244 additions and 7 deletions

View file

@ -65,6 +65,7 @@ class Model:
model_name: str | None
metadata_override: Path | None
dir_model_card: Path
remote_hf_model_id: str | None
# subclasses should define this!
model_arch: gguf.MODEL_ARCH
@ -73,7 +74,7 @@ class Model:
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
small_first_shard: bool = False, hparams: dict[str, Any] | None = None):
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
if type(self) is Model:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
@ -83,11 +84,24 @@ class Model:
self.is_big_endian = is_big_endian
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
self.use_temp_file = use_temp_file
self.lazy = not eager
self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
self.is_safetensors = len(self.part_names) > 0
if not self.is_safetensors:
self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
self.lazy = not eager or (remote_hf_model_id is not None)
self.remote_hf_model_id = remote_hf_model_id
if remote_hf_model_id is not None:
self.is_safetensors = True
def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
self.tensor_names = set(name for name in remote_tensors.keys())
for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id).items():
yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
self.get_tensors = get_remote_tensors
else:
self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
self.is_safetensors = len(self.part_names) > 0
if not self.is_safetensors:
self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
self.hparams = Model.load_hparams(self.dir_model) if hparams is None else hparams
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
@ -393,6 +407,10 @@ class Model:
self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model_card, self.model_name, total_params)
# If we are using HF model id, set the metadata name to the model id
if self.remote_hf_model_id:
self.metadata.name = self.remote_hf_model_id
# Fallback to model directory name if metadata name is still missing
if self.metadata.name is None:
self.metadata.name = self.dir_model.name
@ -5403,6 +5421,14 @@ class LazyTorchTensor(gguf.LazyBase):
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
return cast(torch.Tensor, lazy)
@classmethod
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
dtype = cls._dtype_str_map[remote_tensor.dtype]
shape = remote_tensor.shape
meta = cls.meta_with_dtype_and_shape(dtype, shape)
lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape))
return cast(torch.Tensor, lazy)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
del types # unused
@ -5480,6 +5506,10 @@ def parse_args() -> argparse.Namespace:
"--print-supported-models", action="store_true",
help="Print the supported models"
)
parser.add_argument(
"--remote", action="store_true",
help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B-Instruct'. Note: To access gated repo, set HF_TOKEN environment variable to your Hugging Face token.",
)
args = parser.parse_args()
if not args.print_supported_models and args.model is None:
@ -5520,6 +5550,14 @@ def main() -> None:
dir_model = args.model
if args.remote:
from huggingface_hub import snapshot_download
local_dir = snapshot_download(
repo_id=str(dir_model),
allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
dir_model = Path(local_dir)
logger.info(f"Downloaded config and tokenizer to {local_dir}")
if not dir_model.is_dir():
logger.error(f'Error: {args.model} is not a directory')
sys.exit(1)
@ -5541,6 +5579,9 @@ def main() -> None:
if args.outfile is not None:
fname_out = args.outfile
elif args.remote:
# if remote, use the model ID as the output file name
fname_out = Path("./" + str(args.model).replace("/", "-") + "-{ftype}.gguf")
else:
fname_out = dir_model
@ -5564,7 +5605,8 @@ def main() -> None:
metadata_override=args.metadata, model_name=args.model_name,
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split)
small_first_shard=args.no_tensor_first_split,
remote_hf_model_id=str(args.model) if args.remote else None)
if args.vocab_only:
logger.info("Exporting model vocab...")