HJoonKwon opened this issue 9 months ago
Hey @HJoonKwon! Damn, very good find, thank you! I guess this does matter in compiled forward, where we are padding inputs to static dimensions. We'd need to run the benchmarks, but maybe avoiding the call to `half()` could improve throughput then.
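Something along these lines could be used to sanity-check that; it's just a rough sketch, and the shapes, mask format, and iteration count are assumptions rather than our actual benchmark settings:

```python
# Rough benchmark sketch only: shapes, mask layout and iteration count are
# assumptions, not the repository's actual benchmark settings.
import time

import torch
import torch.nn.functional as F

device = "cuda"
q, k, v = (torch.randn(1, 8, 2048, 64, device=device) for _ in range(3))
# Boolean mask where True means "attend"; all-True here just to exercise the masked path.
mask = torch.ones(1, 8, 2048, 2048, device=device, dtype=torch.bool)


def timed(fn, iters=50):
    # Warm up once, then average over `iters` runs with GPU synchronization.
    fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


# float32 inputs with a mask: the flash kernel cannot be used, so SDPA falls
# back to the memory-efficient (or math) kernel, which accepts float32.
t_fp32 = timed(lambda: F.scaled_dot_product_attention(q, k, v, attn_mask=mask))

# Behaviour in question: cast to half before the call, back to float after.
t_half = timed(
    lambda: F.scaled_dot_product_attention(
        q.half(), k.half(), v.half(), attn_mask=mask
    ).float()
)

print(f"fp32: {t_fp32 * 1e3:.2f} ms/iter, half() round-trip: {t_half * 1e3:.2f} ms/iter")
```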
@Phil26AT Great! Thank you again for your great work. It has inspired me a lot.
On the topic of FlashAttention: you link to FlashAttention and not FlashAttention2 here. Isn't the second version used? If not, why? It seems quite a bit faster.
FlashAttention: https://arxiv.org/abs/2205.14135
FlashAttention2: https://arxiv.org/pdf/2307.08691.pdf?trk=public_post_comment-text
Thanks for your great work!
I'm just curious whether your code here is using flash attention or not when `mask` is not `None`. My guess is that it falls back to memory-efficient attention instead, since the PyTorch flash attention kernel does not support an attention mask. In addition, if memory-efficient attention is used, the `half()` cast would not have been needed when `mask` is not `None`. Thank you!

++ I did some experiments. Even if `sdp_flash` is enabled, it is not executed when `mask` is not `None`. If we force PyTorch to use only the flash kernel, it spits out an error like the one below, while the memory-efficient kernel does not.
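For reference, the kind of check I ran can be sketched roughly like this; the shapes, dtypes, and mask format are placeholders, not the repository's code:

```python
# Minimal sketch of the experiment (assumed shapes, dtypes and mask format,
# not the repository's code). Exact behaviour depends on the PyTorch version.
import torch
import torch.nn.functional as F

device = "cuda"
q, k, v = (
    torch.randn(1, 8, 512, 64, device=device, dtype=torch.half) for _ in range(3)
)
mask = torch.ones(1, 8, 512, 512, device=device, dtype=torch.bool)

# Flash kernel only: the flash backend does not accept attn_mask, so SDPA is
# left with no usable kernel and raises a RuntimeError.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    try:
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    except RuntimeError as err:
        print("flash-only with mask failed:", err)

# Memory-efficient kernel only: the same masked call runs fine.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=False, enable_mem_efficient=True
):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    print("memory-efficient with mask ok:", out.shape)
```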