huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
Apache License 2.0

[Schedulers] Analysis of `simple`, `exponential`, `polyexponential` and `beta` #9490

Open hlky opened 2 weeks ago

hlky commented 2 weeks ago

I'm creating this issue to present my findings in relation to a discussion in #9416 about supporting additional schedulers used in A1111/Forge/Comfy etc., specifically the simple, exponential, polyexponential and beta schedulers.

I've tested these schedulers and compared them to Diffusers with step counts 4, 8, 15, and 30. sgm_uniform is also included in these tests to confirm the findings in the above linked issue. On the Diffusers side we test timestep_spacing linspace, leading and trailing with both interpolation_type linear and log_linear.



Simple

The simple schedule type is an exact match to timestep_spacing trailing with interpolation_type linear.

SGM Uniform

As found in the linked issue, this schedule type is a near-exact match to timestep_spacing trailing with interpolation_type linear and, in turn, a near-exact match to the simple schedule type.
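The relationship between simple and sgm_uniform can be sketched without any dependencies. This is a minimal pure-Python reading of the test code in this issue, assuming the same scaled-linear beta schedule (beta_start 0.00085, beta_end 0.012, 1000 training timesteps); the function names `sd_sigmas`, `simple_schedule` and `sgm_uniform_schedule` are made up for illustration and are not Diffusers or Forge APIs.

```python
import math


def sd_sigmas(beta_start=0.00085, beta_end=0.012, num_train_timesteps=1000):
    """Ascending sigma table for a scaled-linear beta schedule."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    sigmas, alpha_cumprod = [], 1.0
    for i in range(num_train_timesteps):
        beta = (lo + (hi - lo) * i / (num_train_timesteps - 1)) ** 2
        alpha_cumprod *= 1.0 - beta
        sigmas.append(math.sqrt((1.0 - alpha_cumprod) / alpha_cumprod))
    return sigmas


def simple_schedule(sigmas, n):
    """Pick n table entries at a fixed stride, starting from the largest sigma."""
    stride = len(sigmas) / n
    return [sigmas[-(1 + int(i * stride))] for i in range(n)] + [0.0]


def sgm_uniform_schedule(sigmas, n):
    """Evenly spaced fractional timesteps, interpolated linearly in log sigma."""
    log_sigmas = [math.log(s) for s in sigmas]
    t_max = len(sigmas) - 1
    out = []
    for i in range(n):
        # Equivalent to linspace(t_max, 0, n + 1) with the last point dropped.
        t = t_max * (1 - i / n)
        low = math.floor(t)
        high = min(low + 1, t_max)
        w = t - low
        out.append(math.exp((1 - w) * log_sigmas[low] + w * log_sigmas[high]))
    return out + [0.0]


table = sd_sigmas()
for a, b in zip(simple_schedule(table, 4), sgm_uniform_schedule(table, 4)):
    print(f"simple={a:.4f}  sgm_uniform={b:.4f}")
```

The two differ only in where they sample the table: simple takes integer indices (999, 749, 499, 249 for 4 steps), while sgm_uniform takes fractional timesteps (999, 749.25, 499.5, 249.75), which explains the small offsets between the two sets of results.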

Exponential and Polyexponential

These schedule types produce exactly the same results when Polyexponential uses its default rho (1.0). Neither has a match in Diffusers, so they need to be implemented.
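Why the two coincide at rho = 1.0 is easiest to see side by side. The following is a dependency-free sketch of the two functions from the test code in this issue (the originals use torch); the helper `linspace` and the `_sigmas` function names are illustrative only.

```python
import math


def linspace(start, stop, n):
    """n evenly spaced values from start to stop, inclusive."""
    if n == 1:
        return [start]
    step = (stop - start) / (n - 1)
    return [start + i * step for i in range(n)]


def exponential_sigmas(n, sigma_min, sigma_max):
    """Linear interpolation in log sigma, then a trailing zero."""
    logs = linspace(math.log(sigma_max), math.log(sigma_min), n)
    return [math.exp(v) for v in logs] + [0.0]


def polyexponential_sigmas(n, sigma_min, sigma_max, rho=1.0):
    """Ramp from 1 to 0 raised to rho, mapped onto the log sigma range."""
    span = math.log(sigma_max) - math.log(sigma_min)
    ramp = [r**rho for r in linspace(1.0, 0.0, n)]
    return [math.exp(r * span + math.log(sigma_min)) for r in ramp] + [0.0]


exp_sigmas = exponential_sigmas(8, 0.0292, 14.6146)
poly_sigmas = polyexponential_sigmas(8, 0.0292, 14.6146)  # rho defaults to 1.0
print(max(abs(a - b) for a, b in zip(exp_sigmas, poly_sigmas)))
```

With rho = 1.0 the ramp is linear, so `ramp * span + log(sigma_min)` is the same straight line in log sigma that the exponential schedule interpolates. Any other rho bends that line, and the interior values diverge while the endpoints stay fixed.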


Beta

There is no match in Diffusers, so this schedule type needs to be implemented.
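For reference, the beta schedule used in these tests can be written in closed form. This is a reading of the `beta_scheduler` test function in this issue, not a formula quoted from the paper; $F^{-1}_{\alpha,\beta}$ is assumed to denote the quantile function of the Beta distribution (`scipy.stats.beta.ppf` in the test code), with $\alpha = \beta = 0.6$ in the tests:

$$
\sigma_i = \sigma_{\min} + F^{-1}_{\alpha,\beta}\!\left(1 - \frac{i}{n-1}\right)\left(\sigma_{\max} - \sigma_{\min}\right), \qquad i = 0, \dots, n-1, \qquad \sigma_n = 0
$$

Unlike the exponential schedule, the interpolation here is linear in sigma rather than in log sigma, which is consistent with the early beta values staying close to sigma_max (14.6146, 13.5770, 11.4518, ... at 8 steps).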


The code used to produce these results is attached below the test results. I've confirmed the results are accurate with a simple modification to Forge that prints the sigmas for each tested schedule type.

Test results

4 steps

timestep_spacing=linspace, interpolation_type=linear

tensor([14.6146,  2.9183,  0.9324,  0.0292,  0.0000])

timestep_spacing=leading, interpolation_type=linear

tensor([4.1167, 1.6237, 0.6984, 0.0413, 0.0000])

timestep_spacing=trailing, interpolation_type=linear

tensor([14.6146,  4.0817,  1.6129,  0.6932,  0.0000])

timestep_spacing=linspace, interpolation_type=log_linear

tensor([14.6146,  3.0890,  0.6529,  0.1380,  0.0292,  0.0000])

timestep_spacing=leading, interpolation_type=log_linear

tensor([14.6146,  3.0890,  0.6529,  0.1380,  0.0292,  0.0000])

timestep_spacing=trailing, interpolation_type=log_linear

tensor([14.6146,  3.0890,  0.6529,  0.1380,  0.0292,  0.0000])


simple

tensor([14.6146,  4.0817,  1.6129,  0.6932,  0.0000])


sgm_uniform

tensor([14.6146,  4.0861,  1.6156,  0.6952,  0.0000])


exponential

tensor([14.6146,  1.8400,  0.2317,  0.0292,  0.0000])


polyexponential

tensor([14.6147,  1.8400,  0.2317,  0.0292,  0.0000])


beta

tensor([14.6146, 10.5976,  4.0462,  0.0292,  0.0000])

8 steps

timestep_spacing=linspace, interpolation_type=linear

tensor([14.6146,  6.6780,  3.5221,  2.0606,  1.2768,  0.7913,  0.4397,  0.0292,
         0.0000])

timestep_spacing=leading, interpolation_type=linear

tensor([7.3718, 4.1167, 2.5109, 1.6237, 1.0760, 0.6984, 0.4022, 0.0413, 0.0000])

timestep_spacing=trailing, interpolation_type=linear

tensor([14.6146,  7.2974,  4.0817,  2.4925,  1.6129,  1.0690,  0.6932,  0.3977,
         0.0000])

timestep_spacing=linspace, interpolation_type=log_linear

tensor([14.6146,  6.7190,  3.0890,  1.4201,  0.6529,  0.3002,  0.1380,  0.0634,
         0.0292,  0.0000])

timestep_spacing=leading, interpolation_type=log_linear

tensor([14.6146,  6.7190,  3.0890,  1.4201,  0.6529,  0.3002,  0.1380,  0.0634,
         0.0292,  0.0000])

timestep_spacing=trailing, interpolation_type=log_linear

tensor([14.6146,  6.7190,  3.0890,  1.4201,  0.6529,  0.3002,  0.1380,  0.0634,
         0.0292,  0.0000])


simple

tensor([14.6146,  7.2974,  4.0817,  2.4925,  1.6129,  1.0690,  0.6932,  0.3977,
         0.0000])


sgm_uniform

tensor([14.6146,  7.3020,  4.0861,  2.4960,  1.6156,  1.0712,  0.6952,  0.3997,
         0.0000])


exponential

tensor([14.6146,  6.0130,  2.4740,  1.0179,  0.4188,  0.1723,  0.0709,  0.0292,
         0.0000])


polyexponential

tensor([14.6147,  6.0130,  2.4740,  1.0179,  0.4188,  0.1723,  0.0709,  0.0292,
         0.0000])


beta

tensor([14.6146, 13.5770, 11.4518,  8.7596,  5.8842,  3.1920,  1.0668,  0.0292,
         0.0000])

15 steps

timestep_spacing=linspace, interpolation_type=linear

tensor([14.6146,  9.6826,  6.6780,  4.7746,  3.5221,  2.6666,  2.0606,  1.6156,
         1.2768,  1.0097,  0.7913,  0.6056,  0.4397,  0.2780,  0.0292,  0.0000])

timestep_spacing=leading, interpolation_type=linear

tensor([9.5436, 6.7684, 4.9510, 3.7216, 2.8629, 2.2441, 1.7841, 1.4316, 1.1530,
        0.9261, 0.7353, 0.5693, 0.4179, 0.2677, 0.0413, 0.0000])

timestep_spacing=trailing, interpolation_type=linear

tensor([14.6146,  9.9172,  7.0089,  5.0878,  3.7997,  2.9183,  2.2765,  1.8024,
         1.4458,  1.1606,  0.9292,  0.7380,  0.5693,  0.4156,  0.2653,  0.0000])

timestep_spacing=linspace, interpolation_type=log_linear

tensor([14.6146,  9.6560,  6.3797,  4.2151,  2.7850,  1.8400,  1.2157,  0.8032,
         0.5307,  0.3506,  0.2317,  0.1531,  0.1011,  0.0668,  0.0441,  0.0292,
         0.0000])

timestep_spacing=leading, interpolation_type=log_linear

tensor([14.6146,  9.6560,  6.3797,  4.2151,  2.7850,  1.8400,  1.2157,  0.8032,
         0.5307,  0.3506,  0.2317,  0.1531,  0.1011,  0.0668,  0.0441,  0.0292,
         0.0000])

timestep_spacing=trailing, interpolation_type=log_linear

tensor([14.6146,  9.6560,  6.3797,  4.2151,  2.7850,  1.8400,  1.2157,  0.8032,
         0.5307,  0.3506,  0.2317,  0.1531,  0.1011,  0.0668,  0.0441,  0.0292,
         0.0000])


simple

tensor([14.6146,  9.9720,  7.0089,  5.0878,  3.8155,  2.9183,  2.2765,  1.8085,
         1.4458,  1.1606,  0.9324,  0.7380,  0.5693,  0.4179,  0.2653,  0.0000])


sgm_uniform

tensor([14.6146,  9.9391,  7.0019,  5.0924,  3.8092,  2.9183,  2.2797,  1.8073,
         1.4467,  1.1629,  0.9324,  0.7391,  0.5712,  0.4183,  0.2667,  0.0000])


exponential

tensor([14.6146,  9.3743,  6.0130,  3.8569,  2.4740,  1.5869,  1.0179,  0.6529,
         0.4188,  0.2686,  0.1723,  0.1105,  0.0709,  0.0455,  0.0292,  0.0000])


polyexponential

tensor([14.6147,  9.3743,  6.0130,  3.8569,  2.4740,  1.5869,  1.0179,  0.6529,
         0.4188,  0.2686,  0.1723,  0.1105,  0.0709,  0.0455,  0.0292,  0.0000])


beta

tensor([14.6146, 14.2837, 13.5770, 12.6113, 11.4518, 10.1517,  8.7596,  7.3219,
         5.8842,  4.4921,  3.1920,  2.0325,  1.0668,  0.3601,  0.0292,  0.0000])

30 steps

timestep_spacing=linspace, interpolation_type=linear

tensor([14.6146, 11.9176,  9.8142,  8.1585,  6.8431,  5.7886,  4.9356,  4.2397,
         3.6669,  3.1913,  2.7931,  2.4569,  2.1705,  1.9246,  1.7116,  1.5257,
         1.3619,  1.2166,  1.0865,  0.9691,  0.8622,  0.7640,  0.6730,  0.5877,
         0.5067,  0.4286,  0.3515,  0.2722,  0.1835,  0.0292,  0.0000])

timestep_spacing=leading, interpolation_type=linear

tensor([11.4769,  9.5436,  8.0043,  6.7684,  5.7678,  4.9510,  4.2790,  3.7216,
         3.2556,  2.8629,  2.5295,  2.2441,  1.9980,  1.7841,  1.5968,  1.4316,
         1.2846,  1.1530,  1.0342,  0.9261,  0.8270,  0.7353,  0.6499,  0.5693,
         0.4924,  0.4179,  0.3439,  0.2677,  0.1822,  0.0413,  0.0000])

timestep_spacing=trailing, interpolation_type=linear

tensor([14.6146, 12.0177,  9.9172,  8.3028,  7.0089,  5.9347,  5.0878,  4.3919,
         3.7997,  3.3211,  2.9183,  2.5671,  2.2765,  2.0260,  1.8024,  1.6129,
         1.4458,  1.2931,  1.1606,  1.0410,  0.9292,  0.8299,  0.7380,  0.6499,
         0.5693,  0.4924,  0.4156,  0.3417,  0.2653,  0.1763,  0.0000])

timestep_spacing=linspace, interpolation_type=log_linear

tensor([14.6146, 11.8793,  9.6560,  7.8487,  6.3797,  5.1857,  4.2151,  3.4262,
         2.7850,  2.2637,  1.8400,  1.4956,  1.2157,  0.9882,  0.8032,  0.6529,
         0.5307,  0.4314,  0.3506,  0.2850,  0.2317,  0.1883,  0.1531,  0.1244,
         0.1011,  0.0822,  0.0668,  0.0543,  0.0441,  0.0359,  0.0292,  0.0000])

timestep_spacing=leading, interpolation_type=log_linear

tensor([14.6146, 11.8793,  9.6560,  7.8487,  6.3797,  5.1857,  4.2151,  3.4262,
         2.7850,  2.2637,  1.8400,  1.4956,  1.2157,  0.9882,  0.8032,  0.6529,
         0.5307,  0.4314,  0.3506,  0.2850,  0.2317,  0.1883,  0.1531,  0.1244,
         0.1011,  0.0822,  0.0668,  0.0543,  0.0441,  0.0359,  0.0292,  0.0000])

timestep_spacing=trailing, interpolation_type=log_linear

tensor([14.6146, 11.8793,  9.6560,  7.8487,  6.3797,  5.1857,  4.2151,  3.4262,
         2.7850,  2.2637,  1.8400,  1.4956,  1.2157,  0.9882,  0.8032,  0.6529,
         0.5307,  0.4314,  0.3506,  0.2850,  0.2317,  0.1883,  0.1531,  0.1244,
         0.1011,  0.0822,  0.0668,  0.0543,  0.0441,  0.0359,  0.0292,  0.0000])


simple

tensor([14.6146, 12.0177,  9.9720,  8.3028,  7.0089,  5.9631,  5.0878,  4.3919,
         3.8155,  3.3211,  2.9183,  2.5767,  2.2765,  2.0260,  1.8085,  1.6129,
         1.4458,  1.2973,  1.1606,  1.0410,  0.9324,  0.8299,  0.7380,  0.6524,
         0.5693,  0.4924,  0.4179,  0.3417,  0.2653,  0.1793,  0.0000])


sgm_uniform

tensor([14.6146, 11.9969,  9.9391,  8.3072,  7.0019,  5.9489,  5.0924,  4.3900,
         3.8092,  3.3251,  2.9183,  2.5738,  2.2797,  2.0267,  1.8073,  1.6156,
         1.4467,  1.2969,  1.1629,  1.0421,  0.9324,  0.8319,  0.7391,  0.6526,
         0.5712,  0.4936,  0.4183,  0.3437,  0.2667,  0.1801,  0.0000])


exponential

tensor([14.6146, 11.7947,  9.5190,  7.6823,  6.2000,  5.0037,  4.0382,  3.2591,
         2.6302,  2.1227,  1.7131,  1.3826,  1.1158,  0.9005,  0.7268,  0.5865,
         0.4734,  0.3820,  0.3083,  0.2488,  0.2008,  0.1621,  0.1308,  0.1056,
         0.0852,  0.0688,  0.0555,  0.0448,  0.0361,  0.0292,  0.0000])


polyexponential

tensor([14.6147, 11.7948,  9.5190,  7.6823,  6.2000,  5.0037,  4.0382,  3.2591,
         2.6302,  2.1227,  1.7131,  1.3826,  1.1158,  0.9005,  0.7268,  0.5865,
         0.4734,  0.3820,  0.3083,  0.2488,  0.2008,  0.1621,  0.1308,  0.1056,
         0.0852,  0.0688,  0.0555,  0.0448,  0.0361,  0.0292,  0.0000])


beta

tensor([14.6146, 14.5159, 14.3024, 14.0041, 13.6349, 13.2047, 12.7212, 12.1912,
        11.6212, 11.0168, 10.3837,  9.7273,  9.0529,  8.3656,  7.6707,  6.9732,
         6.2782,  5.5909,  4.9165,  4.2601,  3.6270,  3.0226,  2.4526,  1.9227,
         1.4392,  1.0089,  0.6397,  0.3414,  0.1279,  0.0292,  0.0000])


from diffusers import EulerDiscreteScheduler
import torch
import math
import numpy as np
from scipy import stats

beta_start = 0.00085
beta_end = 0.012
num_train_timesteps = 1000
# scaled-linear beta schedule: linear in sqrt(beta), then squared
betas = (
    torch.linspace(
        beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32
    )
    ** 2
)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# not flipped, contrary to diffusers
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
log_sigmas = sigmas.log()
discard_next_to_last_sigma = False

def sigma_to_t(sigma: torch.Tensor):
    log_sigma = sigma.log()
    dists = log_sigma - log_sigmas[:, None]
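    # index of the last table entry whose log sigma is <= the query
    low_idx = (
        dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=log_sigmas.shape[0] - 2)
    )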
    high_idx = low_idx + 1
    low, high = log_sigmas[low_idx], log_sigmas[high_idx]
    w = (low - log_sigma) / (low - high)
    w = w.clamp(0, 1)
    t = (1 - w) * low_idx + w * high_idx
    return t.view(sigma.shape)

def t_to_sigma(t: torch.Tensor):
    t = t.float()
    low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
    log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
    return log_sigma.exp()

m_sigma_min, m_sigma_max = (sigmas[0].item(), sigmas[-1].item())

def append_zero(x: torch.Tensor):
    return torch.cat([x, x.new_zeros([1])])

def simple_scheduler(n: int):
    sigs = []
    ss = len(sigmas) / n
    for x in range(n):
        sigs += [float(sigmas[-(1 + int(x * ss))])]
    sigs += [0.0]
    return torch.FloatTensor(sigs)

def sgm_uniform(n: int, sigma_min: float, sigma_max: float):
    start = sigma_to_t(torch.tensor(sigma_max))
    end = sigma_to_t(torch.tensor(sigma_min))
    sigs = [t_to_sigma(ts) for ts in torch.linspace(start, end, n)[:-1]]
    sigs += [0.0]
    return torch.FloatTensor(sigs)

def get_sigmas_polyexponential(
    n: int, sigma_min: float, sigma_max: float, rho: float = 1.0
):
    """Constructs a polynomial in log sigma noise schedule."""
    ramp = torch.linspace(1, 0, n) ** rho
    sigmas = torch.exp(
        ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min)
    )
    return append_zero(sigmas)

def get_sigmas_exponential(n: int, sigma_min: float, sigma_max: float):
    sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n).exp()
    return append_zero(sigmas)

def beta_scheduler(
    n: int, sigma_min: float, sigma_max: float, alpha: float = 0.6, beta: float = 0.6
):
    # From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et al., 2024)
    timesteps = 1 - np.linspace(0, 1, n)
    # map evenly spaced quantiles through the inverse CDF of Beta(alpha, beta)
    timesteps = [stats.beta.ppf(x, alpha, beta) for x in timesteps]
    sigmas = [sigma_min + (x * (sigma_max - sigma_min)) for x in timesteps]
    sigmas += [0.0]
    return torch.FloatTensor(sigmas)

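def diffusers_scheduler(
    num_inference_steps: int, timestep_spacing: str, interpolation_type: str
):
    # The checkpoint arguments to from_pretrained are missing here; `...` marks the gap.
    scheduler: EulerDiscreteScheduler = EulerDiscreteScheduler.from_pretrained(
        ...,
        timestep_spacing=timestep_spacing,
        interpolation_type=interpolation_type,
    )
    scheduler.set_timesteps(num_inference_steps)
    print(
        f"### timestep_spacing={timestep_spacing}, interpolation_type={interpolation_type}"
    )
    print(scheduler.sigmas)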

def simple(num_inference_steps: int):
    simple_sigmas = simple_scheduler(num_inference_steps)
    print("### simple")
    print(simple_sigmas)

def sgm(num_inference_steps: int):
    sgm_uniform_sigmas = sgm_uniform(
        n=num_inference_steps + (1 if not discard_next_to_last_sigma else 0),
        sigma_min=m_sigma_min,
        sigma_max=m_sigma_max,
    )
    print("### sgm_uniform")
    print(sgm_uniform_sigmas)

def exponential(num_inference_steps: int):
    exponential_sigmas = get_sigmas_exponential(
        num_inference_steps, m_sigma_min, m_sigma_max
    )
    print("### exponential")
    print(exponential_sigmas)

def polyexponential(num_inference_steps: int):
    polyexponential_sigmas = get_sigmas_polyexponential(
        num_inference_steps, m_sigma_min, m_sigma_max
    )
    print("### polyexponential")
    print(polyexponential_sigmas)

def beta(num_inference_steps: int):
    beta_sigmas = beta_scheduler(
        num_inference_steps, m_sigma_min, m_sigma_max, alpha=0.6, beta=0.6
    )
    print("### beta")
    print(beta_sigmas)

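def sigmas_for_steps(num_inference_steps: int):
    print(f"## {num_inference_steps} steps")
    # Diffusers: every timestep_spacing for each interpolation_type
    for interpolation_type in ("linear", "log_linear"):
        for timestep_spacing in ("linspace", "leading", "trailing"):
            diffusers_scheduler(
                num_inference_steps,
                timestep_spacing=timestep_spacing,
                interpolation_type=interpolation_type,
            )
    # A1111/Forge/Comfy schedule types
    simple(num_inference_steps)
    sgm(num_inference_steps)
    exponential(num_inference_steps)
    polyexponential(num_inference_steps)
    beta(num_inference_steps)

for steps in [4, 8, 15, 30]:
    sigmas_for_steps(steps)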

cc @asomoza

asomoza commented 2 weeks ago

@yiyixuxu WDYT?

yiyixuxu commented 1 week ago

Thanks a lot for doing this. Let's close the "scheduler gap"! 😅

hlky commented 2 days ago

Exponential and beta are now merged, so I think we can consider this issue resolved. As mentioned, polyexponential matches exponential unless the rho value is changed from the default, and I don't have any stats on how common it is for users to actually change that value in the webuis.