google-deepmind / alphafold

Open source code for AlphaFold 2.
Apache License 2.0

Port to Metal and macOS #604

Open: philipturner opened this issue 2 years ago

philipturner commented 2 years ago

I've been taking a shot at porting this from NVIDIA/CUDA to M1/Metal so that I can run it on my personal MacBook Pro (1 TB disk). I expected to port the entire framework in a single day, but ran into significant problems with disk space. It would have been much less challenging if the total database collection were even just halved in size.

I wish to engineer proteins that never existed before in biology and may have no evolutionary history. AlphaFold would let me rapidly search the space of amino acid sequences until I find proteins that suit my needs. I fear that AlphaFold, heavily tuned for multiple sequence alignments (MSAs), would perform poorly on proteins foreign to biological evolution. Nevertheless, it would be a great tool in my toolbox.

I'm thinking of either (1) purchasing an external HDD or (2) seeing whether you could reduce the collective database size from 600 (actually 700) GB to something like 200 GB. When/if I finish the port, would you consider merging these contributions into the main branch? I don't think that's very likely, but "never say never".

Fork and documentation of my porting efforts: https://github.com/philipturner/alphafold-metal
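Since disk space was the blocker here, a quick way to see whether a volume can hold the databases before starting a download is to compare its free space with the expected size. This is only a sketch: the 700 GB default is the figure quoted above, not an official requirement, and `has_room_for_databases` is a hypothetical helper built on Python's standard `shutil.disk_usage`.

```python
import shutil

# Size quoted in this thread for the full database collection; an assumption,
# not an official figure. Adjust it for the databases you actually download.
REQUIRED_GB = 700


def free_gb(path: str) -> float:
    """Free space on the volume holding `path`, in gigabytes (10**9 bytes)."""
    return shutil.disk_usage(path).free / 1e9


def has_room_for_databases(path: str, required_gb: float = REQUIRED_GB) -> bool:
    """True if the volume holding `path` has at least `required_gb` GB free."""
    return free_gb(path) >= required_gb


if __name__ == "__main__":
    print(f"{free_gb('.'):.1f} GB free; enough for databases: "
          f"{has_room_for_databases('.')}")
```

Pointing `path` at the mount point of an external drive checks that drive instead of the internal disk.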

tjonesster commented 1 year ago

The big issue is really just jaxlib. If there were an XLA backend for Metal, I think this would be a pretty trivial problem. I think there is a jaxlib built for macOS on ARM, but I don't think it currently has a Metal backend.
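To see which backend an installed jaxlib actually runs on (for example, whether a Metal plugin was picked up or JAX fell back to the CPU), you can ask JAX directly. A minimal sketch, assuming only `jax.default_backend()` and `jax.devices()`, which recent JAX releases provide; `describe_jax_backend` is a hypothetical helper name, and it returns `None` when jax is not installed.

```python
def describe_jax_backend():
    """Return (default backend name, device descriptions), or None without jax."""
    try:
        import jax
    except ImportError:
        return None
    # Name of the platform JAX uses for new computations, e.g. "cpu" or "gpu";
    # a PJRT plugin such as jax-metal reports its own platform name.
    backend = jax.default_backend()
    # Devices of that default backend, as human-readable strings.
    devices = [str(d) for d in jax.devices()]
    return backend, devices


if __name__ == "__main__":
    print(describe_jax_backend())
```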

philipturner commented 1 year ago

I did just get a 4 TB HDD, so I can try again. It just depends on whether I want to explore protein engineering. Porting stuff to Metal is my strong suit.

ChrisLou-bioinfo commented 11 months ago

Waiting for your day3

philipturner commented 11 months ago

Apple already ported AlphaFold to Metal, with the Metal backend for JAX. It was likely inspired by their internal Ajax project experimenting with large language models.

AlphaFold isn't relevant to my work now. It wasn't actually relevant to protein engineering before either, because you can design proteins that don't require an AI to characterize. Eric Drexler made that point decades ago, and he still holds it: design monomers to be predictable.

tomgoddard commented 11 months ago

There is a jax-metal backend created by Apple, described here:

https://developer.apple.com/metal/jax/

I tried jax-metal version 0.0.4 with localcolabfold (https://github.com/YoshitakaMo/localcolabfold) on a MacBook Pro M1 Max (32 GPU cores, model MacBookPro18,2) running macOS 14.1.2. An AlphaFold prediction of a 40-amino-acid protein (PDB 8ff2 chain A) took 15 minutes and produced completely incorrect structures, with all residues on top of each other. There were no error messages, AlphaFold reported that it was using the GPU, and Activity Monitor showed 100% GPU utilization. When I ran the same prediction without jax-metal installed, jax used the CPU, completed in 11 minutes, and made physically reasonable predictions. For comparison, this prediction on Linux with an Nvidia 3090 takes several seconds, roughly 100 times faster.

So it appears that Apple's early jax-metal 0.0.4 has problems that produce incorrect calculations, and it runs slower than jax on the CPU, so it is not ready for use with AlphaFold.
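Because the broken runs finished without any errors, it can help to check the output coordinates automatically. One simple test, sketched below under the assumption that the prediction is a single-chain, single-model PDB file, relies on consecutive C-alpha atoms in a real chain sitting about 3.8 Å apart: if the mean spacing is far below that, the residues have collapsed onto each other. The helper names and the 50% threshold are illustrative choices, not part of AlphaFold or ColabFold.

```python
import math

# Consecutive C-alpha atoms in a real protein chain sit about 3.8 Å apart.
CA_CA_DISTANCE = 3.8


def ca_coordinates(pdb_text: str):
    """Collect C-alpha (x, y, z) from ATOM records using fixed PDB columns."""
    coords = []
    for line in pdb_text.splitlines():
        # Columns 13-16 hold the atom name, 31-54 the x, y, z coordinates.
        if line.startswith("ATOM") and line[12:16].strip() == "CA":
            coords.append((float(line[30:38]), float(line[38:46]), float(line[46:54])))
    return coords


def looks_collapsed(coords, min_fraction=0.5):
    """True if the mean consecutive CA-CA distance is below min_fraction * 3.8 Å.

    Assumes one chain and one model, so consecutive CA atoms are bonded neighbours.
    """
    if len(coords) < 2:
        return False
    steps = [math.dist(a, b) for a, b in zip(coords, coords[1:])]
    return sum(steps) / len(steps) < min_fraction * CA_CA_DISTANCE
```

For example, `looks_collapsed(ca_coordinates(open(path).read()))` on a predicted PDB file (the path is hypothetical) should return `True` for outputs with all residues on top of each other.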

The simple numpy test of jax-metal given on Apple's jax-metal web page issues a warning that "JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!"

(jax-metal) $ python -c 'import jax; print(jax.numpy.arange(10))'
2023-12-06 23:47:53.336057: W pjrt_plugin/src/mps_client.cc:534] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Max

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

[0 1 2 3 4 5 6 7 8 9]
(jax-metal) $

sokrypton commented 11 months ago

I recall the jax-metal version being very buggy. For example, jnp.eye would return an all-zeros matrix instead of the identity matrix.
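A bug like the `jnp.eye` one can be caught before launching a long prediction by comparing a handful of basic operations against NumPy. The sketch below takes the array module as a parameter: pass `jax.numpy` to exercise whichever backend JAX selected (Metal, if jax-metal is installed). The set of checks and the loose tolerances are assumptions chosen to catch gross errors, not a validation suite.

```python
import numpy as np


def sanity_check(xp, to_numpy=np.asarray):
    """Compare a few basic ops from array module `xp` against NumPy.

    Returns {check name: True if `xp` agrees with NumPy}.
    """
    rng = np.random.default_rng(0)
    a = rng.standard_normal((8, 8)).astype(np.float32)
    b = rng.standard_normal((8, 8)).astype(np.float32)
    # Loose tolerances: GPUs may use reduced-precision float32 matmuls, and the
    # goal is to catch grossly wrong results such as an all-zeros identity.
    tol = dict(rtol=1e-2, atol=1e-2)
    return {
        "eye": bool(np.array_equal(to_numpy(xp.eye(4)), np.eye(4))),
        "arange": bool(np.array_equal(to_numpy(xp.arange(10)), np.arange(10))),
        "matmul": bool(np.allclose(to_numpy(xp.asarray(a) @ xp.asarray(b)), a @ b, **tol)),
        "sum": bool(np.allclose(to_numpy(xp.sum(xp.asarray(a))), a.sum(), **tol)),
    }
```

For example, `import jax.numpy as jnp; print(sanity_check(jnp))` reports one boolean per check; any `False` means that backend disagrees with NumPy on basic arithmetic.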

aldospanjaard commented 11 months ago


How did you get localcolabfold to run on Apple Metal? It normally requires CUDA to run on the GPU, right? I'm very interested in getting either AlphaFold or localcolabfold running on my M2 Max MacBook Pro. If you could share the steps you took to get this to work, I'd be very grateful. I have already installed the Metal versions of jax, tensorflow and PyTorch.

tomgoddard commented 11 months ago

As noted in my previous comment and Sergey's previous comment, jax-metal 0.0.4 does not work correctly on Apple ARM GPUs. If you want to try it anyway, simply install localcolabfold following the instructions on that project's GitHub site under "For Mac with Apple Silicon (M1 chip)":

https://github.com/YoshitakaMo/localcolabfold/blob/main/README.md#for-mac-with-apple-silicon-m1-chip

Then install jax-metal 0.0.4 as instructed on Apple's web page:

https://developer.apple.com/metal/jax/

Note that jax-metal 0.0.4 requires jax 0.4.11, but the localcolabfold installer installs jax 0.3.25, so you will then need to upgrade jax. That in turn requires upgrading dm-haiku, an ML library built on top of jax. If I recall correctly, I did that with:

pip3 install jax==0.4.11 dm-haiku==0.0.10
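After changing versions by hand, it is worth confirming what pip actually installed before running a prediction. This sketch uses the standard `importlib.metadata.version` to compare installed versions with the pins mentioned in this comment. The `dm-haiku` pin of `0.0.10` is an assumption (dm-haiku releases are numbered 0.0.x), so adjust `PINS` to whatever you installed; `pin_mismatches` is a hypothetical helper.

```python
from importlib import metadata

# Versions discussed in this thread; the dm-haiku entry is an assumption.
PINS = {"jax": "0.4.11", "dm-haiku": "0.0.10", "jax-metal": "0.0.4"}


def pin_mismatches(pins, get_version=metadata.version):
    """Return {package: installed version or None} for every unmet pin."""
    mismatches = {}
    for package, wanted in pins.items():
        try:
            installed = get_version(package)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != wanted:
            mismatches[package] = installed
    return mismatches


if __name__ == "__main__":
    bad = pin_mismatches(PINS)
    print("all pins satisfied" if not bad else f"mismatched: {bad}")
```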

A test run of AlphaFold on a single sequence then completed without errors, but gave completely wrong results with all residues on top of each other, and took a long time: as I noted before, more than 100 times slower than an Nvidia 3090 / Linux system. Running on the Mac with only the CPU (that is, without jax-metal) worked fine, but was also extremely slow, although faster than with jax-metal. To run with the jax CPU backend instead of jax-metal, you can set an environment variable in the shell before running colabfold_batch:

export CUDA_VISIBLE_DEVICES=-1
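The same idea can be scripted from Python. In this sketch, `cpu_only_env` sets JAX's own `JAX_PLATFORMS=cpu` setting alongside the `CUDA_VISIBLE_DEVICES=-1` variable suggested above, on the assumption that `JAX_PLATFORMS` is the more direct way to keep JAX off the Metal plugin, since `CUDA_VISIBLE_DEVICES` is a CUDA setting. Both helper names are hypothetical, and the input and output paths passed to `colabfold_batch` are placeholders.

```python
import os
import subprocess


def cpu_only_env(base=None):
    """Return a copy of the environment that steers JAX to its CPU backend."""
    env = dict(os.environ if base is None else base)
    env["JAX_PLATFORMS"] = "cpu"        # JAX's setting for which backends to use
    env["CUDA_VISIBLE_DEVICES"] = "-1"  # the variable suggested in this thread
    return env


def run_colabfold_cpu(input_path, result_dir):
    """Run colabfold_batch on the CPU backend; raises if it exits non-zero."""
    subprocess.run(
        ["colabfold_batch", input_path, result_dir],
        env=cpu_only_env(),
        check=True,
    )
```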

philipturner commented 11 months ago

I'm going to turn off notifications for this issue; @mention me if my expertise is needed here.

tomgoddard commented 10 months ago

Apple released version 0.0.5 of jax-metal on December 21, 2023. I tried an AlphaFold run with localcolabfold on PDB 8ff2_A (40 amino acids), as in my comment above from December 15, 2023, and got the same results. The run completed in 16 minutes without errors, reported that it was using jax-metal, and warned that not all JAX functionality is supported:

2024-01-21 15:29:38,774 Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-01-21 15:29:50,797 Running on GPU

and the result had all residues on top of each other (image attached).

[Attached image: 8ff2_a_jax_metal]
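When comparing several runs, the two log messages quoted above can be extracted automatically to record which backend each run used. A minimal sketch, assuming the wording matches the jax-metal 0.0.5 / localcolabfold log shown in this comment (other versions may phrase it differently); `summarize_log` is a hypothetical helper. Note that "Running on GPU" says nothing about whether the results are correct.

```python
import re

# Patterns taken from the log lines quoted in this thread.
PLATFORM_WARNING = re.compile(r"Platform '(\w+)' is experimental")
RUNNING_ON = re.compile(r"Running on (\w+)")


def summarize_log(text):
    """Extract the experimental platform name and the reported device, if any."""
    platform = PLATFORM_WARNING.search(text)
    device = RUNNING_ON.search(text)
    return {
        "experimental_platform": platform.group(1) if platform else None,
        "running_on": device.group(1) if device else None,
    }
```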