Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Three-dimensional local attention #947

Open JohannesGaessler opened 4 months ago

JohannesGaessler commented 4 months ago

As of right now, FlashAttention only supports one-dimensional local attention. I intend to implement up to three-dimensional local attention, where the effective attention mask would be a rectangular cuboid. The interface I'm imagining would be similar to cudaMemcpy2D: instead of the current two window-size values, you would specify six values for the cuboid shape and two values for the y/z strides in the flattened array. The intended use case is image/video models.
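
A minimal sketch of what such a parameter block might look like, loosely modeled on the extent/stride style of cudaMemcpy2D. None of these names exist in flash-attention today; they are purely illustrative:

```cpp
// Hypothetical parameters for a 3D local-attention window (all names are
// assumptions, not part of the flash-attention API).
struct LocalAttention3DParams {
    // Window extents around each query position, per axis
    // (how far backward/forward attention may reach along that axis).
    int window_left_x, window_right_x;  // innermost (fastest-varying) axis
    int window_left_y, window_right_y;
    int window_left_z, window_right_z;
    // Strides mapping (x, y, z) coordinates into the flattened sequence:
    //   index = x + y * stride_y + z * stride_z
    int stride_y;  // e.g. width W of an image/video frame
    int stride_z;  // e.g. W * H for a video of W x H frames
};
```

For a W×H×T video flattened in row-major order, one would set stride_y = W and stride_z = W * H; the existing 1D behavior falls out as the special case where the y and z windows each cover exactly one element.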

Looking at the code, it seems that one-dimensional local attention is implemented relatively simply by adjusting the values of n_block_min and n_block_max. I think the easiest way to implement three-dimensional local attention would be to add two more loops over the y and z dimensions of the cuboid; see the sketch below. To avoid a performance regression, I would use a template parameter that lets the compiler know when the loops only cover a single element.
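
A minimal sketch of that idea, assuming the hypothetical LocalAttention3DParams above and deliberately omitting the actual flash-attention tiling logic. Is3D is a compile-time flag so that, on the existing 1D path, both outer loops collapse to a single iteration and can be optimized away:

```cpp
#include <algorithm>

// Invokes f(start, end) once per contiguous 1D segment of the flattened
// sequence covered by the cuboid window for query position (q_x, q_y, q_z).
// In the real kernel, each segment would instead set n_block_min/n_block_max
// for the K/V blocks it covers. This is only an illustration of the loop
// structure, not flash-attention's actual code.
template <bool Is3D, typename F>
void for_each_local_segment(const LocalAttention3DParams& p,
                            int q_x, int q_y, int q_z, F&& f) {
    // With Is3D == false, both outer loops run exactly once (z == q_z,
    // y == q_y), so the compiler can reduce this to the current 1D path.
    const int z_lo = Is3D ? std::max(q_z - p.window_left_z, 0) : q_z;
    const int z_hi = Is3D ? q_z + p.window_right_z : q_z;
    const int y_lo = Is3D ? std::max(q_y - p.window_left_y, 0) : q_y;
    const int y_hi = Is3D ? q_y + p.window_right_y : q_y;
    for (int z = z_lo; z <= z_hi; ++z) {
        for (int y = y_lo; y <= y_hi; ++y) {
            // Unchanged 1D window along the innermost axis.
            const int n_min = std::max(q_x - p.window_left_x, 0);
            const int n_max = q_x + p.window_right_x;
            const int base  = y * p.stride_y + z * p.stride_z;
            f(base + n_min, base + n_max);
        }
    }
}
```

Upper-bound clamping against the actual sequence extents is left out here; the kernel would also need to skip (y, z) coordinates that fall outside the image/video volume.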

Feedback and support from developers more experienced with this repository would be very much appreciated, especially when it comes to testing the implementation.

matthijsvk commented 4 months ago

There's a project doing "neighborhood attention", which sounds like what you're looking for: https://github.com/SHI-Labs/NATTEN