showlab / Tune-A-Video

[ICCV 2023] Tune-A-Video: One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation
https://tuneavideo.github.io
Apache License 2.0

Attention weight in CrossAttnDownBlock3D is not trained? #18

Closed zhangfuyang closed 1 year ago

zhangfuyang commented 1 year ago

Hi, Congrats on this awesome work!

One question: for this line, it seems the gradient does not flow through.

I'm not familiar with torch.utils.checkpoint.checkpoint. Does it recompute the activations during the backward pass while running the forward pass under no_grad()?
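That is roughly how gradient checkpointing works. A minimal sketch (a toy linear layer standing in for a checkpointed attention block, not the repo's actual module): the forward pass discards intermediate activations, the backward pass re-runs the forward to rebuild the graph, and gradients still reach the parameters inside.

```python
import torch
from torch.utils.checkpoint import checkpoint

# Toy stand-in for a checkpointed attention block.
lin = torch.nn.Linear(4, 4)

def block(x):
    return lin(x).relu()

x = torch.randn(2, 4, requires_grad=True)
# Forward: `block` runs without storing intermediate activations.
y = checkpoint(block, x, use_reentrant=True)
# Backward: the forward is re-run to rebuild the graph, then grads flow.
y.sum().backward()
print(lin.weight.grad is not None)  # True
print(x.grad is not None)           # True
```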

However, even so, I didn't observe any weight change after one training iteration. Specifically, unet.down_blocks[0].attentions[0].transformer_blocks[0].attn1.to_q.weight is unchanged.

Is this correct, or did I miss something?
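A quick way to check whether a given parameter actually moves after one optimizer step, sketched with a hypothetical toy layer in place of unet.down_blocks[0].attentions[0].transformer_blocks[0].attn1.to_q: snapshot the weight before the step and compare afterwards.

```python
import torch

# Hypothetical stand-in for the attn1.to_q projection: one trainable layer.
net = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(net.parameters(), lr=0.1)

# Snapshot the weight before training.
before = net.weight.detach().clone()

loss = net(torch.randn(2, 4)).pow(2).mean()
loss.backward()
opt.step()

# If the gradient reached the weight, the values differ after the step.
changed = not torch.equal(before, net.weight.detach())
print(changed)  # True
```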

lealaxy commented 1 year ago

Same question. I added the code below right after accelerator.backward(loss):

    accelerator.backward(loss)
    # Report every trainable parameter that received no gradient.
    for name, paras in unet.named_parameters():
        if paras.requires_grad and paras.grad is None:
            logger.error(name)

And this is the program's output:

module.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight
module.down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight
module.down_blocks.0.attentions.0.transformer_blocks.0.attn_temp.to_q.weight
module.down_blocks.0.attentions.0.transformer_blocks.0.attn_temp.to_k.weight
module.down_blocks.0.attentions.0.transformer_blocks.0.attn_temp.to_v.weight
module.down_blocks.0.attentions.0.transformer_blocks.0.attn_temp.to_out.0.weight
module.down_blocks.0.attentions.0.transformer_blocks.0.attn_temp.to_out.0.bias
module.down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.weight
module.down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q.weight
module.down_blocks.0.attentions.1.transformer_blocks.0.attn_temp.to_q.weight
module.down_blocks.0.attentions.1.transformer_blocks.0.attn_temp.to_k.weight
module.down_blocks.0.attentions.1.transformer_blocks.0.attn_temp.to_v.weight
module.down_blocks.0.attentions.1.transformer_blocks.0.attn_temp.to_out.0.weight
module.down_blocks.0.attentions.1.transformer_blocks.0.attn_temp.to_out.0.bias
module.down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight
module.down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight
module.down_blocks.1.attentions.0.transformer_blocks.0.attn_temp.to_q.weight
module.down_blocks.1.attentions.0.transformer_blocks.0.attn_temp.to_k.weight
module.down_blocks.1.attentions.0.transformer_blocks.0.attn_temp.to_v.weight
module.down_blocks.1.attentions.0.transformer_blocks.0.attn_temp.to_out.0.weight
module.down_blocks.1.attentions.0.transformer_blocks.0.attn_temp.to_out.0.bias
module.down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight
module.down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight
module.down_blocks.1.attentions.1.transformer_blocks.0.attn_temp.to_q.weight
module.down_blocks.1.attentions.1.transformer_blocks.0.attn_temp.to_k.weight
module.down_blocks.1.attentions.1.transformer_blocks.0.attn_temp.to_v.weight
module.down_blocks.1.attentions.1.transformer_blocks.0.attn_temp.to_out.0.weight
module.down_blocks.1.attentions.1.transformer_blocks.0.attn_temp.to_out.0.bias
module.down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight
module.down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight
module.down_blocks.2.attentions.0.transformer_blocks.0.attn_temp.to_q.weight
module.down_blocks.2.attentions.0.transformer_blocks.0.attn_temp.to_k.weight
module.down_blocks.2.attentions.0.transformer_blocks.0.attn_temp.to_v.weight
module.down_blocks.2.attentions.0.transformer_blocks.0.attn_temp.to_out.0.weight
module.down_blocks.2.attentions.0.transformer_blocks.0.attn_temp.to_out.0.bias
module.down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight
module.down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight
module.down_blocks.2.attentions.1.transformer_blocks.0.attn_temp.to_q.weight
module.down_blocks.2.attentions.1.transformer_blocks.0.attn_temp.to_k.weight
module.down_blocks.2.attentions.1.transformer_blocks.0.attn_temp.to_v.weight
module.down_blocks.2.attentions.1.transformer_blocks.0.attn_temp.to_out.0.weight
module.down_blocks.2.attentions.1.transformer_blocks.0.attn_temp.to_out.0.bias

It seems the gradient is not flowing through these layers.
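One classic cause of exactly this symptom (a toy sketch, not the repo's code): with reentrant torch.utils.checkpoint, gradients only flow into a checkpointed block if at least one of its tensor inputs has requires_grad=True. If the block is fed a plain tensor, no graph is built through it and its parameters end up with grad=None.

```python
import torch
from torch.utils.checkpoint import checkpoint

lin = torch.nn.Linear(4, 4)

def block(x):
    return lin(x).relu()

# Input that carries requires_grad: the checkpointed output is
# connected to the graph, so lin's weights would receive gradients.
x1 = torch.randn(2, 4, requires_grad=True)
y1 = checkpoint(block, x1, use_reentrant=True)
print(y1.requires_grad)  # True

# Plain input: reentrant checkpointing sees no differentiable input,
# builds no graph, and lin's parameters would end up with grad=None.
x2 = torch.randn(2, 4)
y2 = checkpoint(block, x2, use_reentrant=True)
print(y2.requires_grad)  # False
```

Whether this is what happens in the Tune-A-Video training loop would need checking against how the UNet inputs are prepared.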