huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

[BUG] SNR gamma in v_prediction #9451

Open LinB203 opened 1 week ago

LinB203 commented 1 week ago

Describe the bug

I believe the SNR weighting for v_prediction should follow a trend similar to the epsilon one; otherwise, for T > 600, the model learns almost nothing because the weight approaches zero. If I am wrong, please correct me. Thank you!

[image: plot of the min-SNR loss weight vs. timestep for epsilon and v-prediction]

Reproduction

import torch
from diffusers import DDPMScheduler
from diffusers.training_utils import compute_snr
from matplotlib import pyplot as plt

def get_weight(noise_scheduler, timesteps, snr_gamma):
    # Min-SNR-gamma: clamp the SNR at snr_gamma, then normalize according
    # to the scheduler's prediction type.
    snr = compute_snr(noise_scheduler, timesteps)
    mse_loss_weights = torch.stack(
        [snr, snr_gamma * torch.ones_like(timesteps)], dim=1
    ).min(dim=1)[0]
    if noise_scheduler.config.prediction_type == "epsilon":
        print('prediction_type == "epsilon"')
        mse_loss_weights = mse_loss_weights / snr
    elif noise_scheduler.config.prediction_type == "v_prediction":
        print('prediction_type == "v_prediction"')
        mse_loss_weights = mse_loss_weights / (snr + 1)
    print(f'timesteps {timesteps}\nsnr {snr}\nmse_loss_weights {mse_loss_weights}\n')
    return mse_loss_weights

noise_scheduler = DDPMScheduler(rescale_betas_zero_snr=True)
noise_scheduler_vpred = DDPMScheduler(prediction_type="v_prediction", rescale_betas_zero_snr=True)
timesteps = torch.arange(0, 1000, 5)
snr_gamma = 5.0

data = get_weight(noise_scheduler, timesteps, snr_gamma)
data_vpred = get_weight(noise_scheduler_vpred, timesteps, snr_gamma)

plt.figure()
plt.plot(timesteps, data, label='eps')
plt.plot(timesteps, data_vpred, label='vpred')
plt.legend()
plt.show()
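
As a quick illustration of the discrepancy, the two tail behaviors can be compared on a synthetic SNR ramp spanning the same range as the logs below; `compute_snr` is stubbed out so the snippet runs without a scheduler. The `variant` formula, which adds 1 to the clamped SNR, is purely hypothetical and is not what diffusers implements; it is shown only because it keeps the tail weight near 1, matching the epsilon-style plateau this report suggests.

```python
import torch

# Synthetic SNR curve spanning roughly the same range as the logged values
# (~1e4 at the first timestep down to ~1e-7 at the last); stands in for
# compute_snr so the comparison is self-contained.
snr = torch.logspace(4, -7, steps=200)
snr_gamma = 5.0
clamped = torch.clamp(snr, max=snr_gamma)

# Current diffusers v-prediction min-SNR weighting: min(SNR, gamma) / (SNR + 1).
# As SNR -> 0 (high timesteps with zero terminal SNR), the weight vanishes.
current = clamped / (snr + 1)

# Hypothetical adjustment (NOT diffusers' behavior): (min(SNR, gamma) + 1) / (SNR + 1).
# The tail now approaches 1 instead of 0, mirroring the epsilon curve's plateau.
variant = (clamped + 1) / (snr + 1)

print(f"tail weight, current: {current[-1].item():.2e}")  # vanishes (~1e-7)
print(f"tail weight, variant: {variant[-1].item():.4f}")  # ~1.0
```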

Logs

prediction_type == "epsilon"
timesteps tensor([  0,   5,  10,  15,  20,  25,  30,  35,  40,  45,  50,  55,  60,  65,
         70,  75,  80,  85,  90,  95, 100, 105, 110, 115, 120, 125, 130, 135,
        140, 145, 150, 155, 160, 165, 170, 175, 180, 185, 190, 195, 200, 205,
        210, 215, 220, 225, 230, 235, 240, 245, 250, 255, 260, 265, 270, 275,
        280, 285, 290, 295, 300, 305, 310, 315, 320, 325, 330, 335, 340, 345,
        350, 355, 360, 365, 370, 375, 380, 385, 390, 395, 400, 405, 410, 415,
        420, 425, 430, 435, 440, 445, 450, 455, 460, 465, 470, 475, 480, 485,
        490, 495, 500, 505, 510, 515, 520, 525, 530, 535, 540, 545, 550, 555,
        560, 565, 570, 575, 580, 585, 590, 595, 600, 605, 610, 615, 620, 625,
        630, 635, 640, 645, 650, 655, 660, 665, 670, 675, 680, 685, 690, 695,
        700, 705, 710, 715, 720, 725, 730, 735, 740, 745, 750, 755, 760, 765,
        770, 775, 780, 785, 790, 795, 800, 805, 810, 815, 820, 825, 830, 835,
        840, 845, 850, 855, 860, 865, 870, 875, 880, 885, 890, 895, 900, 905,
        910, 915, 920, 925, 930, 935, 940, 945, 950, 955, 960, 965, 970, 975,
        980, 985, 990, 995])
snr tensor([9.9973e+03, 1.1057e+03, 4.5213e+02, 2.4851e+02, 1.5763e+02, 1.0899e+02,
        7.9858e+01, 6.1013e+01, 4.8113e+01, 3.8892e+01, 3.2070e+01, 2.6882e+01,
        2.2843e+01, 1.9637e+01, 1.7050e+01, 1.4932e+01, 1.3176e+01, 1.1704e+01,
        1.0458e+01, 9.3937e+00, 8.4778e+00, 7.6838e+00, 6.9911e+00, 6.3831e+00,
        5.8466e+00, 5.3708e+00, 4.9469e+00, 4.5678e+00, 4.2272e+00, 3.9202e+00,
        3.6426e+00, 3.3906e+00, 3.1614e+00, 2.9522e+00, 2.7607e+00, 2.5852e+00,
        2.4238e+00, 2.2752e+00, 2.1379e+00, 2.0110e+00, 1.8934e+00, 1.7842e+00,
        1.6827e+00, 1.5882e+00, 1.5001e+00, 1.4178e+00, 1.3408e+00, 1.2688e+00,
        1.2013e+00, 1.1379e+00, 1.0784e+00, 1.0224e+00, 9.6976e-01, 9.2013e-01,
        8.7334e-01, 8.2919e-01, 7.8749e-01, 7.4807e-01, 7.1080e-01, 6.7552e-01,
        6.4211e-01, 6.1046e-01, 5.8045e-01, 5.5198e-01, 5.2496e-01, 4.9930e-01,
        4.7493e-01, 4.5177e-01, 4.2976e-01, 4.0882e-01, 3.8890e-01, 3.6995e-01,
        3.5191e-01, 3.3473e-01, 3.1838e-01, 3.0279e-01, 2.8795e-01, 2.7380e-01,
        2.6032e-01, 2.4748e-01, 2.3523e-01, 2.2356e-01, 2.1243e-01, 2.0181e-01,
        1.9170e-01, 1.8205e-01, 1.7285e-01, 1.6408e-01, 1.5572e-01, 1.4776e-01,
        1.4016e-01, 1.3292e-01, 1.2602e-01, 1.1944e-01, 1.1317e-01, 1.0721e-01,
        1.0152e-01, 9.6104e-02, 9.0948e-02, 8.6040e-02, 8.1369e-02, 7.6924e-02,
        7.2695e-02, 6.8674e-02, 6.4850e-02, 6.1216e-02, 5.7762e-02, 5.4481e-02,
        5.1365e-02, 4.8407e-02, 4.5600e-02, 4.2937e-02, 4.0412e-02, 3.8017e-02,
        3.5749e-02, 3.3599e-02, 3.1564e-02, 2.9638e-02, 2.7815e-02, 2.6092e-02,
        2.4463e-02, 2.2923e-02, 2.1470e-02, 2.0097e-02, 1.8802e-02, 1.7581e-02,
        1.6430e-02, 1.5345e-02, 1.4324e-02, 1.3363e-02, 1.2459e-02, 1.1608e-02,
        1.0809e-02, 1.0059e-02, 9.3548e-03, 8.6941e-03, 8.0747e-03, 7.4943e-03,
        6.9507e-03, 6.4420e-03, 5.9662e-03, 5.5214e-03, 5.1059e-03, 4.7180e-03,
        4.3562e-03, 4.0188e-03, 3.7044e-03, 3.4118e-03, 3.1395e-03, 2.8863e-03,
        2.6511e-03, 2.4328e-03, 2.2302e-03, 2.0424e-03, 1.8685e-03, 1.7075e-03,
        1.5587e-03, 1.4212e-03, 1.2942e-03, 1.1771e-03, 1.0692e-03, 9.6990e-04,
        8.7856e-04, 7.9463e-04, 7.1759e-04, 6.4697e-04, 5.8229e-04, 5.2313e-04,
        4.6907e-04, 4.1975e-04, 3.7481e-04, 3.3391e-04, 2.9675e-04, 2.6303e-04,
        2.3249e-04, 2.0488e-04, 1.7996e-04, 1.5751e-04, 1.3733e-04, 1.1924e-04,
        1.0305e-04, 8.8608e-05, 7.5763e-05, 6.4374e-05, 5.4313e-05, 4.5460e-05,
        3.7706e-05, 3.0949e-05, 2.5096e-05, 2.0059e-05, 1.5762e-05, 1.2129e-05,
        9.0954e-06, 6.5987e-06, 4.5830e-06, 2.9970e-06, 1.7937e-06, 9.3007e-07,
        3.6714e-07, 6.9281e-08])
mse_loss_weights tensor([5.0013e-04, 4.5218e-03, 1.1059e-02, 2.0120e-02, 3.1720e-02, 4.5877e-02,
        6.2611e-02, 8.1950e-02, 1.0392e-01, 1.2856e-01, 1.5591e-01, 1.8600e-01,
        2.1889e-01, 2.5462e-01, 2.9326e-01, 3.3486e-01, 3.7949e-01, 4.2721e-01,
        4.7811e-01, 5.3227e-01, 5.8978e-01, 6.5072e-01, 7.1520e-01, 7.8332e-01,
        8.5520e-01, 9.3096e-01, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00])

prediction_type == "v_prediction"
timesteps tensor([  0,   5,  10,  15,  20,  25,  30,  35,  40,  45,  50,  55,  60,  65,
         70,  75,  80,  85,  90,  95, 100, 105, 110, 115, 120, 125, 130, 135,
        140, 145, 150, 155, 160, 165, 170, 175, 180, 185, 190, 195, 200, 205,
        210, 215, 220, 225, 230, 235, 240, 245, 250, 255, 260, 265, 270, 275,
        280, 285, 290, 295, 300, 305, 310, 315, 320, 325, 330, 335, 340, 345,
        350, 355, 360, 365, 370, 375, 380, 385, 390, 395, 400, 405, 410, 415,
        420, 425, 430, 435, 440, 445, 450, 455, 460, 465, 470, 475, 480, 485,
        490, 495, 500, 505, 510, 515, 520, 525, 530, 535, 540, 545, 550, 555,
        560, 565, 570, 575, 580, 585, 590, 595, 600, 605, 610, 615, 620, 625,
        630, 635, 640, 645, 650, 655, 660, 665, 670, 675, 680, 685, 690, 695,
        700, 705, 710, 715, 720, 725, 730, 735, 740, 745, 750, 755, 760, 765,
        770, 775, 780, 785, 790, 795, 800, 805, 810, 815, 820, 825, 830, 835,
        840, 845, 850, 855, 860, 865, 870, 875, 880, 885, 890, 895, 900, 905,
        910, 915, 920, 925, 930, 935, 940, 945, 950, 955, 960, 965, 970, 975,
        980, 985, 990, 995])
snr tensor([9.9973e+03, 1.1057e+03, 4.5213e+02, 2.4851e+02, 1.5763e+02, 1.0899e+02,
        7.9858e+01, 6.1013e+01, 4.8113e+01, 3.8892e+01, 3.2070e+01, 2.6882e+01,
        2.2843e+01, 1.9637e+01, 1.7050e+01, 1.4932e+01, 1.3176e+01, 1.1704e+01,
        1.0458e+01, 9.3937e+00, 8.4778e+00, 7.6838e+00, 6.9911e+00, 6.3831e+00,
        5.8466e+00, 5.3708e+00, 4.9469e+00, 4.5678e+00, 4.2272e+00, 3.9202e+00,
        3.6426e+00, 3.3906e+00, 3.1614e+00, 2.9522e+00, 2.7607e+00, 2.5852e+00,
        2.4238e+00, 2.2752e+00, 2.1379e+00, 2.0110e+00, 1.8934e+00, 1.7842e+00,
        1.6827e+00, 1.5882e+00, 1.5001e+00, 1.4178e+00, 1.3408e+00, 1.2688e+00,
        1.2013e+00, 1.1379e+00, 1.0784e+00, 1.0224e+00, 9.6976e-01, 9.2013e-01,
        8.7334e-01, 8.2919e-01, 7.8749e-01, 7.4807e-01, 7.1080e-01, 6.7552e-01,
        6.4211e-01, 6.1046e-01, 5.8045e-01, 5.5198e-01, 5.2496e-01, 4.9930e-01,
        4.7493e-01, 4.5177e-01, 4.2976e-01, 4.0882e-01, 3.8890e-01, 3.6995e-01,
        3.5191e-01, 3.3473e-01, 3.1838e-01, 3.0279e-01, 2.8795e-01, 2.7380e-01,
        2.6032e-01, 2.4748e-01, 2.3523e-01, 2.2356e-01, 2.1243e-01, 2.0181e-01,
        1.9170e-01, 1.8205e-01, 1.7285e-01, 1.6408e-01, 1.5572e-01, 1.4776e-01,
        1.4016e-01, 1.3292e-01, 1.2602e-01, 1.1944e-01, 1.1317e-01, 1.0721e-01,
        1.0152e-01, 9.6104e-02, 9.0948e-02, 8.6040e-02, 8.1369e-02, 7.6924e-02,
        7.2695e-02, 6.8674e-02, 6.4850e-02, 6.1216e-02, 5.7762e-02, 5.4481e-02,
        5.1365e-02, 4.8407e-02, 4.5600e-02, 4.2937e-02, 4.0412e-02, 3.8017e-02,
        3.5749e-02, 3.3599e-02, 3.1564e-02, 2.9638e-02, 2.7815e-02, 2.6092e-02,
        2.4463e-02, 2.2923e-02, 2.1470e-02, 2.0097e-02, 1.8802e-02, 1.7581e-02,
        1.6430e-02, 1.5345e-02, 1.4324e-02, 1.3363e-02, 1.2459e-02, 1.1608e-02,
        1.0809e-02, 1.0059e-02, 9.3548e-03, 8.6941e-03, 8.0747e-03, 7.4943e-03,
        6.9507e-03, 6.4420e-03, 5.9662e-03, 5.5214e-03, 5.1059e-03, 4.7180e-03,
        4.3562e-03, 4.0188e-03, 3.7044e-03, 3.4118e-03, 3.1395e-03, 2.8863e-03,
        2.6511e-03, 2.4328e-03, 2.2302e-03, 2.0424e-03, 1.8685e-03, 1.7075e-03,
        1.5587e-03, 1.4212e-03, 1.2942e-03, 1.1771e-03, 1.0692e-03, 9.6990e-04,
        8.7856e-04, 7.9463e-04, 7.1759e-04, 6.4697e-04, 5.8229e-04, 5.2313e-04,
        4.6907e-04, 4.1975e-04, 3.7481e-04, 3.3391e-04, 2.9675e-04, 2.6303e-04,
        2.3249e-04, 2.0488e-04, 1.7996e-04, 1.5751e-04, 1.3733e-04, 1.1924e-04,
        1.0305e-04, 8.8608e-05, 7.5763e-05, 6.4374e-05, 5.4313e-05, 4.5460e-05,
        3.7706e-05, 3.0949e-05, 2.5096e-05, 2.0059e-05, 1.5762e-05, 1.2129e-05,
        9.0954e-06, 6.5987e-06, 4.5830e-06, 2.9970e-06, 1.7937e-06, 9.3007e-07,
        3.6714e-07, 6.9281e-08])
mse_loss_weights tensor([5.0008e-04, 4.5177e-03, 1.1034e-02, 2.0039e-02, 3.1520e-02, 4.5460e-02,
        6.1837e-02, 8.0628e-02, 1.0181e-01, 1.2534e-01, 1.5119e-01, 1.7933e-01,
        2.0971e-01, 2.4229e-01, 2.7701e-01, 3.1384e-01, 3.5272e-01, 3.9358e-01,
        4.3639e-01, 4.8106e-01, 5.2755e-01, 5.7578e-01, 6.2570e-01, 6.7722e-01,
        7.3029e-01, 7.8483e-01, 8.3185e-01, 8.2039e-01, 8.0869e-01, 7.9676e-01,
        7.8460e-01, 7.7224e-01, 7.5969e-01, 7.4697e-01, 7.3410e-01, 7.2108e-01,
        7.0793e-01, 6.9467e-01, 6.8132e-01, 6.6788e-01, 6.5438e-01, 6.4083e-01,
        6.2724e-01, 6.1363e-01, 6.0001e-01, 5.8640e-01, 5.7280e-01, 5.5924e-01,
        5.4572e-01, 5.3225e-01, 5.1886e-01, 5.0555e-01, 4.9232e-01, 4.7920e-01,
        4.6620e-01, 4.5331e-01, 4.4056e-01, 4.2794e-01, 4.1548e-01, 4.0317e-01,
        3.9103e-01, 3.7906e-01, 3.6727e-01, 3.5566e-01, 3.4424e-01, 3.3302e-01,
        3.2200e-01, 3.1119e-01, 3.0058e-01, 2.9019e-01, 2.8001e-01, 2.7005e-01,
        2.6031e-01, 2.5079e-01, 2.4149e-01, 2.3242e-01, 2.2357e-01, 2.1495e-01,
        2.0655e-01, 1.9838e-01, 1.9043e-01, 1.8271e-01, 1.7521e-01, 1.6792e-01,
        1.6086e-01, 1.5401e-01, 1.4738e-01, 1.4096e-01, 1.3474e-01, 1.2873e-01,
        1.2293e-01, 1.1732e-01, 1.1191e-01, 1.0670e-01, 1.0167e-01, 9.6825e-02,
        9.2163e-02, 8.7678e-02, 8.3366e-02, 7.9224e-02, 7.5246e-02, 7.1429e-02,
        6.7769e-02, 6.4261e-02, 6.0901e-02, 5.7684e-02, 5.4608e-02, 5.1666e-02,
        4.8856e-02, 4.6172e-02, 4.3612e-02, 4.1170e-02, 3.8842e-02, 3.6625e-02,
        3.4515e-02, 3.2507e-02, 3.0598e-02, 2.8785e-02, 2.7063e-02, 2.5429e-02,
        2.3879e-02, 2.2410e-02, 2.1018e-02, 1.9701e-02, 1.8455e-02, 1.7277e-02,
        1.6164e-02, 1.5113e-02, 1.4122e-02, 1.3187e-02, 1.2305e-02, 1.1475e-02,
        1.0694e-02, 9.9589e-03, 9.2681e-03, 8.6192e-03, 8.0100e-03, 7.4385e-03,
        6.9027e-03, 6.4007e-03, 5.9308e-03, 5.4911e-03, 5.0800e-03, 4.6959e-03,
        4.3373e-03, 4.0027e-03, 3.6908e-03, 3.4002e-03, 3.1297e-03, 2.8780e-03,
        2.6441e-03, 2.4269e-03, 2.2252e-03, 2.0382e-03, 1.8650e-03, 1.7046e-03,
        1.5563e-03, 1.4191e-03, 1.2925e-03, 1.1757e-03, 1.0681e-03, 9.6896e-04,
        8.7778e-04, 7.9399e-04, 7.1708e-04, 6.4655e-04, 5.8195e-04, 5.2285e-04,
        4.6885e-04, 4.1958e-04, 3.7467e-04, 3.3380e-04, 2.9666e-04, 2.6296e-04,
        2.3244e-04, 2.0484e-04, 1.7992e-04, 1.5748e-04, 1.3731e-04, 1.1922e-04,
        1.0304e-04, 8.8601e-05, 7.5758e-05, 6.4370e-05, 5.4310e-05, 4.5458e-05,
        3.7705e-05, 3.0948e-05, 2.5095e-05, 2.0059e-05, 1.5761e-05, 1.2129e-05,
        9.0953e-06, 6.5986e-06, 4.5830e-06, 2.9970e-06, 1.7937e-06, 9.3007e-07,
        3.6714e-07, 6.9281e-08])

System Info

None

Who can help?

@yiyixuxu @asomoza @sayakpaul

sayakpaul commented 1 week ago

Indeed. I have very seldom trained with v-pred, so I didn't explore this setting much. Do you maybe want to open a PR?

heart-du commented 4 days ago

Does this actually cause a problem?

bghira commented 4 days ago

@drhead cc

drhead commented 4 days ago

So, the intent of min-snr-gamma is to balance the timesteps so that each one is treated as equally important.

Min-snr-gamma attempts to do this with a fixed formula, but you can also do it by setting up trainable parameters to control the timestep weighting, which the EDM2 paper did (https://arxiv.org/pdf/2312.02696, pp. 19-22), based on earlier work by Kendall et al. (https://arxiv.org/abs/1705.07115). This is a much more flexible approach: you don't have to do anything except train for long enough for the weights to settle, and you can effectively guarantee that the parameters will settle at a point where the average training loss of each timestep contributes equally to the gradients.
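
A minimal sketch of this learned-weighting idea, assuming a trainable log-variance per coarse timestep bucket (EDM2 actually learns an MLP over the noise level; the bucketing, toy loss, and all names below are illustrative, not a diffusers API):

```python
import torch

torch.manual_seed(0)

# Kendall-style homoscedastic uncertainty weighting: reweight the loss as
#     raw_loss / exp(log_var) + log_var,
# whose optimum puts exp(log_var[t]) at the average raw loss at timestep t,
# so every timestep ends up contributing equally to gradients.
num_buckets = 100  # timesteps 0..999 grouped 10 per bucket
log_var = torch.zeros(num_buckets, requires_grad=True)
opt = torch.optim.Adam([log_var], lr=0.05)

def weighted_loss(raw_loss, t):
    lv = log_var[t // 10]  # per-sample learned uncertainty
    return (raw_loss / lv.exp() + lv).mean()

# Toy "diffusion loss" whose magnitude grows strongly with timestep,
# mimicking how the raw MSE scales across noise levels.
for _ in range(2000):
    t = torch.randint(0, 1000, (256,))
    raw = (t.float() / 200.0).exp() * (1 + 0.1 * torch.randn(256)).abs()
    opt.zero_grad()
    weighted_loss(raw.detach(), t).backward()
    opt.step()

# exp(log_var) now tracks the average raw loss per bucket, so the
# normalized loss raw / exp(log_var) is ~1 at every timestep.
```

The inverse of the learned `exp(log_var)` plays the role that a fixed min-SNR curve plays, but it adapts to whatever the model's actual per-timestep loss turns out to be.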

I have a training run doing exactly this right now, on v-prediction, and this is what the timestep weightings currently look like: [image: learned timestep-weighting curve from a v-prediction training run]

I would say this is evidence that the v-prediction min-snr-gamma curve is at least more correct than one resembling the epsilon curve, in that there is a peak around timestep 200-300 and lower weights towards the ends. But we can also see that the weights at the tails are definitely not zero, so you are right to suspect that min-snr-gamma's tails approaching zero is not ideal.

I can't really nail down a fixed formula that would reproduce the kind of curve I've gotten from the homoscedastic-uncertainty method, but I do know that the intensity of the peak increases with the size of the latent, so that would need to be accounted for: [image: learned weighting curves for different latent sizes]

I do have a suspicion that this is related to what was pointed out in Hoogeboom et al. (https://arxiv.org/pdf/2301.11093) about high resolution models needing more noise to fully destroy the signal.

In conclusion though, re: min-snr-gamma, the formula is correct as implemented; if the curve has that shape, that's what the formula described in the paper looks like for v-prediction. There are arguably better methods, and room for them to be implemented, but they would have to be separate features; there's nothing to really "fix" in min-snr-gamma.
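
For reference, the weights computed by the reproduction script above can be written as follows, with SNR(t) the per-timestep signal-to-noise ratio and gamma the clamp value:

```latex
% Min-SNR-gamma loss weights as implemented in the reproduction above:
\[
w_\epsilon(t) = \frac{\min(\mathrm{SNR}(t),\, \gamma)}{\mathrm{SNR}(t)},
\qquad
w_v(t) = \frac{\min(\mathrm{SNR}(t),\, \gamma)}{\mathrm{SNR}(t) + 1}.
\]
% As SNR(t) -> 0 (late timesteps with zero terminal SNR), w_eps -> 1
% while w_v -> 0, which is the curve shape discussed in this thread.
```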