Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
IPNDMScheduler,
KarrasVeScheduler,
PNDMScheduler,
Expand Down
9 changes: 7 additions & 2 deletions src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,15 +767,20 @@ def numpy_to_pil(images):

return pil_images

def progress_bar(self, iterable):
def progress_bar(self, iterable=None, total=None):
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
elif not isinstance(self._progress_bar_config, dict):
raise ValueError(
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
)

return tqdm(iterable, **self._progress_bar_config)
if iterable is not None:
return tqdm(iterable, **self._progress_bar_config)
elif total is not None:
return tqdm(total=total, **self._progress_bar_config)
else:
raise ValueError("Either `total` or `iterable` has to be defined.")

def set_progress_bar_config(self, **kwargs):
self._progress_bar_config = kwargs
44 changes: 24 additions & 20 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,30 +541,34 @@ def __call__(
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

# 7. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

# 8. Post-processing
image = self.decode_latents(latents)

# 9. Run safety checker
# 9. Run safety checker as
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)

# 10. Convert to PIL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def get_timesteps(self, num_inference_steps, strength, device):
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:]

return timesteps
return timesteps, num_inference_steps - int(num_inference_steps * strength)

def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
init_image = init_image.to(device=device, dtype=dtype)
Expand Down Expand Up @@ -562,7 +562,7 @@ def __call__(

# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.get_timesteps(num_inference_steps, strength, device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

# 6. Prepare latent variables
Expand All @@ -574,25 +574,29 @@ def __call__(
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

# 8. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

# 9. Post-processing
image = self.decode_latents(latents)
Expand Down
126 changes: 65 additions & 61 deletions src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def get_timesteps(self, num_inference_steps, strength, device):
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:]

return timesteps
return timesteps, num_inference_steps - int(num_inference_steps * strength)

def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
init_image = init_image.to(device=device, dtype=dtype)
Expand Down Expand Up @@ -607,7 +607,7 @@ def __call__(

# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.get_timesteps(num_inference_steps, strength, device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

# 6. Prepare latent variables
Expand All @@ -621,66 +621,70 @@ def __call__(
generator = extra_step_kwargs.pop("generator", None)

# 8. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2)
source_latent_model_input = torch.cat([source_latents] * 2)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)

# predict the noise residual
concat_latent_model_input = torch.stack(
[
source_latent_model_input[0],
latent_model_input[0],
source_latent_model_input[1],
latent_model_input[1],
],
dim=0,
)
concat_text_embeddings = torch.stack(
[
source_text_embeddings[0],
text_embeddings[0],
source_text_embeddings[1],
text_embeddings[1],
],
dim=0,
)
concat_noise_pred = self.unet(
concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
).sample

# perform guidance
(
source_noise_pred_uncond,
noise_pred_uncond,
source_noise_pred_text,
noise_pred_text,
) = concat_noise_pred.chunk(4, dim=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
source_noise_pred_text - source_noise_pred_uncond
)

# Sample source_latents from the posterior distribution.
prev_source_latents = posterior_sample(
self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs
)
# Compute noise.
noise = compute_noise(
self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs
)
source_latents = prev_source_latents

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs
).prev_sample
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2)
source_latent_model_input = torch.cat([source_latents] * 2)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)

# predict the noise residual
concat_latent_model_input = torch.stack(
[
source_latent_model_input[0],
latent_model_input[0],
source_latent_model_input[1],
latent_model_input[1],
],
dim=0,
)
concat_text_embeddings = torch.stack(
[
source_text_embeddings[0],
text_embeddings[0],
source_text_embeddings[1],
text_embeddings[1],
],
dim=0,
)
concat_noise_pred = self.unet(
concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
).sample

# perform guidance
(
source_noise_pred_uncond,
noise_pred_uncond,
source_noise_pred_text,
noise_pred_text,
) = concat_noise_pred.chunk(4, dim=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
source_noise_pred_text - source_noise_pred_uncond
)

# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# Sample source_latents from the posterior distribution.
prev_source_latents = posterior_sample(
self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs
)
# Compute noise.
noise = compute_noise(
self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs
)
source_latents = prev_source_latents

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs
).prev_sample

# call the callback, if provided
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

# 9. Post-processing
image = self.decode_latents(latents)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -540,30 +540,34 @@ def __call__(
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

# 7. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

# 8. Post-processing
image = self.decode_latents(latents)

# 9. Run safety checker
# 9. Run safety checker as
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)

# 10. Convert to PIL
Expand Down
Loading