Skip to content

Commit c936097

Browse files
anton-lkumquatexpress
authored andcommitted
Remove the last of ["sample"] (huggingface#842)
1 parent b59419c commit c936097

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

tests/test_pipelines.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1693,9 +1693,9 @@ def test_ddpm_ddim_equality_batched(self):
16931693
ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images
16941694

16951695
generator = torch.manual_seed(0)
1696-
ddim_images = ddim(batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[
1697-
"sample"
1698-
]
1696+
ddim_images = ddim(
1697+
batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy"
1698+
).images
16991699

17001700
# the values aren't exactly equal, but the images look the same visually
17011701
assert np.abs(ddpm_images - ddim_images).max() < 1e-1
@@ -1729,9 +1729,9 @@ def test_lms_stable_diffusion_pipeline(self):
17291729

17301730
prompt = "a photograph of an astronaut riding a horse"
17311731
generator = torch.Generator(device=torch_device).manual_seed(0)
1732-
image = pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy")[
1733-
"sample"
1734-
]
1732+
image = pipe(
1733+
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
1734+
).images
17351735

17361736
image_slice = image[0, -3:, -3:, -1]
17371737
assert image.shape == (1, 512, 512, 3)

0 commit comments

Comments
 (0)