-
Notifications
You must be signed in to change notification settings - Fork 6.5k
Allow converting Flax to PyTorch by adding a "from_flax" keyword #1900
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
The documentation is not available anymore as the PR was closed or merged. |
|
it is working like this !git clone https://huggingface.co/camenduru/plushies
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("/content/plushies", safety_checker=None, from_flax=True).to("cpu")but if I do from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("camenduru/plushies", safety_checker=None, from_flax=True).to("cpu")it is downloading safetensors 😐 |
|
if only flax model in repo it is skipping 😐 |
|
Hi @camenduru! This is a great effort, however network downloading is a bit complicated. When you clone the repo, all the files are available locally and you select the flax ones, but downloading attempts to only retrieve the files that will be needed. I haven't had enough time to study your code in depth, but I believe you should manipulate For easier debugging, I would concentrate on a repo that only has flax weights and not safetensors. When that works, I'd suggest ignoring all the I hope that helps. I'll try to test it better tomorrow. Thanks a lot for working on this! |
|
@pcuenca thanks ❤ I have a question I found this https://setuptools.pypa.io/en/latest/userguide/development_mode.html @patil-suraj taught me |
Absolutely! If you install with |
|
thanks @pcuenca ❤ working now 🎉🎊✨ |
|
how can I pass this test should I edit the test 🤔 |
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("camenduru/plushies", safety_checker=None, from_flax=True).to("cpu")
pipe.save_pretrained("pt")from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("/content/pt", safety_checker=None).to("cuda")
image = pipe("duck", num_inference_steps=50).images[0]
display(image)from diffusers import FlaxStableDiffusionPipeline
pipe, params = FlaxStableDiffusionPipeline.from_pretrained("camenduru/plushies-pt", from_pt=True)
pipe.save_pretrained("flax", params=params)from diffusers import FlaxStableDiffusionPipeline
pipe, params = FlaxStableDiffusionPipeline.from_pretrained("/workspaces/flax", dtype=jax.numpy.bfloat16, safety_checker=None)
params = replicate(params)
real_seed = random.randint(0, 2147483647)
prng_seed = jax.random.PRNGKey(real_seed)
prng_seed = jax.random.split(prng_seed, jax.device_count())
num_samples = jax.device_count()
prompt = "duck"
prompt_n = num_samples * [prompt]
prompt_ids = pipe.prepare_inputs(prompt_n)
prompt_ids = shard(prompt_ids)
images = pipe(prompt_ids, params, prng_seed, jit=True).images
images = pipe.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
display(images[0]) |
|
oh no 😐 I did something wrong |
|
how can I add reviewers back oops sorry |
patrickvonplaten
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't tried it out, but the changes to modeling_utils.py look good to me! Also all tests are passing 😍
Great job @camenduru - this was one of the harder PRs!
@pcuenca @patil-suraj mind taking a look here as well. Would like to have 2 more 👀 on this one as it touches core functionality.
pcuenca
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, thanks a lot @camenduru! I just left a few minor comments. I'll try to test it tomorrow, but I agree with Patrick that it'd be really cool if we could add a simple test that shows how this works.
Great job!
| else: | ||
| logger.warning(f"All Flax model weights were used when initializing {pt_model.__class__.__name__}.\n") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the ok case, right? I wonder if we should just skip this warning.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok
| else: | ||
| logger.warning( | ||
| f"All the weights of {pt_model.__class__.__name__} were initialized from the Flax model.\n" | ||
| "If your task is similar to the task the model of the checkpoint was trained on, " | ||
| f"you can already use {pt_model.__class__.__name__} for predictions without further training." | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comment here, this means everything went well, doesn't it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes deleted
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
|
thanks @pcuenca ❤
I have a question about test should I write test other then this #1900 (comment) |
|
Added a test, conversion seems to be working perfectly! Think we can merge this one :-) Amazing job @camenduru ! |
|
woohoo 🥳 thanks @patrickvonplaten ❤ @pcuenca ❤ @patil-suraj ❤ 🥳 🎉 🎊 |
…gingface#1900) * from_flax * oops * oops * make style with pip install -e ".[dev]" * oops * now code quality happy 😋 * allow_patterns += FLAX_WEIGHTS_NAME * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/models/modeling_utils.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Patrick von Platen <[email protected]> * for test * bye bye is_flax_available() * oops * Update src/diffusers/models/modeling_pytorch_flax_utils.py Co-authored-by: Pedro Cuenca <[email protected]> * Update src/diffusers/models/modeling_pytorch_flax_utils.py Co-authored-by: Pedro Cuenca <[email protected]> * Update src/diffusers/models/modeling_pytorch_flax_utils.py Co-authored-by: Pedro Cuenca <[email protected]> * Update src/diffusers/models/modeling_utils.py Co-authored-by: Pedro Cuenca <[email protected]> * Update src/diffusers/models/modeling_utils.py Co-authored-by: Pedro Cuenca <[email protected]> * make style * add test * finihs Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]>



everything changed I am confused 😐 this probably not works @patrickvonplaten please help me