luosiallen / latent-consistency-model

Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference

Inference problem about loading LoRA weights #57

Open dcfucheng opened 10 months ago

dcfucheng commented 10 months ago

Hey~

Good job~ I have trained an SD LoRA on my custom dataset, but I have a problem with inference ONLY.

I saved the state dict with:

    lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
    StableDiffusionPipeline.save_lora_weights(os.path.join(output_dir, "unet_lora"), lora_state_dict)

The keys of the saved model are named like base_model.model.mid_block.resnets.1.time_emb_proj.lora_B.weight.

But the keys in pytorch_lora_weights.safetensors look like lora_unet_up_blocks_2_attentions_0_proj_in.lora_up.weight, and those can be loaded correctly by pipe.load_lora_weights().

The model we saved cannot be loaded directly. So the question is how to load the LoRA weights we save, or should we convert the LoRA weights before saving?
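For reference, here is how to compare the two key formats (a minimal sketch; the path is a placeholder):

from safetensors.torch import load_file

# Print a few keys from the saved checkpoint to see the naming mismatch.
peft_sd = load_file("/path/unet_lora/pytorch_lora_weights.safetensors")
for key in sorted(peft_sd.keys())[:5]:
    print(key)
# PEFT-style keys:  base_model.model.mid_block.resnets.1.time_emb_proj.lora_B.weight
# kohya-style keys: lora_unet_up_blocks_2_attentions_0_proj_in.lora_up.weight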

Thanks~

JingyeChen commented 10 months ago

I have encountered the same problem.

dcfucheng commented 10 months ago

> I have encountered the same problem.

I tried loading the LoRA weights in the following way. The weights load, but I trained on SD 2.1 and it generates a noisy picture: https://github.com/luosiallen/latent-consistency-model/issues/65

So I am not sure this is correct. You can try it. Welcome to discuss~

import torch
from safetensors.torch import load_file

def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = "default"):
    kohya_ss_state_dict = {}
    for peft_key, weight in module.items():
        # Rename PEFT-style keys to kohya-style keys.
        kohya_key = peft_key.replace("unet.base_model.model", prefix)
        kohya_key = kohya_key.replace("lora_A", "lora_down")
        kohya_key = kohya_key.replace("lora_B", "lora_up")
        # Replace every dot except the last two (".lora_down.weight") with an underscore.
        kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
        kohya_ss_state_dict[kohya_key] = weight.to(dtype)
        # Set the alpha parameter for each LoRA module (hard-coded to 8 here).
        if "lora_down" in kohya_key:
            alpha_key = f'{kohya_key.split(".")[0]}.alpha'
            kohya_ss_state_dict[alpha_key] = torch.tensor(8).to(dtype)

    return kohya_ss_state_dict

lora_weight = load_file('/path/unet_lora/pytorch_lora_weights.safetensors')
lora_state_dict = get_module_kohya_state_dict(lora_weight, "lora_unet", torch.float16)
pipe.load_lora_weights(lora_state_dict)
pipe.fuse_lora()
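After fusing, a quick sanity check of the merged weights (a minimal sketch; LCMScheduler from diffusers, and the prompt and step count are just examples):

from diffusers import LCMScheduler

# Swap in the LCM scheduler and run a few-step generation as a sanity check.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
image = pipe(
    prompt="a photo of a cat",  # example prompt
    num_inference_steps=4,
    guidance_scale=1.0,
).images[0]
image.save("lcm_lora_check.png")
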
zjysteven commented 5 months ago

> I tried loading the LoRA weights in the following way. The weights load, but I trained on SD 2.1 and it generates a noisy picture: #65
>
> (snippet quoted above)

This works, but one caveat: in the snippet above, kohya_ss_state_dict[alpha_key] = torch.tensor(8).to(dtype) hard-codes the lora_alpha value of 8. Remember to change it according to your config. A more robust alternative is to save the LoRA weights with the get_module_kohya_state_dict function provided in the training script: https://github.com/luosiallen/latent-consistency-model/blob/a9ad79587cc8bd1e404ccd1a3056a3da969b2f62/LCM_Training_Script/consistency_distillation/train_lcm_distill_lora_sd_wds.py#L79-L93 With it, you can save the trained LoRA weights like this:

import os
from safetensors.torch import save_file

lora_state_dict = get_module_kohya_state_dict(unet, "lora_unet", weight_dtype)
# StableDiffusionPipeline.save_lora_weights() would add a 'unet.' prefix to the state_dict keys:
# StableDiffusionPipeline.save_lora_weights(os.path.join(output_dir, "unet_lora"), lora_state_dict)
# Instead, save the state_dict directly:
save_file(lora_state_dict, os.path.join(output_dir, "unet_lora", "pytorch_lora_weights.safetensors"))

Then you should be able to load it directly with pipe.load_lora_weights().
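If you want to avoid hard-coding the alpha at all, one option (a sketch, assuming unet is still a PEFT-wrapped model with an adapter named "default") is to read lora_alpha from the adapter config and use that value for the .alpha entries:

import torch

# Sketch: read lora_alpha from the PEFT adapter config instead of hard-coding 8.
# Assumes `unet` is a PEFT-wrapped model with an adapter named "default".
lora_alpha = unet.peft_config["default"].lora_alpha
alpha_tensor = torch.tensor(lora_alpha).to(weight_dtype)  # use when writing the .alpha entries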