[Schedulers] Analysis of simple, exponential, polyexponential and beta #9490

@hlky

I'm creating this issue to present my findings from the discussion in #9416 about supporting additional schedulers used in A1111/Forge/Comfy etc., specifically the simple, exponential, polyexponential and beta schedulers.

I've tested these schedulers and compared them to Diffusers at step counts of 4, 8, 15 and 30. sgm_uniform is also included in these tests to confirm the findings in the issue linked above. On the Diffusers side, we test timestep_spacing linspace, leading and trailing, each with interpolation_type linear and log_linear.

Conclusions

Simple

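The simple schedule type is an exact match to timestep_spacing trailing with interpolation_type linear at 4 and 8 steps. At 15 and 30 steps it is a near match: most values are identical, but some differ slightly (for example 9.9720 vs 9.9172 at 15 steps).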
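
Whether the two schedules agree comes down to how each side turns a step number into an index into the 1000 training sigmas. The sketch below re-derives both index sequences in plain Python. `simple_indices` mirrors the `simple_scheduler` function in the code attached to this issue; `trailing_indices` is my reading of the Diffusers trailing spacing (round, then subtract 1) and should be treated as an assumption, not the library's exact code.

```python
NUM_TRAIN_TIMESTEPS = 1000


def simple_indices(n: int) -> list[int]:
    # Mirrors simple_scheduler: sigmas[-(1 + int(x * ss))], truncating toward zero.
    ss = NUM_TRAIN_TIMESTEPS / n
    return [NUM_TRAIN_TIMESTEPS - 1 - int(x * ss) for x in range(n)]


def trailing_indices(n: int) -> list[int]:
    # Assumed form of "trailing" spacing: round(T - x * step_ratio) - 1.
    step_ratio = NUM_TRAIN_TIMESTEPS / n
    return [round(NUM_TRAIN_TIMESTEPS - x * step_ratio) - 1 for x in range(n)]


def mismatches(n: int) -> list[int]:
    # Positions where the two index sequences disagree.
    pairs = zip(simple_indices(n), trailing_indices(n))
    return [i for i, (a, b) in enumerate(pairs) if a != b]


for n in (4, 8, 15, 30):
    print(n, mismatches(n))
```

Under these assumptions, 4 and 8 steps give identical indices, while 15 and 30 steps disagree by one index at a few positions, which is consistent with the small differences visible in the test results.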

SGM Uniform

As found in the linked issue, this schedule type is a near-exact match to timestep_spacing trailing with interpolation_type linear, and in turn a near-exact match to the simple schedule type.
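
One way to put a number on "near exact" is the largest element-wise gap between two schedules. The helper `max_abs_diff` below is a hypothetical name introduced for illustration; the three lists are copied from the 4-step results in this issue.

```python
def max_abs_diff(a: list[float], b: list[float]) -> float:
    # Largest element-wise gap between two schedules of equal length.
    if len(a) != len(b):
        raise ValueError("schedules must have the same length")
    return max(abs(x - y) for x, y in zip(a, b))


# Values copied from the 4-step results in this issue.
trailing_linear = [14.6146, 4.0817, 1.6129, 0.6932, 0.0000]
simple = [14.6146, 4.0817, 1.6129, 0.6932, 0.0000]
sgm_uniform = [14.6146, 4.0861, 1.6156, 0.6952, 0.0000]

print(max_abs_diff(simple, trailing_linear))
print(max_abs_diff(sgm_uniform, trailing_linear))
```

At 4 steps the gap for simple is zero, and the gap for sgm_uniform is a few thousandths, all of it in the intermediate sigmas.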

Exponential and Polyexponential

These schedule types produce the same results, up to floating-point rounding in the first value, when Polyexponential uses its default rho (1.0). Neither matches any Diffusers configuration tested here, so they need to be implemented.
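
The equivalence at rho = 1.0 follows from the formulas: the polyexponential ramp raised to the power 1 is a plain linear ramp in log sigma, which is what the exponential schedule uses. Here is a minimal plain-Python sketch of both, without the trailing zero. The `linspace` helper stands in for `torch.linspace`, and the sigma bounds are the rounded values from the test results, so the output only approximates the tensors above.

```python
import math


def linspace(start: float, stop: float, n: int) -> list[float]:
    # Plain-Python stand-in for torch.linspace: n evenly spaced points, inclusive.
    if n == 1:
        return [start]
    step = (stop - start) / (n - 1)
    return [start + i * step for i in range(n - 1)] + [stop]


def exponential(n: int, sigma_min: float, sigma_max: float) -> list[float]:
    # Linear ramp in log sigma from sigma_max down to sigma_min.
    points = linspace(math.log(sigma_max), math.log(sigma_min), n)
    return [math.exp(v) for v in points]


def polyexponential(n: int, sigma_min: float, sigma_max: float, rho: float = 1.0) -> list[float]:
    # Ramp from 1 to 0 raised to rho, then mapped onto the log-sigma range.
    span = math.log(sigma_max) - math.log(sigma_min)
    return [math.exp((r ** rho) * span + math.log(sigma_min)) for r in linspace(1.0, 0.0, n)]


# Rounded bounds taken from the test results (assumption: close enough for illustration).
SIGMA_MIN, SIGMA_MAX = 0.0292, 14.6146

exp_sigmas = exponential(4, SIGMA_MIN, SIGMA_MAX)
poly_sigmas = polyexponential(4, SIGMA_MIN, SIGMA_MAX)          # rho = 1.0
poly_rho2 = polyexponential(4, SIGMA_MIN, SIGMA_MAX, rho=2.0)   # diverges from exponential
print(exp_sigmas, poly_sigmas, poly_rho2)
```

With any rho other than 1.0 the two schedules diverge, so a shared implementation would only need the polyexponential form.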

Beta

No tested Diffusers configuration matches, so it needs to be implemented.

Notes

The code used to produce these results is attached below the test results. I've confirmed the results are accurate with a simple modification to Forge that prints the sigmas for each tested schedule type.

Test results

4 steps

timestep_spacing=linspace, interpolation_type=linear

tensor([14.6146,  2.9183,  0.9324,  0.0292,  0.0000])

timestep_spacing=leading, interpolation_type=linear

tensor([4.1167, 1.6237, 0.6984, 0.0413, 0.0000])

timestep_spacing=trailing, interpolation_type=linear

tensor([14.6146,  4.0817,  1.6129,  0.6932,  0.0000])

timestep_spacing=linspace, interpolation_type=log_linear

tensor([14.6146,  3.0890,  0.6529,  0.1380,  0.0292,  0.0000])

timestep_spacing=leading, interpolation_type=log_linear

tensor([14.6146,  3.0890,  0.6529,  0.1380,  0.0292,  0.0000])

timestep_spacing=trailing, interpolation_type=log_linear

tensor([14.6146,  3.0890,  0.6529,  0.1380,  0.0292,  0.0000])

simple

tensor([14.6146,  4.0817,  1.6129,  0.6932,  0.0000])

sgm_uniform

tensor([14.6146,  4.0861,  1.6156,  0.6952,  0.0000])

exponential

tensor([14.6146,  1.8400,  0.2317,  0.0292,  0.0000])

polyexponential

tensor([14.6147,  1.8400,  0.2317,  0.0292,  0.0000])

beta

tensor([14.6146, 10.5976,  4.0462,  0.0292,  0.0000])

8 steps

timestep_spacing=linspace, interpolation_type=linear

tensor([14.6146,  6.6780,  3.5221,  2.0606,  1.2768,  0.7913,  0.4397,  0.0292,
         0.0000])

timestep_spacing=leading, interpolation_type=linear

tensor([7.3718, 4.1167, 2.5109, 1.6237, 1.0760, 0.6984, 0.4022, 0.0413, 0.0000])

timestep_spacing=trailing, interpolation_type=linear

tensor([14.6146,  7.2974,  4.0817,  2.4925,  1.6129,  1.0690,  0.6932,  0.3977,
         0.0000])

timestep_spacing=linspace, interpolation_type=log_linear

tensor([14.6146,  6.7190,  3.0890,  1.4201,  0.6529,  0.3002,  0.1380,  0.0634,
         0.0292,  0.0000])

timestep_spacing=leading, interpolation_type=log_linear

tensor([14.6146,  6.7190,  3.0890,  1.4201,  0.6529,  0.3002,  0.1380,  0.0634,
         0.0292,  0.0000])

timestep_spacing=trailing, interpolation_type=log_linear

tensor([14.6146,  6.7190,  3.0890,  1.4201,  0.6529,  0.3002,  0.1380,  0.0634,
         0.0292,  0.0000])

simple

tensor([14.6146,  7.2974,  4.0817,  2.4925,  1.6129,  1.0690,  0.6932,  0.3977,
         0.0000])

sgm_uniform

tensor([14.6146,  7.3020,  4.0861,  2.4960,  1.6156,  1.0712,  0.6952,  0.3997,
         0.0000])

exponential

tensor([14.6146,  6.0130,  2.4740,  1.0179,  0.4188,  0.1723,  0.0709,  0.0292,
         0.0000])

polyexponential

tensor([14.6147,  6.0130,  2.4740,  1.0179,  0.4188,  0.1723,  0.0709,  0.0292,
         0.0000])

beta

tensor([14.6146, 13.5770, 11.4518,  8.7596,  5.8842,  3.1920,  1.0668,  0.0292,
         0.0000])

15 steps

timestep_spacing=linspace, interpolation_type=linear

tensor([14.6146,  9.6826,  6.6780,  4.7746,  3.5221,  2.6666,  2.0606,  1.6156,
         1.2768,  1.0097,  0.7913,  0.6056,  0.4397,  0.2780,  0.0292,  0.0000])

timestep_spacing=leading, interpolation_type=linear

tensor([9.5436, 6.7684, 4.9510, 3.7216, 2.8629, 2.2441, 1.7841, 1.4316, 1.1530,
        0.9261, 0.7353, 0.5693, 0.4179, 0.2677, 0.0413, 0.0000])

timestep_spacing=trailing, interpolation_type=linear

tensor([14.6146,  9.9172,  7.0089,  5.0878,  3.7997,  2.9183,  2.2765,  1.8024,
         1.4458,  1.1606,  0.9292,  0.7380,  0.5693,  0.4156,  0.2653,  0.0000])

timestep_spacing=linspace, interpolation_type=log_linear

tensor([14.6146,  9.6560,  6.3797,  4.2151,  2.7850,  1.8400,  1.2157,  0.8032,
         0.5307,  0.3506,  0.2317,  0.1531,  0.1011,  0.0668,  0.0441,  0.0292,
         0.0000])

timestep_spacing=leading, interpolation_type=log_linear

tensor([14.6146,  9.6560,  6.3797,  4.2151,  2.7850,  1.8400,  1.2157,  0.8032,
         0.5307,  0.3506,  0.2317,  0.1531,  0.1011,  0.0668,  0.0441,  0.0292,
         0.0000])

timestep_spacing=trailing, interpolation_type=log_linear

tensor([14.6146,  9.6560,  6.3797,  4.2151,  2.7850,  1.8400,  1.2157,  0.8032,
         0.5307,  0.3506,  0.2317,  0.1531,  0.1011,  0.0668,  0.0441,  0.0292,
         0.0000])

simple

tensor([14.6146,  9.9720,  7.0089,  5.0878,  3.8155,  2.9183,  2.2765,  1.8085,
         1.4458,  1.1606,  0.9324,  0.7380,  0.5693,  0.4179,  0.2653,  0.0000])

sgm_uniform

tensor([14.6146,  9.9391,  7.0019,  5.0924,  3.8092,  2.9183,  2.2797,  1.8073,
         1.4467,  1.1629,  0.9324,  0.7391,  0.5712,  0.4183,  0.2667,  0.0000])

exponential

tensor([14.6146,  9.3743,  6.0130,  3.8569,  2.4740,  1.5869,  1.0179,  0.6529,
         0.4188,  0.2686,  0.1723,  0.1105,  0.0709,  0.0455,  0.0292,  0.0000])

polyexponential

tensor([14.6147,  9.3743,  6.0130,  3.8569,  2.4740,  1.5869,  1.0179,  0.6529,
         0.4188,  0.2686,  0.1723,  0.1105,  0.0709,  0.0455,  0.0292,  0.0000])

beta

tensor([14.6146, 14.2837, 13.5770, 12.6113, 11.4518, 10.1517,  8.7596,  7.3219,
         5.8842,  4.4921,  3.1920,  2.0325,  1.0668,  0.3601,  0.0292,  0.0000])

30 steps

timestep_spacing=linspace, interpolation_type=linear

tensor([14.6146, 11.9176,  9.8142,  8.1585,  6.8431,  5.7886,  4.9356,  4.2397,
         3.6669,  3.1913,  2.7931,  2.4569,  2.1705,  1.9246,  1.7116,  1.5257,
         1.3619,  1.2166,  1.0865,  0.9691,  0.8622,  0.7640,  0.6730,  0.5877,
         0.5067,  0.4286,  0.3515,  0.2722,  0.1835,  0.0292,  0.0000])

timestep_spacing=leading, interpolation_type=linear

tensor([11.4769,  9.5436,  8.0043,  6.7684,  5.7678,  4.9510,  4.2790,  3.7216,
         3.2556,  2.8629,  2.5295,  2.2441,  1.9980,  1.7841,  1.5968,  1.4316,
         1.2846,  1.1530,  1.0342,  0.9261,  0.8270,  0.7353,  0.6499,  0.5693,
         0.4924,  0.4179,  0.3439,  0.2677,  0.1822,  0.0413,  0.0000])

timestep_spacing=trailing, interpolation_type=linear

tensor([14.6146, 12.0177,  9.9172,  8.3028,  7.0089,  5.9347,  5.0878,  4.3919,
         3.7997,  3.3211,  2.9183,  2.5671,  2.2765,  2.0260,  1.8024,  1.6129,
         1.4458,  1.2931,  1.1606,  1.0410,  0.9292,  0.8299,  0.7380,  0.6499,
         0.5693,  0.4924,  0.4156,  0.3417,  0.2653,  0.1763,  0.0000])

timestep_spacing=linspace, interpolation_type=log_linear

tensor([14.6146, 11.8793,  9.6560,  7.8487,  6.3797,  5.1857,  4.2151,  3.4262,
         2.7850,  2.2637,  1.8400,  1.4956,  1.2157,  0.9882,  0.8032,  0.6529,
         0.5307,  0.4314,  0.3506,  0.2850,  0.2317,  0.1883,  0.1531,  0.1244,
         0.1011,  0.0822,  0.0668,  0.0543,  0.0441,  0.0359,  0.0292,  0.0000])

timestep_spacing=leading, interpolation_type=log_linear

tensor([14.6146, 11.8793,  9.6560,  7.8487,  6.3797,  5.1857,  4.2151,  3.4262,
         2.7850,  2.2637,  1.8400,  1.4956,  1.2157,  0.9882,  0.8032,  0.6529,
         0.5307,  0.4314,  0.3506,  0.2850,  0.2317,  0.1883,  0.1531,  0.1244,
         0.1011,  0.0822,  0.0668,  0.0543,  0.0441,  0.0359,  0.0292,  0.0000])

timestep_spacing=trailing, interpolation_type=log_linear

tensor([14.6146, 11.8793,  9.6560,  7.8487,  6.3797,  5.1857,  4.2151,  3.4262,
         2.7850,  2.2637,  1.8400,  1.4956,  1.2157,  0.9882,  0.8032,  0.6529,
         0.5307,  0.4314,  0.3506,  0.2850,  0.2317,  0.1883,  0.1531,  0.1244,
         0.1011,  0.0822,  0.0668,  0.0543,  0.0441,  0.0359,  0.0292,  0.0000])

simple

tensor([14.6146, 12.0177,  9.9720,  8.3028,  7.0089,  5.9631,  5.0878,  4.3919,
         3.8155,  3.3211,  2.9183,  2.5767,  2.2765,  2.0260,  1.8085,  1.6129,
         1.4458,  1.2973,  1.1606,  1.0410,  0.9324,  0.8299,  0.7380,  0.6524,
         0.5693,  0.4924,  0.4179,  0.3417,  0.2653,  0.1793,  0.0000])

sgm_uniform

tensor([14.6146, 11.9969,  9.9391,  8.3072,  7.0019,  5.9489,  5.0924,  4.3900,
         3.8092,  3.3251,  2.9183,  2.5738,  2.2797,  2.0267,  1.8073,  1.6156,
         1.4467,  1.2969,  1.1629,  1.0421,  0.9324,  0.8319,  0.7391,  0.6526,
         0.5712,  0.4936,  0.4183,  0.3437,  0.2667,  0.1801,  0.0000])

exponential

tensor([14.6146, 11.7947,  9.5190,  7.6823,  6.2000,  5.0037,  4.0382,  3.2591,
         2.6302,  2.1227,  1.7131,  1.3826,  1.1158,  0.9005,  0.7268,  0.5865,
         0.4734,  0.3820,  0.3083,  0.2488,  0.2008,  0.1621,  0.1308,  0.1056,
         0.0852,  0.0688,  0.0555,  0.0448,  0.0361,  0.0292,  0.0000])

polyexponential

tensor([14.6147, 11.7948,  9.5190,  7.6823,  6.2000,  5.0037,  4.0382,  3.2591,
         2.6302,  2.1227,  1.7131,  1.3826,  1.1158,  0.9005,  0.7268,  0.5865,
         0.4734,  0.3820,  0.3083,  0.2488,  0.2008,  0.1621,  0.1308,  0.1056,
         0.0852,  0.0688,  0.0555,  0.0448,  0.0361,  0.0292,  0.0000])

beta

tensor([14.6146, 14.5159, 14.3024, 14.0041, 13.6349, 13.2047, 12.7212, 12.1912,
        11.6212, 11.0168, 10.3837,  9.7273,  9.0529,  8.3656,  7.6707,  6.9732,
         6.2782,  5.5909,  4.9165,  4.2601,  3.6270,  3.0226,  2.4526,  1.9227,
         1.4392,  1.0089,  0.6397,  0.3414,  0.1279,  0.0292,  0.0000])

Code

from diffusers import EulerDiscreteScheduler
import torch
import math
import numpy as np
from scipy import stats

beta_start = 0.00085
beta_end = 0.012
num_train_timesteps = 1000
betas = (
    torch.linspace(
        beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32
    )
    ** 2
)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# ascending order (sigmas[0] is sigma_min), not flipped as in Diffusers
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
log_sigmas = sigmas.log()
discard_next_to_last_sigma = False


def sigma_to_t(sigma: torch.Tensor):
    log_sigma = sigma.log()
    dists = log_sigma - log_sigmas[:, None]
    low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=log_sigmas.shape[0] - 2)
    high_idx = low_idx + 1
    low, high = log_sigmas[low_idx], log_sigmas[high_idx]
    w = (low - log_sigma) / (low - high)
    w = w.clamp(0, 1)
    t = (1 - w) * low_idx + w * high_idx
    return t.view(sigma.shape)


def t_to_sigma(t: torch.Tensor):
    t = t.float()
    low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
    log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
    return log_sigma.exp()


m_sigma_min, m_sigma_max = (sigmas[0].item(), sigmas[-1].item())


def append_zero(x: torch.Tensor):
    return torch.cat([x, x.new_zeros([1])])


def simple_scheduler(n: int):
    sigs = []
    ss = len(sigmas) / n
    for x in range(n):
        sigs += [float(sigmas[-(1 + int(x * ss))])]
    sigs += [0.0]
    return torch.FloatTensor(sigs)


def sgm_uniform(n: int, sigma_min: float, sigma_max: float):
    start = sigma_to_t(torch.tensor(sigma_max))
    end = sigma_to_t(torch.tensor(sigma_min))
    sigs = [t_to_sigma(ts) for ts in torch.linspace(start, end, n)[:-1]]
    sigs += [0.0]
    return torch.FloatTensor(sigs)


def get_sigmas_polyexponential(
    n: int, sigma_min: float, sigma_max: float, rho: float = 1.0
):
    """Constructs a polynomial in log sigma noise schedule."""
    ramp = torch.linspace(1, 0, n) ** rho
    sigmas = torch.exp(
        ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min)
    )
    return append_zero(sigmas)


def get_sigmas_exponential(n: int, sigma_min: float, sigma_max: float):
    sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n).exp()
    return append_zero(sigmas)


def beta_scheduler(
    n: int, sigma_min: float, sigma_max: float, alpha: float = 0.6, beta: float = 0.6
):
    # From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et al., 2024)
    timesteps = 1 - np.linspace(0, 1, n)
    timesteps = [stats.beta.ppf(x, alpha, beta) for x in timesteps]
    sigmas = [sigma_min + (x * (sigma_max - sigma_min)) for x in timesteps]
    sigmas += [0.0]
    return torch.FloatTensor(sigmas)


def diffusers_scheduler(
    num_inference_steps: int, timestep_spacing: str, interpolation_type: str
):
    scheduler: EulerDiscreteScheduler = EulerDiscreteScheduler.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        timestep_spacing=timestep_spacing,
        interpolation_type=interpolation_type,
        subfolder="scheduler",
    )
    scheduler.set_timesteps(num_inference_steps=num_inference_steps)
    print(
        f"### timestep_spacing={timestep_spacing}, interpolation_type={interpolation_type}"
    )
    print(f"```python\n{scheduler.sigmas}\n```")


def simple(num_inference_steps: int):
    simple_sigmas = simple_scheduler(num_inference_steps)
    print(f"### simple")
    print(f"```python\n{simple_sigmas}\n```")


def sgm(num_inference_steps: int):
    sgm_uniform_sigmas = sgm_uniform(
        n=num_inference_steps + (1 if not discard_next_to_last_sigma else 0),
        sigma_min=m_sigma_min,
        sigma_max=m_sigma_max,
    )
    print(f"### sgm_uniform")
    print(f"```python\n{sgm_uniform_sigmas}\n```")


def exponential(num_inference_steps: int):
    exponential_sigmas = get_sigmas_exponential(
        num_inference_steps, m_sigma_min, m_sigma_max
    )
    print(f"### exponential")
    print(f"```python\n{exponential_sigmas}\n```")


def polyexponential(num_inference_steps: int):
    polyexponential_sigmas = get_sigmas_polyexponential(
        num_inference_steps, m_sigma_min, m_sigma_max
    )
    print(f"### polyexponential")
    print(f"```python\n{polyexponential_sigmas}\n```")


def beta(num_inference_steps: int):
    beta_sigmas = beta_scheduler(
        num_inference_steps, m_sigma_min, m_sigma_max, alpha=0.6, beta=0.6
    )
    print(f"### beta")
    print(f"```python\n{beta_sigmas}\n```")


def sigmas_for_steps(num_inference_steps: int):
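    # Use the function argument, not the loop variable from the enclosing module scope.
    print(f"## {num_inference_steps} steps")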
    diffusers_scheduler(
        num_inference_steps, timestep_spacing="linspace", interpolation_type="linear"
    )
    diffusers_scheduler(
        num_inference_steps, timestep_spacing="leading", interpolation_type="linear"
    )
    diffusers_scheduler(
        num_inference_steps, timestep_spacing="trailing", interpolation_type="linear"
    )
    diffusers_scheduler(
        num_inference_steps,
        timestep_spacing="linspace",
        interpolation_type="log_linear",
    )
    diffusers_scheduler(
        num_inference_steps, timestep_spacing="leading", interpolation_type="log_linear"
    )
    diffusers_scheduler(
        num_inference_steps,
        timestep_spacing="trailing",
        interpolation_type="log_linear",
    )
    simple(num_inference_steps)
    sgm(num_inference_steps)
    exponential(num_inference_steps)
    polyexponential(num_inference_steps)
    beta(num_inference_steps)


for steps in [4, 8, 15, 30]:
    sigmas_for_steps(steps)
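
The near match between sgm_uniform and simple can be traced to `t_to_sigma`: sgm_uniform evaluates sigmas at fractional timesteps by interpolating in log sigma, while simple reads sigmas at integer indices. The sketch below is a plain-Python illustration of that difference for 4 steps. It rebuilds the training sigmas in float64 instead of the float32 tensors used above, so treat the last digits as approximate.

```python
import math

# Rebuild the training sigmas (scaled-linear betas), ascending like in the script above.
beta_start, beta_end, num_train_timesteps = 0.00085, 0.012, 1000
lo, hi = beta_start ** 0.5, beta_end ** 0.5
sigmas = []
alpha_cumprod = 1.0
for i in range(num_train_timesteps):
    beta = (lo + (hi - lo) * i / (num_train_timesteps - 1)) ** 2
    alpha_cumprod *= 1.0 - beta
    sigmas.append(((1.0 - alpha_cumprod) / alpha_cumprod) ** 0.5)


def t_to_sigma(t: float) -> float:
    # Interpolate linearly in log sigma between the two neighbouring integer timesteps.
    low_idx, high_idx = math.floor(t), math.ceil(t)
    w = t - low_idx
    log_sigma = (1 - w) * math.log(sigmas[low_idx]) + w * math.log(sigmas[high_idx])
    return math.exp(log_sigma)


n = 4
# sgm_uniform: n + 1 evenly spaced timesteps from 999 down to 0, last one dropped.
sgm_ts = [999 - i * 999 / n for i in range(n)]
# simple: integer indices, truncated toward zero.
simple_ts = [999 - int(i * num_train_timesteps / n) for i in range(n)]

for t_sgm, t_simple in zip(sgm_ts, simple_ts):
    print(t_sgm, t_to_sigma(t_sgm), t_simple, sigmas[t_simple])
```

For 4 steps the sgm_uniform timesteps are 999, 749.25, 499.5 and 249.75 against 999, 749, 499 and 249 for simple, so every intermediate sigma lands slightly higher.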

cc @asomoza
