Commit 482e7a2

hlky authored and yiyixuxu committed
[Schedulers] Add exponential sigmas / exponential noise schedule (huggingface#9499)
* exponential sigmas
* Apply suggestions from code review
* make style

Co-authored-by: YiYi Xu <[email protected]>
1 parent fa130e1 commit 482e7a2

File tree

1 file changed: +36 -0 lines changed

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 36 additions & 0 deletions
@@ -158,6 +158,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
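
The new flag is a regular config field, so it can be toggled when swapping schedulers on an existing pipeline. A minimal usage sketch, assuming a diffusers version that includes this commit (the checkpoint name is only an illustrative placeholder):

import torch
from diffusers import DiffusionPipeline, EulerDiscreteScheduler

# Illustrative checkpoint; any pipeline whose scheduler is
# EulerDiscreteScheduler works the same way.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Rebuild the scheduler from the existing config, opting into the
# exponential noise schedule.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, use_exponential_sigmas=True
)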
@@ -186,6 +188,7 @@ def __init__(
         prediction_type: str = "epsilon",
         interpolation_type: str = "linear",
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
         sigma_min: Optional[float] = None,
         sigma_max: Optional[float] = None,
         timestep_spacing: str = "linspace",
@@ -235,6 +238,7 @@ def __init__(

         self.is_scale_input_called = False
         self.use_karras_sigmas = use_karras_sigmas
+        self.use_exponential_sigmas = use_exponential_sigmas

         self._step_index = None
         self._begin_index = None
@@ -332,6 +336,12 @@ def set_timesteps(
             raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.")
         if timesteps is not None and self.config.use_karras_sigmas:
             raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
+        if timesteps is not None and self.config.use_exponential_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if self.config.use_exponential_sigmas and self.config.use_karras_sigmas:
+            raise ValueError(
+                "Cannot set both `config.use_exponential_sigmas = True` and `config.use_karras_sigmas = True`"
+            )
         if (
             timesteps is not None
             and self.config.timestep_type == "continuous"
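
The two sigma schedules are mutually exclusive, and neither may be combined with user-supplied `timesteps`; note that the checks fire when `set_timesteps` is called, not at construction time. A quick sanity-check sketch, again assuming a diffusers version with this commit:

from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler(use_exponential_sigmas=True, use_karras_sigmas=True)
try:
    scheduler.set_timesteps(num_inference_steps=30)
except ValueError as err:
    print(err)  # both schedules requested at once -> raises the new error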
@@ -396,6 +406,10 @@ def set_timesteps(
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
         if self.config.final_sigmas_type == "sigma_min":
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
         elif self.config.final_sigmas_type == "zero":
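
As with the Karras branch, each resampled sigma is mapped back to a fractional timestep with `_sigma_to_t`, which interpolates linearly in log-sigma space against the training schedule. A simplified standalone sketch of that inverse mapping (the scheduler's own method handles the same logic in vectorized numpy form):

import numpy as np

def sigma_to_t(sigma: float, log_sigmas: np.ndarray) -> float:
    # log_sigmas: log of the training sigmas, increasing with timestep index.
    log_sigma = np.log(max(sigma, 1e-10))
    # Index of the last training log-sigma that does not exceed log_sigma.
    low_idx = int(np.cumsum(log_sigma - log_sigmas >= 0).argmax())
    low_idx = min(low_idx, log_sigmas.shape[0] - 2)
    high_idx = low_idx + 1
    low, high = log_sigmas[low_idx], log_sigmas[high_idx]
    # Interpolation weight between the two bracketing timesteps.
    w = float(np.clip((low - log_sigma) / (low - high), 0.0, 1.0))
    return (1 - w) * low_idx + w * high_idx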
@@ -468,6 +482,28 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas

+    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L26
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+        return sigmas
+
     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
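
The schedule itself is simply `num_inference_steps` values spaced uniformly in log space from `sigma_max` down to `sigma_min`, i.e. a geometric progression with a constant ratio between consecutive noise levels. A standalone sketch with illustrative sigma bounds (not the scheduler's defaults):

import math
import torch

def exponential_sigmas(sigma_min: float, sigma_max: float, n: int) -> torch.Tensor:
    # n noise levels, uniform in log space, decreasing from sigma_max to sigma_min.
    return torch.linspace(math.log(sigma_max), math.log(sigma_min), n).exp()

sigmas = exponential_sigmas(0.03, 14.6, 10)  # illustrative bounds
print(sigmas)
print(sigmas[1:] / sigmas[:-1])  # constant ratio: the geometric-progression property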
