huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0
26.05k stars, 5.36k forks

class_labels should be provided when num_class_embeds > 0 #3401

Closed stpg06 closed 1 year ago

stpg06 commented 1 year ago

Describe the bug

I'm running into this issue and am not entirely sure what I'm doing wrong. Do I need a metadata file or something?

Reproduction

!python3 train_dreambooth_lora.py \
  --pretrained_model_name_or_path="DeepFloyd/IF-II-L-v1.0" \
  --output_dir="/content/drive/MyDrive/stable_diffusion_weights/zwx23" \
  --revision="main" \
  --instance_data_dir="/content/i/" \
  --class_data_dir="/content/c/" \
  --instance_prompt="photo of a zwx man" \
  --class_prompt="photo of a man" \
  --validation_prompt="painting of a zwx man in a tuxedo" \
  --max_train_steps=5000 \
  --seed=1378 \
  --checkpoints_total_limit=2 \
  --num_class_images=20 \
  --resolution=512 \
  --train_batch_size=2 \
  --pre_compute_text_embeddings \
  --sample_batch_size=4 \
  --mixed_precision="fp16" \
  --num_validation_images=4 \
  --validation_epochs=50 \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-4 \
  --checkpointing_steps=2500 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0

Logs

2023-05-10 23:25:54.719178: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py:258: FutureWarning: `logging_dir` is deprecated and will be removed in version 0.18.0 of 🤗 Accelerate. Use `project_dir` instead.
  warnings.warn(
05/10/2023 23:25:57 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: fp16

You are using a model of type t5 to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
Downloading shards: 100% 2/2 [00:00<00:00, 9998.34it/s]
Loading checkpoint shards: 100% 2/2 [00:26<00:00, 13.05s/it]
{'mid_block_only_cross_attention', 'addition_embed_type_num_heads', 'class_embeddings_concat'} was not found in config. Values will be initialized to default values.
05/10/2023 23:27:11 - INFO - __main__ - ***** Running training *****
05/10/2023 23:27:11 - INFO - __main__ -   Num examples = 4
05/10/2023 23:27:11 - INFO - __main__ -   Num batches each epoch = 2
05/10/2023 23:27:11 - INFO - __main__ -   Num Epochs = 2500
05/10/2023 23:27:11 - INFO - __main__ -   Instantaneous batch size per device = 2
05/10/2023 23:27:11 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 2
05/10/2023 23:27:11 - INFO - __main__ -   Gradient Accumulation steps = 1
05/10/2023 23:27:11 - INFO - __main__ -   Total optimization steps = 5000
Steps:   0% 0/5000 [00:00<?, ?it/s]
Traceback (most recent call last):
/content/train_dreambooth_lora.py:1295 in <module>

  1292
  1293 if __name__ == "__main__":
  1294     args = parse_args()
❱ 1295     main(args)
  1296

/content/train_dreambooth_lora.py:1084 in main

  1081                     )
  1082
  1083                 # Predict the noise residual
❱ 1084                 model_pred = unet(noisy_model_input, timesteps, encod
  1085
  1086                 # if model predicts variance, throw away the predicti
  1087                 # simplified training objective. This means that all

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl

  1498         if not (self._backward_hooks or self._backward_pre_hooks or s
  1499                 or _global_backward_pre_hooks or _global_backward_hoo
  1500                 or _global_forward_hooks or _global_forward_pre_hooks
❱ 1501             return forward_call(*args, **kwargs)
  1502         # Do not call functions when jit is used
  1503         full_backward_hooks, non_full_backward_hooks = [], []
  1504         backward_pre_hooks = []

/usr/local/lib/python3.10/dist-packages/diffusers/models/unet_2d_condition.py:691 in forward

  688
  689         if self.class_embedding is not None:
  690             if class_labels is None:
❱ 691                 raise ValueError("class_labels should be provided when
  692
  693             if self.config.class_embed_type == "timestep":
  694                 class_labels = self.time_proj(class_labels)
ValueError: class_labels should be provided when num_class_embeds > 0
Steps:   0% 0/5000 [00:00<?, ?it/s]
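The failing check is visible in the last traceback frame: when the UNet was built with a class embedding (which the error shows is the case for the IF stage II checkpoint used here), forward refuses to run without class_labels. A minimal pure-Python sketch of that guard, assuming a hypothetical stand-in class (ToyUNet is illustrative only, not the diffusers UNet2DConditionModel):

```python
from typing import List, Optional, Sequence


class ToyUNet:
    """Hypothetical stand-in for a UNet that may carry a class embedding.

    Only the class-label guard from the traceback is modelled; no real
    denoising happens here.
    """

    def __init__(self, has_class_embedding: bool) -> None:
        # In diffusers the attribute is a module or None; a bool suffices here.
        self.has_class_embedding = has_class_embedding

    def forward(
        self,
        sample: Sequence[float],
        timesteps: Sequence[int],
        class_labels: Optional[Sequence[int]] = None,
    ) -> List[float]:
        if self.has_class_embedding and class_labels is None:
            # Same message as the ValueError in the log above.
            raise ValueError(
                "class_labels should be provided when num_class_embeds > 0"
            )
        # Placeholder output: echo the input so callers get something back.
        return list(sample)


stage_two_like = ToyUNet(has_class_embedding=True)
try:
    stage_two_like.forward([0.1, 0.2], [10, 20])
except ValueError as err:
    print(err)  # class_labels should be provided when num_class_embeds > 0

print(stage_two_like.forward([0.1, 0.2], [10, 20], class_labels=[10, 20]))
```

The point of the sketch is that the training data (metadata files, class images) plays no part in this error: it depends only on whether the caller passes class_labels to a UNet configured with a class embedding.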

System Info

Google Colab

patrickvonplaten commented 1 year ago

cc @williamberman for IF Dreambooth

williamberman commented 1 year ago

Hey @stpg06, this script currently doesn't support training the stage II model, so please train only the stage I model for now. I'll be putting up some changes to support the stage II model later today :)

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

shnamin commented 1 year ago

Hi, wondering what the state of this is. I'm encountering the same error. What is the workaround for now?

patrickvonplaten commented 1 year ago

Gentle ping @williamberman

patrickvonplaten commented 1 year ago

Gentle ping again @williamberman

williamberman commented 1 year ago

For IF stage II, you must pass --class_labels_conditioning=timesteps to the training script. I double-checked the training README and confirmed that it's documented there and in all of the stage II training examples. I'm going to close the issue, but if there's another location where it isn't documented, please reopen the issue and let me know where, or open a separate issue and tag me so I can add it to those docs :) Thanks!
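The flag fixes the error because, with it set, the training loop reuses the sampled noise timesteps as the UNet's class_labels instead of passing None. A simplified sketch of that selection step, assuming (as in the README's stage II examples) that the accepted value is timesteps; select_class_labels is a hypothetical helper, not a function in train_dreambooth_lora.py:

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # Same flag name as the DreamBooth scripts; the default None means
    # no class labels are passed to the UNet.
    parser.add_argument("--class_labels_conditioning", default=None)
    return parser


def select_class_labels(
    conditioning: Optional[str], timesteps: List[int]
) -> Optional[List[int]]:
    """Hypothetical helper: what would be passed as unet(..., class_labels=...)."""
    if conditioning == "timesteps":
        # Assumed stage II behaviour: reuse the noise timesteps as class labels.
        return timesteps
    # Any other value (including a misspelled one) leaves class_labels as None,
    # which triggers the ValueError for a UNet that has a class embedding.
    return None


args = build_parser().parse_args(["--class_labels_conditioning", "timesteps"])
print(select_class_labels(args.class_labels_conditioning, [999, 250]))  # [999, 250]
print(select_class_labels(None, [999, 250]))  # None
```

In this sketch an exact string comparison is used, so a value that differs from the expected one silently falls back to None; if the error persists after adding the flag, checking its spelling against the README is a reasonable first step.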