-
Notifications
You must be signed in to change notification settings - Fork 6.5k
vq diffusion classifier free sampling #1294
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
vq diffusion classifier free sampling #1294
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
d0d5beb to
4ee1e06
Compare
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
4ee1e06 to
cac658d
Compare
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
cac658d to
f5fbbbe
Compare
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
f5fbbbe to
8746de8
Compare
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
8746de8 to
fe8db41
Compare
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
fe8db41 to
bfc4459
Compare
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
bfc4459 to
40dc3ff
Compare
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
40dc3ff to
10e2ea3
Compare
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
10e2ea3 to
08984ab
Compare
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
| "https://huggingface.co/datasets/williamberman/misc/resolve/main" | ||
| "/vq_diffusion/teddy_bear_pool_classifier_free_sampling.png" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be moved to the huggingface testing dataset. FWIW you might have to also regenerate the image because I get a different image on VQDiffusionPipelineIntegrationtests#test_vq_diffusion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very cool :-) I'll move it!
src/diffusers/models/embeddings.py
Outdated
| class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin): | ||
| """ | ||
| Utility class for storing learned text embeddings for classifier free sampling | ||
| """ | ||
|
|
||
| @register_to_config | ||
| def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None): | ||
| super().__init__() | ||
|
|
||
| self.learnable = learnable | ||
|
|
||
| if self.learnable: | ||
| assert hidden_size is not None, "learnable=True requires `hidden_size` to be set" | ||
| assert length is not None, "learnable=True requires `length` to be set" | ||
|
|
||
| embeddings = torch.zeros(length, hidden_size) | ||
| else: | ||
| embeddings = None | ||
|
|
||
| self.embeddings = nn.Parameter(embeddings) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if this is the preferred way to add the learned embeddings to the pipeline. An alternative might be to add the additional vector to the scheduler instead
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's very model specific, so moving it to the pipeline here directly :-)
Think that's a bit cleaner! The model works much better now though - thanks!
| tokenizer: CLIPTokenizer, | ||
| transformer: Transformer2DModel, | ||
| scheduler: VQDiffusionScheduler, | ||
| learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's definitely the right way to do it - it's quite specific to vq-diffusion IMO though, so will move it here :-)
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
|
Very nice job @williamberman ! |
* vq diffusion classifier free sampling * correct * uP Co-authored-by: Patrick von Platen <[email protected]>
Adds classifier free sampling to VQ diffusion. This results in significantly better image quality.
The pipeline now has a default guidance_scale of 5.0
Additionally, the ithq dataset uses a learned parameter for the classifier free embeddings. We modify the convert script to add this parameter to the ported model. Weights will have to be reuploaded
Prompts: "teddy bear playing in the pool" and "horse"
Diffusers VQ diffusion with classifier free sampling
Diffusers VQ diffusion without classifier free sampling
Original VQ diffusion implementation with classifier free sampling
Original VQ diffusion implementation without classifier free sampling