pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Marginalized Graph Autoencoder #2152

Open ferdinand-popp opened 3 years ago

ferdinand-popp commented 3 years ago

Hello,

thank you for providing these amazing tools! I am thinking about implementing a Marginalized Graph Autoencoder (MGAE: https://dl.acm.org/doi/10.1145/3132847.3132967), in which the feature vectors are corrupted with noise. I want to compare it to the GAEs that are already implemented. Currently, I am clustering the latent representations to find new subgroups of nodes. What would be the best way to set this up?

Thank you very much for your help!
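
For illustration, the "corrupted with noise" step could be sketched in plain PyTorch as below. This is only a sketch under assumptions: the helper name `corrupt_features`, the two noise modes, and the default `p` are hypothetical and not taken from the MGAE paper, whose exact corruption scheme may differ.

```python
import torch


def corrupt_features(x: torch.Tensor, p: float = 0.2, mode: str = "mask") -> torch.Tensor:
    """Return a noisy copy of a float feature matrix x of shape [num_nodes, num_features].

    Hypothetical helper: `mode` and `p` are illustrative choices, not the paper's.
    """
    if mode == "mask":
        # Zero out each entry independently with probability p.
        keep = (torch.rand_like(x) >= p).to(x.dtype)
        return x * keep
    if mode == "gaussian":
        # Add zero-mean Gaussian noise with standard deviation p.
        return x + p * torch.randn_like(x)
    raise ValueError(f"unknown mode: {mode}")


torch.manual_seed(0)
x = torch.ones(4, 3)                  # toy features: 4 nodes, 3 features each
x_noisy = corrupt_features(x, p=0.5)  # roughly half of the entries become zero
```

The encoder would then see `x_noisy`, while the reconstruction target stays the clean `x`.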

rusty1s commented 3 years ago

I cannot access the paper because it is closed access, so I can only make an educated guess. It looks like you want to add noise to the input feature vectors and reconstruct both the node features and the graph structure, is that correct? In that case, modifications are necessary to also reconstruct the input node features, with an additional MAE/MSE loss on them. Your decoder then needs to consist of two parts: the graph structure decoder and the feature decoder.
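
A two-part decoder of this kind might look roughly like the following in plain PyTorch. It is a minimal sketch, not the MGAE formulation: the class name `TwoHeadDecoder`, the function `reconstruction_loss`, the linear feature head, and the `feature_weight` parameter are all assumptions made for illustration, and the structure part uses a dense adjacency matrix, which is only practical for small graphs.

```python
import torch
import torch.nn.functional as F


class TwoHeadDecoder(torch.nn.Module):
    """Hypothetical decoder with a structure part and a feature part."""

    def __init__(self, latent_dim: int, num_features: int):
        super().__init__()
        # Feature decoder: maps latent vectors back to the input feature space.
        self.feature_head = torch.nn.Linear(latent_dim, num_features)

    def forward(self, z: torch.Tensor):
        # Structure decoder: inner product between all pairs of latent vectors,
        # returned as logits of shape [num_nodes, num_nodes].
        adj_logits = z @ z.t()
        x_hat = self.feature_head(z)
        return adj_logits, x_hat


def reconstruction_loss(adj_logits, x_hat, adj, x, feature_weight: float = 1.0):
    # Binary cross-entropy on the dense adjacency plus MSE on the node features.
    structure_loss = F.binary_cross_entropy_with_logits(adj_logits, adj)
    feature_loss = F.mse_loss(x_hat, x)
    return structure_loss + feature_weight * feature_loss


torch.manual_seed(0)
num_nodes, num_features, latent_dim = 5, 8, 4
x = torch.randn(num_nodes, num_features)      # clean input features (toy data)
adj = torch.zeros(num_nodes, num_nodes)
adj[0, 1] = adj[1, 0] = 1.0                   # one undirected edge
z = torch.randn(num_nodes, latent_dim, requires_grad=True)  # stand-in for encoder output

decoder = TwoHeadDecoder(latent_dim, num_features)
adj_logits, x_hat = decoder(z)
loss = reconstruction_loss(adj_logits, x_hat, adj, x)
loss.backward()
```

Swapping `F.mse_loss` for `F.l1_loss` gives the MAE variant. For larger graphs, scoring only sampled positive and negative edges avoids materializing the full `[num_nodes, num_nodes]` matrix.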

ferdinand-popp commented 3 years ago

Hello rusty1s, thank you for your answer! Yes, you are correct! The loss function is different, and so is the decoder. The MATLAB implementation is in this repo: https://github.com/FakeTibbers/MGAE but porting it to PyTorch Geometric is quite challenging! The idea was to create a stacked single-layer setup to get a latent representation from both the structural information and the node features. Or is there another implementation that denoises and integrates both structure and features? Thank you very much for your help!

rusty1s commented 3 years ago

Note that the GNN encoder already captures both structural information and local feature information, so I do not think any changes are necessary on this part. I feel like the decoder part and the loss formulation can be implemented with just pure PyTorch functionality. Let me know what you have so far and I can have a quick look :)
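
To illustrate why a GNN encoder mixes structure and features, here is a small dense-tensor sketch of one GCN-style propagation step. The function name `gcn_propagate` and the toy graph are assumptions for illustration; in PyG one would normally use a sparse layer such as `GCNConv` instead of a dense adjacency matrix.

```python
import torch


def gcn_propagate(x: torch.Tensor, adj: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """One GCN-style step with dense tensors: D^-1/2 (A + I) D^-1/2 X W."""
    a_hat = adj + torch.eye(adj.size(0))          # add self-loops
    deg_inv_sqrt = a_hat.sum(dim=1).pow(-0.5)     # D^-1/2 as a vector
    norm_adj = deg_inv_sqrt.unsqueeze(1) * a_hat * deg_inv_sqrt.unsqueeze(0)
    return norm_adj @ x @ weight


# Path graph 0 - 1 - 2; only node 0 carries a nonzero feature.
adj = torch.tensor([[0.0, 1.0, 0.0],
                    [1.0, 0.0, 1.0],
                    [0.0, 1.0, 0.0]])
x = torch.tensor([[1.0], [0.0], [0.0]])
weight = torch.ones(1, 1)

z = gcn_propagate(x, adj, weight)
# Node 1 (a neighbor of node 0) now has a nonzero embedding,
# while node 2 (two hops away) is still zero after a single layer.
```

The output for each node depends on its neighbors' features through the adjacency matrix, so the latent representation already reflects both sources of information.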