Closed: why-in-Shanghaitech closed this issue 2 weeks ago
It looks like you are using flex_attention in eager mode. Please use compilation mode instead, i.e.,

```python
flex_attention = torch.compile(flex_attention)
```
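In case it helps, here is a minimal sketch of that fix applied end to end (the tensor shapes below are arbitrary placeholders, not taken from the original repro):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Compile once and reuse the compiled callable; calling flex_attention
# in eager mode falls back to a slow, unfused reference path.
flex_attention = torch.compile(flex_attention)

# Arbitrary example shapes: (batch, heads, seq_len, head_dim).
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flex_attention(q, k, v)
```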
Thank you so much! That solves the problem exactly.
🐛 Describe the bug
I tried to use flex attention in Hugging Face Transformers, only to find it very slow. Compared to the SDPA implementation, flex attention is about 4-5x slower, though it does save CUDA memory.
Tested on RTX 3090, A6000, and A100 GPUs.
Here is the example code: https://gist.github.com/why-in-Shanghaitech/8b8205f98568c6741a2e38dfcdb9d362
I have no idea what is happening. Is this normal? Can anyone reproduce it? Or is this problem related to Hugging Face Transformers?
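For anyone who wants to check the slowdown without the gist, something along these lines (a minimal timing sketch, not the linked gist; the shapes, dtype, and iteration counts are arbitrary assumptions) contrasts SDPA with eager and compiled flex attention:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention

def bench_ms(fn, *args, iters=20):
    # Warm up (this also triggers compilation for compiled callables),
    # then time with CUDA events.
    for _ in range(3):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average ms per call

# Arbitrary shapes: (batch, heads, seq_len, head_dim).
q = torch.randn(1, 16, 2048, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

compiled_flex = torch.compile(flex_attention)

print(f"sdpa           : {bench_ms(F.scaled_dot_product_attention, q, k, v):.3f} ms")
print(f"flex (eager)   : {bench_ms(flex_attention, q, k, v):.3f} ms")
print(f"flex (compiled): {bench_ms(compiled_flex, q, k, v):.3f} ms")
```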
Versions
cc @ezyang @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng