Panchovix / stable-diffusion-webui-reForge

[Feature Request]: Add "Cache FP16 weight for LoRA" setting in FP8 mode #44

Open Animus777 opened 1 month ago

Animus777 commented 1 month ago

Is there an existing issue for this?

What would your feature do?

In both A1111 and Forge, the Optimizations settings include an option called Cache FP16 weight for LoRA.

It allows you to preserve LoRA quality while keeping checkpoint models in FP8 (to save VRAM and RAM). It works in A1111, but it doesn't work in Forge (the FP8 weight setting above it doesn't work either). I assume that's because, unlike A1111, Forge enables FP8 mode not via a UI setting but via the command-line flag --unet-in-fp8-e4m3fn. But I haven't found any command-line flag to keep LoRA in FP16 among the available precision flags:

--all-in-fp32
--all-in-fp16
--unet-in-bf16
--unet-in-fp16
--unet-in-fp8-e4m3fn
--unet-in-fp8-e5m2
--vae-in-fp16
--vae-in-fp32
--vae-in-bf16
--clip-in-fp8-e4m3fn
--clip-in-fp8-e5m2
--clip-in-fp16
--clip-in-fp32
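
To illustrate why the cache matters (a rough numerical sketch of my understanding, not code from either repo): float8_e4m3fn keeps only 3 mantissa bits, so merging a small LoRA delta into an already-quantized FP8 weight picks up an extra layer of quantization noise compared to merging into a cached FP16 copy and quantizing once at the end. Requires PyTorch >= 2.1 for torch.float8_e4m3fn:

    # Rough sketch (assumed setup, not repo code): compare merging a small
    # LoRA delta into an FP8-quantized weight vs. into a cached FP16 copy.
    import torch

    torch.manual_seed(0)
    w = torch.randn(1024, 1024, dtype=torch.float16)   # "checkpoint" weight
    delta = 0.01 * torch.randn_like(w)                  # small LoRA delta

    def q(t):
        # round-trip through float8_e4m3fn
        return t.to(torch.float8_e4m3fn).to(torch.float16)

    ref = w + delta                                     # full-precision reference
    merged_in_fp8 = q(q(w) + delta)                     # no cache: delta hits the quantized weight
    merged_in_fp16 = q(w + delta)                       # cache: quantize once, at the end

    print("error without cache:", (merged_in_fp8 - ref).abs().mean().item())
    print("error with cache:   ", (merged_in_fp16 - ref).abs().mean().item())

The first error should come out noticeably larger, which matches the quality drop described above.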

By digging into the files I found that A1111 has a file named sd_models.py in its modules folder, which seems to contain the relevant code around shared.opts.cache_fp16_weight:

    if check_fp8(model):
        devices.fp8 = True
        first_stage = model.first_stage_model
        model.first_stage_model = None
        for module in model.modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                if shared.opts.cache_fp16_weight:
                    module.fp16_weight = module.weight.data.clone().cpu().half()
                    if module.bias is not None:
                        module.fp16_bias = module.bias.data.clone().cpu().half()
                module.to(torch.float8_e4m3fn)
        model.first_stage_model = first_stage
        timer.record("apply fp8")
    else:
        devices.fp8 = False
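
Note that the snippet above only builds the cache; the cached copy also has to be consumed when a LoRA is actually merged, otherwise the delta is still applied on top of the already-quantized FP8 tensor (in A1111 that part lives in the built-in Lora extension rather than sd_models.py, as far as I can tell). A rough sketch of the consumer side, where apply_lora_delta is a hypothetical helper and only the fp16_weight / fp16_bias attributes come from the snippet above:

    import torch

    def apply_lora_delta(module: torch.nn.Module, delta: torch.Tensor):
        # Hypothetical helper, not actual A1111 or Forge code: merge a LoRA
        # delta into a Linear/Conv2d weight, preferring the cached FP16 copy.
        fp16_weight = getattr(module, "fp16_weight", None)
        if fp16_weight is not None:
            # Merge against the full-precision original, quantize to FP8 once.
            merged = fp16_weight.to(device=delta.device, dtype=delta.dtype) + delta
        else:
            # Fallback: dequantize the FP8 weight first; the original FP16
            # precision is already lost here, hence the quality drop.
            merged = module.weight.data.to(delta.dtype) + delta
        module.weight.data = merged.to(module.weight.dtype)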

I'm not sure it's a simple copy-paste job though... To test whether it's working I like to use an acceleration LoRA, i.e. LCM: just set the Sampler to LCM, Steps to 4 and CFG to 1. The results are extremely bad in FP8, but if you enable Cache FP16 weight for LoRA they are OK.

Panchovix commented 1 month ago

Hi there, thanks for the suggestion. Pushed some commits that may help with this, but I'm not sure they will work (kinda ported the A1111 implementation into model management). Can you try it and tell me how it goes?

EDIT: Didn't work, reverting the commits for now. I'll need to investigate further how to add this feature.