crowsonkb / k-diffusion

Karras et al. (2022) diffusion models for PyTorch
MIT License

sample_dpmpp_2m has a bug? #56

Open hallatore opened 1 year ago

hallatore commented 1 year ago

Hi,

I've been playing around with the sample_dpmpp_2m sampler and found that swapping one variable changes/fixes the blur. I don't know the math behind it, so I might be wrong, but I think there might be a bug in the code? Let me know what you think, and whether you want me to create a PR for it.

Here are my results

hallatore commented 1 year ago

Here is an example testing low steps.

[image: xyz_grid-0023-3849107070]

hallatore commented 1 year ago

Here is a version that works with DPM++ 2M. At least I seem to get pretty good results with it.

[image: xyz_grid-0031-3849107065]

And with "Always discard next-to-last sigma" turned OFF

[image: xyz_grid-0030-3849107065]

At 10 steps: https://imgsli.com/MTYxMjc5

import torch
from tqdm.auto import trange


@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    old_denoised = None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t

        if old_denoised is None or sigmas[i + 1] == 0:
            # First step (or final step to sigma 0): plain first-order update.
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            h_last = t - t_fn(sigmas[i - 1])

            # Changed from upstream: r is the ratio of the larger step size over
            # the smaller one, instead of h_last / h.
            h_min = min(h_last, h)
            h_max = max(h_last, h)
            r = h_max / h_min

            # Changed from upstream: the expm1() term uses the average of the two
            # step sizes instead of h.
            h_d = (h_max + h_min) / 2
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h_d).expm1() * denoised_d

        old_denoised = denoised
    return x

hallatore commented 1 year ago

Some tests on human faces

[image: xyz_grid-0003-954420047]

2blackbar commented 1 year ago

Will this get addressed, or have you guys just moved on?

wywywywy commented 1 year ago

Any comment on this @crowsonkb? Are you still maintaining this repo or should we just fork it?

pbaylies commented 1 year ago

@hallatore yes, please put up a PR for it; thank you!

crowsonkb commented 1 year ago

I need to look at this, but I'm sick right now, so it might be a few days.

ClashSAN commented 1 year ago

@crowsonkb Take care! 🙏

Birch-san commented 1 year ago

to aid understanding, here's what the diff looks like:

[image: diff]

it changes the second-order step.

when we compute r: we no longer take the ratio "h_last over h".
instead: we compute r as the ratio of "the greater of (h_last, h)" over "the smaller of (h_last, h)".

when computing x: we no longer use an (-h).expm1() term, but rather replace h with "the average of (h_last, h)".
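
in code terms, the second-order branch changes roughly like this (the "before" side is reconstructed from the description above, so treat it as an approximation of the upstream function rather than a verbatim copy):

# before: r is the ratio "h_last over h", and the update uses (-h).expm1()
h_last = t - t_fn(sigmas[i - 1])
r = h_last / h
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d

# after: r is the larger step over the smaller one (so always >= 1), and the
# update uses the average step size h_d in place of h
h_last = t - t_fn(sigmas[i - 1])
h_min = min(h_last, h)
h_max = max(h_last, h)
r = h_max / h_min
h_d = (h_max + h_min) / 2
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h_d).expm1() * denoised_d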

@hallatore does that sound like a correct description of your technique?
can you explain how you came to this algorithm? is it a more faithful implementation of the paper, or is it a novel idea?
is there a requirement that r be greater than 1? it seems that's the guarantee you're trying to create?

personally: looking at the before/after samples, I'm not convinced it's "better" or more "correct" — to me it looks "different".

pbaylies commented 1 year ago

Just curious: is there any guarantee that h_min won't be zero, or close to zero? If not, should there be a check to make sure we get sane values of r? Judging from the samples here, it does seem that this change can help in practice.
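
(A hypothetical guard, not something from the posted code, would be to clamp the denominator before forming r, e.g.:)

# hypothetical: keep r finite if two adjacent sigmas are (nearly) equal and
# h_min ends up at or near zero
eps = 1e-8
h_min = torch.clamp(min(h_last, h), min=eps)
h_max = max(h_last, h)
r = h_max / h_min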

hallatore commented 1 year ago

A lot of the code takes for granted that h_last is always lower than h. When that is true, we get a factor between 0 and 1, but when it isn't, we get a value above 1. I'm not sure if min/max-ing is the right approach to fixing this, but we never want an r value above 1.

The other change is that h comes from the current step, while denoised_d is a value between the current and last step based on r. I think it makes sense that if denoised_d is a blend of the current and last step, then h should also be computed from the same blend; otherwise you apply the current step's denoising factor to the computed denoised_d. Here I'm also unsure whether the average is a good fit.

So to sum it up, the changes try to address two things:

  1. In some edge cases the h_last value can be higher than h, which causes the r factor to go above 1.
  2. When multiplying h with denoised_d, we combine a current-step value with a value computed from both the last and current step, which I'm unsure is a good way to do it.

pbaylies commented 1 year ago

If you never want an r value above one, then I'd say set that as a max; clamp the range, make sure it's in the valid range you want. And see if you can find some pathological test cases for it!

wywywywy commented 1 year ago

> If you never want an r value above one, then I'd say set that as a max; clamp the range, make sure it's in the valid range you want. And see if you can find some pathological test cases for it!

Just tried r = torch.clamp(r, max=1.0), and the result is different. Not sure if it's better or worse.
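
Roughly, that puts the clamp in the else branch of the sampler like this (assuming it's applied to the original r = h_last / h rather than to the min/max version):

h_last = t - t_fn(sigmas[i - 1])
r = h_last / h
r = torch.clamp(r, max=1.0)  # cap the step-size ratio at 1, per the suggestion above
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d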

Female warrior, 10 steps, seed 3013411575, without clamp:

[image]

With clamp:

[image]

> When multiplying h with denoised_d, we combine a current-step value with a value computed from both the last and current step, which I'm unsure is a good way to do it.

But the value of h is already h = t_next - t. Pardon my ignorance, but I still don't understand why we should average it?

pbaylies commented 1 year ago

Yes, both of those look good to me...

Metachs commented 1 year ago

It looks to me like all you are seeing is faster convergence due to loss of detail. You get fewer artifacts with low numbers of steps, but the final converged image has significantly less detail, no matter how many steps you use.

The overall effect on generated images is like a soft focus or smearing vaseline on the lens, like what Star Trek would do every time a woman was on screen. It might look better in some instances (particularly in close-ups of women, like the examples posted thus far), but it definitely isn't an overall improvement; this is very obvious in images of space/star fields and other noisy images.

In the following comparisons, "DPM++ 2M Test" is the modified function posted earlier in the thread; the loss of detail is extremely obvious. "DPM++ 2M TestX" is an altered version that removes the "h_d" change that averaged h and h_last, which made no sense to me. It isn't as bad, but it still shows a loss of detail vs. the original implementation.

[image: xyz_grid-0102-20230424115312, prompt "realistic telescope imagery of space", Steps: 10, Sampler: DPM++ 2M, CFG scale: 7, Seed: 325325235, Size: 512x512, Model hash: e1441589a6, Model: SD_v1-5-pruned]

[image: xyz_grid-0101-20230424114158, Seed: 325325235, Model: SD_v1-5-pruned]
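
For reference, a sketch of what the "DPM++ 2M TestX" variant described above presumably looks like (reconstructed from the description, not copied from actual code): the min/max ratio r is kept, but the final update goes back to (-h).expm1() instead of (-h_d).expm1().

h_last = t - t_fn(sigmas[i - 1])
h_min = min(h_last, h)
h_max = max(h_last, h)
r = h_max / h_min
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d  # h, not h_d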

pbaylies commented 1 year ago

@Metachs interesting; what about for higher values of CFG Scale, such as 10 or 15?

Metachs commented 1 year ago

Similar.

10 Steps: [image: xyz_grid-0108-20230424135535]

40 Steps: [image: xyz_grid-0109-20230424135535]

10 Steps: [image: xyz_grid-0111-20230424140315]

40 Steps: [image: xyz_grid-0112-20230424140315]

ride5k commented 1 year ago

> In the following comparisons, "DPM++ 2M Test" is the modified function posted earlier in the thread; the loss of detail is extremely obvious. "DPM++ 2M TestX" is an altered version that removes the "h_d" change that averaged h and h_last, which made no sense to me. It isn't as bad, but it still shows a loss of detail vs. the original implementation.

I made the same change to h_d and prefer the result; it seems like a halfway point between the original and the OP's mod.

elen07zz commented 1 year ago

> In the following comparisons, "DPM++ 2M Test" is the modified function posted earlier in the thread; the loss of detail is extremely obvious. "DPM++ 2M TestX" is an altered version that removes the "h_d" change that averaged h and h_last, which made no sense to me. It isn't as bad, but it still shows a loss of detail vs. the original implementation.

> I made the same change to h_d and prefer the result; it seems like a halfway point between the original and the OP's mod.

How can I make that change?