huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

DPM++ leaves residual noise in SDXL images, which is unexpected #5689

Open nhnt11 opened 11 months ago

nhnt11 commented 11 months ago

Describe the bug

When using DPM++ with SDXL, there is residual noise in the result.

Upon investigation, my colleague @CodeCorrupt and I found that the final sigma was non-zero, which we empirically identified as the culprit: when we artificially set the final sigma to 0, the images were clean.

Reproduction

import torch
from typing import cast
from IPython.display import display  # display() is provided by IPython/Jupyter
from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline

# SDXL base in fp16, pinned to a specific revision
sdxl_model = cast(StableDiffusionXLPipeline, StableDiffusionXLPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
    revision="76d28af79639c28a79fa5c6c6468febd3490a37e",
)).to('cuda')

# DPM++ (multistep) with SDXL's beta schedule
common_config = {'beta_start': 0.00085, 'beta_end': 0.012, 'beta_schedule': 'scaled_linear'}
scheduler = DPMSolverMultistepScheduler(**common_config)
sdxl_model.scheduler = scheduler

# Fixed seed so the residual noise is reproducible
generator = torch.Generator(device='cuda')
generator.manual_seed(96404382)
params = {
    'prompt': ['clorful abstract background, glowing gradient waves on black background'],
    'negative_prompt': [''],
    "num_inference_steps": 40,
    "guidance_scale": 8,
    "width": 1024,
    "height": 1024
}
sdxl_res = sdxl_model(**params, generator=generator, output_type='pil')
sdxl_img = sdxl_res.images[0]
display(sdxl_img)

Output with diffusers main: [image]

Output with a hack that sets the last sigma to 0: [image]

Logs

No response

System Info

Who can help?

@yiyixuxu @patrickvonplaten

nhnt11 commented 11 months ago

Here is the code where we manually set the last sigma to 0: https://github.com/playgroundai/diffusers/commit/3b4a2fd460bdf7b58e2277c663dbf62012a39df8

bghira commented 11 months ago

well that wouldn't be DPM++ then, now would it? :D it's a different scheduler you made. you probably want to rescale the betas.

nhnt11 commented 11 months ago

@bghira By that definition, wouldn't the DPM++ implementation already stop being DPM++ when you use the euler_at_final flag? I think these schedulers include several approximations for numerical stability that admittedly compromise the theory.

Note that for many other samplers, diffusers explicitly appends a 0 sigma. My understanding is that this is necessary because of how the diffusion loop is implemented: in every iteration, the UNet pass happens before the scheduler step, so the final iteration ends with a scheduler step whose output is returned directly. If that final step lands on a non-zero sigma, the noise at that level is never removed, and you get a noisy result.
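To make the loop-ordering argument concrete, here is a toy sketch, not diffusers code: an Euler-style sampler in the k-diffusion parameterization (x = x0 + sigma * eps) with a hypothetical oracle denoiser that predicts x0 exactly. The log-spaced schedule and the sigma range (about 14.6 down to 0.03) are assumptions chosen to be roughly SDXL-like. Because the loop ends on a scheduler step, whatever noise remains at the final sigma is returned as-is.

```python
import math
import random


def make_sigmas(sigma_max, sigma_min, n, append_zero):
    """Log-spaced sigmas from sigma_max to sigma_min (a hypothetical schedule).

    With append_zero=True the schedule ends at 0, like samplers that
    explicitly append a zero sigma; otherwise it ends at sigma_min.
    """
    log_max, log_min = math.log(sigma_max), math.log(sigma_min)
    sigmas = [math.exp(log_max + i * (log_min - log_max) / (n - 1)) for i in range(n)]
    if append_zero:
        sigmas.append(0.0)
    return sigmas


def sample(x0, sigmas, seed=0):
    """Euler loop: denoiser pass first, then scheduler step, as in the thread."""
    rng = random.Random(seed)
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    # Start at the highest noise level: x = x0 + sigma_max * eps
    x = [a + sigmas[0] * e for a, e in zip(x0, eps)]
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        # "UNet pass": a hypothetical oracle that predicts x0 exactly
        x0_hat = x0
        # Euler step from sigma to sigma_next along d = (x - x0_hat) / sigma
        d = [(xi - ai) / sigma for xi, ai in zip(x, x0_hat)]
        x = [xi + (sigma_next - sigma) * di for xi, di in zip(x, d)]
    # The loop ends on a scheduler step, so x sits at noise level sigmas[-1]
    return x, eps


x0 = [0.5, -0.25, 1.0, 0.0]
clean, eps = sample(x0, make_sigmas(14.6, 0.03, 40, append_zero=True))
noisy, _ = sample(x0, make_sigmas(14.6, 0.03, 40, append_zero=False))
print("max residual, final sigma 0:   ", max(abs(c - a) for c, a in zip(clean, x0)))
print("max residual, final sigma 0.03:", max(abs(n - a) for n, a in zip(noisy, x0)))
```

Both runs share the same initial noise; only the schedule's endpoint differs. Even with a perfect denoiser, the residual in the second run is the final sigma times the initial noise, which is consistent with zeroing the last sigma giving clean images.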

Opinion: I think it's a valid tradeoff to compromise the theory a bit when implementing math in finite precision computing environments :)

Disclaimer: very new to all this, so just thinking out loud and learning :)

yiyixuxu commented 11 months ago

As discussed in https://github.com/huggingface/diffusers/pull/5541, you can use use_lu_lambdas and euler_at_final to avoid this issue.

thanks :)

nhnt11 commented 11 months ago

Hey @yiyixuxu, thanks for the pointer - but we're already using those. For clarity - we needed to add this hack to zero out the last sigma ON TOP of the euler_at_final option (or with Lu lambdas, or with Karras sigmas).
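Why zeroing the last sigma still matters with euler_at_final can be seen from the first-order DPM-Solver++ update, which, as far as we understand, is what euler_at_final takes at the last step. Here α_t and σ_t are the VP signal and noise scales, λ_t = log(α_t / σ_t), and x̂_0 is the model's clean-sample prediction; the thread's "sigma" corresponds to σ_t / α_t, so a zero final sigma means σ_t = 0 and α_t = 1.

```latex
x_t = \frac{\sigma_t}{\sigma_s}\, x_s + \alpha_t \left(1 - e^{-h}\right) \hat{x}_0(x_s, s),
\qquad h = \lambda_t - \lambda_s

\sigma_t \to 0 \;\Rightarrow\; \lambda_t \to \infty,\; e^{-h} \to 0,\; \frac{\sigma_t}{\sigma_s} \to 0
\;\Rightarrow\; x_t \to \alpha_t\, \hat{x}_0(x_s, s) = \hat{x}_0(x_s, s)
```

For any σ_t > 0, the term (σ_t / σ_s) x_s carries part of the noise in x_s into the returned sample, even though the update itself is first order.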

spezialspezial commented 10 months ago

Using lu_lambdas and euler_at_final brings a bit of relief, but it does not solve this issue. Thanks to @nhnt11 for the investigation!

mar-muel commented 10 months ago

@nhnt11 Interesting! But could this issue be related to poor performance of DPM++ with SDXL in general?

For example, in Automatic1111 (based on the k-diffusion repo), I'm seeing similar artifacts with the DPM++ 2M scheduler: [image: 00092-42]

"colorful abstract background, glowing gradient waves on black background"
Steps: 20, Sampler: DPM++ 2M, CFG scale: 7.5, Seed: 42, Size: 1024x1024, Model hash: 31e35c80fc, Model: sd_xl_base_1.0, Version: v1.7.0

But using DPM++ 2M Karras instead gives: [image: 00093-42]

"colorful abstract background, glowing gradient waves on black background"
Steps: 20, Sampler: DPM++ 2M Karras, CFG scale: 7.5, Seed: 42, Size: 1024x1024, Model hash: 31e35c80fc, Model: sd_xl_base_1.0, Version: v1.7.0

Similar improvements can be obtained by setting use_karras_sigmas=True in diffusers.
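For comparison, a plain-Python sketch of the noise schedule from Karras et al. (2022) that the "Karras" variants are named after: sigmas are interpolated linearly in sigma^(1/ρ). It assumes ρ = 7 (the paper's default) and, like k-diffusion's `get_sigmas_karras`, appends a final 0. Whether diffusers' `use_karras_sigmas=True` also ends at 0 depends on the scheduler and version, so treat that part as an assumption; the sigma range is likewise an assumed, roughly SDXL-like one.

```python
def karras_sigmas(n, sigma_min, sigma_max, rho=7.0, append_zero=True):
    """Karras et al. (2022) schedule: interpolate linearly in sigma ** (1 / rho)."""
    max_inv_rho = sigma_max ** (1.0 / rho)
    min_inv_rho = sigma_min ** (1.0 / rho)
    sigmas = [
        (max_inv_rho + i / (n - 1) * (min_inv_rho - max_inv_rho)) ** rho
        for i in range(n)
    ]
    if append_zero:
        sigmas.append(0.0)  # k-diffusion style: the last step goes all the way to 0
    return sigmas


# Assumed SDXL-like range, 20 steps as in the A1111 runs above
sigmas = karras_sigmas(20, sigma_min=0.03, sigma_max=14.6)
print([round(s, 3) for s in sigmas])
```

With ρ = 7 the steps are packed toward the low-sigma end, so more of the sampling budget is spent where fine detail (and residual noise) is resolved.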

I agree with you that use_lu_lambdas/euler_at_final are more of an extension to DPM++ and not really a "fix" for it.

EDIT: I tried this again today, and it indeed improves results!

github-actions[bot] commented 9 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

patrickvonplaten commented 9 months ago

This PR should have also improved the quality: https://github.com/huggingface/diffusers/pull/6477

bghira commented 9 months ago

it has not

yiyixuxu commented 8 months ago

@bghira let us know if you have any suggestions to further improve the DPM scheduler for SDXL
