MinkaiXu / GeoLDM

Geometric Latent Diffusion Models for 3D Molecule Generation
MIT License

Autoencoder is identity function on atom coordinates? Equivalence to EDM #6

Open guanjq opened 1 year ago

guanjq commented 1 year ago

Hi Minkai,

Thank you for sharing this work! When analyzing the sampling results of GeoLDM, I found that the latent variable z_x is almost equal to the decoded atom positions. Below are molecules I reconstructed with the decoded atom positions and atom types (left) and with z_x and the decoded atom types (right), respectively. They are almost the same.

[Images: z_x + reconstructed atom types | reconstructed atom positions + reconstructed atom types]

Further analysis of the autoencoder's reconstruction results in GeoLDM indicates that both the encoder and the decoder are almost identity functions on atom coordinates. If so, can I consider GeoLDM to be effectively equivalent to diffusion in 3D space (i.e. EDM), given that the number of latent variables equals the number of atoms and both the encoder and decoder are identity functions on atom coordinates, except for the autoencoder part acting on atom types?
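
In case it helps, here is a minimal sketch of the kind of check I mean; the `encode`/`decode` names and their argument lists are placeholders for the actual autoencoder interface, not necessarily the repo's exact API:

```python
# Minimal sketch of the identity check (encode/decode names and argument
# lists are placeholders; adapt to the actual autoencoder interface).
import torch

@torch.no_grad()
def identity_deviation(autoencoder, x, h, node_mask, edge_mask):
    # Encode coordinates/features into latents, then decode them back.
    z_x, z_h = autoencoder.encode(x, h, node_mask, edge_mask)
    x_rec, h_rec = autoencoder.decode(z_x, z_h, node_mask, edge_mask)
    return {
        "encoder |z_x - x|": (z_x - x).abs().mean().item(),
        "decoder |x_rec - z_x|": (x_rec - z_x).abs().mean().item(),
    }
# Both numbers come out close to zero, i.e. the encoder and decoder act
# (nearly) as identities on the atom coordinates.
```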

If this is correct, I'm also wondering how you trained the autoencoder in your published version. I can understand that training with the reconstruction loss alone would lead to identity functions, but you mentioned in the repo that the encoder remains untrained. If so, why is the encoder not a random mapping but an identity function instead?

Thanks!

MinkaiXu commented 1 year ago

Hi Jiaqi,

Thanks a lot for your interest and your close observation! The published version exactly follows my description (https://github.com/MinkaiXu/GeoLDM#train-the-geoldm) that the encoder remained "untrained".

Following your interesting observation, I took a look at the whole model. My guess, though I'm not 100% sure, is that this phenomenon comes from the initialization of the EGNN layers:

  1. x in the EGNN encoder is updated with weighted relative directions from its neighbors, and it looks like the MLP that computes those weights is initialized with very small values. https://github.com/MinkaiXu/GeoLDM/blob/main/egnn/egnn_new.py#L76
  2. Since x is then updated through residual connections, with such small aggregations the updates become almost identity transformations (see the sketch below). https://github.com/MinkaiXu/GeoLDM/blob/main/egnn/egnn_new.py#L98
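
For intuition, here is a minimal, self-contained sketch (not the repo's exact code) of how a residually connected coordinate update with a small-initialized weight MLP stays close to the identity:

```python
# Toy illustration of an EGNN-style coordinate update with a residual
# connection; the final layer of the weight MLP is initialized with a very
# small gain, so the aggregated update is near zero and x passes through
# almost unchanged. Names and dimensions are illustrative only.
import torch
import torch.nn as nn

class ToyCoordUpdate(nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()
        # Stand-in for the coord_mlp: last linear layer gets a tiny gain,
        # so the per-edge weights start near zero.
        last = nn.Linear(hidden_dim, 1, bias=False)
        nn.init.xavier_uniform_(last.weight, gain=0.001)
        self.coord_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.SiLU(), last
        )

    def forward(self, x, edge_feat, edge_index):
        # x: (N, 3) coordinates; edge_feat: (E, hidden); edge_index: (2, E)
        row, col = edge_index
        rel = x[row] - x[col]                  # relative directions to neighbors
        w = self.coord_mlp(edge_feat)          # ~0 at initialization
        agg = torch.zeros_like(x).index_add_(0, row, rel * w)
        return x + agg                         # residual: x is left almost unchanged

# With near-zero weights, the output barely deviates from the input:
x = torch.randn(5, 3)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
edge_feat = torch.randn(4, 64)
print((ToyCoordUpdate()(x, edge_feat, edge_index) - x).abs().max())  # tiny value
```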

I think your observation also helps explain the model's behavior: I could make it work with any KL regularization over z_x (as in the ablation study in the paper appendix). It seems the latent code needs to keep most of the structural information and leave that part of the modeling complexity to the latent diffusion. From this perspective, I think your understanding of the equivalence to EDM is also true. But very likely my current implementation of GeoLDM is still not perfect, and some reasonably distorted z_x could actually lead to better results :)

guanjq commented 1 year ago

Hi Minkai,

Thank you for your quick response! That makes sense. I also believe there should be a better way to implement GeoLDM. Looking forward to future updates, and happy to discuss further!

Best, Jiaqi