Description
I'm creating this issue to present my findings from the discussion in #9416 about supporting additional schedulers used in A1111/Forge/Comfy etc., specifically the simple, exponential, polyexponential and beta schedulers.
I've tested these schedulers and compared them to Diffusers with step counts 4, 8, 15, and 30. sgm_uniform is also included in these tests to confirm the findings in the issue linked above. On the Diffusers side I test timestep_spacing linspace, leading and trailing, each with interpolation_type linear and log_linear.
Conclusions
Simple
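The simple schedule type matches timestep_spacing trailing with interpolation_type linear: the sigmas are identical at 4 and 8 steps, and at 15 and 30 steps only a few entries differ slightly (e.g. 9.9720 vs 9.9172).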
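The test results show simple and trailing/linear agreeing at 4 and 8 steps, with a few entries differing slightly at 15 and 30 steps. A plausible explanation is how each side turns a fractional position into a timestep index. The sketch below compares only the indices, in plain Python. `simple_indices` mirrors the indexing in `simple_scheduler` from the code in this issue; `trailing_indices` is my assumption of how trailing spacing is computed in Diffusers (round, then subtract 1) and should be checked against the scheduler source. Both function names are illustrative.

```python
NUM_TRAIN_TIMESTEPS = 1000


def simple_indices(n, num_train=NUM_TRAIN_TIMESTEPS):
    # simple_scheduler reads sigmas[-(1 + int(x * ss))], i.e. it truncates.
    ss = num_train / n
    return [num_train - 1 - int(x * ss) for x in range(n)]


def trailing_indices(n, num_train=NUM_TRAIN_TIMESTEPS):
    # Assumed behaviour of timestep_spacing="trailing": round, then subtract 1.
    ratio = num_train / n
    return [round(num_train - k * ratio) - 1 for k in range(n)]


def mismatches(n):
    # Positions at which the two approaches pick a different timestep index.
    pairs = zip(simple_indices(n), trailing_indices(n))
    return [i for i, (a, b) in enumerate(pairs) if a != b]


for n in (4, 8, 15, 30):
    print(n, mismatches(n))
```

Under these assumptions the indices coincide at 4 and 8 steps and differ by one at every third position for 15 and 30 steps, which lines up with the entries that differ in the posted results.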
SGM Uniform
As found in the linked issue, this schedule type is a near-exact match to timestep_spacing trailing with interpolation_type linear, and in turn a near-exact match to the simple schedule type.
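To see why sgm_uniform is close to, but not identical with, trailing spacing, it may help to compare the timesteps involved. This sketch assumes, as in the code in this issue, that sgm_uniform spaces n + 1 points linearly between timestep 999 and 0 and drops the last one, and that trailing spacing lands on integer indices (round, then subtract 1). The function names are illustrative only.

```python
NUM_TRAIN_TIMESTEPS = 1000


def sgm_uniform_timesteps(n, num_train=NUM_TRAIN_TIMESTEPS):
    # n + 1 evenly spaced points from the last index down to 0, last one dropped.
    t_max = num_train - 1
    return [t_max - t_max * i / n for i in range(n)]


def trailing_timesteps(n, num_train=NUM_TRAIN_TIMESTEPS):
    # Assumed trailing spacing: integer indices (round, then subtract 1).
    ratio = num_train / n
    return [round(num_train - k * ratio) - 1 for k in range(n)]


for sgm_t, trail_t in zip(sgm_uniform_timesteps(4), trailing_timesteps(4)):
    print(f"sgm_uniform t={sgm_t:7.2f}  trailing t={trail_t}")
```

The fractional timesteps are then interpolated in log sigma by `t_to_sigma`, which would account for the small offsets in the results (e.g. 4.0861 vs 4.0817 at 4 steps).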
Exponential and Polyexponential
These two schedule types produce the same results (up to rounding in the last printed digit) when Polyexponential uses its default rho (1.0). Neither matches any of the tested Diffusers configurations, so they need to be implemented.
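The equivalence at rho = 1.0 follows from the formulas: both schedules interpolate linearly in log sigma, and polyexponential only raises the interpolation ramp to the power rho. Here is a minimal plain-Python sketch of the two formulas, without the trailing zero that the test code appends. The endpoint values are assumptions, rounded from the printed results in this issue, so the output only approximates the posted tensors.

```python
import math


def exponential_sigmas(n, sigma_min, sigma_max):
    # Evenly spaced in log(sigma), from sigma_max down to sigma_min.
    lo, hi = math.log(sigma_min), math.log(sigma_max)
    return [math.exp(hi + (lo - hi) * i / (n - 1)) for i in range(n)]


def polyexponential_sigmas(n, sigma_min, sigma_max, rho=1.0):
    # The ramp runs from 1 to 0 and is raised to rho; rho=1 leaves it linear.
    lo, hi = math.log(sigma_min), math.log(sigma_max)
    ramp = [(1 - i / (n - 1)) ** rho for i in range(n)]
    return [math.exp(r * (hi - lo) + lo) for r in ramp]


# Assumed endpoints, rounded from the printed results in this issue.
SIGMA_MIN, SIGMA_MAX = 0.0292, 14.6146

exp_s = exponential_sigmas(4, SIGMA_MIN, SIGMA_MAX)
poly_s = polyexponential_sigmas(4, SIGMA_MIN, SIGMA_MAX)
print([round(s, 4) for s in exp_s])
print(all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(exp_s, poly_s)))
```

With any other rho the ramp is no longer linear, so the two schedules diverge everywhere except at the endpoints.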
Beta
There is no match in Diffusers, so this schedule type needs to be implemented.
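One way to sanity-check the beta results: with alpha = beta = 0.6 (the values used in the test code) the Beta distribution is symmetric, so its quantiles satisfy ppf(1 - q) = 1 - ppf(q). Because the schedule maps those quantiles linearly onto [sigma_min, sigma_max], each sigma and its mirror image in the list should sum to sigma_min + sigma_max. The sketch below checks that property against the 8-step values copied from the results in this issue; the helper name is illustrative.

```python
# Non-zero beta sigmas for 8 steps, copied from the test results in this issue.
beta_8 = [14.6146, 13.5770, 11.4518, 8.7596, 5.8842, 3.1920, 1.0668, 0.0292]


def pair_sums(sigmas):
    # Sum each sigma with its mirror image in the list.
    return [a + b for a, b in zip(sigmas, reversed(sigmas))]


target = beta_8[0] + beta_8[-1]  # sigma_max + sigma_min
print(all(abs(s - target) < 1e-3 for s in pair_sums(beta_8)))
```

The spacing is therefore linear in sigma over Beta quantiles rather than uniform in timestep or log sigma, which is consistent with none of the tested Diffusers configurations matching.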
Notes
The code used to produce these results is attached below the test results. I've confirmed the results are accurate with a simple modification to Forge that prints the sigmas for each tested schedule type.
Test results
4 steps
timestep_spacing=linspace, interpolation_type=linear
tensor([14.6146, 2.9183, 0.9324, 0.0292, 0.0000])
timestep_spacing=leading, interpolation_type=linear
tensor([4.1167, 1.6237, 0.6984, 0.0413, 0.0000])
timestep_spacing=trailing, interpolation_type=linear
tensor([14.6146, 4.0817, 1.6129, 0.6932, 0.0000])
timestep_spacing=linspace, interpolation_type=log_linear
tensor([14.6146, 3.0890, 0.6529, 0.1380, 0.0292, 0.0000])
timestep_spacing=leading, interpolation_type=log_linear
tensor([14.6146, 3.0890, 0.6529, 0.1380, 0.0292, 0.0000])
timestep_spacing=trailing, interpolation_type=log_linear
tensor([14.6146, 3.0890, 0.6529, 0.1380, 0.0292, 0.0000])
simple
tensor([14.6146, 4.0817, 1.6129, 0.6932, 0.0000])
sgm_uniform
tensor([14.6146, 4.0861, 1.6156, 0.6952, 0.0000])
exponential
tensor([14.6146, 1.8400, 0.2317, 0.0292, 0.0000])
polyexponential
tensor([14.6147, 1.8400, 0.2317, 0.0292, 0.0000])
beta
tensor([14.6146, 10.5976, 4.0462, 0.0292, 0.0000])
8 steps
timestep_spacing=linspace, interpolation_type=linear
tensor([14.6146, 6.6780, 3.5221, 2.0606, 1.2768, 0.7913, 0.4397, 0.0292,
        0.0000])
timestep_spacing=leading, interpolation_type=linear
tensor([7.3718, 4.1167, 2.5109, 1.6237, 1.0760, 0.6984, 0.4022, 0.0413, 0.0000])
timestep_spacing=trailing, interpolation_type=linear
tensor([14.6146, 7.2974, 4.0817, 2.4925, 1.6129, 1.0690, 0.6932, 0.3977,
        0.0000])
timestep_spacing=linspace, interpolation_type=log_linear
tensor([14.6146, 6.7190, 3.0890, 1.4201, 0.6529, 0.3002, 0.1380, 0.0634,
        0.0292, 0.0000])
timestep_spacing=leading, interpolation_type=log_linear
tensor([14.6146, 6.7190, 3.0890, 1.4201, 0.6529, 0.3002, 0.1380, 0.0634,
        0.0292, 0.0000])
timestep_spacing=trailing, interpolation_type=log_linear
tensor([14.6146, 6.7190, 3.0890, 1.4201, 0.6529, 0.3002, 0.1380, 0.0634,
        0.0292, 0.0000])
simple
tensor([14.6146, 7.2974, 4.0817, 2.4925, 1.6129, 1.0690, 0.6932, 0.3977,
        0.0000])
sgm_uniform
tensor([14.6146, 7.3020, 4.0861, 2.4960, 1.6156, 1.0712, 0.6952, 0.3997,
        0.0000])
exponential
tensor([14.6146, 6.0130, 2.4740, 1.0179, 0.4188, 0.1723, 0.0709, 0.0292,
        0.0000])
polyexponential
tensor([14.6147, 6.0130, 2.4740, 1.0179, 0.4188, 0.1723, 0.0709, 0.0292,
        0.0000])
beta
tensor([14.6146, 13.5770, 11.4518, 8.7596, 5.8842, 3.1920, 1.0668, 0.0292,
        0.0000])
15 steps
timestep_spacing=linspace, interpolation_type=linear
tensor([14.6146, 9.6826, 6.6780, 4.7746, 3.5221, 2.6666, 2.0606, 1.6156,
        1.2768, 1.0097, 0.7913, 0.6056, 0.4397, 0.2780, 0.0292, 0.0000])
timestep_spacing=leading, interpolation_type=linear
tensor([9.5436, 6.7684, 4.9510, 3.7216, 2.8629, 2.2441, 1.7841, 1.4316, 1.1530,
        0.9261, 0.7353, 0.5693, 0.4179, 0.2677, 0.0413, 0.0000])
timestep_spacing=trailing, interpolation_type=linear
tensor([14.6146, 9.9172, 7.0089, 5.0878, 3.7997, 2.9183, 2.2765, 1.8024,
        1.4458, 1.1606, 0.9292, 0.7380, 0.5693, 0.4156, 0.2653, 0.0000])
timestep_spacing=linspace, interpolation_type=log_linear
tensor([14.6146, 9.6560, 6.3797, 4.2151, 2.7850, 1.8400, 1.2157, 0.8032,
        0.5307, 0.3506, 0.2317, 0.1531, 0.1011, 0.0668, 0.0441, 0.0292,
        0.0000])
timestep_spacing=leading, interpolation_type=log_linear
tensor([14.6146, 9.6560, 6.3797, 4.2151, 2.7850, 1.8400, 1.2157, 0.8032,
        0.5307, 0.3506, 0.2317, 0.1531, 0.1011, 0.0668, 0.0441, 0.0292,
        0.0000])
timestep_spacing=trailing, interpolation_type=log_linear
tensor([14.6146, 9.6560, 6.3797, 4.2151, 2.7850, 1.8400, 1.2157, 0.8032,
        0.5307, 0.3506, 0.2317, 0.1531, 0.1011, 0.0668, 0.0441, 0.0292,
        0.0000])
simple
tensor([14.6146, 9.9720, 7.0089, 5.0878, 3.8155, 2.9183, 2.2765, 1.8085,
        1.4458, 1.1606, 0.9324, 0.7380, 0.5693, 0.4179, 0.2653, 0.0000])
sgm_uniform
tensor([14.6146, 9.9391, 7.0019, 5.0924, 3.8092, 2.9183, 2.2797, 1.8073,
        1.4467, 1.1629, 0.9324, 0.7391, 0.5712, 0.4183, 0.2667, 0.0000])
exponential
tensor([14.6146, 9.3743, 6.0130, 3.8569, 2.4740, 1.5869, 1.0179, 0.6529,
        0.4188, 0.2686, 0.1723, 0.1105, 0.0709, 0.0455, 0.0292, 0.0000])
polyexponential
tensor([14.6147, 9.3743, 6.0130, 3.8569, 2.4740, 1.5869, 1.0179, 0.6529,
        0.4188, 0.2686, 0.1723, 0.1105, 0.0709, 0.0455, 0.0292, 0.0000])
beta
tensor([14.6146, 14.2837, 13.5770, 12.6113, 11.4518, 10.1517, 8.7596, 7.3219,
        5.8842, 4.4921, 3.1920, 2.0325, 1.0668, 0.3601, 0.0292, 0.0000])
30 steps
timestep_spacing=linspace, interpolation_type=linear
tensor([14.6146, 11.9176, 9.8142, 8.1585, 6.8431, 5.7886, 4.9356, 4.2397,
        3.6669, 3.1913, 2.7931, 2.4569, 2.1705, 1.9246, 1.7116, 1.5257,
        1.3619, 1.2166, 1.0865, 0.9691, 0.8622, 0.7640, 0.6730, 0.5877,
        0.5067, 0.4286, 0.3515, 0.2722, 0.1835, 0.0292, 0.0000])
timestep_spacing=leading, interpolation_type=linear
tensor([11.4769, 9.5436, 8.0043, 6.7684, 5.7678, 4.9510, 4.2790, 3.7216,
        3.2556, 2.8629, 2.5295, 2.2441, 1.9980, 1.7841, 1.5968, 1.4316,
        1.2846, 1.1530, 1.0342, 0.9261, 0.8270, 0.7353, 0.6499, 0.5693,
        0.4924, 0.4179, 0.3439, 0.2677, 0.1822, 0.0413, 0.0000])
timestep_spacing=trailing, interpolation_type=linear
tensor([14.6146, 12.0177, 9.9172, 8.3028, 7.0089, 5.9347, 5.0878, 4.3919,
        3.7997, 3.3211, 2.9183, 2.5671, 2.2765, 2.0260, 1.8024, 1.6129,
        1.4458, 1.2931, 1.1606, 1.0410, 0.9292, 0.8299, 0.7380, 0.6499,
        0.5693, 0.4924, 0.4156, 0.3417, 0.2653, 0.1763, 0.0000])
timestep_spacing=linspace, interpolation_type=log_linear
tensor([14.6146, 11.8793, 9.6560, 7.8487, 6.3797, 5.1857, 4.2151, 3.4262,
        2.7850, 2.2637, 1.8400, 1.4956, 1.2157, 0.9882, 0.8032, 0.6529,
        0.5307, 0.4314, 0.3506, 0.2850, 0.2317, 0.1883, 0.1531, 0.1244,
        0.1011, 0.0822, 0.0668, 0.0543, 0.0441, 0.0359, 0.0292, 0.0000])
timestep_spacing=leading, interpolation_type=log_linear
tensor([14.6146, 11.8793, 9.6560, 7.8487, 6.3797, 5.1857, 4.2151, 3.4262,
        2.7850, 2.2637, 1.8400, 1.4956, 1.2157, 0.9882, 0.8032, 0.6529,
        0.5307, 0.4314, 0.3506, 0.2850, 0.2317, 0.1883, 0.1531, 0.1244,
        0.1011, 0.0822, 0.0668, 0.0543, 0.0441, 0.0359, 0.0292, 0.0000])
timestep_spacing=trailing, interpolation_type=log_linear
tensor([14.6146, 11.8793, 9.6560, 7.8487, 6.3797, 5.1857, 4.2151, 3.4262,
        2.7850, 2.2637, 1.8400, 1.4956, 1.2157, 0.9882, 0.8032, 0.6529,
        0.5307, 0.4314, 0.3506, 0.2850, 0.2317, 0.1883, 0.1531, 0.1244,
        0.1011, 0.0822, 0.0668, 0.0543, 0.0441, 0.0359, 0.0292, 0.0000])
simple
tensor([14.6146, 12.0177, 9.9720, 8.3028, 7.0089, 5.9631, 5.0878, 4.3919,
        3.8155, 3.3211, 2.9183, 2.5767, 2.2765, 2.0260, 1.8085, 1.6129,
        1.4458, 1.2973, 1.1606, 1.0410, 0.9324, 0.8299, 0.7380, 0.6524,
        0.5693, 0.4924, 0.4179, 0.3417, 0.2653, 0.1793, 0.0000])
sgm_uniform
tensor([14.6146, 11.9969, 9.9391, 8.3072, 7.0019, 5.9489, 5.0924, 4.3900,
        3.8092, 3.3251, 2.9183, 2.5738, 2.2797, 2.0267, 1.8073, 1.6156,
        1.4467, 1.2969, 1.1629, 1.0421, 0.9324, 0.8319, 0.7391, 0.6526,
        0.5712, 0.4936, 0.4183, 0.3437, 0.2667, 0.1801, 0.0000])
exponential
tensor([14.6146, 11.7947, 9.5190, 7.6823, 6.2000, 5.0037, 4.0382, 3.2591,
        2.6302, 2.1227, 1.7131, 1.3826, 1.1158, 0.9005, 0.7268, 0.5865,
        0.4734, 0.3820, 0.3083, 0.2488, 0.2008, 0.1621, 0.1308, 0.1056,
        0.0852, 0.0688, 0.0555, 0.0448, 0.0361, 0.0292, 0.0000])
polyexponential
tensor([14.6147, 11.7948, 9.5190, 7.6823, 6.2000, 5.0037, 4.0382, 3.2591,
        2.6302, 2.1227, 1.7131, 1.3826, 1.1158, 0.9005, 0.7268, 0.5865,
        0.4734, 0.3820, 0.3083, 0.2488, 0.2008, 0.1621, 0.1308, 0.1056,
        0.0852, 0.0688, 0.0555, 0.0448, 0.0361, 0.0292, 0.0000])
beta
tensor([14.6146, 14.5159, 14.3024, 14.0041, 13.6349, 13.2047, 12.7212, 12.1912,
        11.6212, 11.0168, 10.3837, 9.7273, 9.0529, 8.3656, 7.6707, 6.9732,
        6.2782, 5.5909, 4.9165, 4.2601, 3.6270, 3.0226, 2.4526, 1.9227,
        1.4392, 1.0089, 0.6397, 0.3414, 0.1279, 0.0292, 0.0000])
Code
from diffusers import EulerDiscreteScheduler
import torch
import math
import numpy as np
from scipy import stats

beta_start = 0.00085
beta_end = 0.012
num_train_timesteps = 1000
betas = (
    torch.linspace(
        beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32
    )
    ** 2
)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# not flipped, contrary to diffusers
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
log_sigmas = sigmas.log()
discard_next_to_last_sigma = False


def sigma_to_t(sigma: torch.Tensor):
    log_sigma = sigma.log()
    dists = log_sigma - log_sigmas[:, None]
    low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=log_sigmas.shape[0] - 2)
    high_idx = low_idx + 1
    low, high = log_sigmas[low_idx], log_sigmas[high_idx]
    w = (low - log_sigma) / (low - high)
    w = w.clamp(0, 1)
    t = (1 - w) * low_idx + w * high_idx
    return t.view(sigma.shape)


def t_to_sigma(t: torch.Tensor):
    t = t.float()
    low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
    log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
    return log_sigma.exp()


m_sigma_min, m_sigma_max = (sigmas[0].item(), sigmas[-1].item())


def append_zero(x: torch.Tensor):
    return torch.cat([x, x.new_zeros([1])])


def simple_scheduler(n: int):
    # Pick every (len(sigmas) / n)-th sigma, counting back from the largest.
    sigs = []
    ss = len(sigmas) / n
    for x in range(n):
        sigs += [float(sigmas[-(1 + int(x * ss))])]
    sigs += [0.0]
    return torch.FloatTensor(sigs)


def sgm_uniform(n: int, sigma_min: float, sigma_max: float):
    # Space timesteps uniformly, drop the last one, then map back to sigmas.
    start = sigma_to_t(torch.tensor(sigma_max))
    end = sigma_to_t(torch.tensor(sigma_min))
    sigs = [t_to_sigma(ts) for ts in torch.linspace(start, end, n)[:-1]]
    sigs += [0.0]
    return torch.FloatTensor(sigs)


def get_sigmas_polyexponential(
    n: int, sigma_min: float, sigma_max: float, rho: float = 1.0
):
    """Constructs a noise schedule that is polynomial in log sigma."""
    ramp = torch.linspace(1, 0, n) ** rho
    sigmas = torch.exp(
        ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min)
    )
    return append_zero(sigmas)


def get_sigmas_exponential(n: int, sigma_min: float, sigma_max: float):
    """Constructs a noise schedule that is linear in log sigma."""
    sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n).exp()
    return append_zero(sigmas)


def beta_scheduler(
    n: int, sigma_min: float, sigma_max: float, alpha: float = 0.6, beta: float = 0.6
):
    # From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et al., 2024)
    timesteps = 1 - np.linspace(0, 1, n)
    timesteps = [stats.beta.ppf(x, alpha, beta) for x in timesteps]
    sigmas = [sigma_min + (x * (sigma_max - sigma_min)) for x in timesteps]
    sigmas += [0.0]
    return torch.FloatTensor(sigmas)


def diffusers_scheduler(
    num_inference_steps: int, timestep_spacing: str, interpolation_type: str
):
    scheduler: EulerDiscreteScheduler = EulerDiscreteScheduler.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        timestep_spacing=timestep_spacing,
        interpolation_type=interpolation_type,
        subfolder="scheduler",
    )
    scheduler.set_timesteps(num_inference_steps=num_inference_steps)
    print(
        f"### timestep_spacing={timestep_spacing}, interpolation_type={interpolation_type}"
    )
    print(f"```python\n{scheduler.sigmas}\n```")


def simple(num_inference_steps: int):
    simple_sigmas = simple_scheduler(num_inference_steps)
    print(f"### simple")
    print(f"```python\n{simple_sigmas}\n```")


def sgm(num_inference_steps: int):
    sgm_uniform_sigmas = sgm_uniform(
        n=num_inference_steps + (1 if not discard_next_to_last_sigma else 0),
        sigma_min=m_sigma_min,
        sigma_max=m_sigma_max,
    )
    print(f"### sgm_uniform")
    print(f"```python\n{sgm_uniform_sigmas}\n```")


def exponential(num_inference_steps: int):
    exponential_sigmas = get_sigmas_exponential(
        num_inference_steps, m_sigma_min, m_sigma_max
    )
    print(f"### exponential")
    print(f"```python\n{exponential_sigmas}\n```")


def polyexponential(num_inference_steps: int):
    polyexponential_sigmas = get_sigmas_polyexponential(
        num_inference_steps, m_sigma_min, m_sigma_max
    )
    print(f"### polyexponential")
    print(f"```python\n{polyexponential_sigmas}\n```")


def beta(num_inference_steps: int):
    beta_sigmas = beta_scheduler(
        num_inference_steps, m_sigma_min, m_sigma_max, alpha=0.6, beta=0.6
    )
    print(f"### beta")
    print(f"```python\n{beta_sigmas}\n```")


def sigmas_for_steps(num_inference_steps: int):
    # Use the argument here rather than the loop variable `steps` from module scope.
    print(f"## {num_inference_steps}")
    diffusers_scheduler(
        num_inference_steps, timestep_spacing="linspace", interpolation_type="linear"
    )
    diffusers_scheduler(
        num_inference_steps, timestep_spacing="leading", interpolation_type="linear"
    )
    diffusers_scheduler(
        num_inference_steps, timestep_spacing="trailing", interpolation_type="linear"
    )
    diffusers_scheduler(
        num_inference_steps,
        timestep_spacing="linspace",
        interpolation_type="log_linear",
    )
    diffusers_scheduler(
        num_inference_steps, timestep_spacing="leading", interpolation_type="log_linear"
    )
    diffusers_scheduler(
        num_inference_steps,
        timestep_spacing="trailing",
        interpolation_type="log_linear",
    )
    simple(num_inference_steps)
    sgm(num_inference_steps)
    exponential(num_inference_steps)
    polyexponential(num_inference_steps)
    beta(num_inference_steps)


for steps in [4, 8, 15, 30]:
    sigmas_for_steps(steps)

cc @asomoza