sokrypton / ColabDesign

Making Protein Design accessible to all via Google Colab!

Flash attention #173

Open oliverdutton opened 3 months ago

oliverdutton commented 3 months ago

Implements FlashAttention similarly to https://github.com/google-deepmind/alphafold/pull/931
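For context, the core FlashAttention idea is to process keys/values in blocks with an online softmax, so the full attention matrix is never materialised. The sketch below is a plain-JAX illustration of that idea only (function name and block size are illustrative); it is not the Pallas/Triton kernel added in this PR.

```python
import jax
import jax.numpy as jnp

def flash_attention_sketch(q, k, v, block_k=128):
    # q: [q_len, d], k/v: [kv_len, d]. Keys/values are processed in blocks,
    # keeping a running max and running softmax denominator so the full
    # [q_len, kv_len] logit matrix is never materialised.
    q_len, d = q.shape
    kv_len = k.shape[0]
    scale = 1.0 / jnp.sqrt(d)

    o = jnp.zeros((q_len, d), q.dtype)           # running (unnormalised) output
    m = jnp.full((q_len,), -jnp.inf, q.dtype)    # running max of logits
    l = jnp.zeros((q_len,), q.dtype)             # running softmax denominator

    for start in range(0, kv_len, block_k):
        k_blk = k[start:start + block_k]
        v_blk = v[start:start + block_k]
        s = (q @ k_blk.T) * scale                # logits for this block
        m_new = jnp.maximum(m, s.max(axis=-1))
        correction = jnp.exp(m - m_new)          # rescale previous accumulators
        p = jnp.exp(s - m_new[:, None])
        l = l * correction + p.sum(axis=-1)
        o = o * correction[:, None] + p @ v_blk
        m = m_new
    return o / l[:, None]
```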

For a 759-residue protein and model_5, this improves runtime 2.2x on an L4 GPU (37.3 $\rightarrow$ 16.9 seconds, with minibatching of 256 for non-flash attention to avoid OOM).

Here's a colab link showing the runtime improvement and no significant change in prediction output on visual inspection. I didn't want to rerun all the input prep, so I've used a colab with AlphaFold input preparation and applied the fixes for ColabDesign.

Notes

Key variations from a reference flash attention kernel are:

- There are guards against the kernel being called for sequence lengths shorter than the block sizes specified for q and k; in those cases it exits to the reference kernel (see the sketch after this list).
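A minimal sketch of that guard, assuming a thin wrapper around the kernel (`flash_kernel`, `block_q`, and `block_k` are illustrative names, not the PR's actual API):

```python
import jax
import jax.numpy as jnp

def reference_attention(q, k, v):
    # q: [heads, q_len, d], k/v: [heads, kv_len, d] -- plain softmax attention.
    logits = jnp.einsum('hqd,hkd->hqk', q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum('hqk,hkd->hqd', jax.nn.softmax(logits, axis=-1), v)

def attention(q, k, v, flash_kernel=None, block_q=128, block_k=128):
    # Guard: if either sequence length is shorter than the kernel's block
    # sizes, exit to the reference implementation instead of the flash kernel.
    q_len, kv_len = q.shape[-2], k.shape[-2]
    if flash_kernel is None or q_len < block_q or kv_len < block_k:
        return reference_attention(q, k, v)
    return flash_kernel(q, k, v)
```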

Comments

oliverdutton commented 2 months ago

@sokrypton I think this is ready for merging.

It's still strictly opt-in (as Pallas with Triton is only available on Ampere-architecture GPUs and newer).

You could improve performance a bit further by tuning the block sizes and number of warps in an input-shape-dependent manner; similarly, the `subbatch_size` global config setting could be replaced by a default memory-usage heuristic that selects the subbatch size automatically (a hypothetical sketch follows).
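For the second suggestion, one such heuristic might pick the largest subbatch size whose attention logits fit within a memory budget, rather than using a single fixed global value. All names and the budget below are illustrative, not existing ColabDesign config:

```python
def pick_subbatch_size(seq_len, num_heads, bytes_per_elem=4,
                       memory_budget_bytes=2 * 1024**3,
                       candidates=(1024, 512, 256, 128, 64)):
    # Hypothetical heuristic: return the largest candidate subbatch size whose
    # dominant memory cost (the [subbatch, heads, seq_len, seq_len] logit
    # tensor) fits within the budget; otherwise fall back to the smallest.
    for s in candidates:
        if s * num_heads * seq_len * seq_len * bytes_per_elem <= memory_budget_bytes:
            return s
    return candidates[-1]

# e.g. for the 759-residue case above:
# pick_subbatch_size(seq_len=759, num_heads=4)  -> 128 under a 2 GiB budget
```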