From 2542cd1d27de285f875492902f19ee4d6785baac Mon Sep 17 00:00:00 2001
From: Haofan Wang
Date: Sun, 21 Jan 2024 22:18:01 +0800
Subject: [PATCH 1/2] Update pipeline_controlnet_sd_xl.py

---
 .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 02e515c0ff55..78793c2866f4 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -1404,11 +1404,6 @@ def __call__(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
-        # manually for max memory savings
-        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
-            self.upcast_vae()
-            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
-
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
             needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

From 149e9f69647177435fd476a086ccb656a1fd82e6 Mon Sep 17 00:00:00 2001
From: Haofan Wang
Date: Mon, 22 Jan 2024 23:14:28 +0800
Subject: [PATCH 2/2] Update pipeline_controlnet_xs_sd_xl.py

---
 .../controlnetxs/pipeline_controlnet_xs_sd_xl.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py b/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py
index be888d7e1145..ed45b3bb5a1b 100644
--- a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py
+++ b/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py
@@ -1041,11 +1041,6 @@ def __call__(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
-        # manually for max memory savings
-        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
-            self.upcast_vae()
-            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
-
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
             needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
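
Note on why the removal is safe: the deleted block duplicated work that the decode path immediately below it already performs under the `needs_upcasting` guard, and it ran unconditionally, even when `output_type="latent"` skips decoding entirely. Below is a minimal sketch of the surviving path, paraphrased from the diffusers SDXL pipelines around the time of this patch; the `decode_latents_with_upcast` wrapper and the `pipe` argument are ours for illustration, while `upcast_vae`, `force_upcast`, `post_quant_conv`, and `scaling_factor` come from the source.

import torch

def decode_latents_with_upcast(pipe, latents, output_type="pil"):
    """Sketch of the decode path that survives this patch (paraphrased from
    diffusers' SDXL pipelines; `pipe` stands in for the pipeline's `self`)."""
    if not output_type == "latent":
        # The fp16 VAE overflows while decoding, so temporarily upcast to fp32.
        needs_upcasting = pipe.vae.dtype == torch.float16 and pipe.vae.config.force_upcast
        if needs_upcasting:
            pipe.upcast_vae()
            # Match the latents to the (now fp32) VAE parameter dtype.
            latents = latents.to(next(iter(pipe.vae.post_quant_conv.parameters())).dtype)
        image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
        # Restore fp16 afterwards so the VAE keeps its memory savings.
        if needs_upcasting:
            pipe.vae.to(dtype=torch.float16)
        return image
    return latents

Keeping the upcast inside the guarded branch also means the VAE is cast back to fp16 after decoding; with the deleted standalone copy, `needs_upcasting` evaluated to False afterwards, so the VAE appears to have stayed permanently upcast.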