tspeterkim / flash-attention-minimal

Flash Attention in ~100 lines of CUDA (forward pass only)
Apache License 2.0

flash-attention-minimal

A minimal re-implementation of Flash Attention with CUDA and PyTorch. The official implementation can be quite daunting for a CUDA beginner (like me), so this repo tries to be small and educational.
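The forward pass follows the tiled, online-softmax recurrence from the Flash Attention paper: iterate over blocks of K and V, keep a running row max and softmax denominator, and rescale the partial output as each new tile arrives. A rough NumPy sketch of that recurrence (the function names and block size here are illustrative, not the repo's actual kernel API):

```python
import numpy as np

def naive_attention(Q, K, V):
    # standard attention: softmax(Q K^T / sqrt(d)) V, materializing all scores
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

def flash_attention_forward(Q, K, V, block=4):
    # tiled forward pass with online softmax rescaling,
    # mirroring the structure of a Flash Attention CUDA kernel
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros_like(Q)
    m = np.full(N, -np.inf)   # running row max
    l = np.zeros(N)           # running softmax denominator
    for j in range(0, N, block):
        Kj, Vj = K[j:j + block], V[j:j + block]
        S = Q @ Kj.T * scale                   # scores for this K/V tile
        m_new = np.maximum(m, S.max(axis=-1))  # updated row max
        P = np.exp(S - m_new[:, None])         # unnormalized tile probabilities
        corr = np.exp(m - m_new)               # rescale factor for old accumulators
        l = l * corr + P.sum(axis=-1)
        O = O * corr[:, None] + P @ Vj
        m = m_new
    return O / l[:, None]
```

Because the rescaling is exact, the tiled result matches the naive computation to floating-point precision while never holding the full N x N score matrix.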

Usage

Prerequisite

Benchmark

Compare the wall-clock time between manual attention and minimal flash attention:

python bench.py
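bench.py itself isn't reproduced here; it uses PyTorch's profiler, but the comparison it makes can be sketched with a plain wall-clock timer. The helper below is illustrative, not part of the repo:

```python
import time
import numpy as np

def time_fn(fn, *args, warmup=2, iters=10):
    # warm up first, then average wall-clock time over several iterations
    for _ in range(warmup):
        fn(*args)
    start = time.perf_counter()
    for _ in range(iters):
        result = fn(*args)
    elapsed = (time.perf_counter() - start) / iters
    return elapsed, result

def manual_attention(Q, K, V):
    # the unfused baseline: materializes the full N x N score matrix
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V
```

Note that a wall-clock timer only captures total time; the profiler output below breaks the time down by CPU and CUDA activity.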

Sample output on a T4:

=== profiling manual attention ===
...
Self CPU time total: 52.389ms
Self CUDA time total: 52.545ms

=== profiling minimal flash attention === 
...  
Self CPU time total: 11.452ms
Self CUDA time total: 3.908ms

Speed-up achieved!
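The gap comes largely from memory traffic: manual attention writes the full N x N score matrix to HBM and reads it back for the softmax and the value product, while the fused kernel keeps each tile in on-chip SRAM. Some back-of-the-envelope arithmetic (the tile sizes here are illustrative, not the kernel's actual configuration):

```python
def score_matrix_bytes(N, dtype_bytes=4):
    # manual attention materializes an N x N score matrix in HBM
    return N * N * dtype_bytes

def tile_sram_bytes(Br, Bc, d, dtype_bytes=4):
    # a fused kernel holds one Q tile, one K tile, one V tile,
    # and one Br x Bc score tile in SRAM at a time
    return (Br * d + 2 * Bc * d + Br * Bc) * dtype_bytes
```

At N = 1024 and d = 64, the full fp32 score matrix is 4 MiB, while 32 x 32 tiles need only about 28 KiB of SRAM, which fits comfortably in a single SM's shared memory.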

I don't have a GPU

Try out this online Colab demo.

Caveats

Todos