huggingface / diffusers

πŸ€— Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

train_dreambooth_lora_sdxl_advanced.py and train_dreambooth_lora_sdxl.py do not load previously saved checkpoints correctly #6366

Closed prushik closed 9 months ago

prushik commented 10 months ago

Describe the bug

Both examples/dreambooth/train_dreambooth_lora_sdxl.py and examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py seem to have an issue when resuming training from a previously saved checkpoint.

Training and saving checkpoints seem to work correctly; however, when resuming from a previously saved checkpoint, the following messages are produced at script startup:

Resuming from checkpoint checkpoint-10

12/27/2023 16:29:22 - INFO - accelerate.accelerator - Loading states from xqc/checkpoint-10
Loading unet.
12/27/2023 16:29:22 - INFO - peft.tuners.tuners_utils - Already found a `peft_config` attribute in the model. This will lead to having multiple adapters in the model. Make sure to know what you are doing!
Loading text_encoder.
12/27/2023 16:29:23 - INFO - peft.tuners.tuners_utils - Already found a `peft_config` attribute in the model. This will lead to having multiple adapters in the model. Make sure to know what you are doing!

Training appears to continue normally; however, all new checkpoints saved after this point are significantly larger than the earlier ones:

(xl) localhost /media/nvme/xl/diffusers/examples/dreambooth # du -sch xqc/*
87M     xqc/checkpoint-10
110M    xqc/checkpoint-15
110M    xqc/checkpoint-20
110M    xqc/checkpoint-25
87M     xqc/checkpoint-5
88K     xqc/logs
494M    total

Once training resumed from a checkpoint has completed, loading the resulting LoRA produces a large dump of layer names with a message saying that the model contains layers that do not match. (Full error message below.)

To me, it looks like the checkpoint is loaded incorrectly and effectively ignored, a new adapter is trained from scratch, and both versions, old and new, end up saved in the final LoRA.

Reproduction

To reproduce this issue, follow these steps:

  1. Run either of the train_dreambooth_lora_sdxl*.py scripts with appropriate parameters, including --checkpointing_steps (preferably set to a low number so the issue can be reproduced quickly).
  2. After at least one or two checkpoints have been saved, either stop the script or wait for it to complete.
  3. Rerun the same script, but also pass --resume_from_checkpoint latest or --resume_from_checkpoint checkpoint-x.
  4. Observe the effects listed above (PEFT warning message on startup, larger file sizes for the later checkpoints).
  5. After the resumed training completes, attempt to load the finished LoRA (inference will succeed, but the LoRA's effect does not look correct).
  6. Observe the error message produced.

Logs

My full command-line with all arguments looks like this:

python train_dreambooth_lora_sdxl.py --pretrained_model_name_or_path ../../../models/colossus_v5.3 --instance_data_dir /media/nvme/datasets/combined/ --output_dir xqc --resolution 1024 --instance_prompt 'a photo of hxq' --train_text_encoder --num_train_epochs 1 --train_batch_size 1 --gradient_checkpointing --checkpointing_steps 5 --gradient_accumulation_steps 1 --learning_rate 0.0001 --resume_from_checkpoint latest

Error produced during inference with the affected LoRA (truncated because of length):

Loading adapter weights from state_dict led to unexpected keys not found in the model:  ['down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.lora_A_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.lora_B_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.lora_A_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.lora_B_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.lora_A_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.lora_B_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.lora_A_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.lora_B_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.lora_A_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.lora_B_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.lora_A_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.lora_B_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.lora_A_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.lora_B_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.lora_A_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.lora_B_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.lora_A_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.lora_B_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.lora_A_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.lora_B_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.lora_A_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.lora_B_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.lora_A_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.lora_B_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.lora_A_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.lora_B_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.lora_A_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.lora_B_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.lora_A_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.lora_B_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.lora_A_1.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.lora_B_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.lora_A_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.lora_B_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.lora_A_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.lora_B_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.lora_A_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.lora_B_1.default_0.weight', 
'down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.lora_A_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.lora_B_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.lora_A_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.lora_B_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.lora_A_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.lora_B_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.lora_A_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.lora_B_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.lora_A_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.lora_B_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q.lora_A_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q.lora_B_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k.lora_A_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k.lora_B_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v.lora_A_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v.lora_B_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.lora_A_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.lora_B_1.default_0.weight', 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q.lora_A_1.default_0.weight', 

*** TRUNCATED HERE ***

'mid_block.attentions.0.transformer_blocks.8.attn1.to_k.lora_A_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.8.attn1.to_k.lora_B_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.8.attn1.to_v.lora_A_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.8.attn1.to_v.lora_B_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.8.attn1.to_out.0.lora_A_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.8.attn1.to_out.0.lora_B_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.8.attn2.to_q.lora_A_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.8.attn2.to_q.lora_B_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.8.attn2.to_k.lora_A_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.8.attn2.to_k.lora_B_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.8.attn2.to_v.lora_A_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.8.attn2.to_v.lora_B_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.8.attn2.to_out.0.lora_A_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.8.attn2.to_out.0.lora_B_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.9.attn1.to_q.lora_A_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.9.attn1.to_q.lora_B_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.9.attn1.to_k.lora_A_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.9.attn1.to_k.lora_B_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.9.attn1.to_v.lora_A_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.9.attn1.to_v.lora_B_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.9.attn1.to_out.0.lora_A_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.9.attn1.to_out.0.lora_B_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.9.attn2.to_q.lora_A_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.9.attn2.to_q.lora_B_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.9.attn2.to_k.lora_A_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.9.attn2.to_k.lora_B_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.9.attn2.to_v.lora_A_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.9.attn2.to_v.lora_B_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.9.attn2.to_out.0.lora_A_1.default_0.weight', 'mid_block.attentions.0.transformer_blocks.9.attn2.to_out.0.lora_B_1.default_0.weight'].
Loading adapter weights from None led to unexpected keys not found in the model:  ['text_model.encoder.layers.0.self_attn.k_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.0.self_attn.k_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.0.self_attn.v_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.0.self_attn.v_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.0.self_attn.q_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.0.self_attn.q_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.0.self_attn.out_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.0.self_attn.out_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.1.self_attn.k_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.1.self_attn.k_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.1.self_attn.v_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.1.self_attn.v_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.1.self_attn.q_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.1.self_attn.q_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.1.self_attn.out_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.1.self_attn.out_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.2.self_attn.k_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.2.self_attn.k_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.2.self_attn.v_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.2.self_attn.v_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.2.self_attn.q_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.2.self_attn.q_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.2.self_attn.out_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.2.self_attn.out_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.3.self_attn.k_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.3.self_attn.k_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.3.self_attn.v_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.3.self_attn.v_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.3.self_attn.q_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.3.self_attn.q_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.3.self_attn.out_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.3.self_attn.out_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.4.self_attn.k_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.4.self_attn.k_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.4.self_attn.v_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.4.self_attn.v_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.4.self_attn.q_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.4.self_attn.q_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.4.self_attn.out_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.4.self_attn.out_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.5.self_attn.k_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.5.self_attn.k_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.5.self_attn.v_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.5.self_attn.v_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.5.self_attn.q_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.5.self_attn.q_proj.lora_B_1.default_0.weight', 
'text_model.encoder.layers.5.self_attn.out_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.5.self_attn.out_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.6.self_attn.k_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.6.self_attn.k_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.6.self_attn.v_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.6.self_attn.v_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.6.self_attn.q_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.6.self_attn.q_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.6.self_attn.out_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.6.self_attn.out_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.7.self_attn.k_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.7.self_attn.k_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.7.self_attn.v_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.7.self_attn.v_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.7.self_attn.q_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.7.self_attn.q_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.7.self_attn.out_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.7.self_attn.out_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.8.self_attn.k_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.8.self_attn.k_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.8.self_attn.v_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.8.self_attn.v_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.8.self_attn.q_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.8.self_attn.q_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.8.self_attn.out_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.8.self_attn.out_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.9.self_attn.k_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.9.self_attn.k_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.9.self_attn.v_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.9.self_attn.v_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.9.self_attn.q_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.9.self_attn.q_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.9.self_attn.out_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.9.self_attn.out_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.10.self_attn.k_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.10.self_attn.k_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.10.self_attn.v_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.10.self_attn.v_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.10.self_attn.q_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.10.self_attn.q_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.10.self_attn.out_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.10.self_attn.out_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.11.self_attn.k_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.11.self_attn.k_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.11.self_attn.v_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.11.self_attn.v_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.11.self_attn.q_proj.lora_A_1.default_0.weight', 
'text_model.encoder.layers.11.self_attn.q_proj.lora_B_1.default_0.weight', 'text_model.encoder.layers.11.self_attn.out_proj.lora_A_1.default_0.weight', 'text_model.encoder.layers.11.self_attn.out_proj.lora_B_1.default_0.weight'].

System Info

Latest diffusers - master branch pulled on 2023/12/27.

OS - Linux 6.1.9

(xl) localhost /media/nvme/xl # uname -a
Linux localhost 6.1.9-noinitramfs #4 SMP PREEMPT_DYNAMIC Fri Feb 10 03:01:14 -00 2023 x86_64 Intel(R) Core(TM) i5-9500T CPU @ 2.20GHz GenuineIntel GNU/Linux


python - Python 3.10.9

(xl) localhost /media/nvme/xl # diffusers-cli env
Copy-and-paste the text below in your GitHub issue and FILL OUT the two last points.

Who can help?

No response

hi-sushanta commented 10 months ago

Please, can you provide your code?

asomoza commented 10 months ago

I can confirm this issue. I was wondering why no one else was having this problem, so I tried to test it in a Colab, since I normally use custom code. Now I was finally able to test it, and it happens there too. You don't need any custom code; just do the training as stated in the documentation and then resume:

!accelerate launch train_dreambooth_lora_sdxl.py --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" --instance_data_dir="/content/dataset" --instance_prompt="anime screencap of sks girl" --validation_prompt="anime screencap of sks girl" --train_batch_size=1 --gradient_checkpointing --gradient_accumulation_steps=1 --learning_rate=1e-4 --lr_scheduler="constant" --max_train_steps=800 --validation_epochs=25 --output_dir="/content/output_lora" --train_text_encoder --checkpointing_steps=25 --optimizer="AdamW" --use_8bit_adam --resume_from_checkpoint="checkpoint-400" --mixed_precision="fp16"

I got the same errors as @prushik, and I can also see it in the images, training with just one image:


First run, first validation: [image]
First run at 400 steps: [image]
Resume from 400 steps in a second run: [image]

It's obvious that it started as a clean training run, even though it loads the checkpoint and the state without errors.

prushik commented 10 months ago

My code is almost entirely unchanged from the example script; the only change I have made is to the DreamBoothDataset class:

localhost /media/nvme/xl/diffusers/examples/advanced_diffusion_training # diff train_dreambooth_lora_sdxl_advanced.py train_custom.py
876,877c876,884
<             instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
<             self.custom_instance_prompts = None
---
>             # Load .png images
>             instance_images = [Image.open(path) for path in self.instance_data_root.iterdir() if path.suffix == '.png']
>             #self.instance_images = instance_images
>
>             # Load .txt files and store their contents in a list
>             self.custom_instance_prompts = [path.read_text() for path in self.instance_data_root.iterdir() if path.suffix == '.txt']
>
>             #instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
>             #self.custom_instance_prompts = None

This is just to ensure that only .png files are loaded as images and that prompts are loaded from the corresponding .txt files. (Yes, I should have just used jsonl files, but my data was already formatted this way, and changing the code seemed easier than figuring out how the jsonl should be formatted.) I very much doubt this is related to the issue.

prushik commented 10 months ago

I have some more information about the issue.

I took a look at the differences between a saved checkpoint-x/pytorch_lora_weights.safetensors file and the final trained pytorch_lora_weights.safetensors file, and found some discrepancies in the layer names.

Since there are a lot of layers, I just chose a small subset that should be comparable. There is nothing special about the layers I chose to look at; they were picked at random and are representative of the problem affecting every layer in the generated file. I looked at the layers beginning with "unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1."
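
print_layers.py here is just a small helper that lists the tensor keys stored in a .safetensors file. A minimal sketch of such a helper (not the exact script used below; it only assumes the safetensors package is installed) could look like this:

# print_layers.py - minimal sketch: print every tensor key in a .safetensors file
import sys
from safetensors import safe_open

with safe_open(sys.argv[1], framework="pt", device="cpu") as f:
    for key in sorted(f.keys()):
        print(key)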

A LoRA produced in a single training run of examples/dreambooth/train_dreambooth_lora_sdxl.py contains the following:

(xl) localhost /media/nvme/xl # python print_layers.py diffusers/examples/dreambooth/xq3/pytorch_lora_weights.safetensors | grep unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_k.lora_A.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_k.lora_B.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0.lora_A.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0.lora_B.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_q.lora_A.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_q.lora_B.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_v.lora_A.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_v.lora_B.weight

However, a LoRA produced by a run that was resumed from a saved checkpoint has the following layers:

(xl) localhost /media/nvme/xl # python print_layers.py diffusers/examples/dreambooth/xq4/pytorch_lora_weights.safetensors | grep unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_k.lora_A.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_k.lora_A_1.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_k.lora_B.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_k.lora_B_1.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0.lora_A.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0.lora_A_1.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0.lora_B.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0.lora_B_1.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_q.lora_A.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_q.lora_A_1.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_q.lora_B.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_q.lora_B_1.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_v.lora_A.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_v.lora_A_1.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_v.lora_B.weight
unet.up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_v.lora_B_1.weight

Note that each layer now has a lora_A or lora_B and a lora_A_1 or lora_B_1 version. It looks like the _1 version of each layer is the one that causes the giant warning message during inference.

Now the question is where this incorrect _1 version is created. A checkpoint-x/pytorch_lora_weights.safetensors saved during a first run of the training script (one that was not resumed) does NOT contain the _1 versions of each key. However, all checkpoint-x/pytorch_lora_weights.safetensors files produced after training is resumed DO contain the _1 versions.

So it looks like the issue is introduced upon loading of the checkpoint, not saving the checkpoint.

Looking through the training script, this seems to be done by the line accelerator.load_state(os.path.join(args.output_dir, path)), so perhaps this is an Accelerate bug rather than a diffusers bug...
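
A quick way to confirm the duplicate adapter (just a debugging sketch, not code taken from the scripts) is to print the adapter names attached to the unet right after that load_state call:

# Debugging sketch: place right after accelerator.load_state(...) in the training
# script. A fresh run should list only ['default']; a resumed run hitting this bug
# shows an extra adapter name as well.
unwrapped_unet = accelerator.unwrap_model(unet)
print(list(unwrapped_unet.peft_config.keys()))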

asomoza commented 10 months ago

@prushik, if you want someone from diffusers to look at this problem, you should tag someone from the team in the issue; otherwise they probably won't see it.

load_state and save_state have hooks into functions in the same script. For example, the save hook has an issue with get_peft_model_state_dict(model), but even after changing that the state won't resume, so I was also guessing that maybe it's a problem with accelerate or peft and not with diffusers. When I get to the training part of my code again, I will look into it if no one has done so before.

Edit: my diffusers repo wasn't up to date; the current version already has that change, and it is not an error but a fallback that enables loading peft models when peft is not installed.
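
For context, the hooks mentioned above are wired into Accelerate by the training script roughly like this (a paraphrased sketch; the real save_model_hook/load_model_hook bodies are much longer and differ between versions):

# Paraphrased sketch of the checkpointing hook wiring in the training script.
def save_model_hook(models, weights, output_dir):
    # gather the LoRA layers from the unet/text encoders and write
    # pytorch_lora_weights.safetensors into the checkpoint directory
    ...

def load_model_hook(models, input_dir):
    # read pytorch_lora_weights.safetensors back and load it into the models
    # popped from `models` -- this is where the duplicate adapter appears
    ...

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)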

asomoza commented 10 months ago

I was finally able to look into it. I fixed it in my code, but I don't know the correct way to fix it in the official training script, since this probably needs a change in LoraLoaderMixin.

The problem, as you suspected, is that the resume option tries to load the LoRA model again into the unet and the text encoders, but since the code doesn't pass the adapter_name when you resume training, it loads the model as a second adapter, which duplicates the final size and makes the training start from scratch.

What I did was force-load the state_dicts into the unet and the text encoders directly with the inject_adapter_in_model function from the peft library. With that change, I could resume and load the saved checkpoint without errors.

Hope it helps.

prushik commented 10 months ago

Awesome!

> The problem, as you suspected, is that the resume option tries to load the LoRA model again into the unet and the text encoders, but since the code doesn't pass the adapter_name when you resume training, it loads the model as a second adapter, which duplicates the final size and makes the training start from scratch.

I saw that the checkpoint gets loaded with an adapter_name of "default_1" while there is already an adapter called "default" loaded at that point, and it looked like "default" didn't actually have any parameters in it. But after looking further into the diffusers code, it seemed like the "default_1" adapter_name was intentional (get_adapter_name in diffusers/src/diffusers/utils/peft_utils.py, called from load_lora_into_unet in diffusers/src/diffusers/loaders/lora.py).

> What I did was force-load the state_dicts into the unet and the text encoders directly with the inject_adapter_in_model function from the peft library. With that change, I could resume and load the saved checkpoint without errors.

I'm trying to understand this; it would be awesome to have a workaround. Is this change just in the training script's load_model_hook? Can we just replace LoraLoaderMixin.load_lora_into_unet with peft.inject_adapter_in_model? Would you mind posting your modification so we can at least have a temporary workaround? Sorry if I'm being dumb; this is my first time looking at peft.

Thank you for all your help!

asomoza commented 10 months ago

I was wrong about inject_adapter_in_model; that's not needed. We just need to load the state_dict into the existing adapter, so the function to use is set_peft_model_state_dict. I shared some code before and it was wrong, partly because of this and partly because I don't use the official training script. In the end I had to use the official script to debug the problems, so here it is, a working training script with resume:

https://gist.github.com/asomoza/2a7514caceffdbc28f11da5e7f74561c
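
Roughly, the core of the change inside load_model_hook is to load the checkpoint's LoRA weights into the adapter that is already attached, instead of calling LoraLoaderMixin.load_lora_into_unet (which attaches a second adapter). A sketch of that idea, assuming accelerator and unet exist in the surrounding training script (not a verbatim copy of the gist):

# Sketch of the fix: route the saved LoRA weights into the existing "default"
# adapter with set_peft_model_state_dict instead of adding a new adapter.
from peft import set_peft_model_state_dict
from diffusers.loaders import LoraLoaderMixin
from diffusers.utils import convert_unet_state_dict_to_peft

def load_model_hook(models, input_dir):
    # pop the models handed over by Accelerate (unet + text encoders)
    unet_ = None
    text_encoders_ = []
    while len(models) > 0:
        model = models.pop()
        if isinstance(model, type(accelerator.unwrap_model(unet))):
            unet_ = model
        else:
            text_encoders_.append(model)

    # pytorch_lora_weights.safetensors written by the save hook
    lora_state_dict, _ = LoraLoaderMixin.lora_state_dict(input_dir)

    unet_state_dict = {
        k.replace("unet.", ""): v
        for k, v in lora_state_dict.items()
        if k.startswith("unet.")
    }
    unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
    set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
    # the text encoder LoRA weights are loaded into their existing adapters analogously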

Validation at step 1: [image]

Validation at step 100 (end of first run): [image]

Validation at step 125 (resumed from 100): [image]

patrickvonplaten commented 10 months ago

cc @sayakpaul @apolinario

sayakpaul commented 10 months ago

Cannot comment for the advanced script. Cc: @linoytsaban (as Poli is on leave).

But for the SDXL LoRA script, have you tried pulling in the latest changes? We have had issues like https://github.com/huggingface/diffusers/issues/6087, but we have fixed them. Could you please ensure you're using the latest version of the script?

github-actions[bot] commented 9 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

sayakpaul commented 9 months ago

Is this still an issue? Cc: @linoytsaban

sayakpaul commented 9 months ago

Is this still a problem? Cc: @linoytsaban

linoytsaban commented 9 months ago

No longer an issue, I believe 👍🏻, as the changes made in https://github.com/huggingface/diffusers/pull/6225 were also put in place for the advanced scripts (both SDXL and the new SD 1.5 implementation).

sayakpaul commented 9 months ago

Going to close then. But feel free to re-open.