kaist-amsg / LocalRetro

Retrosynthesis prediction for organic molecules with LocalRetro

Update Run_preprocessing.py #5

Closed xiaohongniua closed 2 years ago

xiaohongniua commented 2 years ago

Might be a bug? See the comments in the code.

shuan4638 commented 2 years ago

Hi Hong,

The order of B actually matches the order of the adjacency matrix of the DGL graph, not the DGL bond (edge) order directly (see the function pair_atom_feats at https://github.com/kaist-amsg/LocalRetro/blob/main/scripts/model_utils.py). The following code shows that the two orders are identical.

import torch
import dgl

from rdkit import Chem
from dgllife.utils import smiles_to_bigraph

def get_edit_site_retro(smiles): # the function in Run_preprocessing.py
    mol = Chem.MolFromSmiles(smiles)
    A = [a for a in range(mol.GetNumAtoms())]
    B = []
    for atom in mol.GetAtoms():
        others = []
        bonds = atom.GetBonds()
        for bond in bonds:
            atoms = [bond.GetBeginAtom().GetIdx(), bond.GetEndAtom().GetIdx()]
            other = [a for a in atoms if a != atom.GetIdx()][0]
            others.append(other)
        b = [(atom.GetIdx(), other) for other in sorted(others)]
        B += b
    return A, B

smiles = 'O=C1NC(=O)CCC1'
_, B = get_edit_site_retro(smiles)
g = smiles_to_bigraph(smiles, canonical_atom_order=False)
Adj = torch.transpose(g.adjacency_matrix().coalesce().indices(), 0, 1)

for bond, edge in zip(B, Adj):
    print(bond, edge)

gives

(0, 1) tensor([0, 1])
(1, 0) tensor([1, 0])
(1, 2) tensor([1, 2])
(1, 7) tensor([1, 7])
(2, 1) tensor([2, 1])
(2, 3) tensor([2, 3])
(3, 2) tensor([3, 2])
(3, 4) tensor([3, 4])
(3, 5) tensor([3, 5])
(4, 3) tensor([4, 3])
(5, 3) tensor([5, 3])
(5, 6) tensor([5, 6])
(6, 5) tensor([6, 5])
(6, 7) tensor([6, 7])
(7, 1) tensor([7, 1])
(7, 6) tensor([7, 6])
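
A minimal sketch of why the two columns agree, under two assumptions: a coalesced sparse tensor keeps its indices sorted lexicographically, and the bond list below is read off the printed pairs above rather than produced by RDKit. The helper names are hypothetical. Listing each atom's neighbors in sorted order, atom by atom, is the same as sorting every directed edge by (source, destination).

```python
# Undirected bonds of 'O=C1NC(=O)CCC1', hardcoded from the printed pairs
# (an assumption, so that the sketch needs neither RDKit nor DGL).
bonds = [(0, 1), (1, 2), (2, 3), (3, 4), (3, 5), (5, 6), (6, 7), (1, 7)]


def per_atom_sorted_pairs(bonds, num_atoms):
    """Mimic get_edit_site_retro: (atom, neighbor) pairs, neighbors sorted per atom."""
    neighbors = {a: [] for a in range(num_atoms)}
    for u, v in bonds:
        neighbors[u].append(v)
        neighbors[v].append(u)
    return [(a, n) for a in range(num_atoms) for n in sorted(neighbors[a])]


def row_major_pairs(bonds):
    """Mimic coalesced adjacency indices: all directed edges sorted by (src, dst)."""
    directed = [(u, v) for u, v in bonds] + [(v, u) for u, v in bonds]
    return sorted(directed)


B = per_atom_sorted_pairs(bonds, num_atoms=8)
adj = row_major_pairs(bonds)
print(B == adj)  # True
```

The match therefore depends on the sorting step, not on the order in which DGL stores its edges.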

xiaohongniua commented 2 years ago

That's it! So you changed the order of the bond edit logits dynamically in pair_atom_feats. Thanks a lot. It was kind of confusing to me.

shuan4638 commented 2 years ago

Sorry for the confusion. I believe you are the only person who gets confused! I actually wrote this when I was unfamiliar with the dgl functions, and I realized I can do the same thing in a simpler way:

def pair_atom_feats(g, node_feats):
    sg = g.remove_self_loop() # in case g includes self-loop
    atom_idx1, atom_idx2 = sg.edges() # read edges from the graph without self-loops
    atom_pair_feats = torch.cat((node_feats[atom_idx1.long()], node_feats[atom_idx2.long()]), dim = 1)
    return atom_pair_feats

This way the bonds follow the order of the dgl edges, and I will need to change the get_edit_site function in Run_preprocessing.py to match. Thanks for pointing this out! Will fix this later.
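
A rough pure-Python stand-in for the edge-ordered pairing, in case it helps to see why the edit-site list has to follow the same order. It is only a sketch: plain lists replace tensors, list concatenation replaces torch.cat along dim 1, and the function name and toy values are hypothetical.

```python
def pair_feats_by_edge_order(src, dst, node_feats):
    """Row i is the features of src[i] followed by the features of dst[i]."""
    return [node_feats[u] + node_feats[v] for u, v in zip(src, dst)]


# Toy graph with two bonds (0-1 and 1-2), each stored in both directions.
src = [0, 1, 1, 2]
dst = [1, 0, 2, 1]
node_feats = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]

pair_feats = pair_feats_by_edge_order(src, dst, node_feats)
edit_sites = list(zip(src, dst))  # labels must use this same order

for site, feats in zip(edit_sites, pair_feats):
    print(site, feats)
```

Row i of the pair features describes edge i, so the bond list built during preprocessing has to enumerate the bonds in exactly that edge order.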

xiaohongniua commented 2 years ago

Thanks very much for your time and patience, and thanks for your work. I have learned a lot. Either way is OK. If you are going to change pair_atom_feats, get_edit_site might be:

def get_edit_site_retro(smiles):
    mol = Chem.MolFromSmiles(smiles)
    A = [a for a in range(mol.GetNumAtoms())]
    B = []
    src_list = []
    dst_list = []
    num_bonds = mol.GetNumBonds()
    for i in range(num_bonds):
        bond = mol.GetBondWithIdx(i)
        u = bond.GetBeginAtomIdx()
        v = bond.GetEndAtomIdx()
        src_list.extend([u, v])
        dst_list.extend([v, u])
    for s, d in zip(src_list, dst_list):
        B.append((s, d))
    return A, B

Thanks again. And I will close this issue.

shuan4638 commented 2 years ago

No problem! I wouldn't have noticed this if you hadn't told me :) I guess the code you wrote here is from the source code of dgl-lifesci (https://github.com/awslabs/dgl-lifesci/blob/master/python/dgllife/utils/mol_to_graph.py). I would write

def get_edit_site_retro(smiles): # the function in Run_preprocessing.py
    mol = Chem.MolFromSmiles(smiles)
    A = [a for a in range(mol.GetNumAtoms())]
    B = []
    for bond in mol.GetBonds():
        u, v = bond.GetBeginAtom().GetIdx(), bond.GetEndAtom().GetIdx()
        B += [(u, v), (v, u)]
    return A, B

for faster preprocessing speed.
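
A small check that the two edge-ordered variants in this thread produce the same list, and that this list differs from the sorted adjacency order used earlier. This is a sketch without RDKit: the bond list is hardcoded from the printed pairs above, the begin/end orientation of each bond is an assumption, and the function names are hypothetical.

```python
# Bonds of 'O=C1NC(=O)CCC1' as (begin, end) pairs; orientation is assumed.
bonds = [(0, 1), (1, 2), (2, 3), (3, 4), (3, 5), (5, 6), (6, 7), (7, 1)]


def edit_sites_two_lists(bonds):
    """Variant that fills src/dst lists first, then zips them."""
    src_list, dst_list = [], []
    for u, v in bonds:
        src_list.extend([u, v])
        dst_list.extend([v, u])
    return list(zip(src_list, dst_list))


def edit_sites_direct(bonds):
    """Variant that appends both directions of each bond directly."""
    B = []
    for u, v in bonds:
        B += [(u, v), (v, u)]
    return B


edge_order = edit_sites_direct(bonds)
print(edge_order == edit_sites_two_lists(bonds))  # True
print(edge_order == sorted(edge_order))           # False
```

Because the edge order and the sorted order differ, the new get_edit_site_retro and the new pair_atom_feats would need to be switched together.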