huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

LMS scheduler leaves a lot of noise leftover in the result image when used with SDXL Img2Img Pipeline #5630

Closed: nhnt11 closed this issue 9 months ago

nhnt11 commented 1 year ago

Describe the bug

When using the LMS scheduler with the SDXL img2img pipeline, a lot of noise is left over in the image, especially when strength is close to 0. In other words, when the total number of performed steps is low (e.g. num_inference_steps=50 and strength=0.1), the result images are unusably noisy.
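For context, strength controls how many of the scheduled steps actually run. Here is a minimal sketch of that mapping, mirroring (in simplified form) the img2img pipeline's get_timesteps logic:

def performed_steps(num_inference_steps: int, strength: float) -> int:
    # The img2img pipeline skips the earliest (noisiest) part of the
    # schedule and only runs the tail, so low strength means few steps.
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return num_inference_steps - t_start

print(performed_steps(50, 0.1))  # -> 5: only the 5 smallest-sigma steps run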

Reproduction

Here's some code that first does a prompt-to-image generation, and then an image-to-image from that result with strength=0.1. The image-to-image result looks like an intermediate latent. Note that the prompt-to-image result looks completely fine. This is reproducible with any input image - I just used a prompt-to-image generation because it felt easier to share here.

import torch
from typing import cast
from diffusers import (
    LMSDiscreteScheduler,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLPipeline,
)

sdxl_model = cast(StableDiffusionXLPipeline, StableDiffusionXLPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
    revision="76d28af79639c28a79fa5c6c6468febd3490a37e",
)).to('cuda')
sdxl_img2img_model = cast(StableDiffusionXLImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
    revision="76d28af79639c28a79fa5c6c6468febd3490a37e",
)).to('cuda')

# Use the same LMS scheduler (with SD-style beta settings) for both pipelines
common_config = {'beta_start': 0.00085, 'beta_end': 0.012, 'beta_schedule': 'scaled_linear'}
scheduler = LMSDiscreteScheduler(**common_config)
sdxl_model.scheduler = scheduler
sdxl_img2img_model.scheduler = scheduler

# Disable the invisible watermark and fix the seed for reproducibility
sdxl_model.watermark = None
generator = torch.Generator(device='cuda')
generator.manual_seed(12345)

params = {
    'prompt': ['evening sunset scenery blue sky nature, glass bottle with a galaxy in it'],
    'negative_prompt': ['text, watermark'],
    'num_inference_steps': 50,
    'guidance_scale': 7,
    'width': 1024,
    'height': 1024,
}
sdxl_res = sdxl_model(**params, generator=generator, output_type='pil')
sdxl_img = sdxl_res.images[0]
display(sdxl_img)

img2img_params = {
    'prompt': ['evening sunset scenery blue sky nature, glass bottle with a galaxy in it'],
    'negative_prompt': ['text, watermark'],
    'num_inference_steps': 50,
    'guidance_scale': 7,
    'image': sdxl_img,
    'strength': 0.1,
}

sdxl_img2img_res = sdxl_img2img_model(**img2img_params, generator=generator, output_type='pil')

display(sdxl_img2img_res.images[0])

Image-to-Image Result: image

Logs

No response

System Info

Who can help?

@yiyixuxu @patrickvonplaten

nhnt11 commented 1 year ago

By the way, I've tried prompt-to-image with very low step count, and it works "fine" - the images aren't great but they don't look like intermediates.

Here's a p2i result with num_inference_steps=5: image

nhnt11 commented 1 year ago

By the way, this issue also impacts prompt-to-image gens when using the refiner, since the refiner uses the img2img pipeline.

nhnt11 commented 12 months ago

OK, this is based on an incomplete understanding, but after a lot of reading I'm suspicious that LMS is simply a very bad choice for img2img when only a few steps are performed (i.e. at low strength), and that there might not be any particular implementation bug.

Consider a generation with 100 steps where only the last 5 are performed (strength=0.05). Essentially, the generation will skip the first 95 timesteps and run only the final few, where the sigmas are smallest.

If you look at the last 5 sigmas when running LMS for 100 steps, this is what they look like (values shown in an attached screenshot):

Look at how small the Karras sigma values are relative to the default (non-Karras) ones.
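A quick way to reproduce the comparison, assuming your diffusers version's LMSDiscreteScheduler supports the use_karras_sigmas flag:

from diffusers import LMSDiscreteScheduler

common_config = {'beta_start': 0.00085, 'beta_end': 0.012, 'beta_schedule': 'scaled_linear'}
default_scheduler = LMSDiscreteScheduler(**common_config)
karras_scheduler = LMSDiscreteScheduler(**common_config, use_karras_sigmas=True)

default_scheduler.set_timesteps(100)
karras_scheduler.set_timesteps(100)

# With low img2img strength, only these final (smallest) sigmas are used.
print(default_scheduler.sigmas[-6:])  # last 5 sigmas plus the terminal 0.0
print(karras_scheduler.sigmas[-6:])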

My wild guess is that this is related to what @LuChengTHU says about numerical stability close to t=0 for second-order solvers. LMS is a fourth-order sampler by default, which could explain why the effect is so exaggerated here.

If I force the order to 1, 2, 3, and 4 (by passing it in as a param to step()), here is how the result varies: image
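For reference, here is a minimal scheduler-only loop showing where that order parameter goes; a random tensor stands in for the UNet's noise prediction, so this exercises the scheduler but produces no meaningful image:

import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012,
                                 beta_schedule='scaled_linear')
scheduler.set_timesteps(50)

sample = torch.randn(1, 4, 128, 128)  # stand-in latent
for t in scheduler.timesteps:
    latent_input = scheduler.scale_model_input(sample, t)
    noise_pred = torch.randn_like(latent_input)  # stand-in for unet(latent_input, t, ...)
    # order is a per-call argument of LMSDiscreteScheduler.step;
    # order=1 makes each step effectively first order (Euler-like).
    sample = scheduler.step(noise_pred, t, sample, order=1).prev_sample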

Very, very naively, I would suggest that we do something similar for LMS as we do for DPM++ 2M with lower_order_final - i.e. if we are approaching the last few timesteps, we reduce the "derivative depth", so to speak.

I am a huge noob here so just thinking out loud and learning 😄 🙏

nhnt11 commented 12 months ago

Here is a simple change which forces order = 1 for the last 15 timesteps https://github.com/playgroundai/diffusers/commit/c3a629155853591953b3830c1d87468b50956ccb

This is a proof-of-concept change and I don't know, for example, whether the order should "gradually" drop off instead of going from 4 -> 1 when there are 15 steps left.
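A rough sketch of the idea (the helper name and the 15-step window are illustrative, not the actual patch; see the linked commit for the real change):

def lms_order_for_step(step_index: int, num_steps: int,
                       default_order: int = 4, final_window: int = 15) -> int:
    # Drop to first order for the last few steps, analogous to
    # lower_order_final in DPMSolverMultistepScheduler.
    steps_remaining = num_steps - step_index
    return 1 if steps_remaining <= final_window else default_order

The denoising loop would then pass this through, e.g. scheduler.step(noise_pred, t, sample, order=lms_order_for_step(i, len(timesteps))).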

The change eliminates the noise for various values of strength when num_inference_steps=50:

image

nhnt11 commented 12 months ago

For fun, here's another approach where we always start with order=1 for the first performed step (even if the first step index is > 0) and also drop back to order=1 for the final timesteps: https://github.com/playgroundai/diffusers/commit/6542f63b713db83a84a7ebe954267c52aecac3d6
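A sketch of this combined heuristic (again illustrative, not the actual patch):

def lms_order_for_step(steps_completed: int, steps_remaining: int,
                       default_order: int = 4, final_window: int = 15) -> int:
    # Ramp up from order 1 on the first performed step (so img2img never
    # starts at full order without derivative history), and drop back to
    # order 1 for the final steps.
    ramp_up = steps_completed + 1
    ramp_down = 1 if steps_remaining <= final_window else default_order
    return min(ramp_up, ramp_down, default_order)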

Results alongside the previous approach (I think this approach yields slightly more detail):

image

nhnt11 commented 12 months ago

Oh yikes, I just saw that in the screenshots above, the strength values are wrong. The values in the screenshots are 1 - strength. So e.g. the images labeled 0.9 were actually generated with strength=0.1.

patrickvonplaten commented 11 months ago

cc @yiyixuxu

github-actions[bot] commented 10 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

patrickvonplaten commented 10 months ago

@yiyixuxu is this solved here?

yiyixuxu commented 10 months ago

Is this related to https://github.com/huggingface/diffusers/pull/6187 ?

github-actions[bot] commented 9 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.