ShivamShrirao / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch
https://huggingface.co/docs/diffusers
Apache License 2.0

RuntimeError: shape '[0, 8, 4096, 320]' is invalid for input of size 2621440 #98

Open zakinp opened 1 year ago

zakinp commented 1 year ago

A problem occurs when I use the up-to-date "train_dreambooth.py" to fine-tune the model.

First, I convert a .ckpt model to the diffusers format using the "convert_original_stable_diffusion_to_diffusers.py" script in "/diffusers/scripts/". Then I use the "train_dreambooth.py" script to train the converted model, and this problem occurs.
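
For reference, a minimal sketch (the path is a placeholder, not taken from the report) of loading the converted folder back with diffusers to confirm the conversion produced a usable pipeline before training:

```python
from diffusers import StableDiffusionPipeline

# "./converted-model" is a placeholder for the output directory
# written by convert_original_stable_diffusion_to_diffusers.py.
pipe = StableDiffusionPipeline.from_pretrained("./converted-model")
print(pipe.unet.config)  # UNet settings the checkpoint was converted with
```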

The error message is:

Traceback (most recent call last):
  File "train_dreambooth.py", line 805, in <module>
    main(args)
  File "train_dreambooth.py", line 751, in main
    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/accelerate/utils/operations.py", line 507, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/root/miniconda3/lib/python3.8/site-packages/torch/autocast_mode.py", line 12, in decorate_autocast
    return func(*args, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py", line 296, in forward
    sample, res_samples = downsample_block(
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py", line 563, in forward
    hidden_states = attn(hidden_states, context=encoder_hidden_states)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/diffusers/models/attention.py", line 169, in forward
    hidden_states = block(hidden_states, context=context)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/diffusers/models/attention.py", line 217, in forward
    hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/diffusers/models/attention.py", line 295, in forward
    hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
  File "/root/miniconda3/lib/python3.8/site-packages/diffusers/models/attention.py", line 267, in reshape_batch_dim_to_heads
    tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
RuntimeError: shape '[0, 8, 4096, 320]' is invalid for input of size 2621440
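
For context, the failing line computes batch_size // head_size with integer division. The numbers in the message imply the incoming tensor has a leading dimension of 2621440 / (4096 * 320) = 2, which is smaller than head_size = 8, so the target shape gets a leading 0. A standalone sketch (synthetic tensor, not taken from the training run) reproduces the exact message:

```python
import torch

head_size = 8
# Leading dim 2, seq_len 4096, dim 320 -> numel = 2 * 4096 * 320 = 2621440
tensor = torch.randn(2, 4096, 320)
batch_size, seq_len, dim = tensor.shape

try:
    # 2 // 8 == 0, so the target shape [0, 8, 4096, 320] has zero elements
    # and cannot hold the 2621440 elements of the input tensor.
    tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
except RuntimeError as err:
    print(err)  # shape '[0, 8, 4096, 320]' is invalid for input of size 2621440
```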

damienrj commented 1 year ago

I am also getting this issue, but am letting the code download the model.

zakinp commented 1 year ago

> I am also getting this issue, but am letting the code download the model.

Yeah, it seems like a bug. This error also occurs when I convert the waifu-diffusion 1.3 model from Hugging Face and train it through DreamBooth.

jaworek commented 1 year ago

I'm getting a similar error. I updated the main branch to the latest commit and also applied the patch from #102. In both cases I still get the same error. I'm trying to run it on a MacBook with an M1 Pro and 32GB of RAM.

[screenshot of the error]

Here's the command that I'm using to run it:

accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=400

I also set accelerate config to use MPS. I was able to successfully run cv_example.py from the accelerate repo.
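
A quick check, assuming a PyTorch build of 1.12 or newer (where the torch.backends.mps module exists), of whether PyTorch itself sees the MPS backend:

```python
import torch

# Requires an Apple Silicon Mac and macOS 12.3+ for MPS to be usable.
print("built:", torch.backends.mps.is_built())          # compiled with MPS support?
print("available:", torch.backends.mps.is_available())  # MPS device usable right now?
```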