huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Tensor size mismatch for UNet2DConditionModel #7611

Closed: metatl closed this issue 5 months ago

metatl commented 5 months ago

Describe the bug

So I just started with diffusers. I was following this tutorial: https://huggingface.co/docs/diffusers/using-diffusers/write_own_pipeline. An error occurred when I reached this line, while denoising the noise tensor:

    with torch.no_grad():
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

The run fails on the first step (0%| | 0/25 [00:00<?, ?it/s]) with the error below; the full traceback is in the Logs section.


    RuntimeError: The size of tensor a (8192) must match the size of tensor b (4096) at non-singleton dimension 1

It seems the code specifically tries to take care of classifier-free guidance by concatenating two copies of the noise tensor (latent_model_input = torch.cat([latents] * 2)), but when that is passed to the UNet the error above is raised. I don't know how to fix this.
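
For context, here is a minimal sketch of the shapes this classifier-free guidance setup expects. The values are assumptions for a 512x512 image with a Stable Diffusion v1 UNet, following the tutorial's variable names; they are not measurements from this report:

    import torch

    # Illustrative shapes only (assumed, not measured from this report).
    latents = torch.randn(1, 4, 64, 64)            # one image's latents
    latent_model_input = torch.cat([latents] * 2)  # doubled: unconditional + conditional branch
    text_embeddings = torch.randn(2, 77, 768)      # the tutorial builds this as torch.cat([uncond_embeddings, text_embeddings])

    # The UNet's cross-attention pairs each latent in the batch with one text embedding,
    # so the two batch dimensions have to agree once guidance doubles the latents.
    assert latent_model_input.shape[0] == text_embeddings.shape[0]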

Reproduction

    from tqdm.auto import tqdm

    scheduler.set_timesteps(num_inference_steps)

    for t in tqdm(scheduler.timesteps):
        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        latent_model_input = torch.cat([latents] * 2)

        latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

        # predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents).prev_sample
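
For completeness, here is a minimal sketch of the setup the loop above assumes. It is reconstructed from the write_own_pipeline tutorial; the checkpoint, scheduler, prompt, and hyperparameters are assumptions rather than values confirmed in this report:

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer
    from diffusers import UNet2DConditionModel, UniPCMultistepScheduler

    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
    model_id = "CompVis/stable-diffusion-v1-4"  # assumed checkpoint, as in the tutorial

    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(torch_device)
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(torch_device)
    scheduler = UniPCMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")

    prompt = ["a photograph of an astronaut riding a horse"]
    height, width = 512, 512          # must stay consistent with the latent shape below
    num_inference_steps = 25
    guidance_scale = 7.5
    generator = torch.manual_seed(0)
    batch_size = len(prompt)

    # Conditional text embeddings.
    text_input = tokenizer(
        prompt, padding="max_length", max_length=tokenizer.model_max_length,
        truncation=True, return_tensors="pt",
    )
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]

    # Unconditional (empty prompt) embeddings, then double the batch for classifier-free guidance.
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length",
        max_length=tokenizer.model_max_length, return_tensors="pt",
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Starting latents: [batch, 4, height/8, width/8] -> [1, 4, 64, 64] for 512x512.
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(torch_device)
    latents = latents * scheduler.init_noise_sigma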

Logs

0%|          | 0/25 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[61], line 13
     11 # predict the noise residual
     12 with torch.no_grad():
---> 13     noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
     15 # perform guidance
     16 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

File /panfs/ccds02/nobackup/people/tyuan/diffusion_models/.env/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /panfs/ccds02/nobackup/people/tyuan/diffusion_models/.env/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /panfs/ccds02/nobackup/people/tyuan/diffusion_models/.env/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py:1216, in UNet2DConditionModel.forward(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, down_intrablock_additional_residuals, encoder_attention_mask, return_dict)
   1213     if is_adapter and len(down_intrablock_additional_residuals) > 0:
   1214         additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
-> 1216     sample, res_samples = downsample_block(
   1217         hidden_states=sample,
   1218         temb=emb,
   1219         encoder_hidden_states=encoder_hidden_states,
   1220         attention_mask=attention_mask,
   1221         cross_attention_kwargs=cross_attention_kwargs,
   1222         encoder_attention_mask=encoder_attention_mask,
   1223         **additional_residuals,
   1224     )
   1225 else:
   1226     sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

File /panfs/ccds02/nobackup/people/tyuan/diffusion_models/.env/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /panfs/ccds02/nobackup/people/tyuan/diffusion_models/.env/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /panfs/ccds02/nobackup/people/tyuan/diffusion_models/.env/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_blocks.py:1279, in CrossAttnDownBlock2D.forward(self, hidden_states, temb, encoder_hidden_states, attention_mask, cross_attention_kwargs, encoder_attention_mask, additional_residuals)
   1277 else:
   1278     hidden_states = resnet(hidden_states, temb)
-> 1279     hidden_states = attn(
   1280         hidden_states,
   1281         encoder_hidden_states=encoder_hidden_states,
   1282         cross_attention_kwargs=cross_attention_kwargs,
   1283         attention_mask=attention_mask,
   1284         encoder_attention_mask=encoder_attention_mask,
   1285         return_dict=False,
   1286     )[0]
   1288 # apply additional residuals to the output of the last pair of resnet and attention blocks
   1289 if i == len(blocks) - 1 and additional_residuals is not None:

File /panfs/ccds02/nobackup/people/tyuan/diffusion_models/.env/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /panfs/ccds02/nobackup/people/tyuan/diffusion_models/.env/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /panfs/ccds02/nobackup/people/tyuan/diffusion_models/.env/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py:397, in Transformer2DModel.forward(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, class_labels, cross_attention_kwargs, attention_mask, encoder_attention_mask, return_dict)
    385         hidden_states = torch.utils.checkpoint.checkpoint(
    386             create_custom_forward(block),
    387             hidden_states,
   (...)
    394             **ckpt_kwargs,
    395         )
    396     else:
--> 397         hidden_states = block(
    398             hidden_states,
    399             attention_mask=attention_mask,
    400             encoder_hidden_states=encoder_hidden_states,
    401             encoder_attention_mask=encoder_attention_mask,
    402             timestep=timestep,
    403             cross_attention_kwargs=cross_attention_kwargs,
    404             class_labels=class_labels,
    405         )
    407 # 3. Output
    408 if self.is_input_continuous:

File /panfs/ccds02/nobackup/people/tyuan/diffusion_models/.env/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /panfs/ccds02/nobackup/people/tyuan/diffusion_models/.env/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /panfs/ccds02/nobackup/people/tyuan/diffusion_models/.env/lib/python3.10/site-packages/diffusers/models/attention.py:372, in BasicTransformerBlock.forward(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels, added_cond_kwargs)
    364         norm_hidden_states = self.pos_embed(norm_hidden_states)
    366     attn_output = self.attn2(
    367         norm_hidden_states,
    368         encoder_hidden_states=encoder_hidden_states,
    369         attention_mask=encoder_attention_mask,
    370         **cross_attention_kwargs,
    371     )
--> 372     hidden_states = attn_output + hidden_states
    374 # 4. Feed-forward
    375 # i2vgen doesn't have this norm 🤷‍♂️
    376 if self.norm_type == "ada_norm_continuous":

File /panfs/ccds02/nobackup/people/tyuan/diffusion_models/.env/lib/python3.10/site-packages/torch/utils/_device.py:77, in DeviceContext.__torch_function__(self, func, types, args, kwargs)
     75 if func in _device_constructors() and kwargs.get('device') is None:
     76     kwargs['device'] = self.device
---> 77 return func(*args, **kwargs)

RuntimeError: The size of tensor a (8192) must match the size of tensor b (4096) at non-singleton dimension 1

System Info

Latest diffusers version

Who can help?

@DN6 @yiyixuxu @sayakpaul

tolgacangoz commented 5 months ago

You probably ran something out of order or ran a cell twice. Could you rerun the whole notebook from scratch?

metatl commented 5 months ago

> You probably ran something out of order or ran a cell twice. Could you rerun the whole notebook from scratch?

I reran it and checked:

    print(latents.size())              # torch.Size([1, 4, 64, 64])
    print(latent_model_input.size())   # torch.Size([2, 4, 64, 64])

So the doubling doesn't seem to be the issue?
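
(A related check that might help narrow this down, suggested here rather than taken from the thread: the failing add happens inside cross-attention, so the text embedding batch matters as much as the latent batch.)

    # Hypothetical extra check: the cross-attention inputs should have matching batch sizes.
    print(text_embeddings.size())      # expected torch.Size([2, 77, 768]) after torch.cat([uncond_embeddings, text_embeddings])
    print(latent_model_input.size())   # expected torch.Size([2, 4, 64, 64])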

tolgacangoz commented 5 months ago

There is nothing wrong with that. As you said, the doubling comes from this line:

    # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
    latent_model_input = torch.cat([latents] * 2)

Could you please run all the code and show the error in a Colab notebook?

metatl commented 5 months ago

> There is nothing wrong with that. As you said, the doubling comes from this line:
>
>     # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
>     latent_model_input = torch.cat([latents] * 2)
>
> Could you please run all the code and show the error in a Colab notebook?

I reran the code from the beginning with that line commented out, and the error is the same: RuntimeError: The size of tensor a (8192) must match the size of tensor b (4096) at non-singleton dimension 1.

I used exactly the same code as in the tutorial. It seems the classifier-free guidance path expects 8192 but is only getting 4096, even though a concatenated tensor is supplied...
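
One way to read those numbers (illustrative arithmetic only, not a diagnosis from the thread): at the UNet block where the traceback ends, the 64x64 latent grid is flattened into a token sequence, so dimension 1 of the hidden states is 64 * 64 = 4096 for the tutorial's 512x512 setup, and 8192 is exactly twice that. In other words, one operand of the failing residual add carries twice as many spatial tokens as the other.

    # Illustrative arithmetic only (values assume the tutorial's 512x512 / 64x64-latent setup).
    latent_hw = 512 // 8                 # 64: latent height/width for a 512x512 image
    tokens = latent_hw * latent_hw       # spatial tokens at the first cross-attention block
    print(tokens)                        # 4096
    print(2 * tokens)                    # 8192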

tolgacangoz commented 5 months ago

Could you examine this Colab notebook?

metatl commented 5 months ago

> Could you examine this Colab notebook?

That notebook absolutely works in Colab. But the same code gives me an error on my local cluster! What could be the cause?
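
(A hedged suggestion for this kind of Colab-vs-local discrepancy, not something done in the thread: compare the package versions in both environments, either with diffusers-cli env or with a quick snippet like the one below.)

    # Print the versions that most often differ between Colab and a local cluster.
    import torch
    import diffusers
    import transformers

    print("torch:", torch.__version__)
    print("diffusers:", diffusers.__version__)
    print("transformers:", transformers.__version__)
    print("CUDA available:", torch.cuda.is_available())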

metatl commented 5 months ago

> Could you examine this Colab notebook?

Thank you for the reminder. I found an error in my copy and paste. Deleting the post.