Implementation of Alphafold 3 in Pytorch
MIT License
Alphafold 3 - Pytorch

You can chat with other researchers about this work here

Review of the paper by Sergey

Illustrated guide by Elana P. Simon

Talk by Max Jaderberg

A fork with full Lightning + Hydra support is being maintained by Alex at this repository

A visualization of the molecules of life used in the repository can be seen and interacted with here



$ pip install alphafold3-pytorch


import torch
from alphafold3_pytorch import Alphafold3
from alphafold3_pytorch.utils.model_utils import exclusive_cumsum

alphafold3 = Alphafold3(
    dim_atom_inputs = 77,
    dim_template_feats = 108
)

# mock inputs

seq_len = 16

molecule_atom_indices = torch.randint(0, 2, (2, seq_len)).long()
molecule_atom_lens = torch.full((2, seq_len), 2).long()

atom_seq_len = molecule_atom_lens.sum(dim=-1).amax()
atom_offsets = exclusive_cumsum(molecule_atom_lens)

atom_inputs = torch.randn(2, atom_seq_len, 77)
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)

additional_molecule_feats = torch.randint(0, 2, (2, seq_len, 5))
additional_token_feats = torch.randn(2, seq_len, 33)
is_molecule_types = torch.randint(0, 2, (2, seq_len, 5)).bool()
is_molecule_mod = torch.randint(0, 2, (2, seq_len, 4)).bool()
molecule_ids = torch.randint(0, 32, (2, seq_len))

template_feats = torch.randn(2, 2, seq_len, seq_len, 108)
template_mask = torch.ones((2, 2)).bool()

msa = torch.randn(2, 7, seq_len, 32)
msa_mask = torch.ones((2, 7)).bool()

additional_msa_feats = torch.randn(2, 7, seq_len, 2)

# required for training, but omitted on inference

atom_pos = torch.randn(2, atom_seq_len, 3)

distogram_atom_indices = molecule_atom_lens - 1

distance_labels = torch.randint(0, 37, (2, seq_len, seq_len))
resolved_labels = torch.randint(0, 2, (2, atom_seq_len))

# offset indices correctly

distogram_atom_indices += atom_offsets
molecule_atom_indices += atom_offsets

# train

loss = alphafold3(
    num_recycling_steps = 2,
    atom_inputs = atom_inputs,
    atompair_inputs = atompair_inputs,
    molecule_ids = molecule_ids,
    molecule_atom_lens = molecule_atom_lens,
    additional_molecule_feats = additional_molecule_feats,
    additional_msa_feats = additional_msa_feats,
    additional_token_feats = additional_token_feats,
    is_molecule_types = is_molecule_types,
    is_molecule_mod = is_molecule_mod,
    msa = msa,
    msa_mask = msa_mask,
    templates = template_feats,
    template_mask = template_mask,
    atom_pos = atom_pos,
    distogram_atom_indices = distogram_atom_indices,
    molecule_atom_indices = molecule_atom_indices,
    distance_labels = distance_labels,
    resolved_labels = resolved_labels
)

loss.backward()

# after much training ...

sampled_atom_pos = alphafold3(
    num_recycling_steps = 4,
    num_sample_steps = 16,
    atom_inputs = atom_inputs,
    atompair_inputs = atompair_inputs,
    molecule_ids = molecule_ids,
    molecule_atom_lens = molecule_atom_lens,
    additional_molecule_feats = additional_molecule_feats,
    additional_msa_feats = additional_msa_feats,
    additional_token_feats = additional_token_feats,
    is_molecule_types = is_molecule_types,
    is_molecule_mod = is_molecule_mod,
    msa = msa,
    msa_mask = msa_mask,
    templates = template_feats,
    template_mask = template_mask
)

sampled_atom_pos.shape # (2, <atom_seqlen>, 3)
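
The "offset indices correctly" step in the example above is easy to misread, so here is a minimal sketch of the idea in plain Python. It assumes that `exclusive_cumsum` gives, for each position, the sum of all preceding lengths (the usual meaning of an exclusive cumulative sum; the library helper works on tensors and may differ in detail). The function names `exclusive_cumsum_list` and `to_flat_atom_indices` are hypothetical and exist only for this illustration.

```python
from itertools import accumulate


def exclusive_cumsum_list(lengths):
    """Return, for each position, the sum of all entries strictly before it."""
    inclusive = list(accumulate(lengths))
    return [total - length for total, length in zip(inclusive, lengths)]


def to_flat_atom_indices(local_indices, lengths):
    """Shift per-molecule atom indices into the flat (concatenated) atom sequence."""
    offsets = exclusive_cumsum_list(lengths)
    return [index + offset for index, offset in zip(local_indices, offsets)]


# Three molecules with 2, 3 and 4 atoms. Pick the last atom of each molecule,
# mirroring `molecule_atom_lens - 1` in the example above.
lengths = [2, 3, 4]
last_atom_local = [n - 1 for n in lengths]

print(exclusive_cumsum_list(lengths))                  # [0, 2, 5]
print(to_flat_atom_indices(last_atom_local, lengths))  # [1, 4, 8]
```

Without the offsets, every molecule's index would point into the first few atoms of the flat sequence rather than into that molecule's own atoms.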

An example with molecule level input handling

import torch
from alphafold3_pytorch import Alphafold3, Alphafold3Input

contrived_protein = 'AG'

mock_atompos = [
    torch.randn(5, 3),   # alanine has 5 non-hydrogen atoms
    torch.randn(4, 3)    # glycine has 4 non-hydrogen atoms
]

train_alphafold3_input = Alphafold3Input(
    proteins = [contrived_protein],
    atom_pos = mock_atompos
)

eval_alphafold3_input = Alphafold3Input(
    proteins = [contrived_protein]
)

# training

alphafold3 = Alphafold3(
    dim_atom_inputs = 3,
    dim_atompair_inputs = 5,
    atoms_per_window = 27,
    dim_template_feats = 108,
    num_molecule_mods = 0,
    confidence_head_kwargs = dict(
        pairformer_depth = 1
    ),
    template_embedder_kwargs = dict(
        pairformer_stack_depth = 1
    ),
    msa_module_kwargs = dict(
        depth = 1
    ),
    pairformer_stack = dict(
        depth = 2
    ),
    diffusion_module_kwargs = dict(
        atom_encoder_depth = 1,
        token_transformer_depth = 1,
        atom_decoder_depth = 1,
    )
)

loss = alphafold3.forward_with_alphafold3_inputs([train_alphafold3_input])

# sampling

sampled_atom_pos = alphafold3.forward_with_alphafold3_inputs(eval_alphafold3_input)

assert sampled_atom_pos.shape == (1, (5 + 4), 3)
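
The `(5 + 4)` in the final assertion comes from counting non-hydrogen atoms per residue. As a rough sketch of that bookkeeping, the snippet below sums per-residue heavy-atom counts for a one-letter sequence. The table `HEAVY_ATOM_COUNTS` and the function `expected_atom_count` are hypothetical helpers, not part of the package. The table is deliberately partial and assumes residues inside a chain, without the terminal OXT oxygen, which matches the alanine and glycine counts used above.

```python
# Heavy (non-hydrogen) atom counts for a few residues, without terminal OXT.
# Illustrative and incomplete; not taken from the package.
HEAVY_ATOM_COUNTS = {
    "G": 4,  # N, CA, C, O
    "A": 5,  # N, CA, C, O, CB
    "S": 6,  # N, CA, C, O, CB, OG
    "V": 7,  # N, CA, C, O, CB, CG1, CG2
}


def expected_atom_count(sequence):
    """Sum the heavy-atom counts of every residue in a one-letter sequence."""
    try:
        return sum(HEAVY_ATOM_COUNTS[residue] for residue in sequence)
    except KeyError as err:
        raise ValueError(f"no atom count recorded for residue {err.args[0]!r}") from None


print(expected_atom_count("AG"))  # 9
```

A check like this can help when building `atom_pos` by hand, since the list must supply one `(num_atoms, 3)` tensor per residue.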

Data preparation

PDB dataset curation

To acquire the AlphaFold 3 PDB dataset, first download all first-assembly (and asymmetric unit) complexes in the Protein Data Bank (PDB), and then preprocess them with the scripts referenced below. The PDB can be downloaded from the RCSB. The two sets of Python scripts below (i.e., filter_pdb_{train,val,test} and cluster_pdb_{train,val,test}) assume you have downloaded the PDB in the mmCIF file format, placing its first-assembly and asymmetric unit mmCIF files at data/pdb_data/unfiltered_assembly_mmcifs/ and data/pdb_data/unfiltered_asym_mmcifs/, respectively.

For reproducibility, we recommend downloading the PDB using AWS snapshots (e.g., 20240101). To do so, refer to AWS's documentation to set up the AWS CLI locally. Alternatively, on the RCSB website, navigate down to "Download Protocols", and follow the download instructions depending on your location.

For example, one can use the following commands to download the PDB as two collections of mmCIF files:

# For `assembly1` complexes, use the PDB's `20240101` AWS snapshot:
aws s3 sync s3://pdbsnapshots/20240101/pub/pdb/data/assemblies/mmCIF/divided/ ./data/pdb_data/unfiltered_assembly_mmcifs
# Or as a fallback, use rsync:
rsync -rlpt -v -z --delete --port=33444 \ ./data/pdb_data/unfiltered_assembly_mmcifs/

# For asymmetric unit complexes, also use the PDB's `20240101` AWS snapshot:
aws s3 sync s3://pdbsnapshots/20240101/pub/pdb/data/structures/divided/mmCIF/ ./data/pdb_data/unfiltered_asym_mmcifs
# Or as a fallback, use rsync:
rsync -rlpt -v -z --delete --port=33444 \ ./data/pdb_data/unfiltered_asym_mmcifs/

WARNING: Downloading the PDB can take up to 700GB of space.

NOTE: The PDB hosts all available AWS snapshots here:

After downloading, you should have two directories: ./data/pdb_data/unfiltered_assembly_mmcifs/ and ./data/pdb_data/unfiltered_asym_mmcifs/.


For these directories, unzip all the files:

find ./data/pdb_data/unfiltered_assembly_mmcifs/ -type f -name "*.gz" -exec gzip -d {} \;
find ./data/pdb_data/unfiltered_asym_mmcifs/ -type f -name "*.gz" -exec gzip -d {} \;
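
If `find` and `gzip` are not available (for example on Windows), the same decompression step can be sketched with the Python standard library. The function name `decompress_gz_files` is hypothetical. Like `gzip -d`, it removes each archive after writing the decompressed file next to it.

```python
import gzip
import shutil
from pathlib import Path


def decompress_gz_files(root):
    """Decompress every *.gz file under `root` in place and return how many were handled."""
    count = 0
    # Materialize the listing first so that deleting archives does not disturb iteration.
    for gz_path in sorted(Path(root).rglob("*.gz")):
        if not gz_path.is_file():
            continue
        target = gz_path.with_suffix("")  # e.g. 15c8.cif.gz -> 15c8.cif
        with gzip.open(gz_path, "rb") as src, open(target, "wb") as dst:
            shutil.copyfileobj(src, dst)
        gz_path.unlink()
        count += 1
    return count


# Example call, using the directories from the commands above:
# decompress_gz_files("./data/pdb_data/unfiltered_assembly_mmcifs/")
# decompress_gz_files("./data/pdb_data/unfiltered_asym_mmcifs/")
```

This streams each file through `shutil.copyfileobj`, so memory use stays small even for large mmCIF files.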

Next run the commands

wget -P ./data/ccd_data/
wget -P ./data/ccd_data/

from the project's root directory to download the latest version of the PDB's Chemical Component Dictionary (CCD) and its structural models. Extract each of these files using the following command:

find data/ccd_data/ -type f -name "*.gz" -exec gzip -d {} \;

PDB dataset filtering

Then run the following with pdb_assembly_dir, pdb_asym_dir, ccd_dir, and mmcif_output_dir replaced with the locations of your local copies of the first-assembly PDB, asymmetric unit PDB, CCD, and your desired dataset output directory (i.e., ./data/pdb_data/unfiltered_assembly_mmcifs/, ./data/pdb_data/unfiltered_asym_mmcifs/, ./data/ccd_data/, and ./data/pdb_data/{train,val,test}_mmcifs/).

python scripts/ --mmcif_assembly_dir <pdb_assembly_dir> --mmcif_asym_dir <pdb_asym_dir> --ccd_dir <ccd_dir> --output_dir <mmcif_output_dir>
python scripts/ --mmcif_assembly_dir <pdb_assembly_dir> --mmcif_asym_dir <pdb_asym_dir> --output_dir <mmcif_output_dir>
python scripts/ --mmcif_assembly_dir <pdb_assembly_dir> --mmcif_asym_dir <pdb_asym_dir> --output_dir <mmcif_output_dir>

See the scripts for more options. Each first-assembly mmCIF that successfully passes all processing steps will be written to mmcif_output_dir within a subdirectory named according to the mmCIF's second and third PDB ID characters (e.g. 5c).
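
To make the output layout concrete, here is a small sketch of how a PDB ID maps to its subdirectory. The helpers `pdb_subdir` and `output_path` are hypothetical, and both the lowercasing and the `<pdb_id>.cif` file name are assumptions made for illustration; the scripts themselves may name files differently.

```python
from pathlib import Path


def pdb_subdir(pdb_id):
    """Return the second and third characters of a four-character PDB ID."""
    pdb_id = pdb_id.strip().lower()
    if len(pdb_id) != 4:
        raise ValueError(f"expected a four-character PDB ID, got {pdb_id!r}")
    return pdb_id[1:3]


def output_path(output_dir, pdb_id, suffix=".cif"):
    """Build an illustrative output path; the file name itself is an assumption."""
    normalized = pdb_id.strip().lower()
    return Path(output_dir) / pdb_subdir(normalized) / f"{normalized}{suffix}"


print(pdb_subdir("15c8"))  # 5c
```

Sharding by the middle two characters keeps any single directory from holding the whole dataset.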

PDB dataset clustering

Next, run the following with mmcif_dir and {train,val,test}_clustering_output_dir replaced, respectively, with your local output directory created using the dataset filtering script above and with your desired clustering output directories (i.e., ./data/pdb_data/{train,val,test}_mmcifs/ and ./data/pdb_data/data_caches/{train,val,test}_clusterings/):

python scripts/ --mmcif_dir <mmcif_dir> --output_dir <train_clustering_output_dir> --clustering_filtered_pdb_dataset
python scripts/ --mmcif_dir <mmcif_dir> --reference_clustering_dir <train_clustering_output_dir> --output_dir <val_clustering_output_dir> --clustering_filtered_pdb_dataset
python scripts/ --mmcif_dir <mmcif_dir> --reference_1_clustering_dir <train_clustering_output_dir> --reference_2_clustering_dir <val_clustering_output_dir> --output_dir <test_clustering_output_dir> --clustering_filtered_pdb_dataset

Note: The --clustering_filtered_pdb_dataset flag is recommended when clustering the filtered PDB dataset as curated using the scripts above, as this flag will enable faster runtimes in this context (since filtering leaves each chain's residue IDs 1-based). However, this flag must not be provided when clustering other (i.e., non-PDB) datasets of mmCIF files. Otherwise, interface clustering may be performed incorrectly, as these datasets' mmCIF files may not use strict 1-based residue indexing for each chain.
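
As a rough sanity check before clustering a non-PDB dataset, one could test whether each chain's residue IDs run 1, 2, ..., N. This sketch assumes that "strict 1-based residue indexing" means exactly that (starting at 1 with no gaps), which is an interpretation rather than something the scripts document. The function names are hypothetical, and passing this check does not by itself make the flag safe to use.

```python
def is_strictly_one_based(residue_ids):
    """True if the IDs are exactly 1, 2, ..., N in order."""
    ids = list(residue_ids)
    return ids == list(range(1, len(ids) + 1))


def all_chains_one_based(chains):
    """`chains` maps a chain ID to that chain's residue IDs, in file order."""
    return all(is_strictly_one_based(ids) for ids in chains.values())


print(all_chains_one_based({"A": [1, 2, 3], "B": [1, 2]}))      # True
print(all_chains_one_based({"A": [1, 2, 3], "B": [5, 6, 7]}))   # False
```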

Note: One can instead download preprocessed (i.e., filtered) mmCIF (train/val/test) files (~25GB, comprising 148k complexes) and chain/interface clustering (train/val/test) files (~3GB) for the PDB's 20240101 AWS snapshot via a shared OneDrive folder. Each of these tar.gz archives should be decompressed within the data/pdb_data/ directory, e.g., via tar -xzf data_caches.tar.gz -C data/pdb_data/. One can also download and prepare PDB distillation data, using the script scripts/ as a reference. Once downloaded, one can run scripts/ to filter this dataset to only those examples associated with at least one PDB entry. For convenience, a mapping of UniProt accession IDs to PDB IDs for training on PDB distillation data has already been downloaded and extracted as data/afdb_data/data_caches/uniprot_to_pdb_id_mapping.dat.
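
For readers who want to inspect the UniProt-to-PDB mapping themselves, the following sketch parses such a file into a dictionary. It assumes the file follows UniProt's three-column, tab-separated ID-mapping layout (accession, ID type, ID value) and that PDB rows use the type `PDB`; that layout has not been verified against this repository's file, so check a few lines first. The function name `parse_uniprot_to_pdb` is hypothetical.

```python
from collections import defaultdict


def parse_uniprot_to_pdb(lines):
    """Map each UniProt accession to the set of (lowercased) PDB IDs listed for it."""
    mapping = defaultdict(set)
    for line in lines:
        parts = line.rstrip("\n").split("\t")
        if len(parts) != 3:
            continue  # skip blank or malformed lines
        accession, id_type, value = parts
        if id_type == "PDB":
            mapping[accession].add(value.lower())
    return dict(mapping)


# Example call (the path is the one named in the note above):
# with open("data/afdb_data/data_caches/uniprot_to_pdb_id_mapping.dat") as handle:
#     uniprot_to_pdb = parse_uniprot_to_pdb(handle)
```

Accessions missing from the resulting dictionary have no associated PDB entry under these assumptions.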


At the project root, run

$ sh ./

Then, add your module to alphafold3_pytorch/, add your tests to tests/, and submit a pull request. You can run the tests locally with

$ pytest tests/

Docker Image

The included Dockerfile contains the dependencies required to run the package and to perform training and inference using PyTorch with GPUs.

The default base image is pytorch/pytorch:2.3.0-cuda12.1-cudnn8-runtime, and the Dockerfile installs the latest version of this package from the main GitHub branch.

## Build Docker Container
docker build -t af3 .

Alternatively, use build arguments to rebuild the image with different software versions:

For example:

## Use build argument to change versions
docker build --build-arg "PYTORCH_TAG=2.2.1-cuda12.1-cudnn8-devel" --build-arg "GIT_TAG=0.1.15" -t af3 .

Then, run the container with GPUs and mount a local volume (for training) using the following command:

## Run Container
docker run -v .:/data --gpus all -it af3


