huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

SDXL LoRA training error with train_text_encoder #4340

Closed jorgemcgomes closed 1 year ago

jorgemcgomes commented 1 year ago

Describe the bug

This concerns train_dreambooth_lora_sdxl.py.

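The SDXL UNet is conditioned on the following outputs of the text encoders: the penultimate-layer hidden states of both encoders (concatenated), and the pooled output of encoder two only.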

When the LoRA adapters are created by the provided training script, both text encoders are fully adapted, i.e. every layer gets LoRA weights: https://github.com/huggingface/diffusers/blob/306a7bd0475c4af03024057277b6454855e9ea1b/examples/dreambooth/train_dreambooth_lora_sdxl.py#L776-L781

So, when training the text encoders, the last layer of encoder ONE receives no gradient: the hidden states are only needed up to the penultimate layer, and the pooled output of encoder one is not used at all.
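
A minimal toy sketch of the gradient flow described above, assuming plain stacks of nn.Linear layers (the hypothetical `ToyEncoder`) stand in for the two CLIP text encoders. Encoder one contributes only its penultimate hidden state; encoder two contributes its penultimate hidden state plus a stand-in "pooled" value computed from its final layer. After `backward()`, only encoder one's last layer is left with `grad is None`.

```python
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for a text encoder: a stack of linear layers that
    returns every hidden state, similar to output_hidden_states=True."""

    def __init__(self, num_layers: int = 4, dim: int = 8):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        hidden_states = [x]  # index 0: the input "embeddings"
        for layer in self.layers:
            x = torch.tanh(layer(x))
            hidden_states.append(x)
        return hidden_states


torch.manual_seed(0)
encoder_one, encoder_two = ToyEncoder(), ToyEncoder()
tokens = torch.randn(2, 8)

hs_one = encoder_one(tokens)
hs_two = encoder_two(tokens)

# Mimic SDXL: penultimate hidden states of both encoders are concatenated...
prompt_embeds = torch.cat([hs_one[-2], hs_two[-2]], dim=-1)
# ...plus a stand-in "pooled" output derived from encoder two's LAST layer only.
pooled = hs_two[-1].mean(dim=-1, keepdim=True)

loss = prompt_embeds.pow(2).mean() + pooled.pow(2).mean()
loss.backward()

for name, enc in [("encoder_one", encoder_one), ("encoder_two", encoder_two)]:
    missing = [n for n, p in enc.named_parameters() if p.grad is None]
    print(name, "params without grad:", missing)
```

Running it reports layers.3.weight and layers.3.bias for encoder one and an empty list for encoder two: encoder one's last layer still runs in the forward pass, but nothing in the loss depends on its output.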

This results in a PyTorch error:

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. Since `find_unused_parameters=True` is enabled, this likely  means that not all `forward` outputs participate in computing loss. You can fix this by making sure all `forward` function outputs participate in calculating loss. 
If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
Parameters which did not receive grad for rank 0: text_model.encoder.layers.11.self_attn.out_proj.lora_linear_layer.up.weight, text_model.encoder.layers.11.self_attn.out_proj.lora_linear_layer.down.weight, text_model.encoder.layers.11.self_attn.q_proj.lora_linear_layer.up.weight, text_model.encoder.layers.11.self_attn.q_proj.lora_linear_layer.down.weight, text_model.encoder.layers.11.self_attn.v_proj.lora_linear_layer.up.weight, text_model.encoder.layers.11.self_attn.v_proj.lora_linear_layer.down.weight, text_model.encoder.layers.11.self_attn.k_proj.lora_linear_layer.up.weight, text_model.encoder.layers.11.self_attn.k_proj.lora_linear_layer.down.weight
Parameter indices which did not receive grad for rank 0: 88 89 90 91 92 93 94 95
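
When the error only reports parameter indices, a hypothetical helper like the one sketched here can list the offending names on a single process, without DDP: run one forward/backward pass, then collect every parameter that requires grad but still has `grad is None`. It assumes gradients were never populated or were cleared with `zero_grad(set_to_none=True)`, so that `None` reliably means "not used in the loss". Setting TORCH_DISTRIBUTED_DEBUG=DETAIL, as the error message suggests, is the built-in alternative under DDP.

```python
import torch
from torch import nn


def params_without_grad(model: nn.Module) -> list[str]:
    """Return names of trainable parameters that received no gradient.

    Assumes grads were cleared with zero_grad(set_to_none=True) (or never set)
    before the forward pass, so a None grad means "not used in the loss".
    """
    return [
        name
        for name, param in model.named_parameters()
        if param.requires_grad and param.grad is None
    ]


# Toy usage: the last layer is executed, but its output is excluded from the loss.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4))
x = torch.randn(3, 4)
penultimate = model[1](model[0](x))
_ = model[2](penultimate)  # computed, then ignored
penultimate.sum().backward()
print(params_without_grad(model))
```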

Reproduction

I can't provide code easily, but I hope the explanation and references are self-sufficient.

Logs

No response

System Info

Who can help?

@sayakpaul

sayakpaul commented 1 year ago

Hi @jorgemcgomes, unfortunately we haven't observed this behavior with the current training script, so we need some way to reproduce the problem.

jorgemcgomes commented 1 year ago

Hi @sayakpaul. OK, let me try to be more helpful.

I ran the example from the README in my setup (torch 2.0.1, RTX A6000, CUDA 11.7) and got the same error.

The error does make logical sense: it is impossible to have gradients for the last layer of encoder one, yet that layer's parameters are clearly included in the parameters passed to the optimizer. I think the real question is: how did this ever work in the first place?

from huggingface_hub import snapshot_download

local_dir = "./dog"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir, repo_type="dataset",
    ignore_patterns=".gitattributes",
)

!accelerate launch diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0"  \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --instance_data_dir="dog" \
  --output_dir="lora-trained-xl-test" \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --train_text_encoder

07/28/2023 14:49:33 - INFO - __main__ - ***** Running training *****
07/28/2023 14:49:33 - INFO - __main__ -   Num examples = 5
07/28/2023 14:49:33 - INFO - __main__ -   Num batches each epoch = 5
07/28/2023 14:49:33 - INFO - __main__ -   Num Epochs = 250
07/28/2023 14:49:33 - INFO - __main__ -   Instantaneous batch size per device = 1
07/28/2023 14:49:33 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 4
07/28/2023 14:49:33 - INFO - __main__ -   Gradient Accumulation steps = 4
07/28/2023 14:49:33 - INFO - __main__ -   Total optimization steps = 500
Steps:   0%|                     | 0/500 [00:02<?, ?it/s, loss=0.186, lr=0.0001]
Traceback (most recent call last):
  File "/workspace/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py", line 1351, in <module>
    main(args)
  File "/workspace/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py", line 1098, in main
    prompt_embeds, pooled_prompt_embeds = encode_prompt(
  File "/workspace/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py", line 555, in encode_prompt
    prompt_embeds = text_encoder(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1139, in forward
    if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by making sure all `forward` function outputs participate in calculating loss.
If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
Parameter indices which did not receive grad for rank 0: 88 89 90 91 92 93 94 95
In addition, you can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information about which particular parameters did not receive gradient on this rank as part of this error

patrickvonplaten commented 1 year ago

Ah yeah, that's interesting! I vaguely remember from transformers that layers whose weights are not used in training can lead to errors in multi-process environments. We probably need to do something semi-hacky here: remove the last layer of the first encoder and add it back later.
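
A rough sketch of the "remove the last layer and add it back later" idea, on a hypothetical toy encoder whose `layers` attribute is an `nn.ModuleList`. For the real CLIP text encoder, the parameter names in the traceback suggest the list lives at `text_model.encoder.layers`, but that path is an assumption here. Two details to watch: the optimizer should be built while the layer is detached, and once a layer is removed the hidden-states list gets one entry shorter, so code that reads the penultimate state with `hidden_states[-2]` would need `hidden_states[-1]` on the truncated model.

```python
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical encoder that returns the input plus every layer output."""

    def __init__(self, num_layers: int = 4, dim: int = 8):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        hidden_states = [x]
        for layer in self.layers:
            x = torch.tanh(layer(x))
            hidden_states.append(x)
        return hidden_states


def detach_last_layer(encoder: nn.Module) -> nn.Module:
    """Remove and return the last layer so its parameters are never trained."""
    last = encoder.layers[-1]
    del encoder.layers[-1]
    return last


def restore_last_layer(encoder: nn.Module, last: nn.Module) -> None:
    """Append a previously removed layer, e.g. before saving the model."""
    encoder.layers.append(last)


torch.manual_seed(0)
encoder = ToyEncoder()
x = torch.randn(2, 8)
full_penultimate = encoder(x)[-2]

last = detach_last_layer(encoder)
# Build the optimizer here, while the layer is detached, so it never sees it.
truncated_output = encoder(x)[-1]  # now equals the full model's [-2] entry
assert torch.allclose(full_penultimate, truncated_output)

restore_last_layer(encoder, last)
```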

jorgemcgomes commented 1 year ago


FYI, this was my solution. It hardcodes "layers.11", which is a bit hacky, but it doesn't cause any problems down the line with LoRA/model loading, saving, etc.

    # Optimizer creation
    params_to_optimize = []
    for name, model in (
        [
            ("unet", unet),
            ("text_encoder_one", text_encoder_one),
            ("text_encoder_two", text_encoder_two),
        ]
        if args.train_text_encoder
        else [("unet", unet)]
    ):
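        if name == "text_encoder_one":
            # SDXL only uses the penultimate hidden states of text_encoder_one,
            # so its last layer (index 11) never receives gradients. Freeze it
            # so those (LoRA) parameters are not handed to the optimizer.
            for n, p in model.named_parameters():
                if "layers.11" in n:
                    p.requires_grad_(False)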
        params = [p for p in model.parameters() if p.requires_grad]
        total_params = sum(p.numel() for p in params)
        logger.info(f"Training {len(params)} params in {name}, total of {total_params} weights.")
        params_to_optimize.extend(params)
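
A variation on the snippet above that avoids hardcoding "layers.11", sketched under the assumption that the index of the last layer is known. For a CLIP text encoder it would presumably be `text_encoder_one.config.num_hidden_layers - 1`, which should be 11 given the layer names in the traceback. The hypothetical `freeze_layer` helper matches on `layers.{index}.` with a trailing dot, so an index such as 1 cannot accidentally also match layers.10 or layers.11.

```python
from torch import nn


def freeze_layer(model: nn.Module, layer_index: int) -> list[str]:
    """Set requires_grad=False on every parameter under `layers.{layer_index}.`
    and return the names that were frozen."""
    marker = f"layers.{layer_index}."
    frozen = []
    for name, param in model.named_parameters():
        # The trailing dot keeps index 1 from matching layers.10 or layers.11.
        if marker in name:
            param.requires_grad_(False)
            frozen.append(name)
    return frozen


class ToyEncoder(nn.Module):
    """Hypothetical model with 12 layers, named like CLIP's encoder layers."""

    def __init__(self, num_layers: int = 12):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(num_layers))


model = ToyEncoder()
print(freeze_layer(model, len(model.layers) - 1))
trainable = [p for p in model.parameters() if p.requires_grad]
```

Only the parameters that still require grad (here 22 of 24) would then be collected into the optimizer, as in the snippet above.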

sayakpaul commented 1 year ago

I think that's still okay. Would you maybe like to open a PR for this?

sayakpaul commented 1 year ago

I am unable to reproduce this with the following on a single-GPU machine, though:

accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="diffusers/stable-diffusion-xl-base-0.9"  \
  --instance_data_dir="dog" \
  --output_dir="dog-lora" \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --train_text_encoder \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub

WandB run page: https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl/runs/8hkhtzm8.

So, if I understand correctly, this only happens in a multi-GPU environment?

bghira commented 1 year ago

yes

jorgemcgomes commented 1 year ago


@sayakpaul for what it's worth, I was running an instance with a single RTX A6000.

sayakpaul commented 1 year ago

Then that's indeed strange, as I couldn't reproduce it on a single A100. You can also check the WandB run page overview, which shows the full command I used to launch training and the git state at launch time.

So, maybe the solution you proposed here is the best way forward :)

garychan22 commented 1 year ago

I have reproduced this error on a single A100 by setting CUDA_VISIBLE_DEVICES=0 and num_process=1 (on a machine with 8 GPUs). The error reads:

Parameters which did not receive grad for rank 0: text_model.encoder.layers.11.self_attn.out_proj.lora_linear_layer.up.weight, text_model.encoder.layers.11.self_attn.out_proj.lora_linear_layer.down.weight, text_model.encoder.layers.11.self_attn.q_proj.lora_linear_layer.up.weight, text_model.encoder.layers.11.self_attn.q_proj.lora_linear_layer.down.weight, text_model.encoder.layers.11.self_attn.v_proj.lora_linear_layer.up.weight, text_model.encoder.layers.11.self_attn.v_proj.lora_linear_layer.down.weight, text_model.encoder.layers.11.self_attn.k_proj.lora_linear_layer.up.weight, text_model.encoder.layers.11.self_attn.k_proj.lora_linear_layer.down.weight
Parameter indices which did not receive grad for rank 0: 88 89 90 91 92 93 94 95

Versions: diffusers 0.20.0.dev0, transformers 4.31.0, accelerate 0.20.3

sayakpaul commented 1 year ago

@jorgemcgomes feel free to open a PR whenever ready :-)

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.