https://arxiv.org/abs/2402.04997
MIT License

Multiflow: protein co-design with discrete and continuous flows

Multiflow is a protein sequence and structure generative model based on our preprint: Generative Flows on Discrete State-Spaces: Enabling Multimodal Flows with Applications to Protein Co-Design.

Our codebase is developed on top of FrameFlow. The sequence generative model is adapted from Discrete Flow Models (DFM).

If you use this codebase, please cite:

@article{campbell2024generative,
  title={Generative Flows on Discrete State-Spaces: Enabling Multimodal Flows with Applications to Protein Co-Design},
  author={Campbell, Andrew and Yim, Jason and Barzilay, Regina and Rainforth, Tom and Jaakkola, Tommi},
  journal={arXiv preprint arXiv:2402.04997},
  year={2024}
}

> [!NOTE]
> This codebase is very fresh. We expect there to be bugs and issues on other systems and environments. Please create a GitHub issue or pull request and we will attempt to help.


Installation

We recommend using mamba for faster dependency resolution. If you use mamba, substitute mamba for conda in the commands below.

# Install environment with dependencies.
conda env create -f multiflow.yml

# Activate environment
conda activate multiflow

# Install local package.
# Current directory should have setup.py.
pip install -e .

Next you need to install torch-scatter manually, matching your torch version. (Unfortunately, torch-scatter has an oddity that prevents it from being installed with the rest of the environment.) We use torch 2.0.1 with CUDA 11.7, so we install the following:

pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.1+cu117.html

If you use a different torch version, you can check it with the following:

# Find your installed version of torch
python
>>> import torch
>>> torch.__version__
# Example: torch 2.0.1+cu117
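The wheel index URL follows a predictable pattern, so you can construct it from `torch.__version__`. Below is a small hypothetical helper (not part of the codebase) that assumes the data.pyg.org naming scheme shown above:

```python
def scatter_wheel_index(torch_version: str) -> str:
    """Build the torch-scatter wheel index URL for a given torch version.

    torch.__version__ looks like "2.0.1+cu117" (or "2.0.1" for CPU-only).
    """
    if "+" in torch_version:
        base, cuda = torch_version.split("+", 1)
    else:
        base, cuda = torch_version, "cpu"
    return f"https://data.pyg.org/whl/torch-{base}+{cuda}.html"

# Pass the result to: pip install torch-scatter -f <url>
print(scatter_wheel_index("2.0.1+cu117"))
# -> https://data.pyg.org/whl/torch-2.0.1+cu117.html
```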

> [!WARNING]
> You will likely run into the following error from DeepSpeed:
>
> ModuleNotFoundError: No module named 'torch._six'

If so, replace `from torch._six import inf` with `from torch import inf` in the following files:

  • /path/to/envs/site-packages/deepspeed/runtime/utils.py
  • /path/to/envs/site-packages/deepspeed/runtime/zero/stage_1_and_2.py

where /path/to/envs is replaced with your path. We would appreciate a pull request to avoid this monkey patch!
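Until that's fixed upstream, the edit can be scripted. Here is a minimal sketch (the `patch_torch_six` helper is hypothetical; the commented file paths are the two listed above, with /path/to/envs left for you to fill in):

```python
from pathlib import Path

def patch_torch_six(source: str) -> str:
    """Swap the removed torch._six import for the modern torch import."""
    return source.replace("from torch._six import inf",
                          "from torch import inf")

# Apply in place to the two affected DeepSpeed files:
# for f in ["/path/to/envs/site-packages/deepspeed/runtime/utils.py",
#           "/path/to/envs/site-packages/deepspeed/runtime/zero/stage_1_and_2.py"]:
#     p = Path(f)
#     p.write_text(patch_torch_six(p.read_text()))
```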

Wandb

Our training relies on logging with Wandb. Create a Wandb account, log in, and authorize Wandb here.

Data

We host the datasets on Zenodo here. Download the following files:

Uncompress test data

mkdir test_set
tar -xzvf test_set.tar.gz -C test_set/

The resulting directory structure should look like
```bash
<current_dir>
├── train_set
│   ├── processed_pdb
│   │   └── <subdir>
│   │       └── <protein_id>.pkl
│   └── processed_synthetic
│       └── <protein_id>.pkl
├── test_set
│   └── processed
│       └── <subdir>
│           └── <protein_id>.pkl
...
```

Our experiments read the data using relative paths. Keep this directory structure to avoid path bugs.

Training

The command to run co-design training is the following:

python -W ignore multiflow/experiments/train_se3_flows.py -cn pdb_codesign

We use Hydra to manage our configs. The training config is at multiflow/configs/pdb_codesign.yaml.
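Hydra also lets you override any field in that yaml from the command line with dotted paths, e.g. `python ... -cn pdb_codesign key.subkey=value` (the field name here is illustrative; check the yaml for real ones). Conceptually, an override does something like this stdlib sketch (not Hydra's actual implementation):

```python
def apply_override(cfg: dict, override: str) -> dict:
    """Apply one Hydra-style 'a.b.c=value' override to a nested dict."""
    dotted, value = override.split("=", 1)
    *parents, leaf = dotted.split(".")
    node = cfg
    for key in parents:
        node = node.setdefault(key, {})  # descend, creating nodes as needed
    node[leaf] = value
    return cfg

# Illustrative field name, not from the real config:
cfg = {"experiment": {"debug": "False"}}
apply_override(cfg, "experiment.debug=True")
print(cfg["experiment"]["debug"])
# -> True
```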

Most important fields:

Inference

We provide pre-trained model weights at this Zenodo link.

Run the following to unpack the weights:

tar -xzvf weights.tar.gz

The following three tasks can be performed.

# Unconditional Co-Design
python -W ignore multiflow/experiments/inference_se3_flows.py -cn inference_unconditional

# Inverse Folding
python -W ignore multiflow/experiments/inference_se3_flows.py -cn inference_inverse_folding

# Forward Folding
python -W ignore multiflow/experiments/inference_se3_flows.py -cn inference_forward_folding

Configs

Config locations:

Most important fields:

[Only for hallucination]