modelscope / DiffSynth-Studio

Enjoy the magic of Diffusion models!
Apache License 2.0
6.58k stars, 600 forks

Flash attention question #14

Open: JonathanLi19 opened this issue 9 months ago

JonathanLi19 commented 9 months ago

Hi, great work! I have a question. In the paper, you said that "we adopt flash attention [6] in all attention layers, including the text encoder, UNet, VAE, ControlNet models, and motion modules". I found the xformers_forward() function in the Attention module. However, this function is never called anywhere in "diffutoon_toon_shading.py". This seems strange, since the pipeline can still generate high-resolution videos. How does this work? Thanks!

Artiprocher commented 9 months ago

The Attention module provides two different forward functions, namely xformers_forward and torch_forward. In our early work, we intended to use only xformers_forward. However, we found that xformers sometimes doesn't work, for example when a tensor dimension is not a power of 2 (2^n). Therefore, we use scaled_dot_product_attention in torch instead of xformers. scaled_dot_product_attention is another implementation of flash attention (it can dispatch to a FlashAttention kernel), and it is only available in torch>=2.0. This implementation is fast and stable enough. If you are interested, you can see this document for more information.
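As an illustration of the switch described above, here is a minimal sketch of a multi-head self-attention block whose `torch_forward` calls `torch.nn.functional.scaled_dot_product_attention` (torch>=2.0). It is not DiffSynth-Studio's actual code: the class `SimpleAttention` and its helper methods are hypothetical. A naive reference path is included to check that both paths agree, using a head dimension of 40, which is not a power of 2.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleAttention(nn.Module):
    """Hypothetical self-attention block; not DiffSynth-Studio's Attention class."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, tokens, dim) -> (batch, heads, tokens, head_dim)
        b, n, _ = x.shape
        return x.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, heads, tokens, head_dim) -> (batch, tokens, dim)
        b, h, n, d = x.shape
        return x.transpose(1, 2).reshape(b, n, h * d)

    def _qkv(self, x: torch.Tensor):
        return (self._split_heads(self.to_q(x)),
                self._split_heads(self.to_k(x)),
                self._split_heads(self.to_v(x)))

    def torch_forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self._qkv(x)
        # PyTorch picks a kernel (FlashAttention, memory-efficient, or math)
        # at runtime based on device, dtype, shapes, and how it was built.
        out = F.scaled_dot_product_attention(q, k, v)
        return self.to_out(self._merge_heads(out))

    def naive_forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self._qkv(x)
        # Materializes the full (tokens x tokens) score matrix for every head.
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        out = scores.softmax(dim=-1) @ v
        return self.to_out(self._merge_heads(out))


torch.manual_seed(0)
attn = SimpleAttention(dim=320, num_heads=8)  # head_dim = 40, not a power of 2
x = torch.randn(2, 64, 320)
with torch.no_grad():
    fast, ref = attn.torch_forward(x), attn.naive_forward(x)
print(torch.allclose(fast, ref, atol=1e-5))
```

Both paths use the default scale 1/sqrt(head_dim), so they compute the same function; the difference is which kernel does the work and how much intermediate memory it needs.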

JonathanLi19 commented 9 months ago

Thank you very much!

JonathanLi19 commented 8 months ago

Hi, sorry to bother you again. I have two more questions.

  1. I tried to rebuild your pipeline using the open-source diffusers library. I replaced the attention layers in this library with the scaled_dot_product_attention function mentioned above instead of xformers; this function is used in the UNet and ControlNet. However, I still get a CUDA OOM error when generating high-resolution videos. Is this because I don't use this function in the VAE and CLIPTextEncoder (these two architectures aren't implemented by diffusers)? Or is there something special in your framework? (You don't use diffusers at all and rewrote the whole codebase, which is really terrific.)
  2. The warning "1 torch was not compiled with flash attention" is shown during code execution. Does it matter? Thank you very much!
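Regarding the OOM in question 1, one possible contributor (not a diagnosis of this specific setup) is any attention layer that still uses a naive implementation, because it materializes a (tokens x tokens) score matrix. The back-of-the-envelope estimate below assumes a Stable-Diffusion-style VAE whose mid-block attention runs at 1/8 of the image resolution with a single head; the helper names `latent_tokens` and `attention_scores_bytes` are hypothetical.

```python
def latent_tokens(height: int, width: int, downscale: int = 8) -> int:
    """Number of spatial tokens when attention runs at 1/downscale of the image size."""
    return (height // downscale) * (width // downscale)


def attention_scores_bytes(tokens: int, heads: int = 1, batch: int = 1,
                           bytes_per_elem: int = 2) -> int:
    """Bytes for the (tokens x tokens) score matrix a naive attention stores."""
    return batch * heads * tokens * tokens * bytes_per_elem


GiB = 1024 ** 3
for h, w in [(512, 512), (1024, 1024), (1536, 1536)]:
    n = latent_tokens(h, w)
    fp16 = attention_scores_bytes(n) / GiB
    fp32 = attention_scores_bytes(n, bytes_per_elem=4) / GiB
    # Figures are per frame and per head; a batch of video frames multiplies them.
    print(f"{h}x{w}: {n} tokens, scores {fp16:.2f} GiB (fp16), {fp32:.2f} GiB (fp32)")
```

FlashAttention-style and memory-efficient kernels compute the same softmax(QK^T)V in tiles without storing this full matrix, while a naive softmax typically allocates at least one more tensor of the same size, so the real peak for a naive layer is higher than these figures.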