showlab / Tune-A-Video

[ICCV 2023] Tune-A-Video: One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation
https://tuneavideo.github.io
Apache License 2.0

attention_mask is None during training #53

Closed liqq010 closed 1 year ago

liqq010 commented 1 year ago

In the paper it is mentioned that the ST-Attn layer only attends to the previous frame and the first frame. But in the code it seems that during training, the attention_mask for the SC-Attn layer is None. Could you explain why this setting differs from the paper?

zhangjiewu commented 1 year ago

it's the same as what we describe in the paper. here, we do not use an attention mask, which would still require computing full attention over the spatiotemporal domain. instead, we stack the keys and values of the first and former frames, and perform SC-Attn as follows:

$$Q=W^Q z_{v_i}, \quad K=W^K \left[ z_{v_1}, z_{v_{i-1}} \right], \quad V=W^V \left[ z_{v_1}, z_{v_{i-1}} \right],$$

$$\mathrm{Attention}(Q,K,V)=\mathrm{Softmax}(\frac{Q K^T}{\sqrt{d}}) \cdot V.$$

this is much more efficient than using an attention mask. the corresponding snippet in our code: https://github.com/showlab/Tune-A-Video/blob/f6a94f0c40d6e6b60b2c0e2a858ba21de504f1e6/tuneavideo/models/attention.py#L292-L301
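to illustrate the idea, here is a minimal PyTorch sketch of the key/value stacking (this is an assumption-laden simplification, not the repo's exact implementation — the linked code applies the same gathering to the projected key/value tensors inside the attention module; the function name `sparse_causal_kv` is made up for this example):

```python
import torch

def sparse_causal_kv(hidden_states: torch.Tensor, video_length: int) -> torch.Tensor:
    """Build the SC-Attn key/value source [z_{v_1}, z_{v_{i-1}}]:
    for each frame i, stack the tokens of the first frame and the
    previous frame (frame 0 simply uses itself for both slots).
    hidden_states: (batch * video_length, seq_len, dim)."""
    b = hidden_states.shape[0] // video_length
    # (batch, frames, seq_len, dim)
    x = hidden_states.reshape(b, video_length, *hidden_states.shape[1:])
    # previous frame for each index i (frame 0 maps to itself)
    former = torch.cat([x[:, :1], x[:, :-1]], dim=1)
    # first frame broadcast to every index i
    first = x[:, :1].expand_as(x)
    # concatenate along the token dimension, so attention over the
    # doubled sequence realizes K = W^K [z_{v_1}, z_{v_{i-1}}]
    kv = torch.cat([first, former], dim=2)
    return kv.reshape(b * video_length, -1, x.shape[-1])
```

because the stacked key/value sequence has only 2x the per-frame token count, each query frame attends to a fixed-size context regardless of video length, instead of the full spatiotemporal sequence a masked attention would have to materialize.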