
Conversation

@townwish4git
Contributor

What does this PR do?

Fixes # (issue)
Fix incorrect call to self.decode() within AsymmetricAutoencoderKL.forward():

- dec = self.decode(z, sample, mask).sample
+ dec = self.decode(z, generator, sample, mask).sample

Related to: issue #8317

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul @yiyixuxu @DN6

@tolgacangoz
Contributor

Thanks for opening this PR!
I wonder what generator is used for here?

@townwish4git
Contributor Author

> Thanks for opening this PR! I wonder what generator is used for here?

The generator used here is the one passed into AsymmetricAutoencoderKL.forward():

def forward(
        self,
        sample: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        ...
-        dec = self.decode(z, sample, mask).sample
+        dec = self.decode(z, generator, sample, mask).sample
        ...

@tolgacangoz
Contributor

But the decode() function doesn't use it 🤔.

else:
    z = posterior.mode()
- dec = self.decode(z, sample, mask).sample
+ dec = self.decode(z, generator, sample, mask).sample
Member


I tend to agree with @tolgacangoz's comments here. It's not used in the decode() function, similar to AutoencoderKL; it's used in forward():

z = posterior.sample(generator=generator)
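
For illustration, here is a minimal sketch of that split (the class below is a hypothetical stand-in for diffusers' DiagonalGaussianDistribution, not the actual implementation): posterior sampling is the only stochastic step, so it is the only place the generator matters, while mode() and decoding are deterministic.

import torch

class DiagonalGaussianPosterior:
    # Toy posterior with the same sample()/mode() split as in diffusers.
    def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
        self.mean = mean
        self.std = torch.exp(0.5 * logvar)

    def sample(self, generator: torch.Generator = None) -> torch.Tensor:
        # The only place the generator is consumed: reproducible noise.
        noise = torch.randn(self.mean.shape, generator=generator)
        return self.mean + self.std * noise

    def mode(self) -> torch.Tensor:
        # Deterministic path; no generator involved.
        return self.mean

posterior = DiagonalGaussianPosterior(torch.zeros(1, 4, 8, 8), torch.zeros(1, 4, 8, 8))
z = posterior.sample(generator=torch.Generator().manual_seed(0))  # reproducible draw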

Member


I don't think there's a stochastic component in the decode() function in the first place, so I further think there's no need to have generator here either:

def decode(self, z: torch.Tensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.Tensor]:

@yiyixuxu WDYT here?

Collaborator

@sayakpaul I had the same questions, but I think their explanation makes sense.

@townwish4git
Contributor Author

> But the decode() function doesn't use it 🤔.

I think this is to maintain consistency of the decode() interface across different VAEs. For example, here is a demo using AsymmetricAutoencoderKL:

...
from diffusers import AsymmetricAutoencoderKL, StableDiffusionInpaintPipeline
...
pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
pipe.vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
pipe.to("cuda")

image = pipe(prompt=prompt, image=image, mask_image=mask_image).images[0]
image.save("image.jpeg")

In this case, when the original VAE is replaced with AsymmetricAutoencoderKL, you don't have to modify the code that calls self.vae.decode() within the pipeline; you can simply retain the original code, where the generator argument is passed in:

image = self.vae.decode(
    latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs
)[0]
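
To illustrate the design point with hypothetical stand-in classes (not the actual diffusers implementations): accepting generator in every decode() signature, even where it is unused, keeps the VAEs drop-in compatible with pipeline code that always passes it.

import torch

class StochasticDecodeVAE:
    def decode(self, z: torch.Tensor, generator: torch.Generator = None):
        # Hypothetical VAE whose decode actually consumes the generator.
        noise = torch.randn(z.shape, generator=generator)
        return z + 0.01 * noise

class DeterministicDecodeVAE:
    def decode(self, z: torch.Tensor, generator: torch.Generator = None):
        # generator is accepted but ignored: decoding is fully deterministic,
        # yet the signature stays call-compatible with the class above.
        return z

def pipeline_decode(vae, latents: torch.Tensor, generator: torch.Generator = None):
    # Pipeline code can pass generator= regardless of which VAE is plugged in.
    return vae.decode(latents, generator=generator)

latents = torch.zeros(1, 4, 8, 8)
pipeline_decode(DeterministicDecodeVAE(), latents, torch.Generator().manual_seed(0))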


@yiyixuxu left a comment


thanks!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu
Collaborator

yiyixuxu commented Jun 3, 2024

@tolgacangoz @sayakpaul
I want to understand a little bit more before merging this PR.
With this demo #8378 (comment), I think it will currently throw an error, so this PR fixes it.

Is there a different way to use AsymmetricAutoencoderKL? I'm asking because we updated the signature and could potentially break a previous use case that was working. If it was previously not working, we are good.

@townwish4git
Contributor Author

@yiyixuxu Most use cases call the AsymmetricAutoencoderKL.encode() or decode() functions, which will not be affected before or after merging.

I'm not sure whether any current use case calls forward() directly, but if one exists:

  • Before merging: it wouldn't throw an error, but would produce unexpected results, as issue #8317 mentions (see the sketch below)
  • After merging: it gets fixed
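
To make the silent failure concrete, here is a sketch with a simplified stand-in for AsymmetricAutoencoderKL.decode() (the positional order z, generator, image, mask matches the fix above; shapes are illustrative only):

import torch

def decode(z, generator=None, image=None, mask=None):
    # Simplified stand-in: generator is unused; image/mask condition the decoder.
    return {"generator": generator, "image": image, "mask": mask}

z = torch.zeros(1, 4, 8, 8)
sample = torch.zeros(1, 3, 64, 64)
mask = torch.ones(1, 1, 64, 64)

# Before the fix: sample lands in the unused generator slot, mask lands in the
# image slot, and the real mask parameter is silently left as None -- no error.
bad = decode(z, sample, mask)

# After the fix: every argument reaches the parameter it was meant for.
good = decode(z, generator=None, image=sample, mask=mask)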

@yiyixuxu merged commit 6be43bd into huggingface:main on Jun 4, 2024.
sayakpaul pushed a commit that referenced this pull request on Dec 23, 2024.