Genentech / voxmol

Apache License 2.0
12 stars 0 forks source link

How to visualize results and display metrics such as molecular stability on the QM9 dataset #1

Open Littletoyone opened 3 months ago

Littletoyone commented 3 months ago

Thank you for providing the code; it shows how simply and effectively this can be written. It is without doubt very good code, but I still have some questions. First, how do I visualize the .xyz files? I reproduced the experiments and found that the .xyz files contain only coordinates. Second, regarding metrics such as molecular stability and atomic stability: I converted the .xyz files into molecule objects using the function "data_atom_xyz" offered by MiDi, but I cannot use MiDi's "sampling_metrics" function, which computes these metrics, as a stand-alone tool...

chefbaker4 commented 2 months ago

Here is how I managed to calculate validity, uniqueness, and novelty:

from rdkit import Chem
import torch

# The following functions are adapted from the file rdkit_functions.py
# found at https://github.com/ehoogeboom/e3_diffusion_for_molecules,
# the code for the paper EDM: E(3) Equivariant Diffusion Model for Molecule Generation in 3D.

def mol2smiles(mol):
    """Return the non-isomeric SMILES string of ``mol``, or ``None``.

    ``mol`` is sanitized in place with ``Chem.SanitizeMol`` first. If
    sanitization raises ``ValueError``, the molecule is treated as invalid
    and ``None`` is returned.
    """
    try:
        Chem.SanitizeMol(mol)
    except ValueError:
        # NOTE(review): presumes RDKit reports sanitization failures as
        # ValueError (subclasses) — confirm for the installed RDKit version.
        smiles = None
    else:
        smiles = Chem.MolToSmiles(mol, isomericSmiles=False)
    return smiles

def compute_validity(generated):
    """Collect the SMILES of every valid molecule in ``generated``.

    A molecule counts as valid when it is not ``None`` and ``mol2smiles``
    can sanitize it. For each valid molecule, the SMILES of its largest
    fragment (by atom count) is kept, so disconnected outputs are reduced to
    their biggest component.

    Args:
        generated: sized iterable of RDKit molecules (e.g. an
            ``SDMolSupplier``). ``None`` entries are counted as invalid.

    Returns:
        ``(valid, fraction)``. ``valid`` is the list of SMILES strings.
        ``fraction`` is ``len(valid) / len(generated)``, or ``0.0`` when
        ``generated`` is empty.
    """
    valid = []

    for mol in generated:
        # Guard clause: skip missing molecules and ones that fail sanitization.
        if mol is None or mol2smiles(mol) is None:
            continue
        mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True)
        largest_mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
        # NOTE(review): mol2smiles may return None for a fragment that fails
        # to re-sanitize; such an entry is still appended, as before.
        valid.append(mol2smiles(largest_mol))

    total = len(generated)
    # Avoid ZeroDivisionError when nothing was generated.
    validity = len(valid) / total if total else 0.0
    return valid, validity

def compute_uniqueness(valid):
    """Deduplicate SMILES strings and report the unique fraction.

    Args:
        valid: list of SMILES strings (e.g. the first item returned by
            ``compute_validity``).

    Returns:
        ``(unique, fraction)``. ``unique`` lists each distinct SMILES once,
        in first-seen order. ``fraction`` is ``len(unique) / len(valid)``,
        or ``0.0`` when ``valid`` is empty.
    """
    # dict.fromkeys deduplicates while keeping first-seen order, so the
    # result is deterministic (a set's iteration order is not).
    unique = list(dict.fromkeys(valid))
    total = len(valid)
    # Avoid ZeroDivisionError when there are no valid molecules.
    uniqueness = len(unique) / total if total else 0.0
    return unique, uniqueness

def compute_novelty(unique,
                    train_data_path='/voxmol/voxmol/dataset/data/qm9/train_data.pth',
                    train_smiles=None):
    """Return the SMILES in ``unique`` that do not occur in the training set.

    Args:
        unique: list of SMILES strings (e.g. from ``compute_uniqueness``).
        train_data_path: file loaded with ``torch.load`` when
            ``train_smiles`` is not given. Each loaded record is indexed as
            ``record['smiles']``.
        train_smiles: optional iterable of training SMILES. Supplying it
            skips loading ``train_data_path``, which is useful when scoring
            several sample sets against the same training data.

    Returns:
        ``(novel, fraction)``. ``novel`` keeps the order of ``unique``.
        ``fraction`` is ``len(novel) / len(unique)``, or ``0.0`` when
        ``unique`` is empty.

    NOTE(review): membership is exact string equality. The result is only
    meaningful if the training SMILES were canonicalized the same way as
    ``mol2smiles`` (non-isomeric RDKit SMILES). Confirm for the dataset.
    """
    if train_smiles is None:
        # SECURITY: torch.load can unpickle arbitrary objects (depending on
        # the torch version / weights_only setting) — only load trusted files.
        train_data = torch.load(train_data_path)
        train_smiles = (record['smiles'] for record in train_data)

    # A set gives O(1) membership instead of scanning a list per query.
    # Empty/None SMILES are skipped, matching the original loader.
    known = {smiles for smiles in train_smiles if smiles}

    novel = [smiles for smiles in unique if smiles not in known]
    total = len(unique)
    # Avoid ZeroDivisionError when there are no unique molecules.
    novelty = len(novel) / total if total else 0.0
    return novel, novelty

def main(sdf_path='/voxmol/voxmol/exps/exp_qm9_sig0.9_lr2e-05/xyzs/_s500_ms1000/molecules_obabel.sdf'):
    """Print validity, uniqueness and novelty for molecules in an SDF file.

    Args:
        sdf_path: SDF file of generated molecules, read with
            ``Chem.SDMolSupplier``. Defaults to the original experiment path.
    """
    suppl = Chem.SDMolSupplier(sdf_path)

    num_generated = len(suppl)
    print(f'{num_generated} generated molecules\n')
    if not num_generated:
        # Every metric below is a ratio; nothing to score.
        return

    valid_list, valid_pct = compute_validity(suppl)
    print(f'{len(valid_list)} valid molecules')
    print(f'Validity: {(100*valid_pct):.2f}%\n')
    if not valid_list:
        # Uniqueness and novelty are ratios over the valid set; stop here
        # instead of dividing by zero downstream.
        return

    unique_list, unique_pct = compute_uniqueness(valid_list)
    print(f'{len(unique_list)} unique molecules')
    print(f'Uniqueness: {(100*unique_pct):.2f}%\n')

    novel_list, novel_pct = compute_novelty(unique_list)
    print(f'{len(novel_list)} novel molecules')
    print(f'Novelty: {(100*novel_pct):.2f}%\n')

# Run only when executed as a script, so importing this module does not
# immediately try to read the SDF file.
if __name__ == "__main__":
    main()