pengxingang / Pocket2Mol

Pocket2Mol: Efficient Molecular Sampling Based on 3D Protein Pockets
MIT License

Evaluation with different training poses? #18

Open ljl123ljl opened 1 year ago

ljl123ljl commented 1 year ago

Wonderful work!

I found that the model was trained on 100k poses. Is there a significant gain when more poses are included in the training set?

I am trying to scale up the data to see what happens. Is the code with the filtering logic for CrossDocked available?

pengxingang commented 1 year ago

Thanks for your interest in our work. We haven't tried training the model with more data, but we speculate that more data would be useful because the current training set is still limited. As for dataset filtering and processing, we directly used the datasets from previous work, which filtered out data points with a binding pose RMSD greater than 1 Å, used mmseqs2 to cluster the data at 30% sequence identity, and then randomly drew 100,000 pairs for training. You can refer to that work for more details about the data filtering.
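A minimal sketch of the filter-then-sample part of that description, for illustration only. It is not the code used by the previous work: the function name `filter_and_sample`, the tuple layout `(protein_fn, ligand_fn, rmsd)`, and the fixed seed are all assumptions, and the mmseqs2 clustering step is left out because it relies on an external tool.

```python
import random


def filter_and_sample(pairs, rmsd_thr=1.0, n_samples=100_000, seed=0):
    """Keep pairs with RMSD <= rmsd_thr, then randomly draw up to n_samples.

    `pairs` is an iterable of (protein_fn, ligand_fn, rmsd) tuples
    (a hypothetical layout chosen for this sketch).
    """
    # Drop every pose whose binding-pose RMSD exceeds the threshold.
    kept = [p for p in pairs if p[2] <= rmsd_thr]
    # If fewer pairs survive than requested, return them all.
    if len(kept) <= n_samples:
        return kept
    # Seeded generator so the draw is reproducible.
    rng = random.Random(seed)
    return rng.sample(kept, n_samples)


# Toy input with made-up file names and RMSD values.
demo_pairs = [
    ("p1.pdb", "l1.sdf", 0.5),
    ("p2.pdb", "l2.sdf", 1.5),
    ("p3.pdb", "l3.sdf", 0.9),
]
demo_subset = filter_and_sample(demo_pairs, rmsd_thr=1.0, n_samples=1)
```

Raising `rmsd_thr` or `n_samples` is the simplest way to experiment with a larger training set under this sketch.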

ljl123ljl commented 1 year ago

Thank you for the detailed reply! I am interested in relaxing the data scope to RMSD < 2 Å, and found that it2_tt_0_lowrmsd_mols_train/test0.types contain 486k/238k entries respectively; maybe that is a good starting point.

However, the CrossDocked dataset provides receptor/gninatypes files, while the crossdocked 10 data provided here contains pocket and docked SDF files.

I wonder if the data preprocessing code could be made available here :) I am fairly new to drug design!
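To gauge how much data a relaxed threshold would yield, one could count the entries of a types file that fall under each cutoff. The sketch below assumes the RMSD is the third whitespace-separated field of each line, which is how the clean_crossdocked.py script posted later in this thread reads it; the function name `count_poses_below` and the sample lines are made up for illustration.

```python
def count_poses_below(lines, rmsd_thr):
    """Count lines whose third field (assumed to be the RMSD) is <= rmsd_thr."""
    count = 0
    for ln in lines:
        fields = ln.split()
        # Skip blank or malformed lines rather than failing on them.
        if len(fields) < 5:
            continue
        if float(fields[2]) <= rmsd_thr:
            count += 1
    return count


# Hypothetical lines in the assumed six-field layout.
sample_lines = [
    "1 5.0 0.8 A/rec_0.gninatypes A/lig_0.gninatypes #-7.1",
    "1 5.0 1.5 A/rec_0.gninatypes A/lig_1.gninatypes #-6.8",
    "0 5.0 2.4 A/rec_0.gninatypes A/lig_2.gninatypes #-6.0",
]
n_strict = count_poses_below(sample_lines, 1.0)
n_relaxed = count_poses_below(sample_lines, 2.0)
```

For a real file, passing the open file object as `lines` works the same way, since iterating over it yields one line at a time.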

pengxingang commented 1 year ago

You can refer to the CrossDocked dataset for the details of the dataset. We found a file clean_crossdocked.py which might be helpful:

import os
import shutil
import pickle
import argparse
from tqdm.auto import tqdm

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--source', type=str, default='./data/crossdocked')
    parser.add_argument('--dest', type=str, required=True)
    parser.add_argument('--rmsd_thr', type=float, default=1.0)
    args = parser.parse_args()

    os.makedirs(args.dest, exist_ok=False)
    types_path = os.path.join(args.source, 'types/it2_tt_completeset_train0.types')

    index = []
    with open(types_path, 'r') as f:
        for ln in tqdm(f):
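            # Each line is expected to have six whitespace-separated fields;
            # only the RMSD and the two .gninatypes paths are used here.
            _, _, rmsd, protein_fn, ligand_fn, _ = ln.split()
            rmsd = float(rmsd)
            # Skip poses whose RMSD exceeds the threshold.
            if rmsd > args.rmsd_thr:
                continue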

            ligand_id = int(ligand_fn[ligand_fn.rfind('_')+1:ligand_fn.rfind('.')])

            protein_fn = protein_fn[:protein_fn.rfind('_')] + '.pdb'
            ligand_raw_fn = ligand_fn[:ligand_fn.rfind('_')] + '.sdf'
            protein_path = os.path.join(args.source, protein_fn)
            ligand_raw_path = os.path.join(args.source, ligand_raw_fn)
            if not (os.path.exists(protein_path) and os.path.exists(ligand_raw_path)):
                continue

            # Use a distinct name so the outer types-file handle `f` is not shadowed.
            with open(ligand_raw_path, 'r') as sdf_file:
                ligand_sdf = sdf_file.read().split('$$$$\n')[ligand_id]
            ligand_save_fn = ligand_fn[:ligand_fn.rfind('.')] + '.sdf'  # include ligand id

            protein_dest = os.path.join(args.dest, protein_fn)
            ligand_dest = os.path.join(args.dest, ligand_save_fn)
            os.makedirs(os.path.dirname(protein_dest), exist_ok=True)
            os.makedirs(os.path.dirname(ligand_dest), exist_ok=True)
            shutil.copyfile(protein_path, protein_dest)
            # Use a distinct name so the outer types-file handle `f` is not shadowed.
            with open(ligand_dest, 'w') as out_file:
                out_file.write(ligand_sdf)

            index.append((protein_fn, ligand_save_fn, rmsd))

    index_path = os.path.join(args.dest, 'index.pkl')
    with open(index_path, 'wb') as f:
        pickle.dump(index, f)

    print('Done. %d protein-ligand pairs in total.' % len(index))

In addition, you can refer to the data processing part of LiGAN, a previous method.
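The least obvious step in the script above is how it turns a .gninatypes ligand name into a pose index and two SDF paths. The sketch below isolates that string handling so it can be checked on its own. The function name `split_gninatypes_name` and the example file name are hypothetical, and the logic assumes the name ends in `_<pose index>.gninatypes`, as the script does.

```python
def split_gninatypes_name(ligand_fn):
    """Return (ligand_id, raw_sdf_fn, save_sdf_fn) for a ligand .gninatypes name.

    Mirrors the slicing in clean_crossdocked.py:
    - ligand_id: the integer between the last '_' and the last '.'
    - raw_sdf_fn: multi-pose SDF file, i.e. the name without '_<id>.<ext>'
    - save_sdf_fn: per-pose SDF file, i.e. the name with '.sdf' as extension
    """
    last_underscore = ligand_fn.rfind('_')
    last_dot = ligand_fn.rfind('.')
    ligand_id = int(ligand_fn[last_underscore + 1:last_dot])
    raw_sdf_fn = ligand_fn[:last_underscore] + '.sdf'
    save_sdf_fn = ligand_fn[:last_dot] + '.sdf'
    return ligand_id, raw_sdf_fn, save_sdf_fn


# Made-up file name following the assumed pattern.
example = split_gninatypes_name('POCKET/abc_lig_tt_min_3.gninatypes')
```

The pose index is then used to pick one record out of the multi-pose SDF file after splitting it on the `$$$$` record terminator.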