FlashAttention currently supports only one-dimensional local attention. I intend to implement local attention in up to three dimensions, where the effective attention mask is a rectangular cuboid. The interface I have in mind is similar to `cudaMemcpy2D`: instead of the two window-size values, you would specify six values for the cuboid shape and two values for the y/z strides in the flattened array. The intended use case is image and video models.
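To make the proposed interface concrete, here is a rough host-side sketch of the mask semantics in plain C++. Everything in it is an assumption for illustration: the struct `LocalWindow3D`, the helpers `unflatten` and `in_window`, and the reading of the "six values" as a left/right extent per axis (by analogy with the existing left/right window size) are hypothetical and do not exist in FlashAttention. It also assumes x is the fastest-varying axis of the flattened sequence.

```cpp
#include <cassert>

// Hypothetical parameter struct: six window extents plus two strides.
// None of these names come from the FlashAttention code base.
struct LocalWindow3D {
    int left_x, right_x;  // window extent along x (assumed fastest-varying axis)
    int left_y, right_y;  // window extent along y
    int left_z, right_z;  // window extent along z
    int stride_y;         // flattened elements between consecutive y rows
    int stride_z;         // flattened elements between consecutive z slices
};

struct Coord3D { int x, y, z; };

// Recover (x, y, z) from a flattened token index using the two strides.
inline Coord3D unflatten(int idx, const LocalWindow3D& w) {
    assert(w.stride_y > 0 && w.stride_z >= w.stride_y);
    return { idx % w.stride_y, (idx % w.stride_z) / w.stride_y, idx / w.stride_z };
}

// True if key k_idx lies inside the cuboid window centred on query q_idx.
inline bool in_window(int q_idx, int k_idx, const LocalWindow3D& w) {
    const Coord3D q = unflatten(q_idx, w);
    const Coord3D k = unflatten(k_idx, w);
    return k.x >= q.x - w.left_x && k.x <= q.x + w.right_x
        && k.y >= q.y - w.left_y && k.y <= q.y + w.right_y
        && k.z >= q.z - w.left_z && k.z <= q.z + w.right_z;
}
```

Under these assumptions, the existing one-dimensional behaviour would correspond to zero y/z extents with both strides set to at least the sequence length, so every token decodes to y = z = 0.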
Looking at the code, one-dimensional local attention appears to be implemented relatively simply by adjusting the values of `n_block_min` and `n_block_max`. I think the easiest way to implement three-dimensional local attention would be to add two more loops over the y and z dimensions of the cuboid. To avoid a performance regression, I would use a template parameter that lets the compiler know when those loops cover only a single element.
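The loop structure I'm describing could look roughly like the following host-side sketch. It is not kernel code: `LocalWindow3D`, `key_ranges`, and the `Is_1D` template parameter are hypothetical names, and the sketch assumes a densely packed volume (x extent equal to `stride_y`, y extent equal to `stride_z / stride_y`). Each emitted range is a contiguous run of keys along x; in the actual kernel, each such run would presumably be turned into its own `n_block_min`/`n_block_max` pair.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical parameter struct (not part of FlashAttention).
struct LocalWindow3D {
    int left_x, right_x, left_y, right_y, left_z, right_z;
    int stride_y, stride_z;  // flattened strides of the y and z axes
};

// Returns the half-open flattened key ranges [begin, end) visible to q_idx.
// Is_1D is known at compile time, so in the 1D case the y/z loop bounds are
// the constants 0 and 1 and the compiler can drop both outer loops.
template <bool Is_1D>
std::vector<std::pair<int, int>> key_ranges(int q_idx, int seqlen_k,
                                            const LocalWindow3D& w) {
    int qx = q_idx, qy = 0, qz = 0;
    int size_x = seqlen_k, size_y = 1, size_z = 1;
    if (!Is_1D) {  // compile-time constant condition
        assert(w.stride_y > 0 && w.stride_z >= w.stride_y);
        size_x = w.stride_y;               // assumes a densely packed volume
        size_y = w.stride_z / w.stride_y;
        size_z = seqlen_k / w.stride_z;
        qx = q_idx % w.stride_y;
        qy = (q_idx % w.stride_z) / w.stride_y;
        qz = q_idx / w.stride_z;
    }
    const int x_begin = std::max(0, qx - w.left_x);
    const int x_end   = std::min(size_x, qx + w.right_x + 1);
    const int y_begin = Is_1D ? 0 : std::max(0, qy - w.left_y);
    const int y_end   = Is_1D ? 1 : std::min(size_y, qy + w.right_y + 1);
    const int z_begin = Is_1D ? 0 : std::max(0, qz - w.left_z);
    const int z_end   = Is_1D ? 1 : std::min(size_z, qz + w.right_z + 1);

    std::vector<std::pair<int, int>> ranges;
    for (int z = z_begin; z < z_end; ++z) {
        for (int y = y_begin; y < y_end; ++y) {
            // Offset of row (y, z) in the flattened sequence; zero in the 1D case.
            const int base = Is_1D ? 0 : z * w.stride_z + y * w.stride_y;
            ranges.emplace_back(base + x_begin, base + x_end);
        }
    }
    return ranges;
}
```

For example, with a 4x4x4 volume (`stride_y = 4`, `stride_z = 16`) and a window of one element in each direction, an interior query yields nine x-runs, while `key_ranges<true>` yields the single range the current 1D path would use.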
Feedback and support from developers more experienced with this repository would be very much appreciated, especially when it comes to testing the implementation.