siliconflow / onediff

OneDiff: An out-of-the-box acceleration library for diffusion models.
https://github.com/siliconflow/onediff/wiki
Apache License 2.0

Load LoRA -> generate image -> unload LoRA: quality degrades after many repetitions #981

Open vincentmmc opened 5 days ago

vincentmmc commented 5 days ago

After repeating the pipeline steps load LoRA -> generate image -> unload LoRA many times, the results get progressively worse. Is this because the LoRA is not unloaded cleanly?

The left image is the first generation; the right one is the generation after 1000 repetitions. Snipaste_2024-06-26_09-31-45

Here is the code to reproduce the issue:

from diffusers import StableDiffusionPipeline
import torch
import oneflow
from onediffx.lora import load_and_fuse_lora, set_and_fuse_adapters, delete_adapters
from onediffx import compile_pipe

pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5',torch_dtype=torch.float16, variant='fp16')
pipe = pipe.to("cuda")
pipe = compile_pipe(pipe)
prompt = "a house"
for i in range(1000):
    generator = torch.Generator(device='cuda')
    generator.manual_seed(0)
    load_and_fuse_lora(pipe, 'Norod78/SD15-IllusionDiffusionPattern-LoRA', weight_name='SD15-IllusionDiffusionPattern-LoRA.safetensors', lora_scale=1.0, adapter_name='SD15-IllusionDiffusionPattern-LoRA')
    load_and_fuse_lora(pipe, 'Norod78/sd15-megaphone-lora', weight_name='SD15-Megaphone-LoRA.safetensors', lora_scale=1.0, adapter_name='SD15-Megaphone-LoRA')
    set_and_fuse_adapters(pipe, adapter_names=['SD15-Megaphone-LoRA','SD15-IllusionDiffusionPattern-LoRA'], adapter_weights=[0.2,0.2])
    result = pipe(prompt,
                generator=generator,
                height=512,
                width=512,
                num_inference_steps=20,).images[0]
    result.save('test_lora3/test_img_'+str(i)+'.png')
    delete_adapters(pipe,'SD15-Megaphone-LoRA')
    delete_adapters(pipe,'SD15-IllusionDiffusionPattern-LoRA')

The versions I am using: onediff=1.1.0.dev1, oneflow=0.9.1.dev20240529+cu118

Machine environment: NVIDIA GeForce RTX 3090, CUDA Version: 12.2

nono909090 commented 5 days ago

I ran into the same issue.

vincentmmc commented 5 days ago

@lijunliangTG

lijunliangTG commented 4 days ago

Thank you for the feedback; we are currently working on reproducing your issue.

lijunliangTG commented 3 days ago
from onediffx.lora import load_and_fuse_lora, set_and_fuse_adapters, delete_adapters, unfuse_lora  # import unfuse_lora

  delete_adapters(pipe, 'SD15-Megaphone-LoRA')
  delete_adapters(pipe, 'SD15-IllusionDiffusionPattern-LoRA')
  unfuse_lora(pipe)  # adding this line fixes it
marigoold commented 1 day ago

@vincentmmc @nono909090 @lijunliangTG This solution is incorrect. The fundamental reason for the quality degradation is that the LoRAs are fused one by one. During each fusion, the FP16 weight (of a Conv2d or Linear layer) is cast to FP32, the LoRA delta is added, and the result is converted back to FP16. Loading multiple LoRAs separately therefore causes multiple FP16 -> FP32 -> FP16 round trips, leading to significant precision loss; if all LoRAs are loaded at once, there is only a single FP16 -> FP32 -> FP16 conversion and precision is preserved. To verify this, you can invoke unfuse_lora after each load_and_fuse_lora call and then call set_adapters before generating images to activate all LoRAs without any precision loss. An API for loading multiple LoRAs at once is under development and will address this issue.
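
To make the mechanism concrete, here is a minimal sketch in plain PyTorch (not the onediffx API; the weight shape and LoRA deltas are made up for illustration). It shows that an FP16 fuse/unfuse cycle does not cancel exactly, because every step rounds back to FP16, so the weights drift further from the original as the cycle is repeated:

import torch

torch.manual_seed(0)
w0 = torch.randn(512, 512).half()                # stands in for a Linear/Conv2d FP16 weight
delta_a = (torch.randn(512, 512) * 1e-3).half()  # stands in for LoRA A's fused delta
delta_b = (torch.randn(512, 512) * 1e-3).half()  # stands in for LoRA B's fused delta

w = w0.clone()
for i in range(1, 1001):
    # "fuse" LoRA A, then LoRA B: each step is an FP16 -> FP32 -> FP16 round trip
    w = (w.float() + delta_a.float()).half()
    w = (w.float() + delta_b.float()).half()
    # "unfuse" both before the next iteration
    w = (w.float() - delta_b.float()).half()
    w = (w.float() - delta_a.float()).half()
    if i in (1, 10, 100, 1000):
        # the accumulated drift from the original weight typically keeps growing
        print(i, (w.float() - w0.float()).abs().max().item())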

nono909090 commented 8 hours ago

I followed the steps you suggested, but the same problem still occurs: the images drift more and more and the quality keeps getting worse. I ran it based on the test code above:

from diffusers import StableDiffusionPipeline
import torch
import oneflow
from onediffx.lora import load_and_fuse_lora, set_and_fuse_adapters, delete_adapters, unfuse_lora
from onediffx import compile_pipe

pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5',torch_dtype=torch.float16, variant='fp16')
pipe = pipe.to("cuda")
pipe = compile_pipe(pipe)
prompt = "a house"
for i in range(1000):
    generator = torch.Generator(device='cuda')
    generator.manual_seed(0)
    load_and_fuse_lora(pipe, 'Norod78/SD15-IllusionDiffusionPattern-LoRA', weight_name='SD15-IllusionDiffusionPattern-LoRA.safetensors', lora_scale=1.0, adapter_name='SD15-IllusionDiffusionPattern-LoRA')
    unfuse_lora(pipe)
    load_and_fuse_lora(pipe, 'Norod78/sd15-megaphone-lora', weight_name='SD15-Megaphone-LoRA.safetensors', lora_scale=1.0, adapter_name='SD15-Megaphone-LoRA')
    unfuse_lora(pipe)
    set_and_fuse_adapters(pipe, adapter_names=['SD15-Megaphone-LoRA','SD15-IllusionDiffusionPattern-LoRA'], adapter_weights=[0.2,0.2])
    result = pipe(prompt,
                generator=generator,
                height=512,
                width=512,
                num_inference_steps=20,).images[0]
    result.save('test_lora4/test_img_'+str(i)+'.png')
    delete_adapters(pipe,'SD15-Megaphone-LoRA')
    delete_adapters(pipe,'SD15-IllusionDiffusionPattern-LoRA')

@marigoold @vincentmmc @lijunliangTG

marigoold commented 7 hours ago

It is also best for the final delete_adapters call to delete all adapters at once, as shown below:

    delete_adapters(pipe, ['SD15-Megaphone-LoRA', 'SD15-IllusionDiffusionPattern-LoRA'])

This is because deleting an adapter also casts the weights from FP32 back to FP16, which introduces additional precision loss there as well. A fix for this is in progress, and we will notify you as soon as it is ready.
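
For reference, here is a sketch of the reproduction loop with the suggestions from this thread combined (unfuse_lora after each load_and_fuse_lora, and a single delete_adapters call for all adapters at the end). It only rearranges the calls already shown above, and per the discussion it may not remove all drift until the fix mentioned above lands:

from diffusers import StableDiffusionPipeline
import torch
import oneflow
from onediffx.lora import load_and_fuse_lora, set_and_fuse_adapters, delete_adapters, unfuse_lora
from onediffx import compile_pipe

pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16, variant='fp16')
pipe = pipe.to("cuda")
pipe = compile_pipe(pipe)
prompt = "a house"
adapter_names = ['SD15-Megaphone-LoRA', 'SD15-IllusionDiffusionPattern-LoRA']
for i in range(1000):
    generator = torch.Generator(device='cuda').manual_seed(0)
    # load each LoRA, unfusing right after so the base weights are restored
    load_and_fuse_lora(pipe, 'Norod78/SD15-IllusionDiffusionPattern-LoRA', weight_name='SD15-IllusionDiffusionPattern-LoRA.safetensors', lora_scale=1.0, adapter_name='SD15-IllusionDiffusionPattern-LoRA')
    unfuse_lora(pipe)
    load_and_fuse_lora(pipe, 'Norod78/sd15-megaphone-lora', weight_name='SD15-Megaphone-LoRA.safetensors', lora_scale=1.0, adapter_name='SD15-Megaphone-LoRA')
    unfuse_lora(pipe)
    # activate both adapters together with the desired weights
    set_and_fuse_adapters(pipe, adapter_names=adapter_names, adapter_weights=[0.2, 0.2])
    result = pipe(prompt, generator=generator, height=512, width=512, num_inference_steps=20).images[0]
    result.save('test_lora/test_img_' + str(i) + '.png')
    # delete all adapters in one call so the FP32 -> FP16 cast happens only once
    delete_adapters(pipe, adapter_names)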