Skip to content

Conversation

@nateraw
Copy link
Contributor

@nateraw nateraw commented Oct 17, 2022

Adds interpolate_stable_diffusion.py, which is a pipeline that lets you generate images as you interpolate between different prompts/seeds.

Its __call__ fn is the same as StableDiffusionPipeline, but with the added text_embeddings kwarg. The walk function is where the logic happens. Fine with renaming these if need be so that walk becomes __call__.

Here's how I have been using it (from within examples/community dir)

from interpolate_stable_diffusion import StableDiffusionWalkPipeline
import torch

pipe = StableDiffusionWalkPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision='fp16',
    torch_dtype=torch.float16,
    safety_checker=None,  # Very important for videos...lots of false positives while interpolating
).to('cuda')

frame_filepaths = pipe.walk(
    prompts=['a dog', 'a cat', 'a horse'],
    seeds=[42, 1337, 1234],
    num_interpolation_steps=16,
    output_dir='./dreams',
    batch_size=24,
    height=512,
    width=512,
    guidance_scale=8.5,
    num_inference_steps=50,
)

Though it seems (I think) that once it's merged, you'll be able to use community='interpolate_stable_diffusion' in the from_pretrained fn instead of having to be in the examples/community dir.

Copy link
Contributor Author

@nateraw nateraw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Few notes

|:----------|:----------------------|:-----------------|----------:|
| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion| [Suraj Patil](https://github.com/patil-suraj/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) |
| One Step U-Net (Dummy) | [Patrick von Platen](https://github.com/patrickvonplaten/) | - |
| One Step U-Net (Dummy) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | - |
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw fixed this table here... @patrickvonplaten your name was under the Description column.

| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion| [Suraj Patil](https://github.com/patil-suraj/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) |
| One Step U-Net (Dummy) | [Patrick von Platen](https://github.com/patrickvonplaten/) | - |
| One Step U-Net (Dummy) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | - |
| Stable Diffusion Interpolation | Interpolate the latent space of Stable Diffusion between different prompts/seeds | [Nate Raw](https://github.com/nateraw/) | - |
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will add notebook once this is merged

@torch.no_grad()
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prompt becomes optional, as you can either pass that or text_embeddings

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Oct 17, 2022

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten patrickvonplaten merged commit ee9875e into main Oct 17, 2022
@patrickvonplaten patrickvonplaten deleted the add-interpolation-pipeline branch October 17, 2022 11:48
@patrickvonplaten
Copy link
Contributor

Thanks a lot @nateraw !

kumquatexpress pushed a commit to harvestlabs/diffusers that referenced this pull request Oct 19, 2022
* ✨ Add Stable Diffusion Interpolation Example

* 💄 style

* Update examples/community/interpolate_stable_diffusion.py

Co-authored-by: Patrick von Platen <[email protected]>
prathikr pushed a commit to prathikr/diffusers that referenced this pull request Oct 26, 2022
* ✨ Add Stable Diffusion Interpolation Example

* 💄 style

* Update examples/community/interpolate_stable_diffusion.py

Co-authored-by: Patrick von Platen <[email protected]>
PhaneeshB pushed a commit to nod-ai/diffusers that referenced this pull request Mar 1, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants