Skip to content

Commit bb9633a

Browse files
authored
Add Custom Timesteps Support to LCMScheduler and Supported Pipelines (huggingface#5874)
* Add custom timesteps support to LCMScheduler. * Add custom timesteps support to StableDiffusionPipeline. * Add custom timesteps support to StableDiffusionXLPipeline. * Add custom timesteps support to remaining Stable Diffusion pipelines which support LCMScheduler (img2img, inpaint). * Add custom timesteps support to remaining Stable Diffusion XL pipelines which support LCMScheduler (img2img, inpaint). * Add custom timesteps support to StableDiffusionControlNetPipeline. * Add custom timesteps support to T2I Stable Diffusion (XL) Adapters. * Clean up Stable Diffusion inpaint tests. * Manually add support for custom timesteps to AltDiffusion pipelines since make fix-copies doesn't appear to work correctly (it deletes the whole pipeline). * make style * Refactor pipeline timestep handling into the retrieve_timesteps function.
1 parent 423e273 commit bb9633a

14 files changed

+784
-53
lines changed

pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,51 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
7373
return noise_cfg
7474

7575

76+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    # Default path: let the scheduler build its own schedule from the step count.
    if timesteps is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule: the scheduler's `set_timesteps` must accept a `timesteps` keyword.
    if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" timestep schedules. Please check whether you are using the correct scheduler."
        )
    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    timesteps = scheduler.timesteps
    # The effective step count is derived from the custom schedule's length.
    return timesteps, len(timesteps)
