livepeer / ai-worker


SDXL-Lightning inference steps ignored #107

Open stronk-dev opened 2 weeks ago

stronk-dev commented 2 weeks ago

Premise: this model does not accept the guidance_scale param, and it loads a specific set of model weights according to the num_inference_steps you want to run (1, 2, 4 or 8 steps).

Because apps request the ByteDance/SDXL-Lightning model ID, the following code makes it default to the 2-step checkpoint:

https://github.com/livepeer/ai-worker/blob/0a26654cccca8501bddf4e026d18cfee6c9891b2/runner/app/pipelines/text_to_image.py#L57-L71

Then, at inference time, it overrides num_inference_steps to 2:

https://github.com/livepeer/ai-worker/blob/0a26654cccca8501bddf4e026d18cfee6c9891b2/runner/app/pipelines/text_to_image.py#L188-L201

Apparently apps need to append 4step or 8step to the model ID if they want a different number of inference steps. This can be very confusing to app developers, who will likely just request ByteDance/SDXL-Lightning with a specific num_inference_steps, which is then quietly overwritten during inference.
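To make the override concrete, here is a minimal, hypothetical reproduction of the behavior described above. It is not the actual ai-worker code (see the linked lines for that); the function name `effective_steps` is illustrative, and it only mimics the "suffix decides, bare ID falls back to 2" logic.

```python
"""Hypothetical reproduction of the SDXL-Lightning step override.

NOT the ai-worker implementation; it only mimics the behavior described
in this issue: the model ID suffix decides the step count and the
caller's num_inference_steps is ignored.
"""

SDXL_LIGHTNING_MODEL_ID = "ByteDance/SDXL-Lightning"


def effective_steps(model_id: str, requested_steps: int) -> int:
    """Return the num_inference_steps that would actually be used."""
    if SDXL_LIGHTNING_MODEL_ID not in model_id:
        # Other models honor the caller's value.
        return requested_steps
    # For SDXL-Lightning the step count comes from the model ID suffix.
    if "4step" in model_id:
        return 4
    if "8step" in model_id:
        return 8
    # A bare "ByteDance/SDXL-Lightning" falls back to the 2-step checkpoint,
    # silently discarding requested_steps.
    return 2


if __name__ == "__main__":
    # An app asks for 8 steps but passes only the bare model ID.
    print(effective_steps("ByteDance/SDXL-Lightning", 8))        # 2
    print(effective_steps("ByteDance/SDXL-Lightning-8step", 8))  # 8
```

The first call shows the confusing case: the request says 8 steps, but the 2-step checkpoint and 2 steps are used.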

This would also explain why people have reported poor output from this model, since running it at 8 steps produces vastly different output than running it at 2 steps.

Possible solutions would be to switch the UNet/LoRAs during inference, or to make the documentation very clear about how this specific model behaves. Luckily, with models like RealVisXL_V4.0_Lightning you're not tied to a specific number of inference steps.
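The LoRA-switching option could look roughly like the sketch below, which picks a LoRA file per requested step count and only reloads when the count changes. Assumptions, none of which come from ai-worker: the filenames follow a `sdxl_lightning_{N}step_lora.safetensors` pattern (check this against the ByteDance/SDXL-Lightning repository), only 2, 4 and 8 steps are handled, and `load_lora` is a hypothetical caller-supplied callback that actually applies the weights (for example a wrapper around a diffusers pipeline's LoRA loading).

```python
"""Hypothetical sketch of per-request LoRA switching for SDXL-Lightning.

Assumed (not from ai-worker): the filename pattern below and a
caller-supplied `load_lora` callback that applies the weights.
"""
from typing import Callable, Optional

# Step counts this sketch handles via LoRA files (an assumption).
LORA_STEPS = (2, 4, 8)


def lora_filename(steps: int) -> str:
    """Map a requested step count to the assumed LoRA filename."""
    if steps not in LORA_STEPS:
        raise ValueError(f"expected one of {LORA_STEPS} steps, got {steps}")
    return f"sdxl_lightning_{steps}step_lora.safetensors"


class LightningLoraSwitcher:
    """Tracks the active LoRA and reloads only when the step count changes."""

    def __init__(self, load_lora: Callable[[str], None]) -> None:
        self._load_lora = load_lora
        self.active_steps: Optional[int] = None

    def prepare(self, steps: int) -> str:
        """Ensure the LoRA for `steps` is active and return its filename."""
        filename = lora_filename(steps)
        if steps != self.active_steps:
            self._load_lora(filename)
            self.active_steps = steps
        return filename


if __name__ == "__main__":
    loaded = []
    switcher = LightningLoraSwitcher(loaded.append)  # stand-in loader
    switcher.prepare(4)
    switcher.prepare(4)  # same step count, no reload
    switcher.prepare(8)
    print(loaded)
    # ['sdxl_lightning_4step_lora.safetensors', 'sdxl_lightning_8step_lora.safetensors']
```

Reloading only on a change matters because a worker serving mixed step counts would otherwise pay the load cost on every request; requests that alternate step counts would still pay it, so batching by step count could help.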

yondonfu commented 2 weeks ago

Agreed that this behavior is confusing. FWIW I originally implemented this as a quick hack to support loading a specific N-step checkpoint for SDXL-Lightning on pipeline initialization (since all the SDXL-Lightning checkpoints are tied to a specific # of inference steps).

LoRA switching at inference time could work (I'm not sure UNet switching at inference time would be a good idea, as it would probably incur a lot more overhead). But since general LoRA switching logic is not implemented yet, IMO the low-hanging fruit of establishing clearer docs would be a better place to start.

rickstaa commented 2 weeks ago

@stronk-dev, @yondonfu, what are your thoughts on removing the 2/4-step models and exclusively serving the 8-step model, while documenting the behavior of the unused parameters? The 2-step model is only 1 second faster, and the difference between the 4-step and 8-step models is minimal. I think this would reduce the confusion.

stronk-dev commented 2 weeks ago

I haven't tested the 4-step model, so I can't speak to its inference speed and quality difference. I'd expect a bigger difference in inference time (I think the worker prints the number of it/s? You could calculate the extra time required from that).
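As a rough back-of-the-envelope estimate, assuming the reported rate $r$ (in it/s) stays roughly constant across the N-step checkpoints and ignoring fixed per-request costs such as text encoding and VAE decoding, the denoising time for $N$ steps and the extra time of 8 steps over 4 steps are approximately:

```latex
t(N) \approx \frac{N}{r}, \qquad
\Delta t_{4 \to 8} = t(8) - t(4) \approx \frac{8 - 4}{r} = \frac{4}{r}
```

For example, at a hypothetical $r = 4$ it/s this comes to about 1 s of extra denoising time per image.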

I'd certainly prefer the simplicity of advertising just one model, but I'd like to see the quality difference between 4-step and 8-step first.

yondonfu commented 1 week ago

what are your thoughts on removing the 2/4-step models and exclusively serving the 8-step model

IMO the 2/4-step models should continue to be supported, and the 8-step model can simply be used as the default when the model ID is ByteDance/SDXL-Lightning. This change, together with a couple of sentences in the relevant docs, should address the OP. The docs should note that ByteDance/SDXL-Lightning is a special case: in most other cases each HF repo has just one model, but ByteDance structured their HF repo differently, so the user can also append the -Nstep suffix to the model ID to use the corresponding N-step model.

My reasoning here is that each N-step model is actually a distinct checkpoint, and devs should be able to request specific checkpoints if they want to; ByteDance/SDXL-Lightning will just be treated as an alias for ByteDance/SDXL-Lightning-8step.
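A minimal sketch of the proposed model-ID handling: the bare ID resolves to the 8-step checkpoint, an explicit `-Nstep` suffix selects that checkpoint, and a conflicting num_inference_steps triggers a warning instead of being replaced silently. The function and constant names are hypothetical, not ai-worker's API, and the warning is an illustrative addition rather than something agreed in this thread.

```python
"""Hypothetical sketch of the alias proposal for SDXL-Lightning model IDs.

"ByteDance/SDXL-Lightning" is an alias for the 8-step checkpoint; an
explicit "-Nstep" suffix selects that checkpoint. Names are illustrative.
"""
import re
import warnings
from typing import Optional, Tuple

BASE_ID = "ByteDance/SDXL-Lightning"
DEFAULT_STEPS = 8
SUPPORTED_STEPS = (1, 2, 4, 8)  # step counts listed in the issue
_MODEL_ID = re.compile(r"^ByteDance/SDXL-Lightning(?:-(\d+)step)?$")


def resolve_lightning(model_id: str, requested_steps: Optional[int] = None) -> Tuple[str, int]:
    """Return (canonical model ID, step count) for an SDXL-Lightning request."""
    match = _MODEL_ID.match(model_id)
    if match is None:
        raise ValueError(f"not an SDXL-Lightning model ID: {model_id!r}")
    # No suffix means the bare alias, which maps to the default checkpoint.
    steps = int(match.group(1)) if match.group(1) else DEFAULT_STEPS
    if steps not in SUPPORTED_STEPS:
        raise ValueError(f"unsupported SDXL-Lightning step count: {steps}")
    if requested_steps is not None and requested_steps != steps:
        # Surface the override instead of silently replacing the value.
        warnings.warn(
            f"num_inference_steps={requested_steps} ignored; {model_id} uses {steps} steps"
        )
    return f"{BASE_ID}-{steps}step", steps


if __name__ == "__main__":
    print(resolve_lightning("ByteDance/SDXL-Lightning"))
    # ('ByteDance/SDXL-Lightning-8step', 8)
    print(resolve_lightning("ByteDance/SDXL-Lightning-4step", 4))
    # ('ByteDance/SDXL-Lightning-4step', 4)
```

Resolving to a canonical ID up front means the pipeline-loading code and the inference-time step override can both read from the same value, so they cannot disagree.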