Skip to content

Commit 5389d07

Browse files
committed
fix
1 parent 1a4f44c commit 5389d07

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -453,12 +453,12 @@ def run_safety_checker(self, image, device, dtype):
453453
)
454454
return image, has_nsfw_concept
455455

456-
def decode_latents(self, latents, generator=None):
456+
def decode_latents(self, latents):
457457
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
458458
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
459459

460460
latents = 1 / self.vae.config.scaling_factor * latents
461-
image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
461+
image = self.vae.decode(latents, return_dict=False)[0]
462462
image = (image / 2 + 0.5).clamp(0, 1)
463463
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
464464
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

tests/models/test_models_vae.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,7 @@ def test_sd_f16(self):
890890

891891
actual_output = out[:2, :2, :2].flatten().cpu()
892892
expected_output = torch.tensor(
893-
[0.2510, 0.3776, 0.0000, 0.0285, 0.1519, 0.1814, 0.0000, 0.0000], dtype=torch.float16
893+
[0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035], dtype=torch.float16
894894
)
895895

896896
assert torch_all_close(actual_output, expected_output, atol=5e-3)

0 commit comments

Comments
 (0)