THUDM / CogVideo

text and image to video generation: CogVideoX (2024) and CogVideo (ICLR 2023)
Apache License 2.0

Problem loading LoRA weights #427

Closed hanwenxu1 closed 1 week ago

hanwenxu1 commented 2 weeks ago

After training on my own data, when I use the LoRA weights for inference via python cli_demo.py, some errors occur.

```text
ValueError: Target modules {'base_model.model.transformer_blocks.22.attn1.to_v', 'base_model.model.transformer_blocks.0.attn1.to_v', 'base_model.model.transformer_blocks.9.attn1.to_v', [... the attn1 to_q/to_k/to_v/to_out.0 projections of transformer_blocks 0-29 ...]} not found in the base model. Please check the target modules and try again.
```

I think the LoRA weights are not being merged into the model. But cli_demo.py does have code to load the LoRA:

```python
if lora_path:
    pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
    pipe.fuse_lora(lora_scale=1 / lora_rank)
```

And in the training code I use:

```python
transformer_lora_config = LoraConfig(
    r=args.rank,
    lora_alpha=args.rank,
    init_lora_weights=True,
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)

# Add adapter and make sure the trainable params are in float32.
transformer = get_peft_model(transformer, transformer_lora_config)
```
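For reference, a minimal sketch of how the trained LoRA layers could be exported in the layout that pipe.load_lora_weights expects, assuming the conventions used in the diffusers CogVideoX LoRA training example (get_peft_model_state_dict from peft and the save_lora_weights classmethod on CogVideoXPipeline; the output directory is a placeholder):

```python
from peft.utils import get_peft_model_state_dict
from diffusers import CogVideoXPipeline

# Collect only the LoRA parameters from the PEFT-wrapped transformer.
transformer_lora_layers = get_peft_model_state_dict(transformer)

# Save them as pytorch_lora_weights.safetensors in the diffusers layout,
# so that pipe.load_lora_weights(...) can resolve the module names directly.
CogVideoXPipeline.save_lora_weights(
    save_directory="my-path/checkpoint-10000",  # placeholder output directory
    transformer_lora_layers=transformer_lora_layers,
)
```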

The command is `python inference/cli_demo.py --prompt "a man is walking on the street." --lora_path my-path/checkpoint-10000`.

Please tell me how to solve it. Thanks!

zRzRzRzRzRzRzR commented 2 weeks ago

Could you take a look when it's convenient for you? @glide-the

hanwenxu1 commented 2 weeks ago

> Could you take a look when it's convenient for you? @glide-the

I have solved it! I renamed the layers in the LoRA weights to match the layer names expected by CogVideoXPipeline, and then the weights were loaded into the model successfully. Thank you for your attention!
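A minimal sketch of that kind of renaming, assuming the checkpoint was saved by PEFT with a `base_model.model.` prefix and that the pipeline looks the modules up under a `transformer.` prefix (the exact prefixes can differ between peft/diffusers versions, and the file paths are placeholders):

```python
from safetensors.torch import load_file, save_file

# Hypothetical paths; adjust to the actual checkpoint location.
src = "my-path/checkpoint-10000/pytorch_lora_weights.safetensors"
dst = "my-path/checkpoint-10000/pytorch_lora_weights_renamed.safetensors"

state_dict = load_file(src)

# Rename PEFT-style keys ("base_model.model.transformer_blocks...") to the
# layout the pipeline resolves ("transformer.transformer_blocks...").
renamed = {
    key.replace("base_model.model.", "transformer."): value
    for key, value in state_dict.items()
}

save_file(renamed, dst)
```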

glide-the commented 2 weeks ago

Differences between the SAT and Hugging Face Diffusers frameworks: the training setup of Diffusers is different from that of SAT; a detailed description will be added to the official Hugging Face documentation later. SAT is clearly faster than HF at 1500 iterations per sample, but the result is that under the same LoRA parameters and learning rate (because HF only exposes this parameter for modification), after 150 iterations SAT has learned the character's appearance, yet at that point it can basically no longer generate the original sample. Diffusers can faithfully restore the character's appearance after iterating to step 10,000, and the original sample differs only in its optical-flow effect (tending toward the target training sample). For the specific differences, please refer to the SAT training task Lora_role_7 and the Diffusers training task Lora_hf_text_video.

@dctnorin

hanwenxu1 commented 2 weeks ago

> Differences between the SAT and Hugging Face Diffusers frameworks: the training setup of Diffusers is different from that of SAT; a detailed description will be added to the official Hugging Face documentation later. SAT is clearly faster than HF at 1500 iterations per sample, but the result is that under the same LoRA parameters and learning rate (because HF only exposes this parameter for modification), after 150 iterations SAT has learned the character's appearance, yet at that point it can basically no longer generate the original sample. Diffusers can faithfully restore the character's appearance after iterating to step 10,000, and the original sample differs only in its optical-flow effect (tending toward the target training sample). For the specific differences, please refer to the SAT training task Lora_role_7 and the Diffusers training task Lora_hf_text_video.
>
> @dctnorin

I see. Thank you for your reply!