Closed: jenkspt closed this issue 1 year ago
@jenkspt Hey Penn! Thanks for the benchmarks.
I think Flash Attention takes the online softmax a step further: it's more of an online softmax-weighted sum. It makes the most sense in the context of CUDA, where you can control HBM access.
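Roughly what I mean, sketched in JAX for a single query (this is just my illustration, not the actual kernel; the block size and shapes are made up):

```python
import jax
import jax.numpy as jnp

def online_softmax_weighted_sum(q, k, v, block_size=128):
    """Streaming softmax(k @ q) @ v for a single query vector q.

    Keeps a running max `m`, running denominator `l`, and running
    (unnormalized) weighted sum `acc`, updating them one key/value
    block at a time with the rescaling trick.
    """
    d = q.shape[-1]
    m = -jnp.inf          # running max of the logits seen so far
    l = 0.0               # running softmax denominator
    acc = jnp.zeros(d)    # running unnormalized weighted sum of values

    for start in range(0, k.shape[0], block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        s = k_blk @ q                       # logits for this block
        m_new = jnp.maximum(m, s.max())
        # rescale the previous accumulators to the new max, then add this block
        correction = jnp.exp(m - m_new)
        p = jnp.exp(s - m_new)
        l = l * correction + p.sum()
        acc = acc * correction + p @ v_blk
        m = m_new

    return acc / l

# quick check against the naive computation
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (64,))
k = jax.random.normal(kk, (512, 64))
v = jax.random.normal(kv, (512, 64))
naive = jax.nn.softmax(k @ q) @ v
print(jnp.allclose(online_softmax_weighted_sum(q, k, v), naive, atol=1e-4))
```

The point is that the output is accumulated block by block without ever materializing the full attention row, which only really pays off when you can keep those blocks in SRAM instead of round-tripping through HBM.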
Are you working on speeding up attention for work, or is this a side project?
I was just messing around with it as a side project, and thought I'd share when I saw this repo. Great work, btw!
Thought you might find this interesting: I did some benchmarking of the online softmax algorithm used in the Flash Attention paper, with code at https://github.com/jenkspt/online-softmax-jax. TL;DR: in my tests it was not reliably faster than the naive softmax.
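For reference, the gist of what's being compared is something like this (a simplified sketch, not the exact code from the repo):

```python
import jax
import jax.numpy as jnp
from jax import lax

def naive_softmax(x):
    # two separate passes over x: one for the max, one for the exp/sum
    e = jnp.exp(x - x.max())
    return e / e.sum()

def online_softmax(x):
    # one fused pass carrying a running max and a rescaled running denominator,
    # then a final normalization pass
    def step(carry, xi):
        m, l = carry
        m_new = jnp.maximum(m, xi)
        l_new = l * jnp.exp(m - m_new) + jnp.exp(xi - m_new)
        return (m_new, l_new), None

    init = (jnp.array(-jnp.inf), jnp.array(0.0))
    (m, l), _ = lax.scan(step, init, x)
    return jnp.exp(x - m) / l

x = jax.random.normal(jax.random.PRNGKey(0), (4096,))
print(jnp.allclose(naive_softmax(x), online_softmax(x), atol=1e-6))
```

On its own, fusing the max and sum passes like this doesn't obviously beat XLA's handling of the naive version, which is roughly what the benchmarks in the repo show.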