Official implementation of RetroBridge, a Markov bridge model for retrosynthesis planning by Ilia Igashov, Arne Schneuing, Marwin Segler, Michael Bronstein and Bruno Correia.
We model single-step retrosynthesis planning as a distribution learning problem in a discrete state space. First, we introduce the Markov Bridge Model, a generative framework aimed at approximating the dependency between two intractable discrete distributions that are accessible only through a finite sample of coupled data points. Our framework is based on the concept of a Markov bridge, a Markov process pinned at its endpoints. Unlike diffusion-based methods, our Markov Bridge Model does not need a tractable noise distribution as a sampling proxy and directly operates on the input product molecules as samples from the intractable prior distribution. We then address the retrosynthesis planning problem with our novel framework and introduce RetroBridge, a template-free retrosynthesis modeling approach that achieves state-of-the-art results on standard evaluation benchmarks.
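As a toy illustration of the pinned-endpoint property (not the model's actual transition matrices, which act on molecular graphs), consider a scalar-state bridge whose probability of jumping to the endpoint grows to 1 over time, so the process is guaranteed to finish at the prescribed endpoint:

```python
import random

def sample_bridge(x0, y, betas):
    """Simulate a toy discrete Markov bridge pinned at endpoint y.

    At step t the state jumps to the endpoint with probability betas[t]
    and stays put otherwise; betas[-1] == 1 guarantees the trajectory
    ends exactly at y, which is the defining property of a bridge.
    """
    z = x0
    trajectory = [z]
    for b in betas:
        if random.random() < b:
            z = y
        trajectory.append(z)
    return trajectory

# Linearly increasing jump probabilities, ending at 1.
T = 10
betas = [t / T for t in range(1, T + 1)]
traj = sample_bridge(x0=0, y=1, betas=betas)
assert traj[0] == 0 and traj[-1] == 1  # pinned at both endpoints
```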
Software | Version
---|---
Python | 3.9
CUDA | 11.6
```shell
conda create --name retrobridge python=3.9 rdkit=2023.09.5 -c conda-forge -y
conda activate retrobridge
pip install -r requirements.txt
```
To sample reactants for a given product molecule:

```shell
mkdir -p models
wget "https://zenodo.org/record/10688201/files/retrobridge.ckpt?download=1" -O models/retrobridge.ckpt
python predict.py --smiles "CN1C=NC2=C1C(=O)N(C(=O)N2C)C" --checkpoint models/retrobridge.ckpt
```
To train the models:

```shell
# RetroBridge
python train.py --config configs/retrobridge.yaml --model RetroBridge

# DiGress baseline
python train.py --config configs/digress.yaml --model DiGress

# ForwardBridge
python mit/train.py --config configs/forwardbridge.yaml
```
Trained models can be downloaded from Zenodo:
```shell
mkdir -p models
wget "https://zenodo.org/record/10688201/files/retrobridge.ckpt?download=1" -O models/retrobridge.ckpt
wget "https://zenodo.org/record/10688201/files/digress.ckpt?download=1" -O models/digress.ckpt
wget "https://zenodo.org/record/10688201/files/forwardbridge.ckpt?download=1" -O models/forwardbridge.ckpt
```
Sampling with RetroBridge model:
```shell
python sample.py \
    --config configs/retrobridge.yaml \
    --checkpoint models/retrobridge.ckpt \
    --samples samples \
    --model RetroBridge \
    --mode test \
    --n_samples 10 \
    --n_steps 500 \
    --sampling_seed 1
```
Sampling with DiGress:
```shell
python sample.py \
    --config configs/digress.yaml \
    --checkpoint models/digress.ckpt \
    --samples samples \
    --model DiGress \
    --mode test \
    --n_samples 10 \
    --n_steps 500 \
    --sampling_seed 1
```
Sampling with ForwardBridge:
```shell
python sample_MIT.py \
    --config configs/forwardbridge.yaml \
    --checkpoint models/forwardbridge.ckpt \
    --samples samples \
    --model RetroBridge \
    --mode test \
    --n_samples 10 \
    --n_steps 500 \
    --sampling_seed 1
```
For round-trip evaluation, download the Molecular Transformer and follow the setup instructions on its GitHub page.
To make forward predictions for all generated reactants, run:

```shell
python src/metrics/round_trip.py \
    --csv_file <path/to/retrobridge_csv> \
    --csv_out <path/to/output_csv> \
    --mol_trans_dir <path/to/MolecularTransformer_dir>
```
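Round-trip evaluation counts a predicted reactant set as correct when the forward model maps it back to the original target product. A minimal sketch of that scoring criterion (the helper below is hypothetical and not part of the repository's code):

```python
def round_trip_accuracy(products, pred_products):
    """Fraction of predictions whose forward-predicted product matches
    the original target product (the round-trip criterion)."""
    hits = sum(p == rp for p, rp in zip(products, pred_products))
    return hits / len(products)

# Toy example with canonical SMILES strings: 2 of 3 round-trips succeed.
targets = ["CCO", "CCN", "c1ccccc1"]
forward = ["CCO", "CCC", "c1ccccc1"]
assert round_trip_accuracy(targets, forward) == 2 / 3
```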
To compute the metrics reported in the paper, run the following in Python:

```python
import numpy as np
import pandas as pd
from pathlib import Path

from src.metrics.eval_csv_helpers import canonicalize, compute_confidence, assign_groups, compute_accuracy

csv_file = Path('<path/to/output_csv>')
df = pd.read_csv(csv_file)
df = assign_groups(df, samples_per_product_per_file=10)
df.loc[(df['product'] == 'C') & (df['true'] == 'C'), 'true'] = 'Placeholder'

df_processed = compute_confidence(df)
for key in ['product', 'pred_product']:
    df_processed[key] = df_processed[key].apply(canonicalize)

compute_accuracy(df_processed, top=[1, 3, 5, 10], scoring=lambda df: np.log(df['confidence']))
```
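The confidence score used above is based on how often a given prediction recurs among the samples drawn for a product: duplicates indicate higher model probability. A stdlib sketch of that idea (the helper name is hypothetical, not the repository's `compute_confidence`):

```python
from collections import Counter

def sample_confidence(samples):
    """Map each unique prediction to its empirical frequency among the
    samples drawn for one product (duplicates mean higher confidence)."""
    counts = Counter(samples)
    total = len(samples)
    return {pred: n / total for pred, n in counts.items()}

# 10 sampled reactant sets for one product; repeats raise confidence.
samples = ["CCO.CC", "CCO.CC", "CCO.CC", "CCN", "CCN", "CCO.CC",
           "CCC", "CCO.CC", "CCN", "CCO.CC"]
conf = sample_confidence(samples)
assert conf["CCO.CC"] == 0.6 and conf["CCN"] == 0.3 and conf["CCC"] == 0.1
```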
RetroBridge is released under the CC BY-NC 4.0 license.
If you have any questions, please contact ilia.igashov@epfl.ch or arne.schneuing@epfl.ch.