chengzeyi / stable-fast

Best inference performance optimization framework for HuggingFace Diffusers on NVIDIA GPUs.

Could anyone share a working demo for switching LoRA with config.enable_cuda_graph = True? #144

Closed CallmeZhangChenchen closed 4 months ago

CallmeZhangChenchen commented 7 months ago

It's not that I'm lazy; I've been at this for several days without success. With config.enable_cuda_graph = True, SD 1.5 inference takes only 700 ms; with config.enable_cuda_graph = False, it takes 3 s. So there has to be a workable way to switch LoRA while config.enable_cuda_graph = True.
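For context, a minimal sketch of the setup under discussion (the compiler module path may differ between stable-fast versions, and "runwayml/stable-diffusion-v1-5" is just an example checkpoint):

import torch
from diffusers import StableDiffusionPipeline
from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig, compile

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")

config = CompilationConfig.Default()
config.enable_cuda_graph = True  # ~700 ms per image here vs ~3 s without it
pipe = compile(pipe, config)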

def update_state_dict(dst, src):
    # In-place copy: write src's values into dst's existing tensors, so the
    # storage addresses captured by the CUDA graph stay valid.
    for key, value in src.items():
        dst[key].copy_(value)
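A hypothetical call site for this helper; pipe and new_lora_state_dict are assumptions (a compiled pipeline and a LoRA state dict whose keys already match unet.state_dict()):

import torch

# Hypothetical usage: new_lora_state_dict must be keyed exactly like
# pipe.unet.state_dict() for copy_ to find the matching tensors.
with torch.no_grad():
    update_state_dict(pipe.unet.state_dict(), new_lora_state_dict)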

The LoRA-switching code in the README raises many questions. Different LoRAs have different structures, and the key names in dst and src don't seem to match, so how can copy_ work? It looks like only LoRAs of the same structure, trained with the same code, can be switched successfully.

TODO: I'll run some experiments based on the existing code: keep the pointers shared and replace only the underlying data, dig into how LoRA works, and contribute a workable LoRA-switching PR.
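The "keep the pointer, replace only the data" idea can be checked in isolation; a minimal sketch (plain PyTorch, not stable-fast-specific):

import torch

# copy_ writes into the tensor's existing storage, so any address a CUDA
# graph captured stays valid; rebinding the attribute would allocate anew.
w = torch.nn.Parameter(torch.zeros(4))
ptr_before = w.data_ptr()
with torch.no_grad():
    w.copy_(torch.ones(4))  # in-place: same storage, new values
assert w.data_ptr() == ptr_before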

CallmeZhangChenchen commented 7 months ago

I experimented with two LoRAs of identical structure, but it didn't work; the output stayed the same.


from diffusers.loaders.lora import LoraLoaderMixin

# load_state_dict with assign=True requires torch >= 2.1.0

# Switch "another" LoRA into the UNet/text encoder of a compiled pipeline.
def switch_lora(model, lora):
    unet = model.unet
    # Grab the live parameter tensors; copy_ below writes into them in place.
    unet_state_dict = unet.state_dict()
    text_encoder_state_dict = model.text_encoder.state_dict()
    # Load the new LoRA's raw state dict (network_alphas is unused here).
    state_dict_new_lora, network_alphas_new_lora = LoraLoaderMixin.lora_state_dict(lora)
    # Remap the LoRA key names to the PEFT-style names used by the live
    # modules, then copy the new values in place.
    for key, value in state_dict_new_lora.items():
        if 'unet' in key:
            key = key.replace("processor.", "")
            key = key.replace("to_out_lora.down", "to_out.0.lora_A.default_0")
            key = key.replace("to_out_lora.up", "to_out.0.lora_B.default_0")
            key = key.replace("_lora.down", ".lora_A.default_0")
            key = key.replace("_lora.up", ".lora_B.default_0")
            key = key.replace(".down.", "_A.default_0.")
            key = key.replace(".up.", "_B.default_0.")
            key = key.split('unet.')[-1]
            unet_state_dict[key].copy_(value)
        elif 'text_encoder' in key:
            key = key.replace("_lora.down", ".lora_A.default_0")
            key = key.replace("_lora.up", ".lora_B.default_0")
            key = key.replace("linear_layer.down", "A.default_0")
            key = key.replace("linear_layer.up", "B.default_0")
            key = key.replace("to_k", "k_proj")
            key = key.replace("to_v", "v_proj")
            key = key.replace("to_q", "q_proj")
            key = key.replace("to_out", "out_proj")
            key = key.split('text_encoder.')[-1]
            text_encoder_state_dict[key].copy_(value)
    # Load the state dicts back with assign=True so the modules keep holding
    # the very same tensors that the CUDA graph captured.
    unet.load_state_dict(unet_state_dict, assign=True)
    model.text_encoder.load_state_dict(text_encoder_state_dict, assign=True)
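A hypothetical call site for switch_lora (compiled_pipe and the LoRA path are placeholders):

import torch

# Hypothetical usage; the .safetensors path is a placeholder.
with torch.no_grad():
    switch_lora(compiled_pipe, "loras/style_b.safetensors")
image = compiled_pipe(prompt="a photo of an astronaut").images[0]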

TODO: the parameters really are copied over; I need to work out why they don't take effect at inference time. I suspect it has to do with fuse_lora.

CallmeZhangChenchen commented 7 months ago

Wow, it works! You must not call fuse_lora when loading the LoRA at model initialization; with the code above, switching then works correctly.

PS: SDXL models don't need any of this. Whether enable_cuda_graph is True or False, switching via set_adapter() succeeds, and the latency is the same either way.
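A sketch of that adapter-based switching using the diffusers PEFT integration (adapter names and paths are placeholders; in recent diffusers the pipeline-level method is set_adapters):

# Sketch of adapter-based LoRA switching; names and paths are placeholders.
pipe.load_lora_weights("loras/style_a.safetensors", adapter_name="style_a")
pipe.load_lora_weights("loras/style_b.safetensors", adapter_name="style_b")

pipe.set_adapters(["style_a"])  # activate style_a
image_a = pipe(prompt="a castle").images[0]

pipe.set_adapters(["style_b"])  # switch to style_b
image_b = pipe(prompt="a castle").images[0]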

TODO: support LoRAs of every shape, and make the switch fast enough that users don't notice it.

wanxingzd commented 7 months ago

What does the performance improvement look like?

CallmeZhangChenchen commented 6 months ago

> What does the performance improvement look like?

Do you mean the performance of stable-fast itself, or of the LoRA switching? Both are quite good, but I'd suggest trying it yourself.

CallmeZhangChenchen commented 6 months ago

After testing, when lora_scale is needed you can simply do text_encoder_state_dict[key].copy_(value * lora_scale).

TODO: switching LoRA this way skips the fuse step, so inference with the LoRA applied goes from 700 ms up to 900 ms. I need to look into the fuse/unfuse operations.
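For reference, fuse/unfuse exist as pipeline-level methods in diffusers; a sketch (whether they can be used safely under CUDA graphs is exactly the open question in this TODO):

# Sketch: fusing folds the LoRA delta into the base weights, so the extra
# LoRA matmuls disappear at inference time; unfuse_lora undoes it.
pipe.fuse_lora(lora_scale=1.0)
image = pipe(prompt="a castle").images[0]
pipe.unfuse_lora()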

senlyu163 commented 6 months ago

My implementation is:

  1. Load the LoRA weights via the diffusers library;
  2. Adjust the LoRA's rank, keys, etc. (refer to the implementation in diffusers/peft);
  3. When fusing, add the processed LoRA weight to the UNet/text-encoder weight with the same key, then write the fused weight back with the update_state_dict method provided by the author; likewise, subtracting on unfuse restores the UNet's original weights (see the sketch after this list).
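A minimal sketch of the fuse/unfuse arithmetic in step 3, assuming PEFT-style up/down factors (all names here are illustrative):

import torch

def fuse_into(weight, lora_up, lora_down, alpha, rank):
    # Fold the LoRA delta into the base weight: W += (alpha / rank) * up @ down
    with torch.no_grad():
        weight += (alpha / rank) * (lora_up @ lora_down)

def unfuse_from(weight, lora_up, lora_down, alpha, rank):
    # Subtract the same delta to restore the original weight.
    with torch.no_grad():
        weight -= (alpha / rank) * (lora_up @ lora_down)

The fused weight can then be written back into the compiled model in place, which is consistent with the unchanged inference speed reported below.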

Conclusion:

  1. In testing, the inference speed under sfast does not change.
  2. Switching takes 2-3 seconds with SD 1.5, versus about 9 seconds with TensorRT 8.6.

chengzeyi commented 6 months ago

A new project that can achieve peak performance without using CUDA graphs is under active development. I hope it can be made public soon.