thu-nics / DiTFastAttn

MIT License

Thoughts on full3d optimization #5

Closed: xesdiny closed this issue 3 months ago

xesdiny commented 3 months ago

Thank you for such relevant work!

You identify three types of redundancy: (1) redundancy in the spatial dimension, (2) similarity between attention outputs at neighboring timesteps, and (3) similarity between the attention outputs of conditional and unconditional inference. If the mmDiT architecture builds its sequence as self-attention over text tokens concatenated with T x H x W video tokens, the premises of methods 1 and 2 (window attention and sharing across timesteps) do not seem to hold well. Method 3 also seems unable to construct a CFG residual cached between the conditional and unconditional passes.
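To make the concern about methods 1 and 2 concrete, here is a minimal sketch. It assumes that text tokens are prepended and video tokens are flattened in (t, h, w) order, and that the window is applied over 1D positions. The layout, the sizes, and the helper names `token_index` and `in_1d_window` are hypothetical and are not taken from DiTFastAttn. Under these assumptions, a temporal neighbor sits H*W positions away, and the text tokens sit far from most video tokens, so a small 1D window misses both.

```python
def token_index(n_text, H, W, t, h, w):
    """1D position of video token (t, h, w) in a hypothetical [text | video] sequence."""
    return n_text + (t * H + h) * W + w


def in_1d_window(i, j, window):
    """True if positions i and j fall inside a 1D attention window of size `window`."""
    return abs(i - j) <= window // 2


# Hypothetical sizes: 4 text tokens, 8x8 latent frames, window of 32 tokens.
n_text, H, W, window = 4, 8, 8, 32
q = token_index(n_text, H, W, t=0, h=2, w=3)           # position 23
right = token_index(n_text, H, W, t=0, h=2, w=4)       # spatial neighbor, position 24
next_frame = token_index(n_text, H, W, t=1, h=2, w=3)  # temporal neighbor, position 87

print(in_1d_window(q, right, window))       # True  (distance 1)
print(in_1d_window(q, next_frame, window))  # False (distance H*W = 64)
print(in_1d_window(q, 0, window))           # False (first text token, distance 23)
```

Real video DiTs may use a 3D window instead, and that would change the temporal case. The text tokens would still need special handling.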

hahnyuan commented 3 months ago

Thank you for your thoughtful feedback. Regarding the application of our methods to mmDiT architecture:

  1. Window attention: You're correct that this may not be as effective in mmDiT due to its different attention structure. We think that preserving full attention for text tokens is likely more appropriate in this case.

  2. Attention output sharing across timesteps: Some degree of sharing may still be acceptable, but we haven't run experiments to measure its impact. This is certainly an interesting direction for future exploration.

  3. Residual sharing: We believe this could still be effective in mmDiT.

  4. CFG residuals: We didn't use CFG residuals. In our experiments with DiT, we found that direct replacement yielded good results while avoiding increased memory consumption and a slight decrease in inference speed. That said, we haven't tested this on other networks, so exploring CFG residuals in different architectures, including mmDiT, could be a valuable avenue for further research.
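The following is a minimal sketch of points 1, 3, and 4. It assumes single-head attention over a [text | video] sequence with a 1D window. It is a hypothetical illustration and not DiTFastAttn's actual code: the function names, the window rule, and the `cache` dictionary are all assumptions. The mask keeps full attention for every pair that involves a text token. Residual sharing caches the difference between full attention and window attention at a refresh step and adds it back at later steps. The CFG variant directly reuses the conditional output for the unconditional branch instead of storing a residual.

```python
import numpy as np


def joint_window_mask(n_text: int, n_video: int, window: int) -> np.ndarray:
    """Boolean mask over a hypothetical [text | video] sequence (True = attend).

    Video-video pairs are limited to a 1D window over flattened positions;
    any pair involving a text token keeps full attention.
    """
    idx = np.arange(n_text + n_video)
    mask = np.abs(idx[:, None] - idx[None, :]) <= window // 2
    mask[:n_text, :] = True  # text queries attend to every token
    mask[:, :n_text] = True  # every query attends to all text tokens
    return mask


def masked_attention(q, k, v, mask):
    """Single-head scaled dot-product attention; masked pairs get zero weight."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


def residual_sharing_attention(q, k, v, mask, cache, refresh):
    """Window attention plus a cached (full - window) residual.

    refresh=True computes full attention and stores the residual;
    refresh=False reuses the residual stored at an earlier step.
    """
    window_out = masked_attention(q, k, v, mask)
    if refresh:
        full_out = masked_attention(q, k, v, np.ones_like(mask))
        cache["residual"] = full_out - window_out
        return full_out
    return window_out + cache["residual"]


def cfg_shared_attention(q, k, v, mask):
    """Direct replacement across CFG: batch 0 = conditional, 1 = unconditional.

    Attention is computed once for the conditional branch and its output is
    reused for the unconditional branch, so no CFG residual is stored.
    """
    out_cond = masked_attention(q[0], k[0], v[0], mask)
    return np.stack([out_cond, out_cond])


# Usage with hypothetical sizes.
rng = np.random.default_rng(0)
n_text, n_video, d = 4, 16, 8
mask = joint_window_mask(n_text, n_video, window=4)
q, k, v = (rng.standard_normal((n_text + n_video, d)) for _ in range(3))
cache = {}
full = residual_sharing_attention(q, k, v, mask, cache, refresh=True)
reused = residual_sharing_attention(q, k, v, mask, cache, refresh=False)
print(np.allclose(full, reused))  # True: same inputs, so window + residual == full
```

In a real sampler, the inputs at the reuse step come from a later timestep. The cached residual then only approximates the missing long-range attention, and whether that approximation holds for mmDiT is exactly the open question in this thread.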

We appreciate your insights into how our methods might apply to or be adapted for different model architectures. We plan to explore these directions in the next version.