showlab / Tune-A-Video

[ICCV 2023] Tune-A-Video: One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation
https://tuneavideo.github.io
Apache License 2.0

unet.down_blocks does not seem to be updating #69

Open czk32611 opened 1 year ago

czk32611 commented 1 year ago

I ran the training code with two GPUs and got the error message: Parameters which did not receive grad for rank 0: down_blocks.2.attentions.1.transformer_blocks.0.attn_temp.to_out.0.weight, down_blocks.2.attentions.1.transformer_blocks.0.attn_temp.to_v.weight, down_blocks.2.attentions.1.transformer_blocks.0.attn_temp.to_k.weight, down_blocks.2.attentions.1.transformer_blocks.0.attn_temp.to_q.weight, xxx.

I double checked and found that down_blocks.0.attentions.0.transformer_blocks.0.attn_temp.to_v.weight is always zero.

This issue may result from the use of torch.utils.checkpoint.checkpoint in https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/models/unet_blocks.py#L300.

Reference: https://github.com/huggingface/transformers/issues/21381 "gradient checkpointing disables requires_grad when freezing part of models (fix with use_reentrant=False)"

Can someone confirm if this issue exists and provide a brief update?
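For context, here is a standalone sketch of the suspected interaction (assuming a PyTorch version that has the use_reentrant flag, i.e. >= 1.11; the Linear layers are only stand-ins, not Tune-A-Video code). It mimics the UNet layout: a checkpointed block fed by a frozen input, followed by a non-checkpointed trainable block so the loss still has a graph:

```python
import torch
from torch.utils.checkpoint import checkpoint

down = torch.nn.Linear(4, 4)   # stands in for a trainable attn_temp inside down_blocks
mid = torch.nn.Linear(4, 4)    # stands in for the non-checkpointed mid block

x = torch.randn(2, 4)          # output of frozen upstream layers: requires_grad=False

# Default (reentrant) checkpointing: gradients never reach `down`.
h = checkpoint(down, x, use_reentrant=True)
mid(h).sum().backward()
print(down.weight.grad)          # None  -> this block silently stops training
print(mid.weight.grad is None)   # False -> the rest of the network still trains

down.zero_grad(set_to_none=True)
mid.zero_grad(set_to_none=True)

# Non-reentrant checkpointing tracks parameter requires_grad correctly.
h = checkpoint(down, x, use_reentrant=False)
mid(h).sum().backward()
print(down.weight.grad is None)  # False -> the down block now receives grads
```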

zhangjiewu commented 1 year ago

can you check if this issue still occurs when training on one gpu?

czk32611 commented 1 year ago

> can you check if this issue still occurs when training on one gpu?

It also occurs when training on one GPU, with no warning or error. The trainable modules in down_blocks still have no grad.

Quick check: you can find that unet.down_blocks[2].attentions[1].transformer_blocks[0].attn_temp.to_out[0].weight is always zero and unet.down_blocks[2].attentions[1].transformer_blocks[0].attn_temp.to_out[0].weight.grad is always None.
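A quick way to run that check after a few training steps (sketch; `unet` here is the UNet instance created in the training script):

```python
# Print which trainable attn_temp parameters in the down blocks are still
# all-zero and have never received a gradient.
for name, p in unet.down_blocks.named_parameters():
    if "attn_temp" in name and p.requires_grad:
        print(name,
              "| all-zero:", (p.detach() == 0).all().item(),
              "| grad is None:", p.grad is None)
```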

DuanXiaoyue-LittleMoon commented 1 year ago

I'm facing the same issue: the network does not seem to be updated during training. The gradients of the trainable modules are always zero. Can anyone resolve this?

guomc9 commented 1 week ago

This is because torch.utils.checkpoint is enabled to save GPU memory. Tune-A-Video uses the default use_reentrant=True in torch.utils.checkpoint, and in that mode gradients only flow through a checkpointed block if at least one of its input tensors has requires_grad=True; if you want the learnable modules inside a checkpointed block to be updated, you must make sure its input tensor has requires_grad=True. Why do the mid and up layers still get updated? The mid block does not go through torch.utils.checkpoint, so its trainable parameters produce a hidden_states tensor with requires_grad=True; the up blocks then receive an input that requires grad, so their checkpointed parameters can still be updated.

The simplest way to solve this is to set 'gradient_checkpointing: False' in the config, if your GPU has enough memory 🙃. Good luck!
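A tiny sketch of the first workaround mentioned above: with the default reentrant checkpointing, making the checkpointed block's input require grad is enough for the parameters inside it to receive gradients (standalone example, not Tune-A-Video code):

```python
import torch
from torch.utils.checkpoint import checkpoint

down = torch.nn.Linear(4, 4)                 # stands in for a checkpointed down block
x = torch.randn(2, 4).requires_grad_(True)   # force the input to carry grad

out = checkpoint(down, x, use_reentrant=True)  # same default the repo uses
out.sum().backward()
print(down.weight.grad is None)  # False -> parameters now receive gradients
```

Passing use_reentrant=False to the checkpoint calls in unet_blocks.py (as in the transformers issue linked above) is the other option if you want to keep gradient checkpointing enabled.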