huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

LoRA error when running train_text_to_image_lora.py: "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!" #5897

Closed: MhDang closed this issue 9 months ago

MhDang commented 10 months ago

Describe the bug

I tried to experiment with LoRA training following examples/text_to_image/README.md#training-with-lora.

However, I got the error RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm) on line 801.

The same issue did not occur when I tried the same example (with the implementation at that time) months ago. I noticed there have been several commits since then.

I followed the README.md to install the packages, and the non-LoRA training works fine.

Thank you very much!

Reproduction

  1. Install packages following README.md:
    git clone https://github.com/huggingface/diffusers
    cd diffusers
    pip install .

    Then cd into the folder examples/text_to_image and run

    pip install -r requirements.txt

  2. In the examples/text_to_image directory, run the following:
    export MODEL_NAME="CompVis/stable-diffusion-v1-4"
    export DATASET_NAME="lambdalabs/pokemon-blip-captions"
    accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
    --pretrained_model_name_or_path=$MODEL_NAME \
    --dataset_name=$DATASET_NAME --caption_column="text" \
    --resolution=512 --random_flip \
    --train_batch_size=1 \
    --num_train_epochs=100 --checkpointing_steps=5000 \
    --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
    --seed=42 \
    --output_dir="sd-pokemon-model-lora" \
    --validation_prompt="cute dragon creature" --report_to="wandb"

Logs

11/22/2023 08:36:20 - INFO - __main__ - ***** Running training *****
11/22/2023 08:36:20 - INFO - __main__ -   Num examples = 833
11/22/2023 08:36:20 - INFO - __main__ -   Num Epochs = 100
11/22/2023 08:36:20 - INFO - __main__ -   Instantaneous batch size per device = 1
11/22/2023 08:36:20 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 1
11/22/2023 08:36:20 - INFO - __main__ -   Gradient Accumulation steps = 1
11/22/2023 08:36:20 - INFO - __main__ -   Total optimization steps = 83300
Steps:   0%|                                                                                                                                | 0/83300 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "./repo/diffusers/examples/text_to_image/train_text_to_image_lora.py", line 975, in <module>
    main()
  File "./repo/diffusers/examples/text_to_image/train_text_to_image_lora.py", line 801, in main
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
  File "./miniconda3/envs/diffusers_cuda117/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "./miniconda3/envs/diffusers_cuda117/lib/python3.9/site-packages/diffusers/models/unet_2d_condition.py", line 1075, in forward
    sample, res_samples = downsample_block(
  File "./miniconda3/envs/diffusers_cuda117/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "./miniconda3/envs/diffusers_cuda117/lib/python3.9/site-packages/diffusers/models/unet_2d_blocks.py", line 1160, in forward
    hidden_states = attn(
  File "./miniconda3/envs/diffusers_cuda117/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "./miniconda3/envs/diffusers_cuda117/lib/python3.9/site-packages/diffusers/models/transformer_2d.py", line 375, in forward
    hidden_states = block(
  File "./miniconda3/envs/diffusers_cuda117/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "./miniconda3/envs/diffusers_cuda117/lib/python3.9/site-packages/diffusers/models/attention.py", line 258, in forward
    attn_output = self.attn1(
  File "./miniconda3/envs/diffusers_cuda117/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "./miniconda3/envs/diffusers_cuda117/lib/python3.9/site-packages/diffusers/models/attention_processor.py", line 522, in forward
    return self.processor(
  File "./miniconda3/envs/diffusers_cuda117/lib/python3.9/site-packages/diffusers/models/attention_processor.py", line 1211, in __call__
    query = attn.to_q(hidden_states, *args)
  File "./miniconda3/envs/diffusers_cuda117/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "./miniconda3/envs/diffusers_cuda117/lib/python3.9/site-packages/diffusers/models/lora.py", line 433, in forward
    out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
  File "./miniconda3/envs/diffusers_cuda117/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "./miniconda3/envs/diffusers_cuda117/lib/python3.9/site-packages/diffusers/models/lora.py", line 220, in forward
    down_hidden_states = self.down(hidden_states.to(dtype))
  File "./miniconda3/envs/diffusers_cuda117/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "./miniconda3/envs/diffusers_cuda117/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)

System Info

Who can help?

@sayakpaul @patrickvonplaten

MhDang commented 10 months ago

If the current version is still in development, would it also be possible to point to a previous working version?

zanedamico commented 10 months ago

I'm also having the same issue, after training successfully last week.

wellCh4n commented 10 months ago

I can reproduce the issue in my case too. Looking at the commit history of the LoRA script, there was a recent commit with big changes; using the previous commit runs fine in my case.

Download this and replace ./diffusers/examples/text_to_image/train_text_to_image_lora.py with it. This is a temporary solution; sadly, I'm not familiar with this code.

IceClear commented 10 months ago

I think the reason is that the LoRA parameters are added to the UNet after the UNet has been sent to the GPU, so the LoRA layers are actually on the CPU, leading to the error. A simple way to fix this is to first add the LoRA layers to the UNet and then move them to the GPU together:

[screenshot of the proposed code change]
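In code, a minimal, self-contained sketch of the same idea (a toy module stands in for the real UNet; this is not the actual script code):

    import torch
    import torch.nn as nn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Toy stand-in for the UNet's attention projection; not the real diffusers module.
    unet = nn.Sequential(nn.Linear(8, 8))

    # Problematic order: the base model is moved first, and a fresh LoRA layer is
    # created afterwards. Newly constructed modules live on the CPU, so the forward
    # pass ends up mixing cuda:0 and cpu tensors.
    unet.to(device)
    lora_down = nn.Linear(8, 4, bias=False)   # still on the CPU at this point

    # Fix: register the LoRA layer as a submodule first, then move the whole model
    # in a single .to() call so every parameter lands on the same device.
    unet.add_module("lora_down", lora_down)
    unet.to(device)

    print({name: p.device for name, p in unet.named_parameters()})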

abetatos commented 10 months ago

Same behaviour here; falling back to the version @wellCh4n provided solved the problem.

NEhlen commented 10 months ago

> I think the reason is that the LoRA parameters are added to the UNet after the UNet has been sent to the GPU, so the LoRA layers are actually on the CPU, leading to the error. A simple way to fix this is to first add the LoRA layers to the UNet and then move them to the GPU together:
>
> [screenshot of the proposed code change]

I can confirm this fixes the error for me. However, at least on Colab with a T4 runtime, I then get an "Expected is_sm80 || is_sm90 to be true, but got false." error when the script tries to backpropagate the loss. I'm not sure whether this is an issue with the new script or a compatibility problem with the CUDA setup on Colab, though.
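If it helps with debugging: the `is_sm80 || is_sm90` check appears to come from PyTorch's fused scaled-dot-product-attention kernels, which only support some paths on Ampere/Hopper GPUs. A rough sketch (assuming PyTorch 2.x; a debugging aid, not a confirmed fix for this thread) to test whether the fused kernels are the culprit on a T4:

    import torch
    import torch.nn.functional as F

    # Debugging sketch (an assumption, not a confirmed fix): force the math SDPA
    # backend so the fused kernels that require sm80/sm90 GPUs are skipped.
    q = torch.randn(1, 8, 64, 40, device="cuda", dtype=torch.float16, requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)

    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
        out = F.scaled_dot_product_attention(q, k, v)
        out.sum().backward()  # the math backend has no Ampere/Hopper requirement
    print("backward pass ran without the fused attention kernels")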

sayakpaul commented 10 months ago

This seems like a setup problem to me as I am unable to reproduce it, even on a Google Colab: https://github.com/huggingface/diffusers/issues/5004#issuecomment-1780909598

hanweikung commented 10 months ago

I got the same error. However, reverting to the previous version, as @wellCh4n suggested, resolved the issue.

maliozer commented 10 months ago

Same issue here, following for a fix.

sayakpaul commented 10 months ago

I am gonna have to repeat myself here:

https://github.com/huggingface/diffusers/issues/5897#issuecomment-1827075282

zanedamico commented 10 months ago

@sayakpaul Is there anything we can do to help you reproduce this issue? It seems significant, as multiple people with different setups have encountered the same problem. Otherwise we're forced to keep using the older version indefinitely.

sayakpaul commented 10 months ago

A Colab notebook would be nice because that's the easiest to reproduce. As already indicated here, I was not able to reproduce at all: https://github.com/huggingface/diffusers/issues/5897#issuecomment-1827075282.

And I am quite sure https://github.com/huggingface/diffusers/pull/5388 will resolve these problems for good.

MohamadZeina commented 10 months ago

Hopefully this is fixed by the move to PEFT. In the meantime, if you don't want to revert to an older version: I had the same issue and fixed it by adding one line:

unet.to(accelerator.device, dtype=weight_dtype)

at my line 539, immediately after the LoRA weights are added, and outside the loop:

    # Accumulate the LoRA params to optimize.
    unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
    unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
    unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
    unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())

# Move the UNet (now including the newly added LoRA layers) onto the accelerator device
unet.to(accelerator.device, dtype=weight_dtype)

Thanks to @IceClear and others who found that part of the UNet was on the wrong device.
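For completeness, an illustrative sanity check (not part of the original script, reusing its `unet_lora_parameters` and `accelerator` variables) can go right after that `unet.to(...)` line to confirm nothing was left behind on the CPU:

    # Illustrative check: after the unet.to(...) call above, every LoRA parameter
    # should report a CUDA device. Compare device types to avoid the
    # "cuda" vs "cuda:0" index subtlety.
    lora_devices = {p.device for p in unet_lora_parameters}
    print("LoRA parameter devices:", lora_devices)
    assert all(d.type == accelerator.device.type for d in lora_devices), (
        f"Some LoRA parameters are still on: {lora_devices}"
    )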

sayakpaul commented 10 months ago

If you want to open a PR fixing it, more than happy to merge :)

MohamadZeina commented 10 months ago

@sayakpaul Thank you - I've opened #6061, let me know if it needs any modification

maliozer commented 10 months ago

> @sayakpaul Thank you - I've opened #6061, let me know if it needs any modification

Is this still in progress?

github-actions[bot] commented 9 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.