@sokrypton I think this is ready for merging.
It's still strictly opt-in (as Pallas with Triton is only available for Ampere-architecture GPUs and newer).
Performance could be improved a bit further by tuning block sizes and the number of warps in an input-shape-dependent manner; similarly, the `subbatch_size` global config setting could be replaced by a heuristic that selects subbatch sizes from a memory budget (see the sketch below).
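For illustration only (none of this is in the PR), such tuning could look something like the following; all names, thresholds and the memory budget are made up:

```python
# Illustrative only: shape-dependent kernel tuning and a memory-based
# subbatch heuristic. The thresholds and budget here are placeholders.

def choose_kernel_params(seq_len: int) -> dict:
    """Pick Pallas block sizes and warp count from the sequence length."""
    if seq_len >= 1024:
        return dict(block_q=128, block_k=64, num_warps=8)
    if seq_len >= 256:
        return dict(block_q=64, block_k=64, num_warps=4)
    return dict(block_q=32, block_k=32, num_warps=2)

def choose_subbatch_size(num_heads: int, seq_len: int,
                         memory_budget_bytes: int = 4 * 1024**3,
                         bytes_per_elem: int = 4) -> int:
    """Largest subbatch whose attention logits fit in the memory budget."""
    per_item = num_heads * seq_len * seq_len * bytes_per_elem  # logits dominate
    return max(1, memory_budget_bytes // per_item)
```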
Implements FlashAttention similarly to https://github.com/google-deepmind/alphafold/pull/931
For a 759-residue protein and model_5, this improves runtime 2.2x on an L4 (37.3 $\rightarrow$ 16.9 seconds, with minibatching of 256 for non-flash attention to avoid OOM).
Here's a colab link showing the runtime improvement and, by visual inspection, no significant change in the prediction output. I didn't want to rerun all the input prep, so I've used a colab with the AlphaFold input preparation and adapted it for colabdesign.
Notes
Key variations from a reference flash attention kernel are:
- There are guards against the kernel being called for sequence lengths shorter than the block sizes specified for q and k; in that case it exits to the reference kernel (sketched below).
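Roughly, the guard amounts to the following sketch (not the PR's exact code); `flash_mha` is a stand-in for whatever Pallas kernel entry point is used:

```python
# Sketch of the short-sequence guard: if q/k are shorter than the kernel's
# block sizes, fall back to reference attention rather than calling the
# Pallas kernel. `flash_mha` is a placeholder, not the actual function name.
import jax
import jax.numpy as jnp

def reference_attention(q, k, v):
    # q, k, v: [..., seq, heads, head_dim]
    logits = jnp.einsum('...qhd,...khd->...hqk', q, k) / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum('...hqk,...khd->...qhd', weights, v)

def attention(q, k, v, block_q=64, block_k=64):
    q_len, k_len = q.shape[-3], k.shape[-3]
    if q_len < block_q or k_len < block_k:
        return reference_attention(q, k, v)  # guard: too short for the flash kernel
    return flash_mha(q, k, v, block_q=block_q, block_k=block_k)  # placeholder kernel call
```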
Comments
With `use_flash_attention=False` (the default, since this is opt-in), I haven't changed behaviour: here's a colab showing the same 37.3 s runtime as the main branch.
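For reference, opting in would look roughly like this; the exact import and config path in colabdesign may differ, so treat the attribute names below as assumptions:

```python
# Hypothetical usage sketch: config path and attribute names are assumptions.
from colabdesign import mk_afdesign_model  # assumed entry point

model = mk_afdesign_model(protocol="fixbb")
# Default: use_flash_attention=False -> identical behaviour to the main branch.
# Opt in (Ampere or newer GPU required), e.g. via the global config:
# model._cfg.model.global_config.use_flash_attention = True
```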