convert_hf_to_gguf.py: 27 changes (18 additions, 9 deletions)

@@ -218,8 +218,7 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call
             logger.info(f"gguf: indexing model part '{part_name}'")
             ctx: ContextManager[Any]
             if is_safetensors:
-                from safetensors import safe_open
-                ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
+                ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name))
             else:
                 ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
 
@@ -228,18 +227,18 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call

             for name in model_part.keys():
                 if is_safetensors:
+                    data: gguf.utility.LocalTensor = model_part[name]
                     if self.lazy:
-                        data = model_part.get_slice(name)
-                        data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data)  # noqa: E731
+                        data_gen = lambda data=data: LazyTorchTensor.from_local_tensor(data)  # noqa: E731
                     else:
-                        data = model_part.get_tensor(name)
-                        data_gen = lambda data=data: data  # noqa: E731
+                        dtype = LazyTorchTensor._dtype_str_map[data.dtype]
+                        data_gen = lambda data=data, dtype=dtype: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape)  # noqa: E731
                 else:
-                    data = model_part[name]
+                    data_torch: Tensor = model_part[name]
                     if self.lazy:
-                        data_gen = lambda data=data: LazyTorchTensor.from_eager(data)  # noqa: E731
+                        data_gen = lambda data=data_torch: LazyTorchTensor.from_eager(data)  # noqa: E731
                     else:
-                        data_gen = lambda data=data: data  # noqa: E731
+                        data_gen = lambda data=data_torch: data  # noqa: E731
                 tensors[name] = data_gen
 
         # verify tensor name presence and identify potentially missing files
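Aside: the data=data and data=data_torch default arguments in the lambdas above are load-bearing. Python closures capture loop variables by reference, so a bare lambda would make every data_gen load whichever tensor the loop visited last; binding through a default argument freezes the current value. A minimal illustration, not part of the PR:

    # Late binding: every closure sees the loop variable's final value
    fns = [lambda: i for i in range(3)]
    print([f() for f in fns])  # [2, 2, 2]

    # Default-argument binding, the pattern used in index_tensors above
    fns = [lambda i=i: i for i in range(3)]
    print([f() for f in fns])  # [0, 1, 2]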
@@ -10002,6 +10001,16 @@ def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
         lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[...] if len(s.get_shape()) == 0 else s[:])
         return cast(torch.Tensor, lazy)
 
+    @classmethod
+    def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor:
+        def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
+            dtype = cls._dtype_str_map[tensor.dtype]
+            return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape)
+        dtype = cls._dtype_str_map[t.dtype]
+        shape = t.shape
+        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r))
+        return cast(torch.Tensor, lazy)
+
     @classmethod
     def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
         dtype = cls._dtype_str_map[remote_tensor.dtype]
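For reference, from_local_tensor defers the same mmap_bytes().view(dtype).reshape(shape) materialization that the eager branch of index_tensors performs immediately. A standalone sketch of that pattern, assuming a hypothetical single-file checkpoint model.safetensors and an abbreviated dtype map standing in for LazyTorchTensor._dtype_str_map:

    from pathlib import Path

    import torch
    import gguf

    # Abbreviated stand-in for LazyTorchTensor._dtype_str_map (assumes F32/F16/BF16 only)
    dtype_map = {"F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16}

    with gguf.utility.SafetensorsLocal(Path("model.safetensors")) as tensors:
        for name, t in tensors.items():
            # t.mmap_bytes() is a flat byte-level memory map of the tensor's file range;
            # reinterpret the bytes as the stored dtype, then restore the shape
            data = torch.from_numpy(t.mmap_bytes()).view(dtype_map[t.dtype]).reshape(t.shape)
            print(name, t.dtype, tuple(data.shape))

Because the backing store is a memory map, only the pages of the file that are actually accessed get read from disk.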
gguf-py/gguf/utility.py: 80 changes (80 additions, 0 deletions)

@@ -1,10 +1,12 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Literal
 
 import os
 import json
+import numpy as np
 
 
 def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -177,6 +179,10 @@ def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
             except KeyError as e:
                 raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
 
+        # order by name (same as default safetensors behavior)
+        # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
+        res = dict(sorted(res.items(), key=lambda t: t[0]))
+
         return res
 
     @classmethod
@@ -266,3 +272,77 @@ def _get_request_headers(cls) -> dict[str, str]:
         if os.environ.get("HF_TOKEN"):
             headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
         return headers
+
+
+@dataclass
+class LocalTensorRange:
+    filename: Path
+    offset: int
+    size: int
+
+
+@dataclass
+class LocalTensor:
+    dtype: str
+    shape: tuple[int, ...]
+    data_range: LocalTensorRange
+
+    def mmap_bytes(self) -> np.ndarray:
+        return np.memmap(self.data_range.filename, offset=self.data_range.offset, shape=self.data_range.size)
+
+
+class SafetensorsLocal:
+    """
+    Read a safetensors file from the local filesystem.
+
+    Custom parsing gives a bit more control over the memory usage.
+    The official safetensors library doesn't expose file ranges.
+    """
+    ALIGNMENT = 8  # bytes
+
+    tensors: dict[str, LocalTensor]
+
+    def __init__(self, filename: Path):
+        with open(filename, "rb") as f:
+            metadata_length = int.from_bytes(f.read(8), byteorder='little')
+            file_size = os.stat(filename).st_size
+            if file_size < 8 + metadata_length:
+                raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")
+
+            metadata_str = f.read(metadata_length).decode('utf-8')
+            try:
+                metadata = json.loads(metadata_str)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Failed to parse safetensors metadata as JSON: {e}")
+
+            data_start_offset = f.tell()
+            alignment = self.ALIGNMENT
+            if data_start_offset % alignment != 0:
+                data_start_offset += alignment - (data_start_offset % alignment)
+
+            tensors: dict[str, LocalTensor] = {}
+            for name, meta in metadata.items():
+                if name == "__metadata__":
+                    # ignore metadata, it's not a tensor
+                    continue
+
+                tensors[name] = LocalTensor(
+                    dtype=meta["dtype"],
+                    shape=tuple(meta["shape"]),
+                    data_range=LocalTensorRange(
+                        filename,
+                        data_start_offset + meta["data_offsets"][0],
+                        meta["data_offsets"][1] - meta["data_offsets"][0],
+                    ),
+                )
+
+            # order by name (same as default safetensors behavior)
+            # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
+            self.tensors = dict(sorted(tensors.items(), key=lambda t: t[0]))
+
+    def __enter__(self, *args, **kwargs):
+        del args, kwargs  # unused
+        return self.tensors
+
+    def __exit__(self, *args, **kwargs):
+        del args, kwargs  # unused
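To make the parsing above concrete: a safetensors file is an 8-byte little-endian header length, a JSON header mapping tensor names to dtype, shape, and data_offsets (byte ranges relative to the end of the header), followed by the raw tensor data. A minimal sketch, assuming a hypothetical tiny.safetensors and the gguf-py from this branch, that writes such a file by hand and reads it back through SafetensorsLocal:

    import json
    from pathlib import Path

    import numpy as np
    import gguf

    a = np.arange(6, dtype=np.float32).reshape(2, 3)
    header = {"a": {"dtype": "F32", "shape": [2, 3], "data_offsets": [0, a.nbytes]}}
    blob = json.dumps(header).encode("utf-8")
    blob += b" " * (-len(blob) % 8)  # safetensors pads the header with spaces to an 8-byte boundary

    path = Path("tiny.safetensors")
    with open(path, "wb") as f:
        f.write(len(blob).to_bytes(8, byteorder="little"))  # metadata_length, as read in __init__
        f.write(blob)                                       # JSON header
        f.write(a.tobytes())                                # tensor bytes at data_start_offset

    with gguf.utility.SafetensorsLocal(path) as tensors:
        t = tensors["a"]
        print(t.dtype, t.shape)  # F32 (2, 3)
        print(t.mmap_bytes().view(np.float32).reshape(t.shape))

One caveat: np.memmap defaults to mode='r+', so mmap_bytes() as written needs the checkpoint file to be writable; an explicit mode='r' would keep the mapping read-only.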