Rose-STL-Lab / LIMO

generative model for drug discovery

Substructure-constrained logP Extremization & logP Targeting #6

Closed · TryLittleHarder closed this issue 2 years ago

TryLittleHarder commented 2 years ago

Hi, you mentioned Substructure-constrained logP Extremization & logP Targeting in your paper, but I can't find either part in the code. Could you please provide more details about how to implement them? Thanks!

PeterEckmann1 commented 2 years ago

@TryLittleHarder Hi, sorry for the late response on this. Are you still interested in these tasks? If so, I can offer some guidance. I'm happy to release the code for logP targeting. The substructure-constrained logP extremization will be a bit more involved, but still possible.

TryLittleHarder commented 2 years ago

Thanks for your reply. I think I still need some help. I added the L2 loss to the loss in get_optimized_z, but the generated molecule didn't have the same substructure as the input molecule. The L1 and L2 losses are -2.4705 and 10860.1191, respectively. Maybe the L2 loss is too large?

PeterEckmann1 commented 2 years ago

How are you applying the L2 loss? The way we did it for the paper was at the output of the decoder, after applying a mask to the positions of the SELFIES string you don't want changed.

For the paper, we used torch.sum(((x - orig_x.clone().detach()) * mask) ** 2), with x being the output of the decoder for the current epoch and orig_x being the output of the decoder for the first epoch, i.e. the input molecule. mask has the same shape as orig_x (torch.zeros_like(orig_x)) and is 1 at the entries that must stay unchanged to keep the substructure intact, and 0 otherwise. This is the code we used to generate the mask:

import torch
import selfies as sf
from rdkit import Chem

smile = ''  # put input SMILES here
orig_z = smiles_to_z([smile], vae)      # encode the input molecule into the VAE latent space
orig_x = torch.exp(vae.decode(orig_z))  # decoded probabilities for the input molecule
substruct = Chem.MolFromSmiles('')  # put the SMILES of the substructure you wish to keep constant here
selfies = list(sf.split_selfies(sf.encoder(smile)))
mask = torch.zeros_like(orig_x)
# mark every (position, symbol) substitution that would break the substructure
for i in range(len(selfies)):
    for j in range(len(dm.dataset.idx_to_symbol)):
        changed = selfies.copy()
        changed[i] = dm.dataset.idx_to_symbol[j]
        m = Chem.MolFromSmiles(sf.decoder(''.join(changed)))
        if m is None or not m.HasSubstructMatch(substruct):
            mask[0][i * len(dm.dataset.idx_to_symbol) + j] = 1
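
One way to catch the reconstruction failures mentioned below is to decode orig_z back to SMILES and check that the substructure survived the round trip before spending time on optimization. This is a sketch, not part of the code we used; it assumes one_hot_to_smiles (used in the loop further down) accepts the decoder output directly:

# sanity check (sketch): does the VAE reconstruction of the input still contain the substructure?
recon = Chem.MolFromSmiles(one_hot_to_smiles(orig_x))
if recon is None or not recon.HasSubstructMatch(substruct):
    print('Warning: VAE reconstruction lost the substructure; optimization may not behave as expected')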

Note that this will not necessarily work for all molecules, as the VAE sometimes fails to reconstruct the input molecule you give it. Once you have a mask and a starting molecule, you can put everything in a training loop:

from tqdm import tqdm

z = orig_z.clone().detach().requires_grad_(True)  # optimize the latent vector directly
optimizer = torch.optim.Adam([z], lr=0.1)
smiles = []
logps = []
for epoch in tqdm(range(50000)):
    optimizer.zero_grad()
    x = torch.exp(vae.decode(z))
    # property loss from the predictor plus the weighted, masked L2 substructure penalty
    loss = model(x) + 1000 * torch.sum(((x - orig_x.clone().detach()) * mask) ** 2)
    loss.backward()
    optimizer.step()
    if epoch % 1000 == 0:
        x, logp = get_logp(z)
        logps.append(logp.item())
        smiles.append(one_hot_to_smiles(x))
print('starting', logps[0])
print('ending', logps[-1])

model here is a logp property predictor, but it could be any other property predictor as well. Hopefully this gives you a good idea of where to go next, but please follow up if you have any questions!
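
If it helps, here is a sketch (not from our original code) for picking a final molecule from the smiles and logps collected above: since the penalty is soft, it is worth filtering out any candidates that lost the substructure before choosing by logP.

# sketch: keep only candidates that still contain the substructure, then choose by logP
kept = []
for s, lp in zip(smiles, logps):
    m = Chem.MolFromSmiles(s)
    if m is not None and m.HasSubstructMatch(substruct):
        kept.append((lp, s))
kept.sort()  # ascending predicted logP; take kept[0] or kept[-1] depending on the direction you optimize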

TryLittleHarder commented 2 years ago

Thanks a lot! That totally solved my problem. But I still want to know how to target logP to a certain range. I tried only counting the loss when the prediction falls outside the range [-2.5, 2.0], and only got a success rate of 8%.

PeterEckmann1 commented 2 years ago

Glad to hear that helped! For logP targeting, we used an MSE loss to the center of the range, so something like torch.mean((out + 2.25) ** 2) (I'm assuming you meant [-2.5, -2.0] rather than [-2.5, +2.0]; you can adapt the 2.25 to whatever the center of your range is). So instead of counting the loss only outside the range, we actually counted it inside the range too. Also note that we used logp, not penalized_logp, for this task, in case you weren't aware, since that could change the results slightly.
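
In case a concrete sketch helps, the only change relative to the loop earlier in this thread is the loss term. This assumes out is just the predictor output model(x) and that the substructure penalty is dropped; it is an illustration, not the exact script we ran:

target = -2.25  # center of the desired range, e.g. [-2.5, -2.0]
for epoch in range(50000):
    optimizer.zero_grad()
    x = torch.exp(vae.decode(z))
    out = model(x)                          # predicted logP
    loss = torch.mean((out - target) ** 2)  # MSE to the center of the range
    loss.backward()
    optimizer.step()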

That said, I'm not very surprised by the 8%, since we got a very similar value (10.4%) for the paper.
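
For reference, here is a minimal sketch of how you could compute that success rate yourself. It assumes success is judged by RDKit's Crippen logP on the decoded molecules, which may not be exactly how we scored it for the paper:

from rdkit import Chem
from rdkit.Chem import Crippen

def success_rate(smiles_list, low=-2.5, high=-2.0):
    # fraction of generated molecules whose computed logP falls inside [low, high]
    hits = 0
    for s in smiles_list:
        m = Chem.MolFromSmiles(s)
        if m is not None and low <= Crippen.MolLogP(m) <= high:
            hits += 1
    return hits / len(smiles_list)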