Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,7 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
Expand Down Expand Up @@ -674,6 +675,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
Expand Down Expand Up @@ -804,8 +809,19 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
if timesteps is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very nice!

Could we maybe factor this out into a function:

timesteps = retrieve_timesteps(self.scheduler, timesteps)

and then use #copied from for all pipelines?

accepts_timesteps = "timesteps" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {self.scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
timesteps = self.scheduler.timesteps
num_inference_steps = len(timesteps)
else:
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps

# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,7 @@ def __call__(
image: PipelineImageInput = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
timesteps: List[int] = None,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
Expand Down Expand Up @@ -734,6 +735,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter is modulated by `strength`.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
Expand Down Expand Up @@ -850,7 +855,18 @@ def __call__(
image = self.image_processor.preprocess(image)

# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {self.scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
timesteps = self.scheduler.timesteps
num_inference_steps = len(timesteps)
else:
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

Expand Down
20 changes: 18 additions & 2 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,7 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
Expand Down Expand Up @@ -816,6 +817,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
Expand Down Expand Up @@ -986,8 +991,19 @@ def __call__(
assert False

# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {self.scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
timesteps = self.scheduler.timesteps
num_inference_steps = len(timesteps)
else:
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps

# 6. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,7 @@ def __call__(
num_inference_steps: int = 4,
strength: float = 0.8,
original_inference_steps: int = None,
timesteps: List[int] = None,
guidance_scale: float = 8.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
Expand Down Expand Up @@ -623,6 +624,10 @@ def __call__(
we will draw `num_inference_steps` evenly spaced timesteps from as our final timestep schedule,
following the Skipping-Step method in the paper (see Section 4.3). If not set this will default to the
scheduler's `original_inference_steps` attribute.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps on the original LCM training/distillation timestep schedule are used. Must be in descending
order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
Expand Down Expand Up @@ -728,10 +733,20 @@ def __call__(
image = self.image_processor.preprocess(image)

# 5. Prepare timesteps
self.scheduler.set_timesteps(
num_inference_steps, device, original_inference_steps=original_inference_steps, strength=strength
)
timesteps = self.scheduler.timesteps
if timesteps is not None:
self.scheduler.set_timesteps(
device=device,
original_inference_steps=original_inference_steps,
timesteps=timesteps,
strength=strength,
)
timesteps = self.scheduler.timesteps
num_inference_steps = len(timesteps)
else:
self.scheduler.set_timesteps(
num_inference_steps, device, original_inference_steps=original_inference_steps, strength=strength
)
timesteps = self.scheduler.timesteps

# 6. Prepare latent variables
original_inference_steps = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ def __call__(
width: Optional[int] = None,
num_inference_steps: int = 4,
original_inference_steps: int = None,
timesteps: List[int] = None,
guidance_scale: float = 8.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
Expand Down Expand Up @@ -561,6 +562,10 @@ def __call__(
we will draw `num_inference_steps` evenly spaced timesteps from as our final timestep schedule,
following the Skipping-Step method in the paper (see Section 4.3). If not set this will default to the
scheduler's `original_inference_steps` attribute.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps on the original LCM training/distillation timestep schedule are used. Must be in descending
order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
Expand Down Expand Up @@ -668,8 +673,17 @@ def __call__(
)

# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device, original_inference_steps=original_inference_steps)
timesteps = self.scheduler.timesteps
if timesteps is not None:
self.scheduler.set_timesteps(
device=device, original_inference_steps=original_inference_steps, timesteps=timesteps
)
timesteps = self.scheduler.timesteps
num_inference_steps = len(timesteps)
else:
self.scheduler.set_timesteps(
num_inference_steps, device, original_inference_steps=original_inference_steps
)
timesteps = self.scheduler.timesteps

# 5. Prepare latent variable
num_channels_latents = self.unet.config.in_channels
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,7 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
Expand Down Expand Up @@ -671,6 +672,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
Expand Down Expand Up @@ -801,8 +806,19 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {self.scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
timesteps = self.scheduler.timesteps
num_inference_steps = len(timesteps)
else:
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps

# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,7 @@ def __call__(
image: PipelineImageInput = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
timesteps: List[int] = None,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
Expand Down Expand Up @@ -737,6 +738,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter is modulated by `strength`.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
Expand Down Expand Up @@ -853,7 +858,18 @@ def __call__(
image = self.image_processor.preprocess(image)

# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {self.scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
timesteps = self.scheduler.timesteps
num_inference_steps = len(timesteps)
else:
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,7 @@ def __call__(
width: Optional[int] = None,
strength: float = 1.0,
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
Expand Down Expand Up @@ -878,6 +879,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter is modulated by `strength`.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
Expand Down Expand Up @@ -1030,7 +1035,18 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

# 4. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {self.scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
timesteps = self.scheduler.timesteps
num_inference_steps = len(timesteps)
else:
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps=num_inference_steps, strength=strength, device=device
)
Expand Down
Loading