google-deepmind / alphafold

Open source code for AlphaFold.
Apache License 2.0

Flash attention #931

Open oliverdutton opened 5 months ago

oliverdutton commented 5 months ago

This implements flash attention with Pallas to reduce runtime and memory usage. It is opt-in, enabled through the global config.

For a 759-residue protein with model_5, this drops peak memory consumption to 5 GB without minibatching and reduces runtime 2.3x on an A100 (15.2 $\rightarrow$ 6.5 seconds; the non-flash baseline uses minibatching of 256 to avoid OOM).

Here's a Colab link showing the runtime improvement and, by visual inspection, no significant change in prediction output.

When combined with https://github.com/google-deepmind/alphafold/pull/930 (bfloat16 support for monomer models), peak memory drops to only 2.7 GB and runtime to 5.6 seconds (a 2.7x speedup relative to non-flash float32).
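For readers unfamiliar with why flash attention saves memory, here is a minimal, dependency-free sketch of the underlying idea (blockwise attention with an online softmax). It is not the Pallas kernel from this pull request; the function names, the block size of 2, and the toy inputs are all made up for illustration. It shows that visiting keys one block at a time, while carrying a running maximum and a running normaliser, gives the same result as materialising every score at once.

```python
import math

def reference_attention(q, keys, values):
    """Softmax attention for one query vector, materialising all scores."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]

def blockwise_attention(q, keys, values, block_k=2):
    """Same result, but keys are visited one block at a time."""
    dim = len(values[0])
    running_max = -math.inf  # largest score seen so far
    running_sum = 0.0        # softmax denominator, relative to running_max
    acc = [0.0] * dim        # unnormalised output, relative to running_max
    for start in range(0, len(keys), block_k):
        k_blk = keys[start:start + block_k]
        v_blk = values[start:start + block_k]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results so they are relative to new_max.
        scale = math.exp(running_max - new_max)
        weights = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * scale + sum(weights)
        acc = [a * scale + sum(w * v[d] for w, v in zip(weights, v_blk))
               for d, a in enumerate(acc)]
        running_max = new_max
    return [a / running_sum for a in acc]

# Toy inputs (hypothetical): one query, five keys, two-dimensional values.
q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.3, 0.3]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]

ref = reference_attention(q, keys, values)
blk = blockwise_attention(q, keys, values, block_k=2)
```

Only one block of scores is alive at any time in `blockwise_attention`, which is where the memory saving comes from; a real kernel applies the same recurrence to tiles of queries and keys on the GPU.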

Notes:

Key variations from a reference flash attention kernel are:

There are guards against the kernel being called for sequence lengths shorter than the block sizes specified for q and k; in that case it falls back to the reference kernel.
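The guard described above can be pictured as a small dispatcher. This is only a sketch of the pattern, not the code in the pull request: `guarded_attention`, `flash_fn`, `reference_fn`, and the default block sizes of 128 are hypothetical names and values chosen for illustration.

```python
def guarded_attention(q, k, v, flash_fn, reference_fn, block_q=128, block_k=128):
    """Call flash_fn only when both sequences fill at least one block."""
    # len(q) and len(k) stand in for the query and key sequence lengths.
    if len(q) < block_q or len(k) < block_k:
        # Too short for the tiled kernel: fall back to the reference path.
        return reference_fn(q, k, v)
    return flash_fn(q, k, v)
```

Keeping the check outside the kernel means the tiled code never has to handle a sequence shorter than a single block.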

I haven't done correctness checks with multimer models; I would do so if there is a positive response to this pull request. I'm not yet certain of the numerical stability of the implementation with bfloat16.

(I can switch out exp and log for exp2 and log2 for a small reduction in runtime. This leads to slightly different predictions, but I believe testing would show equivalent error in structure prediction.)
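The exp2/log2 substitution rests on the identity $e^{x} = 2^{x \log_2 e}$, so the scores are pre-scaled by $\log_2 e$ before a base-2 exponential is applied. The sketch below uses only the Python standard library with made-up function names; it illustrates the identity, not the kernel's code. In float64 the two versions agree to rounding error, and the slightly different predictions mentioned above would come from how that rounding plays out at lower precision.

```python
import math

LOG2_E = math.log2(math.e)  # scale factor that turns exp(x) into 2**(x * LOG2_E)

def softmax_exp(scores):
    """Numerically stable softmax using the natural exponential."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [w / total for w in weights]

def softmax_exp2(scores):
    """Same softmax, computed with base-2 exponentials."""
    scaled = [s * LOG2_E for s in scores]
    m = max(scaled)
    weights = [2.0 ** (s - m) for s in scaled]
    total = sum(weights)
    return [w / total for w in weights]

def logsumexp_exp2(scores):
    """Natural-log logsumexp, computed with exp2 and log2."""
    scaled = [s * LOG2_E for s in scores]
    m = max(scaled)
    return (m + math.log2(sum(2.0 ** (s - m) for s in scaled))) / LOG2_E

scores = [0.3, -1.2, 2.5, 0.0]
p_exp = softmax_exp(scores)
p_exp2 = softmax_exp2(scores)
```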

sokrypton commented 5 months ago

Hi @oliverdutton! Really cool contribution. Mind if we try adding it to ColabFold? We already have fused attention and bfloat16 integrated into the monomer model. It will be interesting to try flash attention as well.

oliverdutton commented 5 months ago

@sokrypton Of course, I've made a pull request in ColabDesign with it (https://github.com/sokrypton/ColabDesign/pull/173)

oliverdutton commented 5 months ago

Before https://github.com/google-deepmind/alphafold/pull/931/commits/d4516d83aaf65aee2e2c90ca85b86acacd464c0f I find transient NaN behaviour on shapes whose sequence length is not evenly divisible by the block size (so loads go out of bounds).

Gist to reproduce the problem:

import jax
from jax import numpy as jnp
from alphafold.model import model

# Jit once, outside the loops; these keyword arguments are treated as static.
f = jax.jit(model.modules.Attention.flash_kernel, static_argnames=(
    'return_residual', 'block_q', 'block_k', 'num_warps', 'num_stages', 'grid', 'interpret', 'debug')
)

key = jax.random.PRNGKey(42)  # fixed key: inputs are identical across repeats
nrepeats = 100
# Sweep sequence lengths from 128 to 255.
for nres in range(128, 256):
    print(nres)
    for i in range(nrepeats):
        q, k, v = jax.random.uniform(key, (3, 1024, nres, 8, 32))
        # Any NaN or inf in the output counts as a failure.
        assert jnp.isfinite(f(q, k, v)).all(), f"Failed with {nres} on run {i}"

After https://github.com/google-deepmind/alphafold/pull/931/commits/d4516d83aaf65aee2e2c90ca85b86acacd464c0f the transient NaN behaviour disappears, so I hope this will now always be NaN-free.
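As a generic illustration of why out-of-bounds loads can produce NaNs (this does not describe what the commit above changed, which is not stated here), consider the last key block of a sequence whose length is not a multiple of the block size. The trailing lanes hold whatever was in memory. In the sketch below those lanes are set to NaN to model the worst case, and `last_block_weights` is a hypothetical helper. Masking the out-of-bounds scores to negative infinity before the softmax keeps them from contributing.

```python
import math

def last_block_weights(scores, valid_len, mask_oob):
    """Softmax weights for one key block whose tail may be out of bounds."""
    if mask_oob:
        # Out-of-bounds lanes get -inf, so exp() maps them to exactly 0.
        scores = [s if i < valid_len else -math.inf
                  for i, s in enumerate(scores)]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [w / total for w in weights]

# Block of 4 lanes, only the first 2 are real; the rest model garbage memory.
block_scores = [0.1, 0.7, math.nan, math.nan]

unmasked = last_block_weights(block_scores, valid_len=2, mask_oob=False)
masked = last_block_weights(block_scores, valid_len=2, mask_oob=True)
```

Without the mask a single NaN lane poisons the normaliser and hence every weight in the block. Because the contents of out-of-bounds memory can vary from run to run, a failure of this kind would be expected to appear only intermittently.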

xlminfei commented 5 months ago

Thank you very much, this improvement is very useful. I am using an RTX 3090 to predict a 3645 aa heterotetramer. With this improvement, the prediction time of a single model has decreased from 59,000 seconds to 43,000 seconds (also out of GPU memory limit).