[Proposal] Support loading from safetensors if file is present. #1357
Conversation
The documentation is not available anymore as the PR was closed or merged.
patil-suraj left a comment:
Thanks a lot for the PR @Narsil! The code looks good, and I'm in favor of supporting safetensors. Is there any example repo with safetensors weights which we could try?
Also, don't we need any changes to save_pretrained to save weights in this format?
Currently, I didn't make any change. Overall I would be more conservative and keep saving weights in this format as an opt-in rather than a default, to keep the changes gradual (since saving in safetensors means users will need it installed to load the weights). Actually, we could load the weights even without safetensors, because the loading logic is super simple: https://gist.github.com/Narsil/3edeec2669a5e94e4707aa0f901d2282
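For illustration, a minimal sketch of that dependency-free loading path, following the published safetensors layout (an 8-byte little-endian header length, a JSON header, then raw tensor bytes). The function name here is mine, not from the gist:

```python
import json
import struct

def read_safetensors_header(path):
    """Read the JSON header of a .safetensors file without the library.

    The file starts with an unsigned 64-bit little-endian integer giving
    the byte length of a JSON header, which maps each tensor name to its
    dtype, shape, and [start, end] data offsets into the rest of the file.
    """
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(header_len))
```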
Not yet, I'm in the process of updating https://github.com/huggingface/safetensors/blob/main/bindings/python/convert.py to support ANY framework by default (just convert ALL pytorch weights of the repo and change the extension).
Design also looks good to me! Thanks a lot for working on it @Narsil!
The weights at https://huggingface.co/Narsil/stable-diffusion-v1-4/ are now converted. Or simply use this Space: https://huggingface.co/spaces/safetensors/convert
Should I create a test for this? (Could have a tiny random pipeline with only safetensors weights to test, but I would need to modify the CI somewhere to install it.)
Yes, a test would be great! Could you maybe just add `safetensors` to the dependency list here: line 189 in 44e56de?
It should then be used automatically by the GitHub Runner (cc @anton-l).
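A hypothetical sketch of what such a test could check (the repo name and the cache-walk approach are illustrative, not the actual test added in this PR):

```python
import os

from diffusers import StableDiffusionPipeline

def test_prefers_safetensors_and_skips_pytorch_bin(tmp_path):
    # Loading a repo that has safetensors weights should not pull the .bin files.
    StableDiffusionPipeline.from_pretrained(
        "Narsil/stable-diffusion-v1-4", cache_dir=tmp_path
    )
    downloaded = [f for _, _, files in os.walk(tmp_path) for f in files]
    assert any(f.endswith(".safetensors") for f in downloaded)
    assert not any(f.endswith(".bin") for f in downloaded)
```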
The loading time is crazy fast, I'm getting ~1.3 sec on CPU 🔥
It should be lower than that. I think I know why: I found a bug while creating the test, I'm fixing everything right now.
Okay, I think you should check the logic again, as I'm not confident about this change.
Since both the pytorch and safetensors files are present in the repo, snapshot_download was fetching both. This modification does the fix, but it adds an extra network call to fetch the model_info to filter out the files before snapshot_download. This is the smallest change I could think of, but maybe we can do even better (and remove that extra call). However, that feels like a larger change which would require a separate PR; the idea would be to remove all network calls when the files are already cached.

Edit: here is the timing snippet I used:

```python
import datetime

import torch
from diffusers import StableDiffusionPipeline

start = datetime.datetime.now()
pipe = StableDiffusionPipeline.from_pretrained(
    "Narsil/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")
print(f"Loaded in {datetime.datetime.now() - start}")
image = pipe("example prompt", num_inference_steps=2).images[0]
```
Also, if that's interesting, I could open another PR to load the pipeline directly on CUDA; this will make a difference with safetensors (and the appropriate flag).
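For context, a minimal sketch of what loading a checkpoint directly onto CUDA looks like with the safetensors API (the file name is illustrative; this is not the follow-up PR itself):

```python
from safetensors.torch import load_file

# load_file can materialize tensors straight on a CUDA device,
# skipping the intermediate CPU copy that torch.load would make.
state_dict = load_file("diffusion_pytorch_model.safetensors", device="cuda:0")
```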
src/diffusers/modeling_utils.py (Outdated)

```python
            pretrained_model_name_or_path, weights_name=SAFETENSORS_WEIGHTS_NAME, **kwargs
        )
        return model_file, False
    except:
```
Could we catch with a more explicit "File doesn't exist" error message here?
Suggested change:

```diff
-    except:
+    except EnvironmentError:
```
It can only be an EnvironmentError, no?
Probably a lot can go wrong here (OOM, no disk space, network error), but we can definitely except on that. I'm not sure which kind of exception is thrown if the file is missing though.
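To make the fallback concrete, here is a hedged sketch of the behavior being discussed, reusing the helper names from this PR's diff (exact signatures in the merged code may differ):

```python
def get_model_file(pretrained_model_name_or_path, **kwargs):
    """Return (model_file, is_pytorch): prefer safetensors, fall back to pytorch."""
    if is_safetensors_available():
        try:
            model_file = _get_model_file(
                pretrained_model_name_or_path, weights_name=SAFETENSORS_WEIGHTS_NAME, **kwargs
            )
            return model_file, False
        except EnvironmentError:
            # No safetensors file in the repo (or it could not be fetched):
            # fall back to the regular pytorch weights below.
            pass
    model_file = _get_model_file(pretrained_model_name_or_path, weights_name=WEIGHTS_NAME, **kwargs)
    return model_file, True
```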
src/diffusers/modeling_utils.py (Outdated)

```python
    return model


def get_model_file(pretrained_model_name_or_path, **kwargs) -> Tuple[str, bool]:
```
The functions become a bit too nested for me here and hard to follow - could we maybe just remove this function and only keep `_get_model_file`? I don't think those 8 lines deserve a new function here; I'd prefer to just call `_get_model_file` directly above.
It mostly keeps the caller tidier IMO; the caller is already a complex function, and this nested code would have to live there, making it even more complex.
It's 8 lines with 2 levels of indentation, which is complex in my book, so keeping this logic on its own is more readable to me.
It's "try getting safetensors; if anything goes wrong, fetch the pytorch one".
Happy to remove it, but I think the calling function will get messier.
```python
    return model_file, True


def _get_model_file(
```
nice!
src/diffusers/modeling_utils.py (Outdated)

```diff
 if device_map is None:
     param_device = "cpu"
-    state_dict = load_state_dict(model_file)
+    state_dict = load_state_dict(model_file, is_pytorch)
```
Actually, could we not use a flag here at all, and instead decide in `load_state_dict`, depending on the model file name (it has to end with either SAFETENSORS_WEIGHTS_NAME or WEIGHTS_NAME), whether it's pytorch or not? It's a bit confusing to me to pass around an `is_pytorch` flag here.
Again, I think the behavior should not depend on the filename, but on what you did previously.
It's bikeshedding at this point; I'll try to find something even more elegant.
src/diffusers/modeling_utils.py (Outdated)

```diff
-def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
+def load_state_dict(checkpoint_file: Union[str, os.PathLike], is_pytorch: bool):
```
Suggested change:

```diff
-def load_state_dict(checkpoint_file: Union[str, os.PathLike], is_pytorch: bool):
+def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
```
src/diffusers/modeling_utils.py (Outdated)

```diff
     """
     try:
-        return torch.load(checkpoint_file, map_location="cpu")
+        if is_pytorch:
```
Suggested change:

```diff
-        if is_pytorch:
+        if Path(checkpoint_file).name == WEIGHTS_NAME:
```
I feel this is wrong: we already know at this point what kind of file it is, so we don't have to guess.
I agree the `is_pytorch` flag is slightly odd, but I don't think doing things based on the filename is great either.
I'll try to figure out something better.
Couldn't find anything satisfactory, so I went your way.
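For reference, the filename-based dispatch the thread converged on would look roughly like this (a sketch assuming the `WEIGHTS_NAME` constant from this file; the merged code may differ):

```python
from pathlib import Path

import torch
from safetensors.torch import load_file

def load_state_dict(checkpoint_file):
    # Dispatch on the file name: WEIGHTS_NAME (the pytorch .bin file)
    # goes through torch.load, anything else through safetensors.
    if Path(checkpoint_file).name == WEIGHTS_NAME:
        return torch.load(checkpoint_file, map_location="cpu")
    return load_file(checkpoint_file, device="cpu")
```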
src/diffusers/modeling_utils.py (Outdated)

```python
    if is_pytorch:
        return torch.load(checkpoint_file, map_location="cpu")
    else:
        return load_file(checkpoint_file, device="cpu")
```
Just for readability, could we maybe do the following:

```diff
-        return load_file(checkpoint_file, device="cpu")
+        return safetensors.torch.load_file(checkpoint_file, device="cpu")
```

Then the reader directly knows the loading comes from safetensors?
Sure.
src/diffusers/pipeline_utils.py (Outdated)

```python
user_agent["custom_pipeline"] = custom_pipeline
user_agent = http_user_agent(user_agent)

info = model_info(
```
Can we wrap all of this into a function, called `is_compatible_with_safetensors()`? And then just do something like:

```python
if is_safetensors_available() and is_compatible_with_safetensors():
    ignore_patterns.append("*.bin")
```
We can. I'd rather leave `model_info` on its own though: it's a network call, and function names like `is_compatible` sound innocuous and fast (which doing network operations is not :))
PR looks very nice! I left some feedback in modeling_utils regarding readability - always a bit worried to introduce too many nested functions and flags. Overall I think this is a really nice implementation! In pipeline_utils it would just be nice to also only do the extra call to the Hub if safetensors is installed, and maybe to move the whole logic into one simple-sounding "is pipeline compatible with safetensors" function.
I don't necessarily agree, but I went your way; I feel it's more a matter of style preferences, so happy to oblige.
Hmm, what I'm suggesting is that we should probably just own that call, since it's exactly what snapshot_download does under the hood. This calls for a larger PR though.
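A hedged sketch of what "owning the call" could mean: list the repo files once via `model_info`, filter them, and fetch each remaining file with `hf_hub_download` instead of going through `snapshot_download` (repo id reused from above; this is an illustration, not the larger PR):

```python
from huggingface_hub import hf_hub_download, model_info

# snapshot_download is essentially model_info plus one download per file,
# so doing it by hand lets us filter before anything is fetched.
repo_id = "Narsil/stable-diffusion-v1-4"
wanted = [
    sibling.rfilename
    for sibling in model_info(repo_id).siblings
    if not sibling.rfilename.endswith(".bin")
]
local_paths = [hf_hub_download(repo_id, filename=f) for f in wanted]
```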
```python
user_agent["custom_pipeline"] = custom_pipeline
user_agent = http_user_agent(user_agent)

if is_safetensors_available():
```
Cool, that works for me - agree with your point that "calling the Hub" is too hidden when fully put inside `is_safetensors_compatible(...)`.
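For completeness, a sketch of what the compatibility check itself could look like, with the Hub call kept visible at the call site as agreed (the predicate body is my assumption, not the exact merged code):

```python
from huggingface_hub import model_info

def is_safetensors_compatible(info) -> bool:
    """True if every pytorch .bin file has a .safetensors sibling in the repo."""
    filenames = {sibling.rfilename for sibling in info.siblings}
    pt_files = [f for f in filenames if f.endswith(".bin")]
    return all(f[: -len(".bin")] + ".safetensors" in filenames for f in pt_files)

# The network call stays visible at the call site:
info = model_info("Narsil/stable-diffusion-v1-4")
if is_safetensors_compatible(info):
    ignore_patterns = ["*.bin"]
```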
```python
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})

USE_TF = os.environ.get("USE_TF", "AUTO").upper()
```
Actually we should remove this -> we don't have TF in diffusers yet
Happy to do it in a follow-up PR if it's too much work here. Essentially we copy-pasted this whole file from transformers and forgot to remove the USE_TF stuff -> we should remove all TF-related code.
```diff
     _torch_available = False
 else:
-    logger.info("Disabling PyTorch because USE_TF is set")
+    logger.info("Disabling PyTorch because USE_TORCH is set")
```
| logger.info("Disabling PyTorch because USE_TORCH is set") | |
| logger.info("Disabling PyTorch because USE_TORCH is not set") |
```python
except importlib_metadata.PackageNotFoundError:
    _safetensors_available = False
else:
    logger.info("Disabling Safetensors because USE_TF is set")
```
| logger.info("Disabling Safetensors because USE_TF is set") | |
| logger.info("Disabling Safetensors because USE_SAFETENSORS is not set") |
Thanks for making the changes - fully agree that it's more a style preference! Left some comments about the log messages. Happy to merge from my side though!
Nice PR - thanks a lot!
Co-authored-by: Patrick von Platen <[email protected]>
[Proposal] Support loading from safetensors if file is present (huggingface#1357)

* [Proposal] Support loading from safetensors if file is present.
* Style.
* Fix.
* Adding some test to check loading logic. + modify download logic to not download pytorch file if not necessary.
* Fixing the logic.
* Addressing comments.
* Factor out into a function.
* Remove dead function.
* Typo.
* Extra fetch only if safetensors is there.
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>
Current proposed behavior mimics transformers: if `safetensors` is installed AND the file is present (on the remote model or locally), then load from it preferably.