ROCm / flash-attention

Fast and memory-efficient exact attention

Clean up for Upstream #81

Closed · micmelesse closed this 2 months ago

micmelesse commented 2 months ago

Hi, this is a PR to add a Triton backend to Flash Attention on ROCm. We hope that this PR will be the first in a series of PRs to that end. Triton has supported ROCm for a while now, and a Triton backend for Flash Attention will allow us to support Flash Attention on both our MI and Navi machines.

In this PR, we enable major parts of fwd, varlen_fwd and fwd_kvcache. However, some features are still missing, such as dropout, sliding window, rotary embeddings and paged attention, and there are a few miscellaneous bugs; these will all be addressed in subsequent PRs. The next PR we plan to file will add support for bwd and varlen_bwd; if we should reprioritize, please let us know.
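
For reference, a minimal sketch of how the three enabled forward paths are exercised, assuming the upstream flash_attn Python API (flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache) and illustrative shapes; this is not code from this PR, and dropout_p stays 0.0 since dropout is among the features not yet enabled:

```python
# Sketch of the three forward paths enabled here, assuming the upstream
# flash_attn Python API; shapes and dtypes follow the upstream docs.
import torch
from flash_attn import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache

batch, seqlen, nheads, d = 2, 128, 4, 64
dtype, device = torch.float16, "cuda"  # ROCm builds of PyTorch also expose "cuda"

q = torch.randn(batch, seqlen, nheads, d, dtype=dtype, device=device)
k = torch.randn(batch, seqlen, nheads, d, dtype=dtype, device=device)
v = torch.randn(batch, seqlen, nheads, d, dtype=dtype, device=device)

# fwd: standard batched forward (dropout_p must stay 0.0 for this backend).
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)

# varlen_fwd: packed variable-length sequences with cumulative offsets.
q_packed = q.reshape(batch * seqlen, nheads, d)
k_packed = k.reshape(batch * seqlen, nheads, d)
v_packed = v.reshape(batch * seqlen, nheads, d)
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, seqlen,
                          dtype=torch.int32, device=device)
out_varlen = flash_attn_varlen_func(
    q_packed, k_packed, v_packed,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=seqlen, max_seqlen_k=seqlen,
    dropout_p=0.0, causal=True,
)

# fwd_kvcache: decode-style single-token step against a preallocated KV cache.
k_cache = torch.zeros(batch, seqlen, nheads, d, dtype=dtype, device=device)
v_cache = torch.zeros(batch, seqlen, nheads, d, dtype=dtype, device=device)
cache_seqlens = torch.zeros(batch, dtype=torch.int32, device=device)
q_step = torch.randn(batch, 1, nheads, d, dtype=dtype, device=device)
k_step = torch.randn(batch, 1, nheads, d, dtype=dtype, device=device)
v_step = torch.randn(batch, 1, nheads, d, dtype=dtype, device=device)
out_step = flash_attn_with_kvcache(
    q_step, k_cache, v_cache, k=k_step, v=v_step,
    cache_seqlens=cache_seqlens, causal=True,
)
```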

We have tested this PR on an MI200 machine. When testing the Triton backend for ROCm, we skip the backward pass, configs with unsupported features, and a randomly chosen portion of head sizes (d); the latter keeps the test times reasonable. The latest results we have are === 64 failed, 30387 passed, 478321 skipped, 1 warning in 3110.86s (0:51:50) ===. There is clearly more work to be done, but we hope this makes a good start.
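
For illustration only (this is not the actual harness in this repo's tests), the skipping policy described above could look like the pytest sketch below; the keep probability, head sizes, and test name are assumptions:

```python
# Hypothetical illustration of the test pruning described above: skip
# unsupported features outright and randomly drop a fraction of head
# sizes (d) to bound total runtime.
import random
import pytest

HEAD_DIM_KEEP_PROB = 0.25  # assumed fraction; the real ratio may differ

@pytest.mark.parametrize("d", [32, 40, 64, 96, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
def test_flash_attn_fwd(d, dropout_p):
    if dropout_p != 0.0:
        pytest.skip("dropout is not yet supported by the Triton backend")
    if random.random() > HEAD_DIM_KEEP_PROB:
        pytest.skip("randomly skipped to keep test times reasonable")
    ...  # run the Triton forward pass and compare against the reference
```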

Please let us know what we can do on our end to help with this process. Finally, this PR includes work from multiple people besides myself; thanks especially to @vgokhale, @scxiao and @jlgreathouse.