getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

Question: is it possible to implement flash attention with keops #286

Closed jaak-s closed 1 year ago

jaak-s commented 1 year ago

Hi, I'm new to pykeops and was wondering if it would be possible to implement flash attention, which is used to remove the quadratic memory requirement on the sequence length: https://github.com/HazyResearch/flash-attention

The basic idea is that one does not need O(N^2) memory: each row of the attention matrix can be computed independently and immediately multiplied by V, so the whole NxN matrix never needs to be stored.
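
To make the idea concrete, here is a minimal sketch in plain PyTorch (the function name and block size are illustrative, not taken from the flash-attention code):

```python
import torch

def attention_by_rows(Q, K, V, block=128):
    """Compute softmax(Q K^T / sqrt(D)) @ V without storing the full N x N matrix:
    only a (block, N) slice of the attention scores exists at any time."""
    N, D = Q.shape
    out = torch.empty_like(Q)
    scale = D ** -0.5
    for start in range(0, N, block):
        q = Q[start:start + block]                       # (block, D) query rows
        scores = (q @ K.T) * scale                       # (block, N) score slice
        out[start:start + block] = scores.softmax(-1) @ V
    return out
```

(Flash attention additionally tiles over the keys/values with an online softmax, but the memory argument is the same.)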

Thanks!

jeanfeydy commented 1 year ago

Hi @jaak-s,

Thanks for your interest in our library! This is entirely doable, since KeOps and FlashAttention both rely on the same core numerical scheme, initially documented by Nvidia for N-body computations in physics. (By the way, KeOps is discussed in appendix D of the FlashAttention paper.)

I actually did this back in April 2021 in the branch "attention", with a drop-in replacement for the MultiheadAttention layer. Benchmarks are available here.
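
For reference, the general pattern with the public LazyTensor interface looks roughly like the sketch below. This is a simplified illustration, not the code of the branch itself, and it assumes the `sumsoftmaxweight` reduction of the LazyTensor API:

```python
import torch
from pykeops.torch import LazyTensor

N, D = 4096, 16
q, k, v = torch.randn(3, N, D).unbind(0)

q_i = LazyTensor(q[:, None, :])    # (N, 1, D) symbolic "query" variable
k_j = LazyTensor(k[None, :, :])    # (1, N, D) symbolic "key" variable
v_j = LazyTensor(v[None, :, :])    # (1, N, D) symbolic "value" variable

s_ij = (q_i | k_j) / D ** 0.5      # (N, N) symbolic score matrix, never materialized
out = s_ij.sumsoftmaxweight(v_j, axis=1)  # = softmax(S) @ V, computed on the fly
```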

Please note, however, that KeOps is not competitive for implementing standard attention layers with attention heads of size > 16.

In this context, I think that KeOps may be of interest to people who want to experiment with "original" attention layers (as we sometimes do in geometric deep learning), but that it is not really a competitive option for Natural Language Processing. I hope that this answers your question!

If you would like to ask anything else, please let me know.

Best regards, Jean

jaak-s commented 1 year ago

Thanks a lot for the detailed answer!

jeanfeydy commented 1 year ago

You're very welcome - that's an important question in today's context :-)

jaak-s commented 1 year ago

Agreed, it is a hot topic.

Even though the current implementation of flash-attention is well optimized for NLP, there are applications outside NLP that need slight modifications, such as relative position encodings or distance-based biases (ALiBi), which are not yet supported (https://github.com/HazyResearch/flash-attention/issues/17).

With a KeOps-based implementation, these changes feel like one-line modifications, which would make any customization quite straightforward :-).
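
For instance, under the same LazyTensor pattern as the sketch above (and with a made-up slope value rather than the per-head slopes of the ALiBi paper), a distance-based bias would just be one extra term in the symbolic score:

```python
import torch
from pykeops.torch import LazyTensor

N, D, slope = 4096, 16, 0.0625               # slope: hypothetical per-head constant
q, k, v = torch.randn(3, N, D).unbind(0)
pos = torch.arange(N, dtype=torch.float32)

q_i = LazyTensor(q[:, None, :])              # (N, 1, D) queries
k_j = LazyTensor(k[None, :, :])              # (1, N, D) keys
v_j = LazyTensor(v[None, :, :])              # (1, N, D) values
i_pos = LazyTensor(pos[:, None, None])       # (N, 1, 1) query positions
j_pos = LazyTensor(pos[None, :, None])       # (1, N, 1) key positions

# Distance-based bias in the spirit of ALiBi: penalize far-away keys.
s_ij = (q_i | k_j) / D ** 0.5 - slope * (i_pos - j_pos).abs()
out = s_ij.sumsoftmaxweight(v_j, axis=1)     # still linear memory in N
```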

jeanfeydy commented 1 year ago

I see, thanks for the pointers :-)

We are not close enough to the Transformer community to implement competitive layers ourselves (I already have my hands full applying KeOps to anatomical data and drug consumption records!), but I'm more than happy to provide performance tips and/or add useful features to KeOps if this could help the "attention" community.

Our priorities for 2023 lie closer to transparent usage on generic hardware (a 100% NumPy-compatible interface, CPU support...) than to bleeding-edge performance on Nvidia GPUs (with automatic mixed precision, etc.), but these are certainly interesting research directions.

Best regards, Jean