Skip to content

Conversation

haofanwang
Copy link
Contributor

What does this PR do?

Fixes a upcasting bug.

Let's have a quick review here.

# manually for max memory savings
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
    self.upcast_vae()
    latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

if not output_type == "latent":
    # make sure the VAE is in float32 mode, as it overflows in float16
    needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

    if needs_upcasting:
        self.upcast_vae()
        latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

    # cast back to fp16 if needed
    if needs_upcasting:
        self.vae.to(dtype=torch.float16)

Assuming that we set our model to float16. At the first if, the vae and latents are upcasted to float32 for avoiding overflow. Then, needs_upcasting is False as the vae is in float32 now. The image is decoded in float32. Everything works fine.

But if we conduct the second inference, the vae is still in float32 while the latents is in float16, then we will encounter RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same. Removing the first few lines solves the problem.

@haofanwang
Copy link
Contributor Author

@sayakpaul Could you review this PR?

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Could you apply this to the other SDXL ControlNet pipeline too?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@haofanwang
Copy link
Contributor Author

Done. Other pipelines look fine.

@haofanwang
Copy link
Contributor Author

@patrickvonplaten

@patrickvonplaten
Copy link
Contributor

@yiyixuxu can you review this?

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for your fix!

@sayakpaul sayakpaul merged commit c9081a8 into huggingface:main Jan 24, 2024
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* Update pipeline_controlnet_sd_xl.py

* Update pipeline_controlnet_xs_sd_xl.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants