G-U-N / Phased-Consistency-Model

[NeurIPS 2024] Boosting the performance of consistency models with PCM!
https://g-u-n.github.io/projects/pcm/
Apache License 2.0

[Inference Issue] ValueError when trying to load LoRA weights with diffusers #2

Open linoytsaban opened 4 months ago

linoytsaban commented 4 months ago

Hey!

Congrats on your work, and thanks a lot for sharing it 🤗 When trying to use the SD1.5 and SDXL checkpoints from the Hub for inference with diffusers, I got the following error when calling load_lora_weights:

import torch
from diffusers import AutoPipelineForText2Image

model_id = "stabilityai/stable-diffusion-xl-base-1.0"
adapter_id = "wangfuyun/PCM_SDXL_LoRAs"

pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16")
pipe.load_lora_weights(adapter_id, weight_name="pcm_sdxl_normalcfg_16step.safetensors")

ValueError: Target modules {'base_model.model.up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0', 'base_model.model.up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0', 'base_model.model.down_blocks.0.attentions.1.proj_in', 'base_model.model.up_blocks.1.attentions.1.proj_in', 'base_model.model.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0', 'base_model.model.up_blocks.3.resnets.0.conv_shortcut', 'base_model.model.down_blocks.3.resnets.0.conv1', 'base_model.model.down_blocks.3.resnets.0.time_emb_proj', 'base_model.model.up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0', 'base_model.model.up_blocks.3.attentions.1.transformer_blocks.0.ff.net.0.proj', 'base_model.model.down_blocks.3.resnets.0.conv2', '
....
, 'base_model.model.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v', 'base_model.model.up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q', 'base_model.model.up_blocks.2.attentions.1.proj_out', 'base_model.model.up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_v', 'base_model.model.up_blocks.3.attentions.0.proj_out', 'base_model.model.up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_v', 'base_model.model.down_blocks.0.resnets.1.time_emb_proj', 'base_model.model.down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v'} not found in the base model. Please check the target modules and try again.
G-U-N commented 4 months ago

The weights were not converted. I will upload the converted weights soon.

Try this:


import torch
from safetensors.torch import load_file, save_file

def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = "default"):
    # Convert a PEFT-format LoRA state dict to kohya-ss naming so diffusers can load it.
    kohya_ss_state_dict = {}
    for peft_key, weight in module.items():
        kohya_key = peft_key.replace("base_model.model", prefix)
        kohya_key = kohya_key.replace("lora_A", "lora_down")
        kohya_key = kohya_key.replace("lora_B", "lora_up")
        # Turn every "." except the last two into "_" to match the kohya key layout.
        kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
        kohya_ss_state_dict[kohya_key] = weight.to(dtype)
        # Set the alpha parameter for each LoRA module.
        if "lora_down" in kohya_key:
            alpha_key = f'{kohya_key.split(".")[0]}.alpha'
            kohya_ss_state_dict[alpha_key] = torch.tensor(8).to(dtype)

    return kohya_ss_state_dict

# pcm_lora_path: local path to the downloaded PCM LoRA .safetensors file
# weight_dtype: e.g. torch.float16, matching the pipeline dtype
pcm_lora_weight = load_file(pcm_lora_path)
pcm_lora_weight_convert = get_module_kohya_state_dict(pcm_lora_weight, "lora_unet", weight_dtype)
pipe.load_lora_weights(pcm_lora_weight_convert)
save_file(pcm_lora_weight_convert, "converted_pcm_lora.safetensors")
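
For context, a minimal sketch of the setup the snippet above assumes; the pipeline is the one from the original report, and the local LoRA path is illustrative:

import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
weight_dtype = torch.float16  # match the pipeline dtype
pcm_lora_path = "pcm_sdxl_normalcfg_16step.safetensors"  # illustrative local path
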
G-U-N commented 4 months ago

Also set

scheduler=DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    timestep_spacing="trailing",
)  # DDIM should just work well. See our discussion on parameterization in the paper.
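
In pipeline terms, this amounts to replacing the pipeline's scheduler before sampling; a minimal sketch, assuming the pipe built earlier:

from diffusers import DDIMScheduler

pipe.scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    timestep_spacing="trailing",
)
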
radames commented 4 months ago

Thanks @G-U-N, do we have to use the modified DDPMScheduler from here? https://github.com/G-U-N/Phased-Consistency-Model/blob/master/code/text_to_image_sd15/scheduling_ddpm_modified.py

G-U-N commented 4 months ago

@radames. You don't need that for inference. I just added a `noise_travel` function to the original DDPM implementation in diffusers for training convenience.

G-U-N commented 4 months ago

Also set

scheduler=DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    timestep_spacing="trailing",
)  # DDIM should just work well. See our discussion on parameterization in the paper.

We can just use this scheduler for inference. I have thought about a more principled scheduler design: you can imagine it as a series of small LCM schedulers. Within each small scheduler we can run stochastic inference, and across schedulers we can apply the deterministic algorithm. But I think that would make the whole thing a bit too complex.

radames commented 4 months ago

Great, it works! I got some weird results with the normal-CFG LoRAs, but the small-CFG ones were consistent. Can I open a PR on Hugging Face with the converted LoRAs?

Same params

prompt = "cinematic picture of an astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
negative_prompt = "3d render, carton, drawing, art, low light, blur, pixelated, low resolution, black and white"
num_inference_steps = 2
height = 512
width = height
guidance_scale = 0
seed = 2341412232
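
For reference, a sketch of the generation call these params feed into; the generator and device handling here are assumptions:

import torch

image = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=num_inference_steps,
    height=height,
    width=width,
    guidance_scale=guidance_scale,
    generator=torch.Generator(device="cuda").manual_seed(seed),
).images[0]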

4step-normal guidance 7.5

8step-normal guidance 7.5

2step guidance 0

4step guidance 0

G-U-N commented 4 months ago

@radames. Yes, many thanks for testing!

For the results with normal CFG, I just realized part of my implementation is flawed, and I've found a better way to do it! I've seen some promising results and will likely upload them in the coming days!

radames commented 4 months ago

Perfect! Please let us know if you want to set up a demo on HF Spaces. I'd be happy to kickstart it for you and transfer it to your profile!

radames commented 4 months ago

Hi @G-U-N, for SDXL do I use the same params for the DDPMScheduler?

G-U-N commented 4 months ago

@radames Yes, DDIM. TCDScheduler should also work.
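
A minimal sketch of swapping in TCD, assuming a diffusers version that ships TCDScheduler (from_config reuses the base scheduler's settings):

from diffusers import TCDScheduler

pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)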

radames commented 4 months ago

I noticed you've converted the weights, yeah! Thanks! BTW, the TCDScheduler works better with SDXL!!

[Comparison images: DDPMScheduler vs. TCDScheduler]

G-U-N commented 4 months ago

Hi @radames. It does not look right with DDPM.

Both setting DDIM with

DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    timestep_spacing="trailing",
)

and using TCD should give good results.

G-U-N commented 4 months ago

DDPM is by nature a stochastic scheduler, which is not aligned with how the PCM LoRA is trained.

G-U-N commented 4 months ago

Reasons for DDPM not getting good results: [figure omitted]

radames commented 4 months ago

Yes, that makes sense! Thanks for the insight and the amazing work! I'll set up a Space demo later today and transfer it to your name. Thanks!

G-U-N commented 4 months ago

Many thanks @radames. I sincerely appreciate your attention and help!

radames commented 4 months ago

Final question: for the new LCM-like LoRA, would it make sense to use the same params?

DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    timestep_spacing="trailing",
)

Making a Space demo with all the LoRA options.

G-U-N commented 4 months ago

Thanks @radames! The demo looks awesome!

For the LCM-like LoRA, use the LCMScheduler; the number of sampling steps can be chosen flexibly.
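
A minimal sketch of that setup; the step count and guidance scale here are illustrative, not prescribed:

from diffusers import LCMScheduler

pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
image = pipe(prompt, num_inference_steps=4, guidance_scale=0).images[0]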

not-ski commented 4 months ago

@radames. Yes, many thanks for testing!

For the results with normal CFG, I just realized part of my implementation is flawed, and I've found a better way to do it! I've seen some promising results and will likely upload them in the coming days!

@G-U-N any update on this? Great work btw <3

xizi commented 3 weeks ago

The weights were not converted. I will upload the converted weights soon. Try this: [...]

Loading a PCM LoRA I trained myself for SDXL with pipe.load_lora_weights(pcm_lora_weight_convert) raises an error: *** IndexError: list index out of range

xizi commented 3 weeks ago


Problem solved