chengzeyi / stable-fast

Best inference performance optimization framework for HuggingFace Diffusers on NVIDIA GPUs.
MIT License

Dynamically Switch LoRA error #104

Open yangshaobooo opened 10 months ago

yangshaobooo commented 10 months ago

When I use the "Dynamically Switch LoRA" approach, the LoRA is not actually switched: after switching to LoRA 2, the pipeline produces the same image as with LoRA 1.

PyTorch version 2.1.2


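```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from PIL import Image

# `args`, `LORA1`, `LORA2` and `compile_model` (presumably a wrapper around
# stable-fast's compiler) are defined elsewhere in the script and not shown.


def update_state_dict(dst, src):
    for key, value in src.items():
        # In-place copy: the compiled forward shares the underlying storage of
        # these tensors, so the new values become visible to it.
        dst[key].copy_(value)


# Switch "another" LoRA into the UNet
def switch_lora(unet, lora):
    # Keep references to the UNet tensors the compiled graph already uses
    state_dict = unet.state_dict()
    # Load the other LoRA (this may create new tensors)
    unet.load_attn_procs(lora)
    # Copy the freshly loaded values into the original tensors
    update_state_dict(state_dict, unet.state_dict())
    # Re-attach the original tensors; assign=True keeps their storage
    unet.load_state_dict(state_dict, assign=True)


def main():
    controlnet = ControlNetModel.from_pretrained(args.controlnet, torch_dtype=torch.float16).to("cuda")
    model = StableDiffusionControlNetPipeline.from_pretrained(
        args.model, controlnet=controlnet, torch_dtype=torch.float16
    ).to("cuda")
    model.load_lora_weights(LORA1)
    model = compile_model(model)

    control_image = Image.open(args.control_image)

    # First generation (LoRA 1)
    generator = torch.Generator("cuda").manual_seed(123)
    image = model(prompt=args.prompt, image=control_image, height=args.height, width=args.width,
                  num_inference_steps=args.steps, guidance_scale=args.guidance_scale,
                  generator=generator).images[0]
    image.save("aigc.png")

    # Dynamically switch to LoRA 2
    switch_lora(model.unet, LORA2)

    # Second generation (LoRA 2), same seed
    generator = torch.Generator("cuda").manual_seed(123)
    image = model(prompt=args.prompt, image=control_image, height=args.height, width=args.width,
                  num_inference_steps=args.steps, guidance_scale=args.guidance_scale,
                  generator=generator).images[0]
    image.save("aigc2.png")


if __name__ == "__main__":
    main()
```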
chengzeyi commented 10 months ago

Besides the UNet, other parts of the pipeline (such as the text encoder) may also need switching.

yangshaobooo commented 10 months ago

> Besides UNet, other parts may also need switching

[Screenshot 2024-01-15 17:01:13]

Why is image 1 the same as image 2? I only switched the LoRA. Do other parts also need to be switched, and if so, how do I switch them?
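To narrow down why the two images match, you could record each component's parameters before the switch and then list which entries actually changed afterwards. The sketch below assumes the pipeline exposes `unet` and `text_encoder` attributes, as Stable Diffusion pipelines in diffusers do. The helper names are hypothetical, and the snapshot is kept on the CPU because a full UNet copy is large.

```python
from typing import Dict, List

import torch
import torch.nn as nn


def snapshot(module: nn.Module) -> Dict[str, torch.Tensor]:
    # CPU copies: unaffected by later in-place updates, and no extra GPU memory.
    return {k: v.detach().to("cpu", copy=True) for k, v in module.state_dict().items()}


def changed_keys(before: Dict[str, torch.Tensor], module: nn.Module) -> List[str]:
    # Keys whose current values differ from the snapshot (new keys are ignored).
    after = module.state_dict()
    return [k for k, v in before.items() if k in after and not torch.equal(v, after[k].cpu())]


# Toy stand-ins for model.unet and model.text_encoder.
modules = {"unet": nn.Linear(3, 3), "text_encoder": nn.Linear(2, 2)}
before = {name: snapshot(m) for name, m in modules.items()}

with torch.no_grad():
    modules["unet"].weight.add_(1.0)  # pretend the switch only reached the UNet

report = {name: changed_keys(before[name], m) for name, m in modules.items()}
print(report)  # {'unet': ['weight'], 'text_encoder': []}
```

There are three possible readings of the report. If the UNet entries change but the second image is still identical, the compiled graph is probably not reading the updated tensors. If nothing changes, the load itself likely did not take effect. If only the UNet changes, the text encoder presumably still carries LoRA 1's weights, if that LoRA had any.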