huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0
25.85k stars, 5.33k forks

LCM example training fails Unexpected key(s) in state_dict: "time_embedding.cond_proj.weight" #6790

Closed: shadowofsoul closed this issue 6 months ago

shadowofsoul commented 8 months ago

Describe the bug

Hi!

I'm following this guide for LCM training to produce models that can run inference in 4 steps with LCMScheduler: https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/README.md

But the example training run fails with:

RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for UNet2DConditionModel: Unexpected key(s) in state_dict: "time_embedding.cond_proj.weight".

I'm not sure what I'm doing wrong; if anyone has a hint, let me know!

Reproduction

Literally follow the example at https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/README.md

Logs

No response

System Info

Who can help?

@sayakpaul @patrickvonplaten

sayakpaul commented 8 months ago

What is the training command?

shadowofsoul commented 8 months ago

Literally the same as in the guide:

export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="lcm"

accelerate launch train_lcm_distill_sd_wds.py \
    --pretrained_teacher_model=$MODEL_NAME \
    --output_dir=$OUTPUT_DIR \
    --mixed_precision=fp16 \
    --resolution=512 \
    --learning_rate=1e-6 --loss_type="huber" --ema_decay=0.95 --adam_weight_decay=0.0 \
    --max_train_steps=1000 \
    --max_train_samples=4000000 \
    --dataloader_num_workers=8 \
    --train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
    --validation_steps=200 \
    --checkpointing_steps=200 --checkpoints_total_limit=10 \
    --train_batch_size=12 \
    --gradient_checkpointing --enable_xformers_memory_efficient_attention \
    --gradient_accumulation_steps=1 \
    --use_8bit_adam \
    --resume_from_checkpoint=latest \
    --report_to=wandb \
    --seed=453645634 

Of course, with a correct output path, and I removed '--push_to_hub'.

sayakpaul commented 8 months ago

I will let @patil-suraj comment further.

dg845 commented 8 months ago

Hi @shadowofsoul, would you be able to share the location in the script where you get the RuntimeError?

shadowofsoul commented 8 months ago

Here is the complete traceback, @dg845:


/home/test-models/tinysd/diffusers/src/diffusers/utils/outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
/home/test-models/tinysd/lib/python3.8/site-packages/accelerate/accelerator.py:393: UserWarning: `log_with=wandb` was passed but no supported trackers are currently installed.
  warnings.warn(f"`log_with={log_with}` was passed but no supported trackers are currently installed.")
02/04/2024 01:03:51 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cpu

Mixed precision type: fp16

{'prediction_type', 'timestep_spacing', 'thresholding', 'sample_max_value', 'rescale_betas_zero_snr', 'variance_type', 'clip_sample_range', 'dynamic_thresholding_ratio'} was not found in config. Values will be initialized to default values.
{'scaling_factor', 'force_upcast'} was not found in config. Values will be initialized to default values.
{'only_cross_attention', 'class_embed_type', 'dropout', 'resnet_time_scale_shift', 'resnet_skip_time_act', 'time_cond_proj_dim', 'addition_embed_type_num_heads', 'resnet_out_scale_factor', 'transformer_layers_per_block', 'time_embedding_act_fn', 'upcast_attention', 'timestep_post_act', 'time_embedding_dim', 'num_attention_heads', 'conv_out_kernel', 'conv_in_kernel', 'reverse_transformer_layers_per_block', 'encoder_hid_dim_type', 'time_embedding_type', 'use_linear_projection', 'addition_time_embed_dim', 'projection_class_embeddings_input_dim', 'encoder_hid_dim', 'class_embeddings_concat', 'mid_block_only_cross_attention', 'addition_embed_type', 'cross_attention_norm', 'dual_cross_attention', 'attention_type', 'num_class_embeds', 'mid_block_type'} was not found in config. Values will be initialized to default values.
{'only_cross_attention', 'class_embed_type', 'dropout', 'resnet_time_scale_shift', 'resnet_skip_time_act', 'addition_embed_type_num_heads', 'resnet_out_scale_factor', 'transformer_layers_per_block', 'time_embedding_act_fn', 'upcast_attention', 'timestep_post_act', 'time_embedding_dim', 'num_attention_heads', 'conv_out_kernel', 'conv_in_kernel', 'reverse_transformer_layers_per_block', 'encoder_hid_dim_type', 'time_embedding_type', 'use_linear_projection', 'addition_time_embed_dim', 'projection_class_embeddings_input_dim', 'encoder_hid_dim', 'class_embeddings_concat', 'mid_block_only_cross_attention', 'addition_embed_type', 'cross_attention_norm', 'dual_cross_attention', 'attention_type', 'num_class_embeds', 'mid_block_type'} was not found in config. Values will be initialized to default values.
Traceback (most recent call last):
  File "train_lcm_distill_sd_wds.py", line 1411, in <module>
    main(args)
  File "train_lcm_distill_sd_wds.py", line 937, in main
    target_unet.load_state_dict(unet.state_dict())
  File "/home/test-models/tinysd/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2153, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for UNet2DConditionModel:
        Unexpected key(s) in state_dict: "time_embedding.cond_proj.weight".
Traceback (most recent call last):
  File "/home/test-models/tinysd/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/test-models/tinysd/lib/python3.8/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/home/test-models/tinysd/lib/python3.8/site-packages/accelerate/commands/launch.py", line 1023, in launch_command
    simple_launcher(args)
  File "/home/test-models/tinysd/lib/python3.8/site-packages/accelerate/commands/launch.py", line 643, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/home/test-models/tinysd/bin/python3', 'train_lcm_distill_sd_wds.py', '--pretrained_teacher_model=runwayml/stable-diffusion-v1-5', '--output_dir=lcm', '--mixed_precision=fp16', '--resolution=512', '--learning_rate=1e-6', '--loss_type=huber', '--ema_decay=0.95', '--adam_weight_decay=0.0', '--max_train_steps=1000', '--max_train_samples=4000000', '--dataloader_num_workers=8', '--train_shards_path_or_url=pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true', '--validation_steps=200', '--checkpointing_steps=200', '--checkpoints_total_limit=10', '--train_batch_size=12', '--gradient_checkpointing', '--enable_xformers_memory_efficient_attention', '--gradient_accumulation_steps=1', '--use_8bit_adam', '--resume_from_checkpoint=latest', '--report_to=wandb', '--seed=453645634']' returned non-zero exit status 1.
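With the default `strict=True`, PyTorch's `load_state_dict` rejects any key that is present in the incoming state dict but absent from the receiving module ("Unexpected key(s)") and any module parameter absent from the dict ("Missing key(s)"); mismatched shapes are reported as well. The sketch below is a simplified stand-in for that check, not PyTorch's code: each model is modeled as a plain dict of parameter name → shape, and the shapes (including an assumed time-conditioning size of 256) are illustrative. It shows why a student U-Net that has `time_embedding.cond_proj.weight` cannot be loaded into a target U-Net built without `time_cond_proj_dim`.

```python
def compare_state_dicts(source, target):
    """Simplified strict-loading check.

    `source` is the state dict being loaded and `target` describes the
    parameters the receiving model actually has; both map name -> shape.
    """
    unexpected = sorted(set(source) - set(target))
    missing = sorted(set(target) - set(source))
    mismatched = sorted(
        name for name in set(source) & set(target) if source[name] != target[name]
    )
    return unexpected, missing, mismatched


# Student U-Net: built with a time-conditioning projection (illustrative shapes).
student = {
    "time_embedding.linear_1.weight": (1280, 320),
    "time_embedding.cond_proj.weight": (320, 256),
}
# Target U-Net: built without time_cond_proj_dim, so it has no cond_proj layer.
target = {
    "time_embedding.linear_1.weight": (1280, 320),
}

unexpected, missing, mismatched = compare_state_dicts(student, target)
print("Unexpected:", unexpected)
print("Missing:", missing)
print("Size mismatch:", mismatched)
```

Only the extra `cond_proj` weight is flagged, which matches the single "Unexpected key(s)" line in the traceback above.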
shadowofsoul commented 8 months ago

@dg845, I want to report that training still fails even with the patch at #6848 (I'm running this on Google Colab now, since xformers requires a GPU):


{'mid_block_only_cross_attention', 'class_embeddings_concat', 'reverse_transformer_layers_per_block', 'addition_time_embed_dim', 'projection_class_embeddings_input_dim', 'use_linear_projection', 'timestep_post_act', 'encoder_hid_dim_type', 'upcast_attention', 'attention_type', 'cross_attention_norm', 'resnet_out_scale_factor', 'time_cond_proj_dim', 'encoder_hid_dim', 'time_embedding_dim', 'dropout', 'addition_embed_type', 'time_embedding_type', 'num_attention_heads', 'dual_cross_attention', 'conv_out_kernel', 'addition_embed_type_num_heads', 'resnet_time_scale_shift', 'resnet_skip_time_act', 'time_embedding_act_fn', 'num_class_embeds', 'conv_in_kernel', 'class_embed_type', 'only_cross_attention', 'transformer_layers_per_block', 'mid_block_type'} was not found in config. Values will be initialized to default values.
{'mid_block_only_cross_attention', 'class_embeddings_concat', 'reverse_transformer_layers_per_block', 'addition_time_embed_dim', 'projection_class_embeddings_input_dim', 'use_linear_projection', 'timestep_post_act', 'encoder_hid_dim_type', 'upcast_attention', 'attention_type', 'cross_attention_norm', 'resnet_out_scale_factor', 'encoder_hid_dim', 'time_embedding_dim', 'dropout', 'addition_embed_type', 'num_attention_heads', 'time_embedding_type', 'dual_cross_attention', 'conv_out_kernel', 'addition_embed_type_num_heads', 'resnet_time_scale_shift', 'resnet_skip_time_act', 'time_embedding_act_fn', 'num_class_embeds', 'conv_in_kernel', 'class_embed_type', 'only_cross_attention', 'transformer_layers_per_block', 'mid_block_type'} was not found in config. Values will be initialized to default values.
Traceback (most recent call last):
  File "/content/diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py", line 1411, in <module>
    main(args)
  File "/content/diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py", line 937, in main
    target_unet.load_state_dict(unet.state_dict())
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for UNet2DConditionModel:
    Unexpected key(s) in state_dict: "time_embedding.cond_proj.weight". 
    size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([320, 768]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
    size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([320, 768]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
    size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([320, 768]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
    size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([320, 768]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
    size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([640, 768]) from checkpoint, the shape in current model is torch.Size([640, 1280]).
    size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([640, 768]) from checkpoint, the shape in current model is torch.Size([640, 1280]).
    size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([640, 768]) from checkpoint, the shape in current model is torch.Size([640, 1280]).
    size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([640, 768]) from checkpoint, the shape in current model is torch.Size([640, 1280]).
    size mismatch for down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([1280, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
    size mismatch for down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([1280, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
    size mismatch for down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([1280, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
    size mismatch for down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([1280, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
    size mismatch for up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([1280, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
    size mismatch for up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([1280, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
    size mismatch for up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([1280, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
    size mismatch for up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([1280, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
    size mismatch for up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([1280, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
    size mismatch for up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([1280, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
    size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([640, 768]) from checkpoint, the shape in current model is torch.Size([640, 1280]).
    size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([640, 768]) from checkpoint, the shape in current model is torch.Size([640, 1280]).
    size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([640, 768]) from checkpoint, the shape in current model is torch.Size([640, 1280]).
    size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([640, 768]) from checkpoint, the shape in current model is torch.Size([640, 1280]).
    size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([640, 768]) from checkpoint, the shape in current model is torch.Size([640, 1280]).
    size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([640, 768]) from checkpoint, the shape in current model is torch.Size([640, 1280]).
    size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([320, 768]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
    size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([320, 768]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
    size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([320, 768]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
    size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([320, 768]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
    size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([320, 768]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
    size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([320, 768]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
    size mismatch for mid_block.attentions.0.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([1280, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
    size mismatch for mid_block.attentions.0.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([1280, 768]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
Traceback (most recent call last):
  File "/usr/local/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/launch.py", line 1023, in launch_command
    simple_launcher(args)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/launch.py", line 643, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
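Every `size mismatch` line above follows one pattern. The cross-attention `attn2.to_k`/`attn2.to_v` projections map text-encoder hidden states of width `cross_attention_dim` to the block's channel width, and `nn.Linear` stores its weight as `(out_features, in_features)`. Stable Diffusion v1.5's U-Net uses `cross_attention_dim=768`, while 1280 is `UNet2DConditionModel`'s default, which suggests (the log alone does not prove it) that `target_unet` was built without the teacher's configuration values. The hypothetical helper `attn2_kv_weight_shape` below is only a sketch that reproduces the checkpoint/model shape pairs from the log for the three channel widths.

```python
def attn2_kv_weight_shape(block_channels, cross_attention_dim):
    """Hypothetical helper: shape of an attn2.to_k / attn2.to_v weight.

    The projection is nn.Linear(cross_attention_dim, block_channels), and
    PyTorch stores Linear weights as (out_features, in_features).
    """
    return (block_channels, cross_attention_dim)


SD15_CROSS_ATTENTION_DIM = 768      # value in SD v1.5's U-Net config
DEFAULT_CROSS_ATTENTION_DIM = 1280  # UNet2DConditionModel's default

# (shape in the checkpoint, shape in the freshly built model) per block width
pairs = [
    (
        attn2_kv_weight_shape(channels, SD15_CROSS_ATTENTION_DIM),
        attn2_kv_weight_shape(channels, DEFAULT_CROSS_ATTENTION_DIM),
    )
    for channels in (320, 640, 1280)
]
for checkpoint_shape, model_shape in pairs:
    print(f"checkpoint {checkpoint_shape} vs model {model_shape}")
```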
dg845 commented 8 months ago

Does something like

    target_unet = UNet2DConditionModel.from_config(unet.config)

instead of

    target_unet = UNet2DConditionModel(**unet.config)

work?
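Whichever constructor is used, one quick check before `target_unet.load_state_dict(unet.state_dict())` is to diff the two models' configs; in diffusers, `model.config` is a dict subclass, so a plain mapping comparison applies. The `config_diff` helper below is a hypothetical sketch, and the config values are illustrative: 768/1280 come from the size-mismatch errors in this thread, 256 is an assumed `time_cond_proj_dim`, and real U-Net configs contain many more keys.

```python
def config_diff(expected, actual):
    """Return {key: (expected_value, actual_value)} for every key whose value
    differs between two config mappings. Keys starting with "_" (metadata
    such as "_class_name") are skipped."""
    keys = set(expected) | set(actual)
    return {
        key: (expected.get(key), actual.get(key))
        for key in sorted(keys)
        if not key.startswith("_") and expected.get(key) != actual.get(key)
    }


# Illustrative configs for the student (unet) and a mis-built target_unet.
student_config = {
    "_class_name": "UNet2DConditionModel",
    "cross_attention_dim": 768,
    "time_cond_proj_dim": 256,
}
target_config = {
    "_class_name": "UNet2DConditionModel",
    "cross_attention_dim": 1280,
    "time_cond_proj_dim": None,
}

print(config_diff(student_config, target_config))
# In the training script this could be called as, e.g.:
#   config_diff(unet.config, target_unet.config)
```

An empty result means both models were built with the same settings, so their state dicts should line up key for key.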

shadowofsoul commented 8 months ago

@dg845 Yes, I can confirm that training has started. I'll let it run and report back with the results. Thanks!

rohit901 commented 8 months ago

@dg845, thanks for the patch. I hope it gets merged soon, since I also ran into the issue reported by @shadowofsoul.

github-actions[bot] commented 7 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

sayakpaul commented 7 months ago

@dg845 did you open a PR for it?

dg845 commented 7 months ago

@sayakpaul, yes, at #6848.

github-actions[bot] commented 6 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.