Closed · hanwenxu1 closed 1 week ago
Could you take a look when it's convenient for you? @glide-the
I've solved it! I renamed the layers in the LoRA weights to match the layer names used by CogVideoXPipeline, and the weights then loaded into the model successfully. Thank you for your attention!
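For anyone hitting the same error, the renaming step described above can be sketched roughly as below. The exact prefixes (`base_model.model.` from the PEFT-saved adapter vs. a `transformer.` prefix for the pipeline) are assumptions inferred from the error message in this thread, so inspect your own checkpoint's keys and adjust:

```python
def remap_lora_keys(state_dict, old_prefix="base_model.model.", new_prefix="transformer."):
    """Rename LoRA keys so they match the pipeline's layer names.

    The prefixes here are assumptions based on the error in this thread;
    check your .safetensors keys and the pipeline's module names to confirm.
    """
    remapped = {}
    for key, value in state_dict.items():
        # Replace only the leading prefix so inner names like
        # "transformer_blocks.0.attn1.to_q" stay untouched.
        if key.startswith(old_prefix):
            key = new_prefix + key[len(old_prefix):]
        remapped[key] = value
    return remapped

# Toy example; real weights would come from safetensors.torch.load_file(...)
demo = {"base_model.model.transformer_blocks.0.attn1.to_q.lora_A.weight": "W"}
print(remap_lora_keys(demo))
# {'transformer.transformer_blocks.0.attn1.to_q.lora_A.weight': 'W'}
```

After remapping, the resulting dict can be saved back out (e.g. with `safetensors.torch.save_file`) and loaded through `pipe.load_lora_weights` as in `cli_demo.py`.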
Differences between the SAT and Hugging Face Diffusers frameworks: the training setup in Diffusers differs from SAT's; a detailed description can be found in the official Hugging Face documentation. At 1,500 iterations per sample, SAT is clearly faster than HF, but under the same LoRA parameters and learning rate (HF only exposes those parameters for modification), after 150 SAT iterations the character's appearance has been learned while the original sample is essentially no longer generated. Diffusers can restore the character's appearance almost perfectly after iterating to step 10,000, and the original sample differs only in its optical-flow effect (tending toward the target training sample). For the specific differences, see the SAT-framework training task Lora_role_7 and the Diffusers-framework training task Lora_hf_text_video.
@dctnorin
I see. Thank you for your reply!
After training on my own data, running inference with the LoRA weights via
`python cli_demo.py`
fails with the following error:
ValueError: Target modules {'base_model.model.transformer_blocks.22.attn1.to_v', 'base_model.model.transformer_blocks.0.attn1.to_v', 'base_model.model.transformer_blocks.9.attn1.to_v', 'base_model.model.transformer_blocks.26.attn1.to_q', 'base_model.model.transformer_blocks.10.attn1.to_q', 'base_model.model.transformer_blocks.18.attn1.to_v', 'base_model.model.transformer_blocks.19.attn1.to_q', 'base_model.model.transformer_blocks.18.attn1.to_q', 'base_model.model.transformer_blocks.25.attn1.to_q', 'base_model.model.transformer_blocks.19.attn1.to_v', 'base_model.model.transformer_blocks.6.attn1.to_v', 'base_model.model.transformer_blocks.22.attn1.to_out.0', 'base_model.model.transformer_blocks.0.attn1.to_k', 'base_model.model.transformer_blocks.14.attn1.to_k', 'base_model.model.transformer_blocks.29.attn1.to_q', 'base_model.model.transformer_blocks.27.attn1.to_k', 'base_model.model.transformer_blocks.11.attn1.to_q', 'base_model.model.transformer_blocks.2.attn1.to_out.0', 'base_model.model.transformer_blocks.10.attn1.to_out.0', 'base_model.model.transformer_blocks.26.attn1.to_out.0', 'base_model.model.transformer_blocks.5.attn1.to_v', 'base_model.model.transformer_blocks.9.attn1.to_q', 'base_model.model.transformer_blocks.6.attn1.to_q', 'base_model.model.transformer_blocks.26.attn1.to_v', 'base_model.model.transformer_blocks.15.attn1.to_out.0', 'base_model.model.transformer_blocks.25.attn1.to_v', 'base_model.model.transformer_blocks.24.attn1.to_v', 'base_model.model.transformer_blocks.9.attn1.to_k', 'base_model.model.transformer_blocks.23.attn1.to_k', 'base_model.model.transformer_blocks.9.attn1.to_out.0', 'base_model.model.transformer_blocks.3.attn1.to_q', 'base_model.model.transformer_blocks.21.attn1.to_v', 'base_model.model.transformer_blocks.2.attn1.to_k', 'base_model.model.transformer_blocks.12.attn1.to_out.0', 'base_model.model.transformer_blocks.4.attn1.to_v', 'base_model.model.transformer_blocks.28.attn1.to_v', 
'base_model.model.transformer_blocks.27.attn1.to_q', 'base_model.model.transformer_blocks.29.attn1.to_k', 'base_model.model.transformer_blocks.13.attn1.to_v', 'base_model.model.transformer_blocks.27.attn1.to_out.0', 'base_model.model.transformer_blocks.12.attn1.to_v', 'base_model.model.transformer_blocks.21.attn1.to_q', 'base_model.model.transformer_blocks.29.attn1.to_out.0', 'base_model.model.transformer_blocks.13.attn1.to_q', 'base_model.model.transformer_blocks.22.attn1.to_q', 'base_model.model.transformer_blocks.0.attn1.to_q', 'base_model.model.transformer_blocks.8.attn1.to_v', 'base_model.model.transformer_blocks.11.attn1.to_k', 'base_model.model.transformer_blocks.26.attn1.to_k', 'base_model.model.transformer_blocks.28.attn1.to_q', 'base_model.model.transformer_blocks.22.attn1.to_k', 'base_model.model.transformer_blocks.11.attn1.to_v', 'base_model.model.transformer_blocks.14.attn1.to_v', 'base_model.model.transformer_blocks.16.attn1.to_k', 'base_model.model.transformer_blocks.24.attn1.to_k', 'base_model.model.transformer_blocks.28.attn1.to_k', 'base_model.model.transformer_blocks.10.attn1.to_k', 'base_model.model.transformer_blocks.8.attn1.to_k', 'base_model.model.transformer_blocks.15.attn1.to_q', 'base_model.model.transformer_blocks.16.attn1.to_out.0', 'base_model.model.transformer_blocks.2.attn1.to_q', 'base_model.model.transformer_blocks.5.attn1.to_q', 'base_model.model.transformer_blocks.19.attn1.to_out.0', 'base_model.model.transformer_blocks.27.attn1.to_v', 'base_model.model.transformer_blocks.7.attn1.to_k', 'base_model.model.transformer_blocks.7.attn1.to_out.0', 'base_model.model.transformer_blocks.2.attn1.to_v', 'base_model.model.transformer_blocks.6.attn1.to_k', 'base_model.model.transformer_blocks.21.attn1.to_k', 'base_model.model.transformer_blocks.15.attn1.to_k', 'base_model.model.transformer_blocks.13.attn1.to_k', 'base_model.model.transformer_blocks.18.attn1.to_k', 'base_model.model.transformer_blocks.21.attn1.to_out.0', 
'base_model.model.transformer_blocks.23.attn1.to_q', 'base_model.model.transformer_blocks.23.attn1.to_v', 'base_model.model.transformer_blocks.20.attn1.to_v', 'base_model.model.transformer_blocks.4.attn1.to_q', 'base_model.model.transformer_blocks.3.attn1.to_k', 'base_model.model.transformer_blocks.20.attn1.to_q', 'base_model.model.transformer_blocks.17.attn1.to_q', 'base_model.model.transformer_blocks.25.attn1.to_out.0', 'base_model.model.transformer_blocks.23.attn1.to_out.0', 'base_model.model.transformer_blocks.17.attn1.to_v', 'base_model.model.transformer_blocks.1.attn1.to_v', 'base_model.model.transformer_blocks.20.attn1.to_k', 'base_model.model.transformer_blocks.14.attn1.to_q', 'base_model.model.transformer_blocks.1.attn1.to_out.0', 'base_model.model.transformer_blocks.20.attn1.to_out.0', 'base_model.model.transformer_blocks.7.attn1.to_v', 'base_model.model.transformer_blocks.5.attn1.to_k', 'base_model.model.transformer_blocks.16.attn1.to_q', 'base_model.model.transformer_blocks.1.attn1.to_q', 'base_model.model.transformer_blocks.4.attn1.to_k', 'base_model.model.transformer_blocks.8.attn1.to_q', 'base_model.model.transformer_blocks.29.attn1.to_v', 'base_model.model.transformer_blocks.1.attn1.to_k', 'base_model.model.transformer_blocks.3.attn1.to_out.0', 'base_model.model.transformer_blocks.7.attn1.to_q', 'base_model.model.transformer_blocks.12.attn1.to_k', 'base_model.model.transformer_blocks.10.attn1.to_v', 'base_model.model.transformer_blocks.3.attn1.to_v', 'base_model.model.transformer_blocks.24.attn1.to_out.0', 'base_model.model.transformer_blocks.17.attn1.to_k', 'base_model.model.transformer_blocks.8.attn1.to_out.0', 'base_model.model.transformer_blocks.16.attn1.to_v', 'base_model.model.transformer_blocks.13.attn1.to_out.0', 'base_model.model.transformer_blocks.5.attn1.to_out.0', 'base_model.model.transformer_blocks.6.attn1.to_out.0', 'base_model.model.transformer_blocks.14.attn1.to_out.0', 'base_model.model.transformer_blocks.11.attn1.to_out.0', 
'base_model.model.transformer_blocks.24.attn1.to_q', 'base_model.model.transformer_blocks.28.attn1.to_out.0', 'base_model.model.transformer_blocks.18.attn1.to_out.0', 'base_model.model.transformer_blocks.12.attn1.to_q', 'base_model.model.transformer_blocks.15.attn1.to_v', 'base_model.model.transformer_blocks.19.attn1.to_k', 'base_model.model.transformer_blocks.25.attn1.to_k', 'base_model.model.transformer_blocks.0.attn1.to_out.0', 'base_model.model.transformer_blocks.4.attn1.to_out.0', 'base_model.model.transformer_blocks.17.attn1.to_out.0'} not found in the base model. Please check the target modules and try again.
I think the LoRA weights have not been merged into the model. But
`cli_demo.py`
already contains the code to load the LoRA:

```python
if lora_path:
    pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
    pipe.fuse_lora(lora_scale=1 / lora_rank)
```
And in the training code I use:

```python
# Add adapter and make sure the trainable params are in float32.
transformer_lora_config = LoraConfig(
    r=args.rank,
    lora_alpha=args.rank,
    init_lora_weights=True,
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
```
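As context for the error above: PEFT matches each `target_modules` entry against the base model's module names either exactly or as a dot-separated suffix, so short names like `to_q` can match many blocks, while a fully qualified `base_model.model.…` name matches nothing inside the bare transformer. The snippet below is a rough approximation of that matching rule (not PEFT's actual implementation), just to illustrate why the saved adapter's names fail:

```python
def matches_target(module_name, target_modules):
    # Approximation of PEFT's check: an entry matches if it equals the
    # module name or is a suffix preceded by a dot.
    return any(
        module_name == t or module_name.endswith("." + t)
        for t in target_modules
    )

names = [
    "transformer_blocks.0.attn1.to_q",
    "transformer_blocks.0.attn1.to_out.0",
    "transformer_blocks.0.ff.net.0.proj",
]

# Short suffixes match the attention projections...
print([n for n in names if matches_target(n, ["to_k", "to_q", "to_v", "to_out.0"])])
# ['transformer_blocks.0.attn1.to_q', 'transformer_blocks.0.attn1.to_out.0']

# ...but a fully qualified name with a stale prefix matches nothing:
print(any(matches_target(n, ["base_model.model.transformer_blocks.0.attn1.to_q"]) for n in names))
# False
```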
The command is:

```shell
python inference/cli_demo.py --prompt "a man is walking on the street." --lora_path my-path/checkpoint-10000
```
Please tell me how to solve it. Thanks!