3030 CONFIG_NAME ,
3131 DIFFUSERS_CACHE ,
3232 HUGGINGFACE_CO_RESOLVE_ENDPOINT ,
33+ SAFETENSORS_WEIGHTS_NAME ,
3334 WEIGHTS_NAME ,
3435 is_accelerate_available ,
36+ is_safetensors_available ,
3537 is_torch_version ,
3638 logging ,
3739)
5153 from accelerate .utils import set_module_tensor_to_device
5254 from accelerate .utils .versions import is_torch_version
5355
56+ if is_safetensors_available ():
57+ import safetensors
58+
5459
5560def get_parameter_device (parameter : torch .nn .Module ):
5661 try :
@@ -84,10 +89,13 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
8489
8590def load_state_dict (checkpoint_file : Union [str , os .PathLike ]):
8691 """
87- Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
92+ Reads a checkpoint file, returning properly formatted errors if they arise.
8893 """
8994 try :
90- return torch .load (checkpoint_file , map_location = "cpu" )
95+ if os .path .basename (checkpoint_file ) == WEIGHTS_NAME :
96+ return torch .load (checkpoint_file , map_location = "cpu" )
97+ else :
98+ return safetensors .torch .load_file (checkpoint_file , device = "cpu" )
9199 except Exception as e :
92100 try :
93101 with open (checkpoint_file ) as f :
@@ -104,7 +112,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
104112 ) from e
105113 except (UnicodeDecodeError , ValueError ):
106114 raise OSError (
107- f"Unable to load weights from pytorch checkpoint file for '{ checkpoint_file } ' "
115+ f"Unable to load weights from checkpoint file for '{ checkpoint_file } ' "
108116 f"at '{ checkpoint_file } '. "
109117 "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
110118 )
@@ -375,75 +383,39 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
375383
376384 # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
377385 # Load model
378- pretrained_model_name_or_path = str (pretrained_model_name_or_path )
379- if os .path .isdir (pretrained_model_name_or_path ):
380- if os .path .isfile (os .path .join (pretrained_model_name_or_path , WEIGHTS_NAME )):
381- # Load from a PyTorch checkpoint
382- model_file = os .path .join (pretrained_model_name_or_path , WEIGHTS_NAME )
383- elif subfolder is not None and os .path .isfile (
384- os .path .join (pretrained_model_name_or_path , subfolder , WEIGHTS_NAME )
385- ):
386- model_file = os .path .join (pretrained_model_name_or_path , subfolder , WEIGHTS_NAME )
387- else :
388- raise EnvironmentError (
389- f"Error no file named { WEIGHTS_NAME } found in directory { pretrained_model_name_or_path } ."
390- )
391- else :
386+
387+ model_file = None
388+ if is_safetensors_available ():
392389 try :
393- # Load from URL or cache if already cached
394- model_file = hf_hub_download (
390+ model_file = _get_model_file (
395391 pretrained_model_name_or_path ,
396- filename = WEIGHTS_NAME ,
392+ weights_name = SAFETENSORS_WEIGHTS_NAME ,
397393 cache_dir = cache_dir ,
398394 force_download = force_download ,
399- proxies = proxies ,
400395 resume_download = resume_download ,
396+ proxies = proxies ,
401397 local_files_only = local_files_only ,
402398 use_auth_token = use_auth_token ,
403- user_agent = user_agent ,
404- subfolder = subfolder ,
405399 revision = revision ,
400+ subfolder = subfolder ,
401+ user_agent = user_agent ,
406402 )
407-
408- except RepositoryNotFoundError :
409- raise EnvironmentError (
410- f"{ pretrained_model_name_or_path } is not a local folder and is not a valid model identifier "
411- "listed on 'https://huggingface.co/models'\n If this is a private repository, make sure to pass a "
412- "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
413- "login`."
414- )
415- except RevisionNotFoundError :
416- raise EnvironmentError (
417- f"{ revision } is not a valid git identifier (branch name, tag name or commit id) that exists for "
418- "this model name. Check the model page at "
419- f"'https://huggingface.co/{ pretrained_model_name_or_path } ' for available revisions."
420- )
421- except EntryNotFoundError :
422- raise EnvironmentError (
423- f"{ pretrained_model_name_or_path } does not appear to have a file named { WEIGHTS_NAME } ."
424- )
425- except HTTPError as err :
426- raise EnvironmentError (
427- "There was a specific connection error when trying to load"
428- f" { pretrained_model_name_or_path } :\n { err } "
429- )
430- except ValueError :
431- raise EnvironmentError (
432- f"We couldn't connect to '{ HUGGINGFACE_CO_RESOLVE_ENDPOINT } ' to load this model, couldn't find it"
433- f" in the cached files and it looks like { pretrained_model_name_or_path } is not the path to a"
434- f" directory containing a file named { WEIGHTS_NAME } or"
435- " \n Checkout your internet connection or see how to run the library in"
436- " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
437- )
438- except EnvironmentError :
439- raise EnvironmentError (
440- f"Can't load the model for '{ pretrained_model_name_or_path } '. If you were trying to load it from "
441- "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
442- f"Otherwise, make sure '{ pretrained_model_name_or_path } ' is the correct path to a directory "
443- f"containing a file named { WEIGHTS_NAME } "
444- )
445-
446- # restore default dtype
403+ except :
404+ pass
405+ if model_file is None :
406+ model_file = _get_model_file (
407+ pretrained_model_name_or_path ,
408+ weights_name = WEIGHTS_NAME ,
409+ cache_dir = cache_dir ,
410+ force_download = force_download ,
411+ resume_download = resume_download ,
412+ proxies = proxies ,
413+ local_files_only = local_files_only ,
414+ use_auth_token = use_auth_token ,
415+ revision = revision ,
416+ subfolder = subfolder ,
417+ user_agent = user_agent ,
418+ )
447419
448420 if low_cpu_mem_usage :
449421 # Instantiate model with empty weights
@@ -691,3 +663,88 @@ def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
691663 return unwrap_model (model .module )
692664 else :
693665 return model
666+
667+
668+ def _get_model_file (
669+ pretrained_model_name_or_path ,
670+ * ,
671+ weights_name ,
672+ subfolder ,
673+ cache_dir ,
674+ force_download ,
675+ proxies ,
676+ resume_download ,
677+ local_files_only ,
678+ use_auth_token ,
679+ user_agent ,
680+ revision ,
681+ ):
682+ pretrained_model_name_or_path = str (pretrained_model_name_or_path )
683+ if os .path .isdir (pretrained_model_name_or_path ):
684+ if os .path .isfile (os .path .join (pretrained_model_name_or_path , weights_name )):
685+ # Load from a PyTorch checkpoint
686+ model_file = os .path .join (pretrained_model_name_or_path , weights_name )
687+ return model_file
688+ elif subfolder is not None and os .path .isfile (
689+ os .path .join (pretrained_model_name_or_path , subfolder , weights_name )
690+ ):
691+ model_file = os .path .join (pretrained_model_name_or_path , subfolder , weights_name )
692+ return model_file
693+ else :
694+ raise EnvironmentError (
695+ f"Error no file named { weights_name } found in directory { pretrained_model_name_or_path } ."
696+ )
697+ else :
698+ try :
699+ # Load from URL or cache if already cached
700+ model_file = hf_hub_download (
701+ pretrained_model_name_or_path ,
702+ filename = weights_name ,
703+ cache_dir = cache_dir ,
704+ force_download = force_download ,
705+ proxies = proxies ,
706+ resume_download = resume_download ,
707+ local_files_only = local_files_only ,
708+ use_auth_token = use_auth_token ,
709+ user_agent = user_agent ,
710+ subfolder = subfolder ,
711+ revision = revision ,
712+ )
713+ return model_file
714+
715+ except RepositoryNotFoundError :
716+ raise EnvironmentError (
717+ f"{ pretrained_model_name_or_path } is not a local folder and is not a valid model identifier "
718+ "listed on 'https://huggingface.co/models'\n If this is a private repository, make sure to pass a "
719+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
720+ "login`."
721+ )
722+ except RevisionNotFoundError :
723+ raise EnvironmentError (
724+ f"{ revision } is not a valid git identifier (branch name, tag name or commit id) that exists for "
725+ "this model name. Check the model page at "
726+ f"'https://huggingface.co/{ pretrained_model_name_or_path } ' for available revisions."
727+ )
728+ except EntryNotFoundError :
729+ raise EnvironmentError (
730+ f"{ pretrained_model_name_or_path } does not appear to have a file named { weights_name } ."
731+ )
732+ except HTTPError as err :
733+ raise EnvironmentError (
734+ f"There was a specific connection error when trying to load { pretrained_model_name_or_path } :\n { err } "
735+ )
736+ except ValueError :
737+ raise EnvironmentError (
738+ f"We couldn't connect to '{ HUGGINGFACE_CO_RESOLVE_ENDPOINT } ' to load this model, couldn't find it"
739+ f" in the cached files and it looks like { pretrained_model_name_or_path } is not the path to a"
740+ f" directory containing a file named { weights_name } or"
741+ " \n Checkout your internet connection or see how to run the library in"
742+ " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
743+ )
744+ except EnvironmentError :
745+ raise EnvironmentError (
746+ f"Can't load the model for '{ pretrained_model_name_or_path } '. If you were trying to load it from "
747+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
748+ f"Otherwise, make sure '{ pretrained_model_name_or_path } ' is the correct path to a directory "
749+ f"containing a file named { weights_name } "
750+ )
0 commit comments