Conversation

@Narsil (Contributor) commented Nov 21, 2022

Current proposed behavior mimics transformers:

  • Do not enforce safetensors.
  • If safetensors is installed AND the file is present (on the remote model or locally), then load from it preferably.
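
A minimal sketch of that precedence, with hypothetical helper and file names (not the exact diffusers code path):

import os

# Prefer the .safetensors file when the safetensors library is installed
# and the file is present; otherwise fall back to the PyTorch .bin weights.
def pick_weights_file(model_dir: str) -> str:
    try:
        import safetensors  # noqa: F401  (only imported to check availability)
        safetensors_available = True
    except ImportError:
        safetensors_available = False

    safetensors_path = os.path.join(model_dir, "diffusion_pytorch_model.safetensors")
    pytorch_path = os.path.join(model_dir, "diffusion_pytorch_model.bin")

    if safetensors_available and os.path.exists(safetensors_path):
        return safetensors_path
    return pytorch_path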

@HuggingFaceDocBuilderDev commented Nov 21, 2022

The documentation is not available anymore as the PR was closed or merged.

@patil-suraj (Contributor) left a comment

Thanks a lot for the PR @Narsil! The code looks good, and I'm in favor of supporting safetensors. Is there any example repo with safetensors weights which we could try?

Also, don't we need any changes to save_pretrained to save weights in this format?

@Narsil (Contributor, author) commented Nov 21, 2022

Also, don't we need any changes to save_pretrained to save weights in this format?

Currently, I haven't made any changes.

pipeline.save_pretrained calls each module's save method on its own, which makes it harder to use safetensors.
model.save_pretrained does support save_function, which makes it easy for users to save in safetensors.

However, I don't think you can override save_function from within pipeline.save_pretrained.

Overall I would be more conservative and keep saving weights in safetensors as an opt-in rather than the default, to keep the changes gradual (saving in safetensors means users will need it installed to load the weights). Actually, we could load the weights even without safetensors because the loading is super simple: https://gist.github.com/Narsil/3edeec2669a5e94e4707aa0f901d2282
I'm just being conservative about changing patterns.
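
As an illustration of the save_function hook mentioned above, a user-side sketch (assuming save_pretrained forwards save_function to the state-dict write; the model and output path are just examples, and this PR does not change saving):

import safetensors.torch
from diffusers import UNet2DConditionModel

# Load a single sub-model and save it through safetensors.torch.save_file,
# which takes (state_dict, filename) just like torch.save is called here.
model = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
model.save_pretrained("./unet-safetensors", save_function=safetensors.torch.save_file)
# The file would likely still be written under the default .bin filename,
# which is part of why saving stays opt-in rather than being changed here.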

Thanks a lot for the PR @Narsil! The code looks good, and I'm in favor of supporting safetensors. Is there any example repo with safetensors weights which we could try?

Not yet. I'm in the process of updating https://github.com/huggingface/safetensors/blob/main/bindings/python/convert.py to support ANY framework by default (just convert ALL PyTorch weights of the repo and change the extension).

@patrickvonplaten (Contributor)

Design also looks good to me! Thanks a lot for working on it @Narsil!

@Narsil (Contributor, author) commented Nov 22, 2022

@patil-suraj

https://huggingface.co/Narsil/stable-diffusion-v1-4/
https://huggingface.co/Narsil/tiny-stable-diffusion-torch

Are now converted.

Simply:

git clone https://github.com/huggingface/safetensors
cd safetensors/bindings/python
python convert.py MODEL_ID

Or simply this: https://huggingface.co/spaces/safetensors/convert

@Narsil (Contributor, author) commented Nov 22, 2022

Should I create a test for this? (We could have a tiny random pipeline with only safetensors weights to test, but I would need to modify the CI somewhere to install it.)

@patrickvonplaten (Contributor)

Should I create a test for this? (We could have a tiny random pipeline with only safetensors weights to test, but I would need to modify the CI somewhere to install it.)

Yes, a test would be great! Could you maybe just add safetensors here:

"torchvision",

?
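
A rough sketch of the kind of test being discussed (the repo id is the converted tiny model mentioned above; the test name and structure are illustrative):

import unittest

from diffusers import DiffusionPipeline


class SafetensorsLoadingTest(unittest.TestCase):
    def test_load_pipeline_with_safetensors_weights(self):
        # Narsil/tiny-stable-diffusion-torch was converted above, so it carries
        # .safetensors weights; loading it should exercise the new code path
        # whenever the safetensors library is installed.
        pipe = DiffusionPipeline.from_pretrained("Narsil/tiny-stable-diffusion-torch")
        self.assertIsNotNone(pipe)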

@patrickvonplaten (Contributor)

It should then be used automatically by the GitHub runner (cc @anton-l).

@patil-suraj (Contributor) commented Nov 23, 2022

The loading time is crazy fast, I'm getting ~1.3 sec on CPU 🔥

@Narsil (Contributor, author) commented Nov 23, 2022

The loading time is crazy fast, I'm getting ~1.3 sec on CPU 🔥

It should be lower than that. I think I know why: I found a bug while creating the test, and I'm fixing everything right now.

@Narsil (Contributor, author) commented Nov 23, 2022

Okay, I think you should check the logic again, as I'm not fully confident about this change.

snapshot_download uses allow and ignore lists to get the snapshot in a ready-to-go fashion.

Since both .safetensors and .bin files were accepted before, it downloaded both versions and everything still worked. However, it was loading the PyTorch files for transformers (see huggingface/safetensors#105).

This modification fixes that, but adds an extra network call to fetch the model_info in order to filter out files before snapshot_download. This is the smallest change I could think of, but maybe we can do even better (and remove the *.msgpack oddity at the same time).

However, this feels like a larger change, which would require a separate PR.

The idea would be to remove all network calls when using _OFFLINE and make sure the CPU load happens in a few ms (it is currently sitting at 1.6s for me and my lousy network, and with offline options it only goes down to 1.2-1.3s, which is not normal).
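
A rough sketch of the filtering idea described here (hypothetical helper name; the PR wires the same logic into pipeline loading):

from huggingface_hub import model_info, snapshot_download

def download_preferring_safetensors(repo_id: str) -> str:
    # The extra network call: list the files on the Hub so we can decide
    # which weight format to download before snapshotting the repo.
    info = model_info(repo_id)
    filenames = {sibling.rfilename for sibling in info.siblings}
    ignore_patterns = ["*.msgpack"]  # Flax weights are never needed here
    if any(name.endswith(".safetensors") for name in filenames):
        ignore_patterns.append("*.bin")  # skip the duplicate PyTorch weights
    return snapshot_download(repo_id, ignore_patterns=ignore_patterns)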

Edit:

import datetime

import torch
from diffusers import StableDiffusionPipeline

# Time the end-to-end pipeline load.
start = datetime.datetime.now()
pipe = StableDiffusionPipeline.from_pretrained("Narsil/stable-diffusion-v1-4", torch_dtype=torch.float16).to("cuda")
print(f"Loaded in {datetime.datetime.now() - start}")
image = pipe("example prompt", num_inference_steps=2).images[0]

@Narsil (Contributor, author) commented Nov 23, 2022

Also, if that's of interest, I could open another PR to load the pipeline directly on CUDA; this would make a difference with safetensors (and the appropriate flag).

pretrained_model_name_or_path, weights_name=SAFETENSORS_WEIGHTS_NAME, **kwargs
)
return model_file, False
except:
Reviewer (Contributor):

Could we catch this with a more explicit "File doesn't exist" error message here?

Reviewer (Contributor):

Suggested change:
-except:
+except EnvironmentError:

It can only be an EnvironmentError, no?

@Narsil (author):

Probably a lot can go wrong here (OOM, no disk space, network error), but we can definitely except on that. I'm not sure which kind of exception is thrown if the file is missing, though.

return model


def get_model_file(pretrained_model_name_or_path, **kwargs) -> Tuple[str, bool]:
Reviewer (Contributor):

The functions become a bit too nested for me here and are hard to follow - could we maybe just remove this function and keep only _get_model_file? I don't think those 8 lines deserve a new function here; I'd prefer to just call _get_model_file directly above.

@Narsil (author):

It mostly keeps the caller tidier IMO; the caller is already a complex function, and this nested code would otherwise have to live there, making it even more complex.

It's 8 lines with 2 levels of indentation, which is complex in my book, so keeping this logic on its own is more readable to me.
It is essentially: try getting safetensors; if anything goes wrong, fetch the PyTorch one.
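
A minimal sketch of that control flow, with hf_hub_download standing in for the internal _get_model_file helper and the usual diffusers weight-file names (both stand-ins are assumptions, not the exact PR code):

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
WEIGHTS_NAME = "diffusion_pytorch_model.bin"

def get_model_file(repo_id: str, **kwargs):
    # Try the safetensors weights first; if that file is missing, fall back
    # to the regular PyTorch weights. The second return value answers
    # "is this a PyTorch .bin file?".
    try:
        return hf_hub_download(repo_id, filename=SAFETENSORS_WEIGHTS_NAME, **kwargs), False
    except EntryNotFoundError:
        # The PR's own helper surfaces failures as EnvironmentError, which is
        # what the "except EnvironmentError" suggestion above relies on.
        return hf_hub_download(repo_id, filename=WEIGHTS_NAME, **kwargs), True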

@Narsil (author):

Happy to remove it, but I think the calling function will get messier.

return model_file, True


def _get_model_file(
Reviewer (Contributor):

nice!

if device_map is None:
    param_device = "cpu"
-    state_dict = load_state_dict(model_file)
+    state_dict = load_state_dict(model_file, is_pytorch)
Reviewer (Contributor):

Actually, could we not use a flag here at all? Instead, decide in load_state_dict, depending on the model file name (it has to end with either SAFETENSORS_WEIGHTS_NAME or WEIGHTS_NAME), whether it's PyTorch or not. It's a bit confusing to me to pass around an is_pytorch flag here.

@Narsil (author):

Again, I think the behavior should not depend on the filename, but on what you did previously.

It's bikeshedding at this point, I'll try to find something even more elegant.



-def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
+def load_state_dict(checkpoint_file: Union[str, os.PathLike], is_pytorch: bool):
Reviewer (Contributor):

Suggested change:
-def load_state_dict(checkpoint_file: Union[str, os.PathLike], is_pytorch: bool):
+def load_state_dict(checkpoint_file: Union[str, os.PathLike]):

"""
try:
return torch.load(checkpoint_file, map_location="cpu")
if is_pytorch:
Reviewer (Contributor):

Suggested change:
-if is_pytorch:
+if Path(checkpoint_file).name == WEIGHTS_NAME:

@Narsil (author):

I feel this is wrong; we already know at this point what kind of file it is, so we don't have to guess.

I agree the is_pytorch flag is slightly odd, but I don't think doing things based on the filename is great.
I'll try to figure out something better.

@Narsil (author):

Couldn't find anything satisfactory, so I went your way.

    if is_pytorch:
        return torch.load(checkpoint_file, map_location="cpu")
    else:
        return load_file(checkpoint_file, device="cpu")
Reviewer (Contributor):

Just for readability, could we maybe do the following:

Suggested change:
-        return load_file(checkpoint_file, device="cpu")
+        return safetensors.torch.load_file(checkpoint_file, device="cpu")

Then the reader directly knows the loading comes from safetensors?

@Narsil (author):

Sure.
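
Putting these suggestions together, the loader ends up looking roughly like this (a sketch, not the exact merged code):

import os
from typing import Union

import safetensors.torch
import torch

WEIGHTS_NAME = "diffusion_pytorch_model.bin"

def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
    # Dispatch on the file name: the regular PyTorch checkpoint goes through
    # torch.load, anything else is treated as a safetensors file.
    if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
        return torch.load(checkpoint_file, map_location="cpu")
    return safetensors.torch.load_file(checkpoint_file, device="cpu")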

user_agent["custom_pipeline"] = custom_pipeline
user_agent = http_user_agent(user_agent)

info = model_info(
Reviewer (Contributor):

Can we wrap all of this into a function called is_compatible_with_safetensors()?

And then just do something like:

if is_safetensors_available() and is_compatible_with_safetensors():
    ignore_patterns.append("*.bin")

@Narsil (author):

We can.

I'd rather leave model_info on its own though.

It's a network call, and function names like is_compatible seem innocuous and fast (which network operations are not :))

@patrickvonplaten (Contributor)

PR looks very nice!

I left some feedback in modeling_utils regarding readability - always a bit worried about introducing too many nested functions and flags.

Overall I think this is a really nice implementation! In pipeline_utils it would also be nice to only make the extra call to the Hub if safetensors is installed, and maybe to move the whole logic into one simple-sounding "is pipe compatible with safetensors" function.

@Narsil (Contributor, author) commented Nov 26, 2022

PR looks very nice!

I left some feedback in modeling_utils regarding readability - always a bit worried about introducing too many nested functions and flags.

I don't necessarily agree, but I went your way; I feel it's more a style preference, so happy to oblige.

In pipeline_utils it would also be nice to only make the extra call to the Hub if safetensors is installed, and maybe to move the whole logic into one simple-sounding "is pipe compatible with safetensors" function.

Hmm, what I'm suggesting is that we should probably just own that call, since it's exactly what snapshot_download does:
https://github.com/huggingface/huggingface_hub/blob/22c1431c598dca24e74efa2bed803df4a58fd862/src/huggingface_hub/_snapshot_download.py#L173

This calls for a larger PR though.
Safetensors files could also be leveraged in the Flax codebase (yet another PR).

snapshot_download assumes you know which files you need purely based on where you are in the codebase (Flax vs PT). This PR breaks that assumption, since the answer now depends on the combination of what's on the server and what's in the environment.

user_agent["custom_pipeline"] = custom_pipeline
user_agent = http_user_agent(user_agent)

if is_safetensors_available():
Reviewer (Contributor):

Cool, that works for me - I agree with your point that "calling the Hub" is too hidden when fully put inside is_safetensors_compatible(...).

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})

USE_TF = os.environ.get("USE_TF", "AUTO").upper()
Reviewer (Contributor):

Actually we should remove this -> we don't have TF in diffusers yet

@patrickvonplaten (Contributor) commented Nov 28, 2022:

Happy to do it in a follow-up PR if it's too much work here. Essentially we copy-pasted this whole file from transformers and forgot to remove the USE_TF stuff -> we should remove all TF-related code.

    _torch_available = False
else:
-    logger.info("Disabling PyTorch because USE_TF is set")
+    logger.info("Disabling PyTorch because USE_TORCH is set")
Reviewer (Contributor):

Suggested change:
-logger.info("Disabling PyTorch because USE_TORCH is set")
+logger.info("Disabling PyTorch because USE_TORCH is not set")

except importlib_metadata.PackageNotFoundError:
_safetensors_available = False
else:
logger.info("Disabling Safetensors because USE_TF is set")
Reviewer (Contributor):

Suggested change:
-logger.info("Disabling Safetensors because USE_TF is set")
+logger.info("Disabling Safetensors because USE_SAFETENSORS is not set")

@patrickvonplaten (Contributor)

I don't necessarily agree, but I went your way; I feel it's more a style preference, so happy to oblige.

Thanks for making the changes - fully agree that it's more a style preference!

Left some comments about USE_TF - in general we can remove all of those => we don't have TF in diffusers yet (if it's easy to do in import_utils.py, maybe we could do it in this PR - otherwise happy to do it myself in a follow-up PR).

Happy to merge from my side though!

@patrickvonplaten (Contributor) left a comment

Nice PR - thanks a lot!

@Narsil Narsil merged commit 5755d16 into huggingface:main Nov 28, 2022
@Narsil Narsil deleted the support_safetensors branch November 28, 2022 09:39
sliard pushed a commit to sliard/diffusers that referenced this pull request Dec 21, 2022
…ingface#1357)

* [Proposal] Support loading from safetensors if file is present.

* Style.

* Fix.

* Adding some test to check loading logic.

+ modify download logic to not download pytorch file if not necessary.

* Fixing the logic.

* Adressing comments.

* factor out into a function.

* Remove dead function.

* Typo.

* Extra fetch only if safetensors is there.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

Co-authored-by: Patrick von Platen <[email protected]>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023