Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/diffusers/models/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,12 +337,16 @@ def __init__(self, parameters, deterministic=False):
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
self.var = self.std = torch.zeros_like(
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)

def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
device = self.parameters.device
sample_device = "cpu" if device.type == "mps" else device
sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
sample = torch.randn(self.mean.shape, generator=generator, device=sample_device)
# make sure sample is on the same device as the parameters and has same dtype
sample = sample.to(device=device, dtype=self.parameters.dtype)
x = self.mean + self.std * sample
return x

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,26 +217,6 @@ def __call__(
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image)

# encode the init image into latents and scale the latents
init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents

# expand init_latents for batch_size
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)

# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)

timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)

# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)

# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
Expand Down Expand Up @@ -297,6 +277,28 @@ def __call__(
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

# encode the init image into latents and scale the latents
latents_dtype = text_embeddings.dtype
init_image = init_image.to(device=self.device, dtype=latents_dtype)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents

# expand init_latents for batch_size
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)

# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)

timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)

# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
Expand Down Expand Up @@ -341,7 +343,9 @@ def __call__(
image = image.cpu().permute(0, 2, 3, 1).numpy()

safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)

if output_type == "pil":
image = self.numpy_to_pil(image)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,43 +234,6 @@ def __call__(
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

# preprocess image
if not isinstance(init_image, torch.FloatTensor):
init_image = preprocess_image(init_image)
init_image = init_image.to(self.device)

# encode the init image into latents and scale the latents
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)

init_latents = 0.18215 * init_latents

# Expand init_latents for batch_size and num_images_per_prompt
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
init_latents_orig = init_latents

# preprocess mask
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image)
mask_image = mask_image.to(self.device)
mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)

# check sizes
if not mask.shape == init_latents.shape:
raise ValueError("The mask and init_image should be the same size!")

# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)

timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)

# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)

# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
Expand Down Expand Up @@ -335,6 +298,43 @@ def __call__(
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

# preprocess image
if not isinstance(init_image, torch.FloatTensor):
init_image = preprocess_image(init_image)

# encode the init image into latents and scale the latents
latents_dtype = text_embeddings.dtype
init_image = init_image.to(device=self.device, dtype=latents_dtype)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents

# Expand init_latents for batch_size and num_images_per_prompt
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
init_latents_orig = init_latents

# preprocess mask
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image)
mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)

# check sizes
if not mask.shape == init_latents.shape:
raise ValueError("The mask and init_image should be the same size!")

# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)

timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)

# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
Expand Down
118 changes: 118 additions & 0 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,124 @@ def test_stable_diffusion_inpaint_num_images_per_prompt(self):

assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)

@unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
def test_stable_diffusion_fp16(self):
"""Test that stable diffusion works with fp16"""
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

# put models in fp16
unet = unet.half()
vae = vae.half()
bert = bert.half()

# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)

prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)
image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images

assert image.shape == (1, 128, 128, 3)

@unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
def test_stable_diffusion_img2img_fp16(self):
"""Test that stable diffusion img2img works with fp16"""
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

init_image = self.dummy_image.to(torch_device)

# put models in fp16
unet = unet.half()
vae = vae.half()
bert = bert.half()

# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionImg2ImgPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)

prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)
image = sd_pipe(
[prompt],
generator=generator,
num_inference_steps=2,
output_type="np",
init_image=init_image,
).images

assert image.shape == (1, 32, 32, 3)

@unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
def test_stable_diffusion_inpaint_fp16(self):
"""Test that stable diffusion inpaint works with fp16"""
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))

# put models in fp16
unet = unet.half()
vae = vae.half()
bert = bert.half()

# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionInpaintPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)

prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)
image = sd_pipe(
[prompt],
generator=generator,
num_inference_steps=2,
output_type="np",
init_image=init_image,
mask_image=mask_image,
).images

assert image.shape == (1, 32, 32, 3)


class PipelineTesterMixin(unittest.TestCase):
def tearDown(self):
Expand Down