119+
120+
76121
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
77122
class AltDiffusionPipeline(
78123
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
@@ -662,6 +707,7 @@ def __call__(
662707
height: Optional[int] = None,
663708
width: Optional[int] = None,
664709
num_inference_steps: int = 50,
710+
timesteps: List[int] = None,
665711
guidance_scale: float = 7.5,
666712
negative_prompt: Optional[Union[str, List[str]]] = None,
667713
num_images_per_prompt: Optional[int] = 1,
@@ -693,6 +739,10 @@ def __call__(
693739
num_inference_steps (`int`, *optional*, defaults to 50):
694740
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
695741
expense of slower inference.
742+
timesteps (`List[int]`, *optional*):
743+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
744+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
745+
passed will be used. Must be in descending order.
696746
guidance_scale (`float`, *optional*, defaults to 7.5):
697747
A higher guidance scale value encourages the model to generate images closely linked to the text
698748
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -830,8 +880,7 @@ def __call__(
830880
image_embeds = torch.cat([negative_image_embeds, image_embeds])
831881

832882
# 4. Prepare timesteps
833-
self.scheduler.set_timesteps(num_inference_steps, device=device)
834-
timesteps = self.scheduler.timesteps
883+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
835884

836885
# 5. Prepare latent variables
837886
num_channels_latents = self.unet.config.in_channels

pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,51 @@ def preprocess(image):
109109
return image
110110

111111

112+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    # Default path: let the scheduler build its own schedule from the step count.
    if timesteps is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule: the scheduler's `set_timesteps` must accept a `timesteps` keyword.
    if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" timestep schedules. Please check whether you are using the correct scheduler."
        )
    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    timesteps = scheduler.timesteps
    # The effective step count is derived from the custom schedule's length.
    return timesteps, len(timesteps)
155+
156+
112157
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
113158
class AltDiffusionImg2ImgPipeline(
114159
DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
@@ -714,6 +759,7 @@ def __call__(
714759
image: PipelineImageInput = None,
715760
strength: float = 0.8,
716761
num_inference_steps: Optional[int] = 50,
762+
timesteps: List[int] = None,
717763
guidance_scale: Optional[float] = 7.5,
718764
negative_prompt: Optional[Union[str, List[str]]] = None,
719765
num_images_per_prompt: Optional[int] = 1,
@@ -751,6 +797,10 @@ def __call__(
751797
num_inference_steps (`int`, *optional*, defaults to 50):
752798
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
753799
expense of slower inference. This parameter is modulated by `strength`.
800+
timesteps (`List[int]`, *optional*):
801+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
802+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
803+
passed will be used. Must be in descending order.
754804
guidance_scale (`float`, *optional*, defaults to 7.5):
755805
A higher guidance scale value encourages the model to generate images closely linked to the text
756806
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -873,7 +923,7 @@ def __call__(
873923
image = self.image_processor.preprocess(image)
874924

875925
# 5. set timesteps
876-
self.scheduler.set_timesteps(num_inference_steps, device=device)
926+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
877927
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
878928
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
879929

pipelines/controlnet/pipeline_controlnet.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,51 @@
9191
"""
9292

9393

94+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    # Default path: let the scheduler build its own schedule from the step count.
    if timesteps is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule: the scheduler's `set_timesteps` must accept a `timesteps` keyword.
    if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" timestep schedules. Please check whether you are using the correct scheduler."
        )
    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    timesteps = scheduler.timesteps
    # The effective step count is derived from the custom schedule's length.
    return timesteps, len(timesteps)
137+
138+
94139
class StableDiffusionControlNetPipeline(
95140
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
96141
):
@@ -812,6 +857,7 @@ def __call__(
812857
height: Optional[int] = None,
813858
width: Optional[int] = None,
814859
num_inference_steps: int = 50,
860+
timesteps: List[int] = None,
815861
guidance_scale: float = 7.5,
816862
negative_prompt: Optional[Union[str, List[str]]] = None,
817863
num_images_per_prompt: Optional[int] = 1,
@@ -854,6 +900,10 @@ def __call__(
854900
num_inference_steps (`int`, *optional*, defaults to 50):
855901
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
856902
expense of slower inference.
903+
timesteps (`List[int]`, *optional*):
904+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
905+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
906+
passed will be used. Must be in descending order.
857907
guidance_scale (`float`, *optional*, defaults to 7.5):
858908
A higher guidance scale value encourages the model to generate images closely linked to the text
859909
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -1059,8 +1109,7 @@ def __call__(
10591109
assert False
10601110

10611111
# 5. Prepare timesteps
1062-
self.scheduler.set_timesteps(num_inference_steps, device=device)
1063-
timesteps = self.scheduler.timesteps
1112+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
10641113
self._num_timesteps = len(timesteps)
10651114

10661115
# 6. Prepare latent variables

pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,51 @@ def retrieve_latents(encoder_output, generator):
5353
raise AttributeError("Could not access latents of provided encoder_output")
5454

5555

56+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    # Default path: let the scheduler build its own schedule from the step count.
    if timesteps is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule: the scheduler's `set_timesteps` must accept a `timesteps` keyword.
    if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" timestep schedules. Please check whether you are using the correct scheduler."
        )
    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    timesteps = scheduler.timesteps
    # The effective step count is derived from the custom schedule's length.
    return timesteps, len(timesteps)
99+
100+
56101
EXAMPLE_DOC_STRING = """
57102
Examples:
58103
```py
@@ -592,6 +637,7 @@ def __call__(
592637
num_inference_steps: int = 4,
593638
strength: float = 0.8,
594639
original_inference_steps: int = None,
640+
timesteps: List[int] = None,
595641
guidance_scale: float = 8.5,
596642
num_images_per_prompt: Optional[int] = 1,
597643
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -623,6 +669,10 @@ def __call__(
623669
we will draw `num_inference_steps` evenly spaced timesteps from as our final timestep schedule,
624670
following the Skipping-Step method in the paper (see Section 4.3). If not set this will default to the
625671
scheduler's `original_inference_steps` attribute.
672+
timesteps (`List[int]`, *optional*):
673+
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
674+
timesteps on the original LCM training/distillation timestep schedule are used. Must be in descending
675+
order.
626676
guidance_scale (`float`, *optional*, defaults to 7.5):
627677
A higher guidance scale value encourages the model to generate images closely linked to the text
628678
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -728,10 +778,14 @@ def __call__(
728778
image = self.image_processor.preprocess(image)
729779

730780
# 5. Prepare timesteps
731-
self.scheduler.set_timesteps(
732-
num_inference_steps, device, original_inference_steps=original_inference_steps, strength=strength
781+
timesteps, num_inference_steps = retrieve_timesteps(
782+
self.scheduler,
783+
num_inference_steps,
784+
device,
785+
timesteps,
786+
original_inference_steps=original_inference_steps,
787+
strength=strength,
733788
)
734-
timesteps = self.scheduler.timesteps
735789

736790
# 6. Prepare latent variables
737791
original_inference_steps = (

0 commit comments

Comments
 (0)