keras-team / keras


Add Flash Attention #19418

Open · innat opened this issue 5 months ago

innat commented 5 months ago

Description

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.

Paper: https://arxiv.org/abs/2205.14135 (cited by 671)

Implementation

Hugging Face: https://huggingface.co/docs/text-generation-inference/en/conceptual/flash_attention

Others

A second version of the paper has since been released:

https://arxiv.org/abs/2307.08691

FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
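
For context, below is a minimal sketch of the reference scaled dot-product attention that both FlashAttention versions reproduce exactly; the fused kernels compute the same output but tile the computation so the full [T, T] score matrix is never materialized in slow GPU memory. The function name `reference_attention` and the shapes are illustrative assumptions, not existing Keras code.

```python
# Minimal sketch of the exact attention that FlashAttention reproduces,
# written with keras.ops. The full [T, T] score matrix is materialized
# here; FlashAttention computes the same output in tiles so that matrix
# never has to round-trip through HBM.
import math

import numpy as np
from keras import ops


def reference_attention(q, k, v, mask=None):
    """softmax(Q K^T / sqrt(d)) V for q, k, v of shape [batch, heads, T, d]."""
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = ops.matmul(q, ops.swapaxes(k, -1, -2)) * scale  # [B, H, T, T]
    if mask is not None:
        scores = ops.where(mask, scores, -1e9)  # mask out disallowed positions
    return ops.matmul(ops.softmax(scores, axis=-1), v)  # [B, H, T, d]


# Example: self-attention over a random batch.
q = np.random.normal(size=(2, 4, 128, 64)).astype("float32")
out = reference_attention(q, q, q)  # shape: (2, 4, 128, 64)
```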

fchollet commented 5 months ago

For JAX, we may want to rely on Pallas. For TF, since we can't rely on custom ops, we may have to skip support.

Presumably we should add it in the form of a new backend op, ops.nn.flash_attention.
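
As a rough illustration of that shape, here is a minimal sketch of how such an op might dispatch per backend, assuming the usual Keras 3 pattern of a thin `keras.ops` frontend routing to backend-specific code. The function name, the fallback behavior, and the Pallas placeholder are assumptions for discussion, not an agreed design.

```python
# Hypothetical sketch of a `flash_attention` backend op: JAX would route
# to a fused Pallas kernel, while backends without custom-op support
# either skip support or fall back to plain O(T^2)-memory attention.
import math

from keras import backend, ops


def flash_attention(query, key, value, mask=None):
    """Fused attention where available, reference attention otherwise."""
    if backend.backend() == "jax":
        # This is where a Pallas flash-attention kernel would be invoked.
        # The kernel's import path and signature vary by JAX version, so
        # it is left as a placeholder in this sketch.
        raise NotImplementedError("wire up the Pallas kernel here")
    # TF and other backends without custom ops: fall back to the plain,
    # non-fused computation so the op still returns correct results.
    scale = 1.0 / math.sqrt(query.shape[-1])
    scores = ops.matmul(query, ops.swapaxes(key, -1, -2)) * scale
    if mask is not None:
        scores = ops.where(mask, scores, -1e9)
    return ops.matmul(ops.softmax(scores, axis=-1), value)
```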