wengong-jin / hgraph2graph

Hierarchical Generation of Molecular Graphs using Structural Motifs
MIT License

Extracting latent vector for a molecule #21

Open michal-pikusa opened 3 years ago

michal-pikusa commented 3 years ago

Hi,

I've successfully trained a generation model on my set of molecules, and I'm able to sample from it with generate.py. However, is it possible to easily extract the latent vectors for the molecules I used as the training set?

I've previously used your JTNN model, where this was easy with the encode_latent function in your model class. However, HierVAE in hgnn.py does not seem to have such a function. Could you point me to the helper functions I would need to call to encode a molecule after training and get its latent vector?

Thank you!

ShayneWierbowski commented 3 years ago

I'm just playing around with this repository myself and was also interested in this question (so I'm not extensively familiar with the code). An official response would certainly be better, but from what I can tell you should only need to modify the reconstruct method of the HierVAE model.

    def reconstruct(self, batch):
        # I think these should match the output from the preprocess.py script
        graphs, tensors, _ = batch

        # Convert the numpy arrays to LongTensors (and move them to the GPU)
        tree_tensors, graph_tensors = tensors = make_cuda(tensors)

        # Encode batch of compounds
        root_vecs, tree_vecs, _, graph_vecs = self.encoder(tree_tensors, graph_tensors)

        # rsample projects root_vecs (hidden_size wide) into the latent space through
        # R_mean / R_var; with perturb=False it returns the mean without sampling noise.
        # So the root_vecs returned here (not the encoder output above) are the latent
        # embeddings: return them at this point instead of decoding if that is all you need.
        root_vecs, root_kl = self.rsample(root_vecs, self.R_mean, self.R_var, perturb=False)

        # Decode the latent embedding(s) back into SMILES string(s)
        return self.decoder.decode((root_vecs, root_vecs, root_vecs), greedy=True, max_decode_step=150)
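To make the "second step" concrete: as far as I can tell from hgnn.py, rsample applies two linear layers (R_mean and R_var, both mapping hidden_size to latent_size) to the encoder's root vector and only adds Gaussian noise when perturb=True. The sketch below restates that in NumPy, with random weights standing in for the trained layers. The random weights, the helper names, and the `-abs()` on the log-variance reflect my reading of the repo, not a verified copy of it, so check them against your hgnn.py.

```python
import numpy as np

# Sizes from the argparse defaults used in this thread (hidden 250, latent 32)
HIDDEN_SIZE, LATENT_SIZE = 250, 32

rng = np.random.default_rng(0)
# Random stand-ins for the trained nn.Linear layers R_mean and R_var
W_mean, b_mean = rng.normal(size=(LATENT_SIZE, HIDDEN_SIZE)) * 0.01, np.zeros(LATENT_SIZE)
W_var, b_var = rng.normal(size=(LATENT_SIZE, HIDDEN_SIZE)) * 0.01, np.zeros(LATENT_SIZE)


def linear(x, W, b):
    """Same affine map as torch.nn.Linear: x @ W.T + b."""
    return x @ W.T + b


def rsample(root_vecs, perturb=True, rng=rng):
    """Sketch of HierVAE.rsample: map encoder output to a latent vector."""
    batch_size = root_vecs.shape[0]
    z_mean = linear(root_vecs, W_mean, b_mean)
    # Log-variance is forced to be non-positive, i.e. variance <= 1
    z_log_var = -np.abs(linear(root_vecs, W_var, b_var))
    # KL(N(z_mean, exp(z_log_var)) || N(0, I)), averaged over the batch
    kl = -0.5 * np.sum(1.0 + z_log_var - z_mean**2 - np.exp(z_log_var)) / batch_size
    if perturb:
        eps = rng.standard_normal(z_mean.shape)
        z = z_mean + np.exp(z_log_var / 2) * eps  # reparameterization trick
    else:
        z = z_mean  # deterministic: the posterior mean
    return z, kl


# A batch of 2 fake encoder outputs (hidden_size wide)
root_vecs = rng.standard_normal((2, HIDDEN_SIZE))
z, kl = rsample(root_vecs, perturb=False)
print(root_vecs.shape, "->", z.shape)  # (2, 250) -> (2, 32)
```

With perturb=False the same molecule always maps to the same latent_size-dimensional vector, which is what you want when extracting embeddings; the noisy perturb=True path is presumably what training uses.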

If you aren't sure how to get everything into the right format for the reconstruct method's input, I wrote a test script that loads a model from a checkpoint, takes a given SMILES string, encodes it, and then decodes it.

Again, I'm not 100% confident this is accurate; my attempted encode-decode round trip did not reproduce the same compound, so it's possible I'm misunderstanding something here.

    # Imports / argument parser / helper functions
    # (copied from generate.py, preprocess.py and/or hgraph/hgnn.py)
    # Assumes a CPU-patched copy of hgraph: the stock code calls .cuda()
    # in several places (e.g. inside HierVAE.rsample).

    import random
    import argparse
    import torch

    from hgraph import *
    import rdkit

    def make_cuda(tensors):
        # CPU-only variant of make_cuda from hgnn.py: converts numpy arrays
        # to LongTensors but does not move them to the GPU
        tree_tensors, graph_tensors = tensors
        make_tensor = lambda x: x if type(x) is torch.Tensor else torch.tensor(x)
        tree_tensors = [make_tensor(x).long() for x in tree_tensors[:-1]] + [tree_tensors[-1]]
        graph_tensors = [make_tensor(x).long() for x in graph_tensors[:-1]] + [graph_tensors[-1]]
        return tree_tensors, graph_tensors

    def to_numpy(tensors):
        convert = lambda x : x.numpy() if type(x) is torch.Tensor else x
        a,b,c = tensors
        b = [convert(x) for x in b[0]], [convert(x) for x in b[1]]
        return a, b, c

    def tensorize(mol_batch, vocab):
        x = MolGraph.tensorize(mol_batch, vocab, common_atom_vocab)
        return to_numpy(x)

    parser = argparse.ArgumentParser()
    parser.add_argument('--vocab', required=True)
    parser.add_argument('--atom_vocab', default=common_atom_vocab)
    parser.add_argument('--model', required=True)

    parser.add_argument('--seed', type=int, default=7)
    parser.add_argument('--nsample', type=int, default=10000)

    # Model hyperparameters: these must match the ones the checkpoint was trained with
    parser.add_argument('--rnn_type', type=str, default='LSTM')
    parser.add_argument('--hidden_size', type=int, default=250)
    parser.add_argument('--embed_size', type=int, default=250)
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--latent_size', type=int, default=32)
    parser.add_argument('--depthT', type=int, default=15)
    parser.add_argument('--depthG', type=int, default=15)
    parser.add_argument('--diterT', type=int, default=1)
    parser.add_argument('--diterG', type=int, default=3)
    parser.add_argument('--dropout', type=float, default=0.0)

    args = parser.parse_args()

    # Seed the RNGs for reproducibility
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # Parse vocabulary file
    vocab = [x.strip("\r\n ").split() for x in open(args.vocab)]
    args.vocab = PairVocab(vocab)

    # Test compound to reconstruct
    smiles = ['C=Cc1cccc(C(=O)N2CC(c3ccc(F)cc3)C(C)(C)C2)c1']

    print("\nINPUT SMILES: {0}\n".format(" ".join(smiles)))

    # Convert SMILES strings into MolGraph tree / graph tensors
    # (see preprocess.py)
    batches, tensors, all_orders = tensorize(smiles, args.vocab)

    # Convert the numpy arrays to LongTensors
    tree_tensors, graph_tensors = make_cuda(tensors)

    # Load the checkpoint model ([0] is the model state_dict in the saved tuple)
    model = HierVAE(args)
    model.load_state_dict(torch.load(args.model, map_location=torch.device('cpu'))[0])
    model.eval()

    # No gradients are needed for encoding / decoding
    with torch.no_grad():
        # Encode compound: this root_vecs is the encoder's hidden representation
        # (hidden_size wide), not yet the latent vector
        root_vecs, tree_vecs, _, graph_vecs = model.encoder(tree_tensors, graph_tensors)

        print("\nENCODER OUTPUT (hidden_size)\n")
        print(root_vecs)
        print("\n")

        # rsample projects root_vecs into the latent space via R_mean / R_var;
        # with perturb=False it returns the mean, i.e. the latent embedding
        root_vecs, root_kl = model.rsample(root_vecs, model.R_mean, model.R_var, perturb=False)

        print("\nLATENT_EMBEDDING (latent_size)\n")
        print(root_vecs)
        print("\n")

        # Decode compound
        decoded_smiles = model.decoder.decode((root_vecs, root_vecs, root_vecs), greedy=True, max_decode_step=150)

    # The decoded and original SMILES / compound do not match.
    # Not sure if this is because something is done wrong or just
    # because this compound is one that couldn't be reconstructed
    # accurately.
    print("DECODED SMILES: {0}".format(" ".join(decoded_smiles)))

ShayneWierbowski commented 3 years ago

It's possible that the original and decoded SMILES strings don't match because of differences in the underlying RDKit implementation (i.e. the installed version), as referenced in other issues (#20).

When I expand this to test multiple compounds, I start getting two errors that I believe are linked to this:

Traceback (most recent call last):
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/test_smiles2tensor.py", line 70, in <module>
    o = tensorize(smiles, args.vocab)
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/test_smiles2tensor.py", line 29, in tensorize
    x = MolGraph.tensorize(mol_batch, vocab, common_atom_vocab)
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/hgraph/mol_graph.py", line 152, in tensorize
    mol_batch = [MolGraph(x) for x in mol_batch]
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/hgraph/mol_graph.py", line 152, in <listcomp>
    mol_batch = [MolGraph(x) for x in mol_batch]
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/hgraph/mol_graph.py", line 22, in __init__
    self.order = self.label_tree()
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/hgraph/mol_graph.py", line 123, in label_tree
    tree.nodes[i]['assm_cands'] = get_assm_cands(mol, hist, inter_label, pa_cls, len(inter_atoms))
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/hgraph/chemutils.py", line 120, in get_assm_cands
    mol = get_clique_mol(mol, atoms)
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/hgraph/chemutils.py", line 111, in get_clique_mol
    smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
rdkit.Chem.rdchem.KekulizeException: Can't kekulize mol.  Unkekulized atoms: 22

and

Traceback (most recent call last):
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/test_smiles2tensor.py", line 70, in <module>
    o = tensorize(smiles, args.vocab)
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/test_smiles2tensor.py", line 29, in tensorize
    x = MolGraph.tensorize(mol_batch, vocab, common_atom_vocab)
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/hgraph/mol_graph.py", line 153, in tensorize
    tree_tensors, tree_batchG = MolGraph.tensorize_graph([x.mol_tree for x in mol_batch], vocab)
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/hgraph/mol_graph.py", line 194, in tensorize_graph
    fnode[v] = vocab[attr]
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/hgraph/vocab.py", line 43, in __getitem__
    return self.hmap[x[0]], self.vmap[x]
KeyError: 'C1=CNN=C1'

Will try to get the package versions corrected and confirm if this is the problem.
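When scaling the test script up to many molecules, a single molecule that fails (like the KekulizeException or KeyError above) aborts the whole run. A minimal sketch, assuming a hypothetical `encode_batch(smiles_list)` callable that wraps the tensorize, encoder and rsample steps from the script and returns one latent vector per SMILES: encode in fixed-size batches, and if a batch fails, retry its molecules one at a time so that only the offending SMILES are skipped and recorded.

```python
from typing import Callable, Dict, List, Sequence, Tuple


def encode_all(
    smiles: Sequence[str],
    encode_batch: Callable[[List[str]], List],
    batch_size: int = 50,
) -> Tuple[Dict[str, object], Dict[str, str]]:
    """Encode SMILES in batches and return (latents, failures).

    latents maps SMILES -> latent vector; failures maps SMILES -> error text.
    encode_batch is a hypothetical wrapper around tensorize/encoder/rsample.
    """
    latents: Dict[str, object] = {}
    failures: Dict[str, str] = {}
    for start in range(0, len(smiles), batch_size):
        batch = list(smiles[start:start + batch_size])
        try:
            latents.update(zip(batch, encode_batch(batch)))
        except Exception:
            # One bad molecule fails the whole batch, so retry individually
            for s in batch:
                try:
                    latents[s] = encode_batch([s])[0]
                except Exception as e:  # e.g. KeyError for an out-of-vocabulary motif
                    failures[s] = f"{type(e).__name__}: {e}"
    return latents, failures


# Usage with a fake encoder standing in for the real model
def fake_encode(batch):
    if "BAD" in batch:
        raise KeyError("C1=CNN=C1")
    return [len(s) for s in batch]


latents, failures = encode_all(["CCO", "BAD", "c1ccccc1"], fake_encode, batch_size=2)
print(latents)   # {'CCO': 3, 'c1ccccc1': 8}
print(failures)  # {'BAD': "KeyError: 'C1=CNN=C1'"}
```

Keying the results by SMILES makes it easy to join the latent vectors back to your molecule list afterwards; duplicate SMILES collapse to one entry.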

michal-pikusa commented 3 years ago

Thank you @ShayneWierbowski! Your code actually works pretty nicely, I think.

I've been able to encode the set of molecules I work with (around 4,000 of them) and reconstruct them with your code, and about 20% of my set is reconstructed exactly (1:1). The rest are not exact matches, but the median Tanimoto similarity to the originals is ~0.8, so most reconstructions are still very similar to the input, with minor modifications. This can probably be improved by tuning the training hyperparameters, as I saw something similar with their previous work (JT-VAE).
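The two numbers quoted above (fraction reconstructed exactly and median Tanimoto similarity) could be computed along the lines of the sketch below. Which fingerprint was used isn't stated; in practice you would canonicalize both SMILES with RDKit (`Chem.MolToSmiles(Chem.MolFromSmiles(s))`) and compare fingerprints with `DataStructs.TanimotoSimilarity`. To keep the example dependency-free, it assumes the SMILES are already canonical and represents each fingerprint as the set of its on-bit indices (what `GetOnBits()` returns for an RDKit bit vector). The fingerprint values in the usage lines are toy data.

```python
from statistics import median
from typing import Dict, Iterable, List, Set, Tuple


def tanimoto(a: Set[int], b: Set[int]) -> float:
    """Tanimoto similarity of two fingerprints given as on-bit sets: |A & B| / |A | B|."""
    if not a and not b:
        return 1.0  # convention chosen here for two empty fingerprints
    return len(a & b) / len(a | b)


def reconstruction_stats(
    pairs: Iterable[Tuple[str, str]],
    fingerprints: Dict[str, Set[int]],
) -> Tuple[float, float]:
    """Return (exact-match fraction, median Tanimoto) over all (input, decoded) pairs.

    Assumes SMILES are already canonical and `fingerprints` maps each SMILES
    to its on-bit set (e.g. set(fp.GetOnBits()) for an RDKit bit vector).
    """
    pairs = list(pairs)
    exact = sum(orig == dec for orig, dec in pairs) / len(pairs)
    sims: List[float] = [tanimoto(fingerprints[o], fingerprints[d]) for o, d in pairs]
    return exact, median(sims)


# Toy fingerprints, not real Morgan bits
fps = {"CCO": {1, 2, 3}, "CCN": {1, 2, 4}, "c1ccccc1": {5, 6}}
pairs = [("CCO", "CCO"), ("CCO", "CCN"), ("c1ccccc1", "c1ccccc1")]
exact, med = reconstruction_stats(pairs, fps)
print(f"exact: {exact:.2f}, median Tanimoto: {med:.2f}")  # exact: 0.67, median Tanimoto: 1.00
```

Comparing canonical SMILES rather than raw strings matters: the decoder can emit a different but equivalent SMILES for the same molecule.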

As for your error, make sure the molecules you encode are either part of your training set or contain only motifs that are in the vocabulary. You cannot encode a molecule with out-of-vocabulary motifs, and I think that's why you are getting the KeyError. I haven't gotten a single one while encoding all the molecules in my training set.
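The KeyError in the traceback is raised at `self.hmap[x[0]]` in vocab.py, which suggests the first element of the motif pair ('C1=CNN=C1') is not a key of the vocabulary. Assuming, as the script's parsing and `PairVocab` suggest, that each vocabulary line holds whitespace-separated SMILES with the motif first, a quick way to check whether a given motif is covered is to load the first column into a set. These are hypothetical helpers, not part of hgraph2graph, and the vocabulary content in the usage lines is a toy example.

```python
from typing import Iterable, List, Set


def load_motifs(lines: Iterable[str]) -> Set[str]:
    """Collect the first whitespace-separated field of each non-empty vocab line."""
    motifs: Set[str] = set()
    for line in lines:
        fields = line.split()
        if fields:
            motifs.add(fields[0])
    return motifs


def missing_motifs(candidates: Iterable[str], motifs: Set[str]) -> List[str]:
    """Return the candidate motif SMILES that are not in the vocabulary, in order."""
    return [m for m in candidates if m not in motifs]


# Usage: with a real file, `with open("vocab.txt") as f: motifs = load_motifs(f)`
import io

toy_vocab = io.StringIO("C1=CC=CC=C1 C1=CC=CC=C1\nCCO CCO\n")
motifs = load_motifs(toy_vocab)
print(missing_motifs(["CCO", "C1=CNN=C1"], motifs))  # ['C1=CNN=C1']
```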

Thanks again for your help.

ShayneWierbowski commented 3 years ago

@michal-pikusa I'm glad you found this helpful!

Thanks for your suggestion about the vocabulary as well; it's definitely an important consideration. In my case I think the errors were linked to the RDKit version, and everything ran smoothly after I reinstalled the correct version.

In my expanded evaluation I got similar results to yours (~20% reconstructed exactly, the rest with generally high Tanimoto similarity).

From your experience retraining the model and/or tweaking hyperparameters for a different drug/vocabulary set, do you have a sense of how long training takes? I haven't played with that yet.

michal-pikusa commented 3 years ago

@ShayneWierbowski: Training on my 4k set took ~30 minutes on a single GPU with 16 GB of memory, so it's really fast.