
Commit 1a4f44c

Commit message: "more"

1 parent: f6bcece

4 files changed: 109 additions, 6 deletions

src/diffusers/models/autoencoder_kl.py

Lines changed: 3 additions & 1 deletion

@@ -294,7 +294,9 @@ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decod
         return DecoderOutput(sample=dec)
 
     @apply_forward_hook
-    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+    def decode(
+        self, z: torch.FloatTensor, return_dict: bool = True, generator=None
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
         """
         Decode a batch of images.
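
The `generator` threaded into `decode` lets callers seed any randomness in the decoding step (the plain KL decoder has none, but `ConsistencyDecoderVae` below draws noise). A minimal usage sketch, with an illustrative checkpoint id and a placeholder latent:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
z = torch.randn(1, 4, 64, 64)  # placeholder: a 512x512 image maps to a 4x64x64 latent

# The new keyword is accepted (and forwarded by pipelines); seeding it makes
# stochastic decoders reproducible, and deterministic ones simply ignore it.
image = vae.decode(z, return_dict=False, generator=torch.Generator("cpu").manual_seed(0))[0]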

src/diffusers/models/consistency_decoder_vae.py

Lines changed: 1 addition & 1 deletion

@@ -242,7 +242,7 @@ def decode(
         return_dict: bool = True,
         num_inference_steps=2,
     ) -> Union[DecoderOutput, torch.FloatTensor]:
-        z = (z - self.means) / self.stds
+        z = (z * self.config.scaling_factor - self.means) / self.stds
 
         scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
         z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)
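
Why the one-line change: pipelines divide latents by the VAE's `scaling_factor` before calling `decode` (see the pipeline diff below), so `z` arrives unscaled, while `self.means` / `self.stds` appear to be statistics over the scaled latent space; multiplying by `scaling_factor` first puts `z` back into that space before normalizing. A toy sketch of the bookkeeping, with placeholder statistics:

import torch

scaling_factor = 0.18215            # the SD v1.x VAE default
means = torch.zeros(1, 4, 1, 1)     # placeholders; the real statistics ship with the checkpoint
stds = torch.ones(1, 4, 1, 1)

raw = torch.randn(1, 4, 32, 32)     # raw encoder latent
scaled = raw * scaling_factor       # the space pipelines denoise in
z = scaled / scaling_factor         # what decode(...) receives, i.e. raw again

# The fixed line re-applies the factor, so normalization happens in the
# scaled space: (z * scaling_factor - means) / stds == (scaled - means) / stds
z_norm = (z * scaling_factor - means) / stds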

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 5 additions & 3 deletions

@@ -453,12 +453,12 @@ def run_safety_checker(self, image, device, dtype):
            )
         return image, has_nsfw_concept
 
-    def decode_latents(self, latents):
+    def decode_latents(self, latents, generator=None):
         deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents, return_dict=False)[0]
+        image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
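
Since `decode_latents` is deprecated in favor of `VaeImageProcessor.postprocess`, here is a hedged sketch of the recommended path, assuming a pipeline that exposes the standard `image_processor` attribute and a placeholder latent batch:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None)
latents = torch.randn(1, 4, 64, 64)  # placeholder latent batch

# Deprecated: image = pipe.decode_latents(latents)
# Replacement: decode, then post-process with the pipeline's image processor.
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = pipe.image_processor.postprocess(image, output_type="np")
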
@@ -838,7 +838,9 @@ def __call__(
                         callback(step_idx, t, latents)
 
         if not output_type == "latent":
-            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+                0
+            ]
             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
         else:
             image = latents
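
Taken together, the pipeline now forwards its `generator` into the VAE decode, so one seeded generator controls both the denoising loop and a stochastic decoder. A sketch of the intended end-to-end call, mirroring the integration tests below:

import torch
from diffusers import ConsistencyDecoderVae, StableDiffusionPipeline

vae = ConsistencyDecoderVae.from_pretrained("williamberman/consistency-decoder")
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
pipe.to("cuda")

# The same generator seeds the sampler and the consistency decoder's noise draw.
image = pipe("horse", generator=torch.Generator("cpu").manual_seed(0)).images[0]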

tests/models/test_models_vae.py

Lines changed: 100 additions & 1 deletion

@@ -16,11 +16,19 @@
 import gc
 import unittest
 
+import numpy as np
 import torch
 from parameterized import parameterized
 
-from diffusers import AsymmetricAutoencoderKL, AutoencoderKL, AutoencoderTiny, ConsistencyDecoderVae
+from diffusers import (
+    AsymmetricAutoencoderKL,
+    AutoencoderKL,
+    AutoencoderTiny,
+    ConsistencyDecoderVae,
+    StableDiffusionPipeline,
+)
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.loading_utils import load_image
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     floats_tensor,
@@ -795,3 +803,94 @@ def test_stable_diffusion_encode_sample(self, seed, expected_slice):
 
         tolerance = 3e-3 if torch_device != "mps" else 1e-2
         assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
+
+
+@slow
+class ConsistencyDecoderVaeIntegrationTests(unittest.TestCase):
+    def tearDown(self):
+        # clean up the VRAM after each test
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_encode_decode(self):
+        vae = ConsistencyDecoderVae.from_pretrained("williamberman/consistency-decoder")  # TODO - update
+        vae.to(torch_device)
+
+        image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/sketch-mountains-input.jpg"
+        ).resize((256, 256))
+        image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[
+            None, :, :, :
+        ].cuda()
+
+        latent = vae.encode(image).latent_dist.mean
+
+        sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
+
+        actual_output = sample[0, :2, :2, :2].flatten().cpu()
+        expected_output = torch.tensor([-0.0141, -0.0014, 0.0115, 0.0086, 0.1051, 0.1053, 0.1031, 0.1024])
+
+        assert torch_all_close(actual_output, expected_output, atol=5e-3)
+
+    def test_sd(self):
+        vae = ConsistencyDecoderVae.from_pretrained("williamberman/consistency-decoder")  # TODO - update
+        pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
+        pipe.to(torch_device)
+
+        out = pipe(
+            "horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
+        ).images[0]
+
+        actual_output = out[:2, :2, :2].flatten().cpu()
+        expected_output = torch.tensor([0.7686, 0.8228, 0.6489, 0.7455, 0.8661, 0.8797, 0.8241, 0.8759])
+
+        assert torch_all_close(actual_output, expected_output, atol=5e-3)
+
+    def test_encode_decode_f16(self):
+        vae = ConsistencyDecoderVae.from_pretrained(
+            "williamberman/consistency-decoder", torch_dtype=torch.float16
+        )  # TODO - update
+        vae.to(torch_device)
+
+        image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/sketch-mountains-input.jpg"
+        ).resize((256, 256))
+        image = (
+            torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :]
+            .half()
+            .cuda()
+        )
+
+        latent = vae.encode(image).latent_dist.mean
+
+        sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
+
+        actual_output = sample[0, :2, :2, :2].flatten().cpu()
+        expected_output = torch.tensor(
+            [-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471], dtype=torch.float16
+        )
+
+        assert torch_all_close(actual_output, expected_output, atol=5e-3)
+
+    def test_sd_f16(self):
+        vae = ConsistencyDecoderVae.from_pretrained(
+            "williamberman/consistency-decoder", torch_dtype=torch.float16
+        )  # TODO - update
+        pipe = StableDiffusionPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, vae=vae, safety_checker=None
+        )
+        pipe.to(torch_device)
+
+        out = pipe(
+            "horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
+        ).images[0]
+
+        actual_output = out[:2, :2, :2].flatten().cpu()
+        expected_output = torch.tensor(
+            [0.2510, 0.3776, 0.0000, 0.0285, 0.1519, 0.1814, 0.0000, 0.0000], dtype=torch.float16
+        )
+
+        assert torch_all_close(actual_output, expected_output, atol=5e-3)
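
Note on the new tests: they sit behind the `@slow` decorator (from `diffusers.utils.testing_utils`), so they are skipped unless slow tests are enabled, conventionally via `RUN_SLOW=1 pytest tests/models/test_models_vae.py -k ConsistencyDecoderVaeIntegrationTests`, and they assume a CUDA device since inputs are moved with `.cuda()`. The checkpoint id `williamberman/consistency-decoder` is flagged `# TODO - update`, so it is presumably a staging location for the final weights.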
