lucidrains / flash-attention-jax

Implementation of Flash Attention in Jax
MIT License

Online Softmax from FlashAttention #2

Closed · jenkspt closed this issue 1 year ago

jenkspt commented 2 years ago

Thought you might find this interesting. I did some benchmarking of the online softmax algorithm used in the flash attention paper: https://github.com/jenkspt/online-softmax-jax. TL;DR: it's not reliably faster than the naive softmax.
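For context, here is a minimal sketch of the online (single-pass) softmax being benchmarked, written in JAX. The function name and chunking are illustrative, not taken from the linked repo: the idea is to stream over chunks while maintaining a running max and a running sum of exponentials, rescaling the sum whenever the max changes.

```python
import jax.numpy as jnp

def online_softmax(x, chunk_size=4):
    # Running max m and running sum s of exp(x - m), updated chunk by chunk.
    m = -jnp.inf
    s = 0.0
    for chunk in jnp.split(x, len(x) // chunk_size):
        m_new = jnp.maximum(m, chunk.max())
        # Rescale the old sum to the new max before adding this chunk's terms.
        s = s * jnp.exp(m - m_new) + jnp.exp(chunk - m_new).sum()
        m = m_new
    # One final pass to produce the normalized probabilities.
    return jnp.exp(x - m) / s
```

This computes the same result as a standard two-pass softmax; the benchmark question is only whether the streaming formulation is faster, which (per the numbers above) it isn't reliably on accelerators where the whole row fits in memory anyway.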

lucidrains commented 2 years ago

@jenkspt hey Penn! thanks for the benchmarks

i think Flash Attention takes the online softmax a step further. it is more of an online softmax-weighted sum over the values. it makes the most sense in the context of CUDA, where you can control HBM access
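A rough sketch of what that extra step looks like, for a single query against chunked keys/values (names and shapes are illustrative, not the repo's actual kernel): alongside the running max and running denominator, you keep a running weighted sum of value rows, and all three are rescaled by the same factor when the max updates. This is why values never need to be revisited, which is where the HBM savings come from in the CUDA setting.

```python
import jax.numpy as jnp

def online_attention(q, K, V, chunk_size=4):
    # q: (d,) single query; K: (n, d) keys; V: (n, d_v) values.
    # m: running max logit, s: running softmax denominator,
    # o: running (unnormalized) softmax-weighted sum of value rows.
    m = -jnp.inf
    s = 0.0
    o = jnp.zeros(V.shape[-1])
    n = K.shape[0]
    for i in range(0, n, chunk_size):
        logits = K[i:i + chunk_size] @ q          # (chunk,)
        m_new = jnp.maximum(m, logits.max())
        scale = jnp.exp(m - m_new)                # rescale factor for old max
        p = jnp.exp(logits - m_new)               # unnormalized probs, this chunk
        s = s * scale + p.sum()
        o = o * scale + p @ V[i:i + chunk_size]   # fold this chunk's values in
        m = m_new
    return o / s                                  # normalize once at the end
```

Each chunk of K and V is touched exactly once, which is the property the flash attention kernels exploit when tiling through HBM.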

are you working on speeding up attention for work or is this a side project?

jenkspt commented 2 years ago

I was just messing around with it as a side project -- and thought I'd share when I saw this repo. Great work btw!