
Commit 29d5c20

[Proposal] Support loading from safetensors if file is present. (huggingface#1357)
* [Proposal] Support loading from safetensors if file is present.
* Style.
* Fix.
* Adding some test to check loading logic. + modify download logic to not download pytorch file if not necessary.
* Fixing the logic.
* Adressing comments.
* factor out into a function.
* Remove dead function.
* Typo.
* Extra fetch only if safetensors is there.
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
1 parent 2ef6813 commit 29d5c20

File tree

5 files changed: 169 additions, 66 deletions

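
In short: with the optional safetensors package installed, weights are fetched and loaded from *.safetensors files whenever they are available, and the existing *.bin PyTorch files remain the fallback. A minimal usage sketch (the repository id is only an example and is assumed to ship safetensors weights):

from diffusers import DiffusionPipeline

# With safetensors installed, *.bin files are skipped whenever the repository
# provides a safetensors counterpart for every PyTorch weight file.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")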

dependency_versions_table.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@
     "pytest": "pytest",
     "pytest-timeout": "pytest-timeout",
     "pytest-xdist": "pytest-xdist",
+    "safetensors": "safetensors",
     "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
     "scipy": "scipy",
     "regex": "regex!=2019.12.17",

modeling_utils.py

Lines changed: 120 additions & 63 deletions
@@ -30,8 +30,10 @@
     CONFIG_NAME,
     DIFFUSERS_CACHE,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+    SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
     is_accelerate_available,
+    is_safetensors_available,
     is_torch_version,
     logging,
 )
@@ -51,6 +53,9 @@
     from accelerate.utils import set_module_tensor_to_device
     from accelerate.utils.versions import is_torch_version
 
+if is_safetensors_available():
+    import safetensors
+
 
 def get_parameter_device(parameter: torch.nn.Module):
     try:
@@ -84,10 +89,13 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
 
 def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
     """
-    Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
+    Reads a checkpoint file, returning properly formatted errors if they arise.
     """
     try:
-        return torch.load(checkpoint_file, map_location="cpu")
+        if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
+            return torch.load(checkpoint_file, map_location="cpu")
+        else:
+            return safetensors.torch.load_file(checkpoint_file, device="cpu")
    except Exception as e:
        try:
            with open(checkpoint_file) as f:
@@ -104,7 +112,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
                     ) from e
         except (UnicodeDecodeError, ValueError):
             raise OSError(
-                f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
+                f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
                 f"at '{checkpoint_file}'. "
                 "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
             )
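
The new branch keeps torch.load for the .bin file name and routes everything else through safetensors. A minimal standalone sketch of that dispatch (the local path is hypothetical, and WEIGHTS_NAME is spelled out as the constant's value from utils/__init__.py):

import os
import torch
import safetensors.torch

WEIGHTS_NAME = "diffusion_pytorch_model.bin"  # value of the constant used above

def load_weights(checkpoint_file):
    # Same dispatch as load_state_dict: the file name decides the loader.
    if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
        return torch.load(checkpoint_file, map_location="cpu")
    return safetensors.torch.load_file(checkpoint_file, device="cpu")

state_dict = load_weights("./unet/diffusion_pytorch_model.safetensors")  # hypothetical local file
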
@@ -375,75 +383,39 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
         # Load model
-        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-        if os.path.isdir(pretrained_model_name_or_path):
-            if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
-                # Load from a PyTorch checkpoint
-                model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
-            elif subfolder is not None and os.path.isfile(
-                os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
-            ):
-                model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
-            else:
-                raise EnvironmentError(
-                    f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
-                )
-        else:
+
+        model_file = None
+        if is_safetensors_available():
             try:
-                # Load from URL or cache if already cached
-                model_file = hf_hub_download(
+                model_file = _get_model_file(
                     pretrained_model_name_or_path,
-                    filename=WEIGHTS_NAME,
+                    weights_name=SAFETENSORS_WEIGHTS_NAME,
                     cache_dir=cache_dir,
                     force_download=force_download,
-                    proxies=proxies,
                     resume_download=resume_download,
+                    proxies=proxies,
                     local_files_only=local_files_only,
                     use_auth_token=use_auth_token,
-                    user_agent=user_agent,
-                    subfolder=subfolder,
                     revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
                 )
-
-            except RepositoryNotFoundError:
-                raise EnvironmentError(
-                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
-                    "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
-                    "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
-                    "login`."
-                )
-            except RevisionNotFoundError:
-                raise EnvironmentError(
-                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
-                    "this model name. Check the model page at "
-                    f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
-                )
-            except EntryNotFoundError:
-                raise EnvironmentError(
-                    f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
-                )
-            except HTTPError as err:
-                raise EnvironmentError(
-                    "There was a specific connection error when trying to load"
-                    f" {pretrained_model_name_or_path}:\n{err}"
-                )
-            except ValueError:
-                raise EnvironmentError(
-                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
-                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
-                    f" directory containing a file named {WEIGHTS_NAME} or"
-                    " \nCheckout your internet connection or see how to run the library in"
-                    " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
-                )
-            except EnvironmentError:
-                raise EnvironmentError(
-                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
-                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
-                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
-                    f"containing a file named {WEIGHTS_NAME}"
-                )
-
-        # restore default dtype
+            except:
+                pass
+        if model_file is None:
+            model_file = _get_model_file(
+                pretrained_model_name_or_path,
+                weights_name=WEIGHTS_NAME,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+            )
 
         if low_cpu_mem_usage:
             # Instantiate model with empty weights
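
The new control flow is: if safetensors is installed, try to resolve the .safetensors file first, swallow any failure, and only then fall back to the PyTorch .bin file. A minimal sketch of that fallback pattern in isolation (fetch_file is a hypothetical stand-in for _get_model_file):

def resolve_weights(fetch_file):
    # fetch_file(name) is assumed to return a local path or raise if the file
    # cannot be found locally or on the Hub, mirroring _get_model_file below.
    model_file = None
    try:
        model_file = fetch_file("diffusion_pytorch_model.safetensors")
    except Exception:
        # Any failure (missing file, no network, ...) falls through to the .bin weights.
        pass
    if model_file is None:
        model_file = fetch_file("diffusion_pytorch_model.bin")
    return model_file
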
@@ -691,3 +663,88 @@ def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
         return unwrap_model(model.module)
     else:
         return model
+
+
+def _get_model_file(
+    pretrained_model_name_or_path,
+    *,
+    weights_name,
+    subfolder,
+    cache_dir,
+    force_download,
+    proxies,
+    resume_download,
+    local_files_only,
+    use_auth_token,
+    user_agent,
+    revision,
+):
+    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+    if os.path.isdir(pretrained_model_name_or_path):
+        if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
+            # Load from a PyTorch checkpoint
+            model_file = os.path.join(pretrained_model_name_or_path, weights_name)
+            return model_file
+        elif subfolder is not None and os.path.isfile(
+            os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
+        ):
+            model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
+            return model_file
+        else:
+            raise EnvironmentError(
+                f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
+            )
+    else:
+        try:
+            # Load from URL or cache if already cached
+            model_file = hf_hub_download(
+                pretrained_model_name_or_path,
+                filename=weights_name,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                resume_download=resume_download,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                user_agent=user_agent,
+                subfolder=subfolder,
+                revision=revision,
+            )
+            return model_file
+
+        except RepositoryNotFoundError:
+            raise EnvironmentError(
+                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
+                "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
+                "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+                "login`."
+            )
+        except RevisionNotFoundError:
+            raise EnvironmentError(
+                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
+                "this model name. Check the model page at "
+                f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+            )
+        except EntryNotFoundError:
+            raise EnvironmentError(
+                f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
+            )
+        except HTTPError as err:
+            raise EnvironmentError(
+                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
+            )
+        except ValueError:
+            raise EnvironmentError(
+                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+                f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+                f" directory containing a file named {weights_name} or"
+                " \nCheckout your internet connection or see how to run the library in"
+                " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+            )
+        except EnvironmentError:
+            raise EnvironmentError(
+                f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+                f"containing a file named {weights_name}"
+            )
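
From the caller's side nothing changes except the preferred file format. A sketch of the expected model-level usage (the repository id and subfolder are examples; that the repo hosts a diffusion_pytorch_model.safetensors file is an assumption):

from diffusers import UNet2DConditionModel

# If diffusion_pytorch_model.safetensors is present (and safetensors is installed),
# it is downloaded and loaded; otherwise diffusion_pytorch_model.bin is used as before.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")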

pipeline_utils.py

Lines changed: 29 additions & 2 deletions
@@ -26,7 +26,7 @@
 
 import diffusers
 import PIL
-from huggingface_hub import snapshot_download
+from huggingface_hub import model_info, snapshot_download
 from packaging import version
 from PIL import Image
 from tqdm.auto import tqdm
@@ -44,6 +44,7 @@
     BaseOutput,
     deprecate,
     is_accelerate_available,
+    is_safetensors_available,
     is_torch_version,
     is_transformers_available,
     logging,
@@ -117,6 +118,23 @@ class AudioPipelineOutput(BaseOutput):
     audios: np.ndarray
 
 
+def is_safetensors_compatible(info) -> bool:
+    filenames = set(sibling.rfilename for sibling in info.siblings)
+    pt_filenames = set(filename for filename in filenames if filename.endswith(".bin"))
+    is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames)
+    for pt_filename in pt_filenames:
+        prefix, raw = os.path.split(pt_filename)
+        if raw == "pytorch_model.bin":
+            # transformers specific
+            sf_filename = os.path.join(prefix, "model.safetensors")
+        else:
+            sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
+        if sf_filename not in filenames:
+            logger.warning("{sf_filename} not found")
+            is_safetensors_compatible = False
+    return is_safetensors_compatible
+
+
 class DiffusionPipeline(ConfigMixin):
     r"""
     Base class for all models.
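
is_safetensors_compatible only reports True when every .bin file in the repo has a .safetensors counterpart, with pytorch_model.bin (the transformers convention) mapping to model.safetensors. A small illustration using stand-in objects (SimpleNamespace mimics the sibling entries returned by model_info; it is not the real huggingface_hub type):

from types import SimpleNamespace

info = SimpleNamespace(
    siblings=[
        SimpleNamespace(rfilename="unet/diffusion_pytorch_model.bin"),
        SimpleNamespace(rfilename="unet/diffusion_pytorch_model.safetensors"),
        SimpleNamespace(rfilename="text_encoder/pytorch_model.bin"),
        SimpleNamespace(rfilename="text_encoder/model.safetensors"),
    ]
)
print(is_safetensors_compatible(info))  # True: every .bin has a safetensors twin
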
@@ -459,7 +477,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
 
         # make sure we don't download flax weights
-        ignore_patterns = "*.msgpack"
+        ignore_patterns = ["*.msgpack"]
 
         if custom_pipeline is not None:
             allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
@@ -473,6 +491,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             user_agent["custom_pipeline"] = custom_pipeline
         user_agent = http_user_agent(user_agent)
 
+        if is_safetensors_available():
+            info = model_info(
+                pretrained_model_name_or_path,
+                use_auth_token=use_auth_token,
+                revision=revision,
+            )
+            if is_safetensors_compatible(info):
+                ignore_patterns.append("*.bin")
+
         # download all allow_patterns
         cached_folder = snapshot_download(
             pretrained_model_name_or_path,
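
The download-side effect is that the whole set of *.bin files is excluded from snapshot_download whenever the repo check passes. A sketch of that logic outside the pipeline class (the repo id is only an example; is_safetensors_compatible is the helper defined earlier in this diff):

from huggingface_hub import model_info, snapshot_download

repo_id = "runwayml/stable-diffusion-v1-5"  # example id
ignore_patterns = ["*.msgpack"]  # flax weights are always skipped

info = model_info(repo_id)
if is_safetensors_compatible(info):
    # Every PyTorch weight has a safetensors twin, so the .bin files need not be fetched.
    ignore_patterns.append("*.bin")

cached_folder = snapshot_download(repo_id, ignore_patterns=ignore_patterns)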

utils/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@
     is_inflect_available,
     is_modelcards_available,
     is_onnx_available,
+    is_safetensors_available,
     is_scipy_available,
     is_tf_available,
     is_torch_available,
@@ -69,6 +70,7 @@
 WEIGHTS_NAME = "diffusion_pytorch_model.bin"
 FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
 ONNX_WEIGHTS_NAME = "model.onnx"
+SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
 ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
 HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
 DIFFUSERS_CACHE = default_cache_path
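
The new constant follows the existing naming scheme: same base name as WEIGHTS_NAME with the .bin extension swapped for .safetensors, which is the mapping the compatibility check in pipeline_utils.py relies on for non-transformers weights. A one-line illustration:

WEIGHTS_NAME = "diffusion_pytorch_model.bin"
SAFETENSORS_WEIGHTS_NAME = WEIGHTS_NAME[: -len(".bin")] + ".safetensors"
assert SAFETENSORS_WEIGHTS_NAME == "diffusion_pytorch_model.safetensors"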

utils/import_utils.py

Lines changed: 17 additions & 1 deletion
@@ -42,6 +42,7 @@
 USE_TF = os.environ.get("USE_TF", "AUTO").upper()
 USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
 USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
+USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper()
 
 STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
 
@@ -55,7 +56,7 @@
     except importlib_metadata.PackageNotFoundError:
         _torch_available = False
 else:
-    logger.info("Disabling PyTorch because USE_TF is set")
+    logger.info("Disabling PyTorch because USE_TORCH is set")
     _torch_available = False
 
 
@@ -109,6 +110,17 @@
 else:
     _flax_available = False
 
+if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
+    _safetensors_available = importlib.util.find_spec("safetensors") is not None
+    if _safetensors_available:
+        try:
+            _safetensors_version = importlib_metadata.version("safetensors")
+            logger.info(f"Safetensors version {_safetensors_version} available.")
+        except importlib_metadata.PackageNotFoundError:
+            _safetensors_available = False
+else:
+    logger.info("Disabling Safetensors because USE_TF is set")
+    _safetensors_available = False
 
 _transformers_available = importlib.util.find_spec("transformers") is not None
 try:
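
Like the other backends, safetensors support is gated on both an environment variable and the package actually being importable; setting USE_SAFETENSORS to anything outside the true/auto values disables it. A minimal sketch (the variable must be set before diffusers is first imported):

import os

os.environ["USE_SAFETENSORS"] = "0"  # must happen before the first diffusers import

from diffusers.utils import is_safetensors_available

print(is_safetensors_available())  # False: the safetensors code paths are skipped
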
@@ -190,6 +202,10 @@ def is_torch_available():
     return _torch_available
 
 
+def is_safetensors_available():
+    return _safetensors_available
+
+
 def is_tf_available():
     return _tf_available
 