Shen-Lab / LDM-3DG

[ICLR 2024] "Latent 3D Graph Diffusion" by Yuning You, Ruida Zhou, Jiwoong Park, Haotian Xu, Chao Tian, Zhangyang Wang, Yang Shen
GNU General Public License v3.0
28 stars, 6 forks

3d decoder is not deterministic, 3d decoding and encoding are not reversible #6

Closed: yryMax closed this issue 2 months ago

yryMax commented 2 months ago

Hi Yuning,

Sorry to bother you again. This is not really an "issue": I'm experimenting with your geometric AE, but I'm not sure that what I'm doing is correct. I'd be grateful if you could help me judge whether the results I got are reasonable.

Geometric decoder is not deterministic

I tried to generate 3D molecules from fixed z_samples and smiles_samples, but each time I ran the same code snippet the outcomes looked very different (see below). Is this expected?

[Figures: decoded conformers from Trial 1, Trial 2, and Trial 3]

3d decoding and encoding are not reversible

I decoded z_samples into smiles_samples and conformer_samples, and my goal is to recover the second half of the latent vector from the SMILES and conformer samples alone. However, if I apply the geometric decoder and then the geometric encoder, the reconstructed latent differs a lot from z_samples. This surprised me, because I earlier tried reconstructing the first half of the latent with the topology AE, and it fit very well (see below). Is this normal?

[Figure 1: 2D latent space (first half of z)]

[Figure 2: 3D latent space (second half of z)]

yryMax commented 2 months ago

My decoding pipeline

(code borrowed from AE_Geometry_and_Unconditional_Latent_Diffusion/sample3_latent_ddpm_qm9_3d.py)

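# Imports assumed by this snippet. mol_to_graph_data_obj_simple_2D is a helper
# from the LDM-3DG repository and must be imported from there.
import torch
import torch_geometric as tgeom
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Geometry import Point3D
from torch_geometric.data import Batch
from tqdm import tqdm


def smile_to_conformer(smiles, z, decoder_3d_instance):
    mol3d_list = []
    for idx, smi in enumerate(tqdm(smiles)):
        mol = Chem.MolFromSmiles(smi)
        mol = Chem.AddHs(mol)

        # Initial coordinates for the decoder. EmbedMolecule is stochastic unless
        # randomSeed is set, so these positions can differ from run to run.
        try:
            AllChem.EmbedMolecule(mol, maxAttempts=5000)
            positions = mol.GetConformers()[0].GetPositions()
        except Exception:
            # No 3D conformer was produced; fall back to 2D coordinates.
            AllChem.Compute2DCoords(mol)
            positions = mol.GetConformers()[0].GetPositions()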

        data = mol_to_graph_data_obj_simple_2D(mol)[0]
        data.x = data.x.float()[:, :118]
        data.edge_attr = data.edge_attr.float()
        data.n_nodes = data.x.shape[0]
        data.n_edges = data.edge_index.shape[1]
        data.pos = torch.tensor(positions).float()

        data.edge_index = tgeom.nn.radius_graph(data.pos, r=5, loop=False)
        edge_attr = torch.exp(-torch.norm(data.pos[data.edge_index[0]] - data.pos[data.edge_index[1]], dim=1))
        data.edge_attr = torch.einsum('i,j->ij', edge_attr, torch.linspace(1, 5, 16).to(edge_attr.device))
        data.n_edges = torch.tensor(data.edge_index.shape[1]).long()

        latent = z[idx].unsqueeze(dim=0).to('cuda')
        data = data.to('cuda')

        batch = Batch.from_data_list([data]).to('cuda')

        with torch.no_grad():
            pos = decoder_3d_instance(batch, latent[:, 250:])[0][-1]
        pos = pos.cpu()
        conf = mol.GetConformer()
        for jdx in range(mol.GetNumAtoms()):
            conf.SetAtomPosition(jdx, Point3D(pos[jdx, 0].item(), pos[jdx, 1].item(), pos[jdx, 2].item()))

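        # Note: MMFF refinement changes the decoder's output coordinates, and
        # molecules that fail here are skipped, so mol3d_list may not align with z.
        try:
            AllChem.MMFFOptimizeMolecule(mol)
            mol3d_list.append(mol)
        except Exception:
            continue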
    return mol3d_list
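A note on comparing the trials: 3D coordinates are only defined up to a rigid rotation and translation, so two decoded conformers can look very different on screen while being geometrically close. One way to check how much the runs really differ is an aligned RMSD. The following is a minimal sketch using the Kabsch algorithm in NumPy; the function name `kabsch_rmsd` is hypothetical, and it assumes both coordinate arrays list the same atoms in the same order (which should hold when both molecules come from the same SMILES through the function above). Separately, the RDKit embedding step in the snippet is seeded randomly by default, which is one plausible source of run-to-run variation; passing `randomSeed` to `AllChem.EmbedMolecule` would be a way to test that.

```python
import numpy as np


def kabsch_rmsd(p, q):
    """RMSD between two (N, 3) coordinate arrays after optimal rigid alignment.

    Assumes row i of `p` and row i of `q` describe the same atom.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    if p.shape != q.shape or p.ndim != 2 or p.shape[1] != 3:
        raise ValueError("p and q must both have shape (N, 3)")

    # Remove translation by centering both point sets.
    p_c = p - p.mean(axis=0)
    q_c = q - q.mean(axis=0)

    # Kabsch: SVD of the 3x3 covariance matrix gives the optimal rotation.
    h = p_c.T @ q_c
    u, _, vt = np.linalg.svd(h)

    # Force a proper rotation (det = +1) so mirror images are not matched.
    d = np.sign(np.linalg.det(vt.T @ u.T))
    rot = vt.T @ np.diag([1.0, 1.0, d]) @ u.T

    # Rotate p onto q (row-vector convention) and measure the residual.
    p_rot = p_c @ rot.T
    return float(np.sqrt(np.mean(np.sum((p_rot - q_c) ** 2, axis=1))))
```

With RDKit molecules, the inputs could be obtained with `mol.GetConformer().GetPositions()` for each trial. A value near zero would mean the trials differ only by orientation; a large value would point to real variation in the decoded geometry.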

My encoding pipeline

mol_one = mol3d_list[0]
data = mol_to_graph_data_obj_simple_3D(mol_one)[0]
data.x = data.x.float()[:, :118]
data.edge_index = radius_graph(data.positions, r=5.0, loop=False)
bt = Batch.from_data_list([data]).to('cuda')
with torch.no_grad():
    emb_3d = encoder_3d_instance(bt.x, bt.positions, bt.edge_index, bt.batch, True)[1].to('cpu')
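To put a number on "differs a lot", the original and re-encoded latents can be compared directly. The sketch below is only an illustration: `latent_roundtrip_report` is a hypothetical helper, and it assumes both inputs are already NumPy arrays of the same shape (for example the second half of `z_samples` and the encoder output for the same molecules). Two things in the pipelines above may inflate the mismatch independently of the autoencoder itself: the decoded conformer is refined with MMFF before it is re-encoded, and molecules that fail are skipped, so `mol3d_list[0]` does not necessarily correspond to `z[0]`.

```python
import numpy as np


def latent_roundtrip_report(z_ref, z_rec):
    """Compare reference latents with re-encoded latents, row by row.

    Both inputs are arrays of shape (n_samples, dim) or (dim,).
    Returns global RMSE, mean relative L2 error, and mean cosine similarity.
    """
    z_ref = np.atleast_2d(np.asarray(z_ref, dtype=float))
    z_rec = np.atleast_2d(np.asarray(z_rec, dtype=float))
    if z_ref.shape != z_rec.shape:
        raise ValueError("z_ref and z_rec must have the same shape")

    diff = z_rec - z_ref
    ref_norm = np.linalg.norm(z_ref, axis=1)
    rec_norm = np.linalg.norm(z_rec, axis=1)

    # Small floor avoids division by zero for all-zero latents.
    eps = 1e-12
    rel_l2 = np.linalg.norm(diff, axis=1) / np.maximum(ref_norm, eps)
    cosine = np.sum(z_ref * z_rec, axis=1) / np.maximum(ref_norm * rec_norm, eps)

    return {
        "rmse": float(np.sqrt(np.mean(diff ** 2))),
        "mean_relative_l2": float(rel_l2.mean()),
        "mean_cosine": float(cosine.mean()),
    }
```

Running the same report on the 2D half and the 3D half would make the two figures comparable on one scale, for instance by comparing the mean cosine similarity of each half.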

I exported the full snippet for reference: visualization.ipynb.pdf

In any case, thanks again for your great work and for the rapid responses to all the issues that have been raised!

yyou1996 commented 2 months ago

Thanks for your comments. We do need to say that the 3D AE is far from ideal compared to the 2D AE: as shown in Table 1, its RMSE is 0.69. So the current performance bottleneck is more on the 3D part, as you correctly understand.

The reversibility you mention is very interesting. It is basically saying that Enc(Dec(Enc(G))) and Enc(G) have a non-trivial mismatch for the 3D features. I haven't checked it from this perspective, but it does not surprise me that much, since I did think the 3D AE's quality could be further improved.

yryMax commented 2 months ago

Thank you so much for your response! Sorry to bother you again. I am currently trying to recover the z_sample from generated smiles and conformers and this is really important to me.

Do you think it's feasible to manually write an inverse of the 2D/3D decoder currently used? (For example, DDIM has a corresponding DDIM inversion, whereas for the DDPM sampler this is not feasible because its sampling steps inject random noise.)

I am not a researcher in this field and lack domain knowledge, so this may be a stupid question. But if it's feasible, I will look into the details of your paper and implementation and try to implement the inversion; otherwise I won't spend time on it and will look for other methods.

Again I appreciate your great work and your patience

yyou1996 commented 1 month ago

In my view it is feasible. The 2D part is already there, and the 3D part might need some effort to make it work.
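For readers wondering what a manual inverse could look like when no exact inverse exists, one generic option is optimization-based inversion: keep the decoder fixed and search for the latent whose decoded output best matches the target. The sketch below is not the repository's method and does not use its decoder; `ToyDecoder` is a hypothetical stand-in (`x = tanh(W z + b)`) with a hand-written gradient, used only to show the loop. With a real PyTorch decoder the gradient would come from autograd, and any stochastic inputs (such as the initial RDKit conformer) would presumably need to be held fixed.

```python
import numpy as np


class ToyDecoder:
    """Hypothetical differentiable decoder: x = tanh(W z + b)."""

    def __init__(self, latent_dim, out_dim, seed=0):
        rng = np.random.default_rng(seed)
        self.w = rng.normal(size=(out_dim, latent_dim)) / np.sqrt(latent_dim)
        self.b = 0.1 * rng.normal(size=out_dim)

    def __call__(self, z):
        return np.tanh(self.w @ z + self.b)

    def loss_and_grad(self, z, x_target):
        """Return 0.5 * ||dec(z) - x_target||^2 and its gradient w.r.t. z."""
        y = self(z)
        residual = y - x_target
        loss = 0.5 * float(residual @ residual)
        # Chain rule: d tanh(a) / da = 1 - tanh(a)^2.
        grad = self.w.T @ ((1.0 - y ** 2) * residual)
        return loss, grad


def invert_by_gradient_descent(decoder, x_target, z_init, lr=0.02, steps=2000):
    """Search for a latent z whose decoded output matches x_target."""
    z = np.array(z_init, dtype=float)
    history = []
    for _ in range(steps):
        loss, grad = decoder.loss_and_grad(z, x_target)
        history.append(loss)
        z = z - lr * grad
    # Record the loss at the final iterate as well.
    history.append(decoder.loss_and_grad(z, x_target)[0])
    return z, history
```

A possible usage: build `dec = ToyDecoder(4, 16)`, decode a known latent to get `x`, then call `invert_by_gradient_descent(dec, x, np.zeros(4))` and inspect how the loss history falls. Whether the recovered latent matches the original depends on how well conditioned the decoder is, which is exactly the open question for the 3D part.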