huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

RuntimeError: Error(s) in loading state_dict for CLIPTextModelWithProjection #8955

Open dann2333 opened 3 months ago

dann2333 commented 3 months ago

Describe the bug

Hi, I'm trying to fine-tune stabilityai/stable-diffusion-3-medium-diffusers using the official diffusers scripts. Training ran normally, except that the loss would not decrease. I wanted to add a validation prompt to check that everything was working, so I stopped the training process with Ctrl-C and added the --validation_prompt and --validation_epochs parameters. However, when I tried to restart training, I got the error below. I also tried other checkpoints and removing those two parameters, but nothing worked.

Reproduction

Here are the checkpoint links: https://drive.google.com/drive/folders/16RbJa_W4H7aQiGf7QhTXVEJV53LPuS8n?usp=sharing and https://drive.google.com/drive/folders/1zT3LmB7SNtavHP3tgbodTgE13VT0cYvb?usp=sharing

The training command is:

accelerate launch train_dreambooth_lora_sd3.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-3-medium-diffusers" \
  --output_dir=sd3-lora \
  --instance_data_dir="sample-imgs" \
  --instance_prompt="xxx" \
  --resolution=900 \
  --train_batch_size=1 \
  --train_text_encoder \
  --gradient_accumulation_steps=16 \
  --optimizer="adamw" \
  --learning_rate=1e-6 \
  --text_encoder_lr=1e-6 \
  --lr_scheduler="cosine" \
  --lr_warmup_steps=500 \
  --max_train_steps=4000 \
  --rank=32 \
  --seed="42" \
  --gradient_checkpointing \
  --resume_from_checkpoint latest \
  --center_crop \
  --report_to="wandb" \
  --checkpointing_steps 20 \
  --checkpoints_total_limit 3 \
  --validation_prompt="xxx" \
  --validation_epochs=1
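With --resume_from_checkpoint latest, the script resumes from the highest-numbered checkpoint-N directory in --output_dir (checkpoint-1100 in the log below). A minimal sketch of that selection, assuming it follows the pattern used in the diffusers example scripts (list the directory, keep entries starting with "checkpoint", sort by the numeric step); find_latest_checkpoint is a hypothetical name, and the directory listing in the example is made up:

```python
def find_latest_checkpoint(output_dir_entries):
    """Return the highest-numbered 'checkpoint-N' entry, or None if there is none."""
    dirs = [d for d in output_dir_entries if d.startswith("checkpoint")]
    # Sort numerically on the step suffix; a plain string sort would put
    # "checkpoint-980" after "checkpoint-1100".
    dirs = sorted(dirs, key=lambda d: int(d.split("-")[1]))
    return dirs[-1] if dirs else None


if __name__ == "__main__":
    # Hypothetical entries, as os.listdir(output_dir) might return them.
    entries = ["checkpoint-980", "checkpoint-1100", "checkpoint-1000", "logs"]
    print(find_latest_checkpoint(entries))  # checkpoint-1100
```

Because every checkpoint in the directory was written by the same save hook, picking a different one only changes which step is restored, not how its weights were saved.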

Logs

07/24/2024 08:16:14 - INFO - __main__ - ***** Running training *****
07/24/2024 08:16:14 - INFO - __main__ -   Num examples = 93
07/24/2024 08:16:14 - INFO - __main__ -   Num batches each epoch = 93
07/24/2024 08:16:14 - INFO - __main__ -   Num Epochs = 667
07/24/2024 08:16:14 - INFO - __main__ -   Instantaneous batch size per device = 1
07/24/2024 08:16:14 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 16
07/24/2024 08:16:14 - INFO - __main__ -   Gradient Accumulation steps = 16
07/24/2024 08:16:14 - INFO - __main__ -   Total optimization steps = 4000
Resuming from checkpoint checkpoint-1100
07/24/2024 08:16:14 - INFO - accelerate.accelerator - Loading states from sd3-lora/checkpoint-1100
Traceback (most recent call last):
  File "/workspace/sd3-fine-tune/train_dreambooth_lora_sd3.py", line 1876, in <module>
    main(args)
  File "/workspace/sd3-fine-tune/train_dreambooth_lora_sd3.py", line 1576, in main
    accelerator.load_state(os.path.join(args.output_dir, path))
  File "/opt/conda/lib/python3.10/site-packages/accelerate/accelerator.py", line 3131, in load_state
    hook(models, input_dir)
  File "/workspace/sd3-fine-tune/train_dreambooth_lora_sd3.py", line 1291, in load_model_hook
    _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
  File "/root/diffusers/src/diffusers/training_utils.py", line 221, in _set_state_dict_into_text_encoder
    set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
  File "/opt/conda/lib/python3.10/site-packages/peft/utils/save_and_load.py", line 353, in set_peft_model_state_dict
    load_result = model.load_state_dict(peft_model_state_dict, strict=False)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2153, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for CLIPTextModelWithProjection:
        size mismatch for text_model.encoder.layers.0.self_attn.k_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.0.self_attn.k_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.0.self_attn.v_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.0.self_attn.v_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.0.self_attn.q_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.0.self_attn.q_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.0.self_attn.out_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.0.self_attn.out_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.1.self_attn.k_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.1.self_attn.k_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.1.self_attn.v_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.1.self_attn.v_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.1.self_attn.q_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.1.self_attn.q_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.1.self_attn.out_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.1.self_attn.out_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.2.self_attn.k_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.2.self_attn.k_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.2.self_attn.v_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.2.self_attn.v_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.2.self_attn.q_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.2.self_attn.q_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.2.self_attn.out_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.2.self_attn.out_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.3.self_attn.k_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.3.self_attn.k_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.3.self_attn.v_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.3.self_attn.v_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.3.self_attn.q_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.3.self_attn.q_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.3.self_attn.out_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.3.self_attn.out_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.4.self_attn.k_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.4.self_attn.k_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.4.self_attn.v_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.4.self_attn.v_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.4.self_attn.q_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.4.self_attn.q_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.4.self_attn.out_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.4.self_attn.out_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.5.self_attn.k_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.5.self_attn.k_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.5.self_attn.v_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.5.self_attn.v_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.5.self_attn.q_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.5.self_attn.q_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.5.self_attn.out_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.5.self_attn.out_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.6.self_attn.k_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.6.self_attn.k_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.6.self_attn.v_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.6.self_attn.v_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.6.self_attn.q_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.6.self_attn.q_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.6.self_attn.out_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.6.self_attn.out_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.7.self_attn.k_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.7.self_attn.k_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.7.self_attn.v_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.7.self_attn.v_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.7.self_attn.q_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.7.self_attn.q_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.7.self_attn.out_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.7.self_attn.out_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.8.self_attn.k_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.8.self_attn.k_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.8.self_attn.v_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.8.self_attn.v_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.8.self_attn.q_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.8.self_attn.q_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.8.self_attn.out_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.8.self_attn.out_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.9.self_attn.k_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.9.self_attn.k_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.9.self_attn.v_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.9.self_attn.v_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.9.self_attn.q_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.9.self_attn.q_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.9.self_attn.out_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.9.self_attn.out_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.10.self_attn.k_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.10.self_attn.k_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.10.self_attn.v_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.10.self_attn.v_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.10.self_attn.q_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.10.self_attn.q_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.10.self_attn.out_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.10.self_attn.out_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.11.self_attn.k_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.11.self_attn.k_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.11.self_attn.v_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.11.self_attn.v_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.11.self_attn.q_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.11.self_attn.q_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
        size mismatch for text_model.encoder.layers.11.self_attn.out_proj.lora_A.default.weight: copying a param with shape torch.Size([32, 1280]) from checkpoint, the shape in current model is torch.Size([32, 768]).
        size mismatch for text_model.encoder.layers.11.self_attn.out_proj.lora_B.default.weight: copying a param with shape torch.Size([1280, 32]) from checkpoint, the shape in current model is torch.Size([768, 32]).
Traceback (most recent call last):
  File "/opt/conda/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/opt/conda/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
    args.func(args)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1106, in launch_command
    simple_launcher(args)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/commands/launch.py", line 704, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/opt/conda/bin/python', 'train_dreambooth_lora_sd3.py', '--pretrained_model_name_or_path=stabilityai/stable-diffusion-3-medium-diffusers', '--output_dir=coconut-lora', '--instance_data_dir=sample-imgs', '--instance_prompt=xxx', '--resolution=900', '--train_batch_size=1', '--train_text_encoder', '--gradient_accumulation_steps=16', '--optimizer=adamw', '--learning_rate=1e-6', '--text_encoder_lr=1e-6', '--lr_scheduler=cosine', '--lr_warmup_steps=500', '--max_train_steps=4000', '--rank=32', '--seed=42', '--gradient_checkpointing', '--resume_from_checkpoint', 'latest', '--center_crop', '--report_to=wandb', '--checkpointing_steps', '20', '--checkpoints_total_limit', '3', '--validation_prompt=xxx', '--validation_epochs=1']' returned non-zero exit status 1.
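The shapes in the log are a useful clue: SD3's first text encoder (CLIP ViT-L) has a hidden size of 768, while the second (OpenCLIP ViT-bigG) has 1280, so 1280-wide LoRA weights stored under the text_encoder. prefix look like the second encoder's weights. A minimal sketch for checking this directly on a checkpoint, assuming the LoRA weights are in a pytorch_lora_weights.safetensors file inside the checkpoint directory (the diffusers default file name; verify against your own checkpoint) and that the safetensors package is installed. The helper names lora_input_dims and read_shapes are hypothetical:

```python
import sys


def lora_input_dims(shapes, prefix):
    """Collect the input dimension of every lora_A weight stored under `prefix`.

    `shapes` maps parameter names to shape tuples. A LoRA A matrix has shape
    (rank, in_features), so its second dimension is the hidden size of the
    layer it adapts (768 for SD3's CLIP-L encoder, 1280 for CLIP-bigG).
    """
    return {
        shape[1]
        for name, shape in shapes.items()
        if name.startswith(prefix) and ".lora_A." in name and len(shape) == 2
    }


def read_shapes(path):
    """Read tensor shapes from a .safetensors file without loading the tensors."""
    from safetensors import safe_open  # imported lazily; only needed for real files

    with safe_open(path, framework="pt") as f:
        return {name: tuple(f.get_slice(name).get_shape()) for name in f.keys()}


if __name__ == "__main__" and len(sys.argv) > 1:
    # Hypothetical usage: pass the LoRA file from a checkpoint directory.
    shapes = read_shapes(sys.argv[1])
    for prefix in ("text_encoder.", "text_encoder_2."):
        print(prefix, sorted(lora_input_dims(shapes, prefix)))
```

For example, run it as python check_lora_dims.py sd3-lora/checkpoint-1100/pytorch_lora_weights.safetensors (the script name is arbitrary). Given the error above, the text_encoder. prefix would be expected to report 1280 rather than 768.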

System Info

- 🤗 Diffusers version: 0.30.0.dev0
- Platform: Linux-6.5.0-25-generic-x86_64-with-glibc2.35
- Running on a notebook?: No
- Running on Google Colab?: No
- Python version: 3.10.13
- PyTorch version (GPU?): 2.2.1 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.24.0
- Transformers version: 4.42.4
- Accelerate version: 0.32.1
- PEFT version: 0.11.1
- Bitsandbytes version: not installed
- Safetensors version: 0.4.3
- xFormers version: not installed
- Accelerator: NVIDIA L40s, 49152 MiB VRAM
- Using GPU in script?: NVIDIA L40s, 49152 MiB VRAM
- Using distributed or parallel set-up in script?: No

Who can help?

@sayakpaul

sayakpaul commented 3 months ago

> However, when I tried to re-start the train I only found the error below. I tried to use other checkpoints and delete that two params, however no one works.

What does this mean in code?

sayakpaul commented 3 months ago

@linoytsaban could you check if you're able to reproduce this with text encoder training?

dann2333 commented 3 months ago

> However, when I tried to re-start the train I only found the error below. I tried to use other checkpoints and delete that two params, however no one works.
>
> What does this mean in code?

accelerate launch train_dreambooth_lora_sd3.py --pretrained_model_name_or_path="stabilityai/stable-diffusion-3-medium-diffusers" --output_dir=sd3-lora --instance_data_dir="sample-imgs" --instance_prompt="xxx" --resolution=900 --train_batch_size=1 --train_text_encoder --gradient_accumulation_steps=16 --optimizer="adamw" --learning_rate=1e-6 --text_encoder_lr=1e-6 --lr_scheduler="cosine" --lr_warmup_steps=500 --max_train_steps=4000 --rank=32 --seed="42" --gradient_checkpointing --resume_from_checkpoint latest --center_crop --report_to="wandb" --checkpointing_steps 20 --checkpoints_total_limit 3 --validation_prompt="xxx" --validation_epochs=1

The prompts are a bit long, so I cut them out. If the exact prompts are needed, just reply. Thanks!

linoytsaban commented 2 months ago

Hey @dann2333 thanks for reporting!

The issue occurs when resuming from a checkpoint (I was able to reproduce it). I'm not sure yet why it happens, but basically it occurs when training with checkpointing flags such as:

--gradient_checkpointing 
--resume_from_checkpoint="latest"

Specifically, this call errors.
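The failing call receives only the keys under one prefix. A minimal sketch of that filtering step, loosely modeled on _set_state_dict_into_text_encoder from the traceback (as far as I can tell, the real helper also converts the key format before calling peft's set_peft_model_state_dict). It shows that text_encoder_2. keys are not picked up by the text_encoder. prefix, which suggests the 1280-wide tensors were already stored under text_encoder. when the checkpoint was saved. split_by_prefix is a hypothetical name and the keys below are illustrative:

```python
def split_by_prefix(lora_state_dict, prefix):
    """Keep only the entries under `prefix` and strip the prefix from their keys."""
    return {
        key[len(prefix):]: value
        for key, value in lora_state_dict.items()
        if key.startswith(prefix)
    }


if __name__ == "__main__":
    # Hypothetical checkpoint contents; string values stand in for tensors.
    state = {
        "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": "A1",
        "text_encoder_2.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": "A2",
        "transformer.transformer_blocks.0.attn.to_q.lora_A.weight": "T",
    }
    print(split_by_prefix(state, "text_encoder."))
    # {'text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight': 'A1'}
```

The trailing dot in the prefix matters: "text_encoder_2." does not start with "text_encoder.", so the two encoders' keys stay separate on load.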

sevenclay commented 2 months ago

> Hey @dann2333 thanks for reporting!
>
> The issue is when resuming from checkpoint (I was able to reproduce), I'm not sure yet as to why that happens but basically when training with checkpointing such as -
>
> --gradient_checkpointing
> --resume_from_checkpoint="latest"
>
> specifically, this call errors

Since text_encoder_one and text_encoder_two are both CLIPTextModelWithProjection, the isinstance checks in save_model_hook cannot tell them apart when saving the models:

isinstance(model, type(unwrap_model(text_encoder_one))) 
isinstance(model, type(unwrap_model(text_encoder_two)))

Because they are the same type, the first branch matches both models, so text_encoder_two's LoRA weights overwrite text_encoder_one's. A small workaround is:

elif isinstance(model, type(unwrap_model(text_encoder_one))) and model.config.hidden_size == 768:
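To make the collision concrete, here is a self-contained sketch using a hypothetical stub class in place of CLIPTextModelWithProjection (no real models are loaded). route_isinstance_only mimics the isinstance/elif pattern quoted above, and route_by_hidden_size applies the hidden-size workaround; both function names are illustrative, not the script's own:

```python
from dataclasses import dataclass


@dataclass
class Config:
    hidden_size: int


class CLIPTextModelWithProjectionStub:
    """Hypothetical stand-in for transformers' CLIPTextModelWithProjection."""

    def __init__(self, hidden_size):
        self.config = Config(hidden_size)


def route_isinstance_only(models, text_encoder_one, text_encoder_two):
    """Mimics the original hook: the first matching branch wins."""
    saved = {"text_encoder": None, "text_encoder_2": None}
    for model in models:
        if isinstance(model, type(text_encoder_one)):
            saved["text_encoder"] = model  # encoder two also lands here
        elif isinstance(model, type(text_encoder_two)):
            saved["text_encoder_2"] = model  # never reached: same class as above
    return saved


def route_by_hidden_size(models, text_encoder_one):
    """Workaround: same class check, then disambiguate by hidden size."""
    saved = {"text_encoder": None, "text_encoder_2": None}
    for model in models:
        if isinstance(model, type(text_encoder_one)):
            if model.config.hidden_size == 768:
                saved["text_encoder"] = model
            elif model.config.hidden_size == 1280:
                saved["text_encoder_2"] = model
    return saved


if __name__ == "__main__":
    te1 = CLIPTextModelWithProjectionStub(768)
    te2 = CLIPTextModelWithProjectionStub(1280)
    buggy = route_isinstance_only([te1, te2], te1, te2)
    fixed = route_by_hidden_size([te1, te2], te1)
    print(buggy["text_encoder"].config.hidden_size, buggy["text_encoder_2"])  # 1280 None
    print(fixed["text_encoder"].config.hidden_size, fixed["text_encoder_2"].config.hidden_size)  # 768 1280
```

The hidden-size check is a workaround rather than a general fix. If the model is wrapped (for example by DDP in a multi-GPU run), the real hook would likely need unwrap_model(model).config instead of model.config.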
tianbuwei commented 2 months ago

> > Hey @dann2333 thanks for reporting! The issue is when resuming from checkpoint (I was able to reproduce), I'm not sure yet as to why that happens but basically when training with checkpointing such as -
> >
> > --gradient_checkpointing
> > --resume_from_checkpoint="latest"
> >
> > specifically, this call errors
>
> Since both the text_encoder_one and text_encoder_two classes are CLIPTextModelWithProjection, there is an issue when saving the models in the save_model_hook:
>
> isinstance(model, type(unwrap_model(text_encoder_one)))
> isinstance(model, type(unwrap_model(text_encoder_two)))
>
> They are of the same type, so text_encoder_two will overwrite text_encoder_one. A small workaround is:
>
> elif isinstance(model, type(unwrap_model(text_encoder_one))) and model.config.hidden_size == 768:

Really? You're so smart

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.