Since PyTorch 2.2.0, FlashAttention-2 has been integrated into PyTorch's SDPA (scaled dot-product attention) implementation, but FlashAttention support was dropped on the Windows platform.
Since this project already requires users to install xformers (which ships with FlashAttention-2 support on Windows by default),
I recommend using xformers instead of diffusers' default implementation (which uses PyTorch SDPA).
Using xformers also avoids the extra transpose operations on q/k/v and the attention output.
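To illustrate the transpose point, here is a minimal sketch (tensor names and shapes are illustrative, not from this project's code): PyTorch's `F.scaled_dot_product_attention` expects `(batch, heads, seq, dim)`, while the projections in a diffusers-style attention block naturally produce `(batch, seq, heads, dim)` — the layout `xformers.ops.memory_efficient_attention` accepts directly.

```python
import torch
import torch.nn.functional as F

# Attention projections produce (batch, seq, heads * dim), which reshapes
# naturally to (batch, seq, heads, dim).
B, S, H, D = 2, 16, 4, 8
q = torch.randn(B, S, H, D)
k = torch.randn(B, S, H, D)
v = torch.randn(B, S, H, D)

# PyTorch SDPA expects (batch, heads, seq, dim), so each input must be
# transposed before the call and the output transposed back afterwards:
out_sdpa = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

# xformers takes (batch, seq, heads, dim) directly, no transposes needed
# (shown commented out since xformers may not be installed):
# import xformers.ops as xops
# out_xf = xops.memory_efficient_attention(q, k, v)

print(out_sdpa.shape)  # (B, S, H, D) again, after the round-trip transposes
```

In diffusers, calling `pipe.enable_xformers_memory_efficient_attention()` on a pipeline switches the attention processors to the xformers path.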