lucidrains / egnn-pytorch

Implementation of E(n)-Equivariant Graph Neural Networks, in Pytorch
MIT License
412 stars 68 forks

EGNN explodes for larger problems #7

Closed denjots closed 3 years ago

denjots commented 3 years ago

I'm going to look into this further... closing.

denjots commented 3 years ago

This needs further research I think. Closing for now.

denjots commented 3 years ago

I've worked quite a bit with this neural network layer now, and I think it's probably the least stable thing I've ever tried to train! It's almost impossible to keep it from exploding unless you train it with a tiny learning rate. I suspect the algorithm itself is broken, or maybe not yet properly described in the paper. I've tried the latest fix, which helps a bit, but training 4 layers on randomly sized graphs of fewer than 100 nodes, I managed to get past 58 epochs of training (loss going down) before it started exploding to NaN.

Looking at the algorithm, I just think adding an unbounded weighted sum of relative coordinates to the existing coordinates is never going to be stable for anything other than a toy problem. The weights needed for 100 nodes are never going to be comparable in magnitude to the weights needed for 10 nodes, say. Maybe the original authors meant to normalize the weights in some way, or perhaps norm the relative coordinate vectors - or perhaps they just never tried it on larger, messier problems.
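For concreteness, the update I'm talking about is roughly the following - just a sketch with illustrative shapes and names, not the library's exact code:

```python
import torch

# sketch of the EGNN coordinate update: each node adds an *unbounded*
# weighted sum of relative coordinates to its position
n, d = 100, 3                      # number of nodes, spatial dimension
coors = torch.randn(n, d)          # node coordinates
coor_weights = torch.randn(n, n)   # per-edge weights from the coordinate MLP

rel_coors = coors.unsqueeze(1) - coors.unsqueeze(0)                # (n, n, d) pairwise x_i - x_j
coor_update = (coor_weights.unsqueeze(-1) * rel_coors).sum(dim=1)  # (n, d) sum over all other nodes

coors = coors + coor_update  # the sum grows with n, so more nodes -> larger updates
```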

lucidrains commented 3 years ago

@denjots thanks for letting me know! Have you tried any of those solutions you listed? Norming the relative vectors for weighting seems like a good idea

lucidrains commented 3 years ago

@denjots there's also a lot of other things to try, if you are willing. It just seems too promising a technique to let go entirely

lucidrains commented 3 years ago

@denjots do you want to try setting norm_rel_coors = True for your larger experiments? normalizing the coor_weights may be a good idea as well
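usage would be something like this, assuming the constructor accepts the flag as written above (dims arbitrary):

```python
import torch
from egnn_pytorch import EGNN

layer = EGNN(dim=64, norm_rel_coors=True)  # flag discussed above; dims arbitrary

feats = torch.randn(1, 32, 64)   # (batch, nodes, feature dim)
coors = torch.randn(1, 32, 3)    # (batch, nodes, xyz)

feats, coors = layer(feats, coors)
```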

denjots commented 3 years ago

I have to admit I have already played with those things and others, e.g. using tanh to squash the coordinate weights. It's basically a balancing act between a model that is stable but unwilling to learn anything and one that is unstable. Neither situation is good, obviously. Adding some normalization works for graphs of up to 100 nodes, but larger graphs produce NaNs again. Applying more normalization stabilizes it again, but then it learns very, very slowly (if it is indeed learning anything). It just feels like a neural net model that really wants to be unstable! Also, I'm a little worried that some tweaks will break equivariance, but I think normalization at least should not do so.

I'm currently trying to rescale the weights by various functions of the graph size - maybe that will work?
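Roughly, the rescaling I'm experimenting with looks like this (my own tweak applied outside the library, not its actual code):

```python
import torch

def rescale_coor_weights(coor_weights: torch.Tensor) -> torch.Tensor:
    # squash the per-edge coordinate weights with tanh, then divide by the
    # number of nodes so the summed update stays roughly size-independent
    num_nodes = coor_weights.shape[-1]
    return torch.tanh(coor_weights) / num_nodes

coor_weights = torch.randn(100, 100)             # raw output of the coordinate MLP
coor_weights = rescale_coor_weights(coor_weights)
```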

lucidrains commented 3 years ago

@denjots are you using EGNN or EGNN_sparse? perhaps only restricting the operation to local neighborhoods of the biomolecule would work better?

lucidrains commented 3 years ago

@denjots yea, I mean, instability isn't really something foreign to us deep learning folks. it's just a problem to be solved

lucidrains commented 3 years ago

@denjots yea, i think the three things to try are (1) restricting updates to local atoms only, (2) norming the relative vectors, and (3) norming the coordinate weights coming out of the coordinate MLP. none of what i listed will break equivariance
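(2) and (3) in code would look roughly like this, sketched outside the library with illustrative names:

```python
import torch
import torch.nn.functional as F

rel_coors = torch.randn(50, 50, 3)    # pairwise x_i - x_j, illustrative
coor_weights = torch.randn(50, 50)    # output of the coordinate MLP, illustrative

# (2) norm the relative vectors to unit length before weighting
rel_coors = F.normalize(rel_coors, dim=-1)

# (3) norm the coordinate weights so their magnitude is independent of graph size
coor_weights = F.normalize(coor_weights, dim=-1)

coor_update = (coor_weights.unsqueeze(-1) * rel_coors).sum(dim=1)
```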

denjots commented 3 years ago

I haven't tried the sparse version - too many dependencies, some of which don't install out of the box on my system. I did try some early experiments with a distance-dependent mask in the basic EGNN - maybe I will revisit them. It's a shame, because EGNNs looked so simple, but I guess things can be too simple sometimes...

lucidrains commented 3 years ago

@denjots i wouldn't let up yet, there's another similar paper from the recent ICLR that corroborates the results: https://github.com/lucidrains/geometric-vector-perceptron. I think it is worth trying to localize, or at least clamp the relative vector norms.

the alternative is to use SE3-Transformers, which, although it works, is an order of magnitude slower

lucidrains commented 3 years ago

@denjots I have another version of EGNN that you could try, one that is free of the PyG dependencies: https://github.com/lucidrains/En-transformer. you can restrict to local neighbors with the num_nearest_neighbors keyword
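usage is along these lines, assuming the constructor and forward match the En-transformer README (dims arbitrary):

```python
import torch
from en_transformer import EnTransformer

# restrict each feature/coordinate update to the 16 nearest neighbors
model = EnTransformer(dim=64, depth=4, num_nearest_neighbors=16)

feats = torch.randn(1, 256, 64)
coors = torch.randn(1, 256, 3)

feats, coors = model(feats, coors)
```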

denjots commented 3 years ago

I have tried SE-3 transformers - both Fabian's and yours - and they are both stable. Obviously the speed and simplicity of EGNN was very attractive, but you can't have everything. Thanks for the pointer to the En-transformer - I had noticed it, but was keen to get the basic EGNN working first. I will queue the transformer up for testing and see how it goes.

lucidrains commented 3 years ago

@denjots thank you for relaying your results! this matters a lot to me, as if this approach can be made to work, it may make Alphafold2 replication a lot simpler

lucidrains commented 3 years ago

@denjots I'll have a version of EGNN without PyG, also restricted to nearest neighbors, ready for you to try later this evening

lucidrains commented 3 years ago

@denjots I've added the feature in 0.0.20, so you can use it like so https://github.com/lucidrains/egnn-pytorch/blob/0.0.20/tests/test_equivariance.py#L22-L37
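in short, assuming the layer's signature matches the linked test (dims arbitrary):

```python
import torch
from egnn_pytorch import EGNN

# new in 0.0.20: restrict the update to each node's nearest neighbors
layer = EGNN(dim=64, num_nearest_neighbors=8)

feats = torch.randn(1, 128, 64)
coors = torch.randn(1, 128, 3)

feats, coors = layer(feats, coors)
```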

denjots commented 3 years ago

Thanks - will have a look tomorrow.

denjots commented 3 years ago

I think I'm happy enough with where it's at now. I went back to my own masking in the end - mainly because the idea of a fixed number of neighbors wasn't very realistic - so I now just mask the coordinate weights with a specified distance threshold. This, combined with the new initialization, stabilizes training up to at least 8 layers, which is good enough for me. It's not learning much of value so far, but that's likely more a science problem I need to look at than a software problem. Thanks a lot for your help!
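For reference, the masking is just something along these lines (my own code around the layer, not part of the library):

```python
import torch

def distance_mask(coors: torch.Tensor, threshold: float) -> torch.Tensor:
    # coors: (nodes, 3) -> boolean (nodes, nodes) mask of pairs within the threshold
    dist = torch.cdist(coors, coors)
    return dist < threshold

coors = torch.randn(200, 3) * 10.0
coor_weights = torch.randn(200, 200)                                # per-edge coordinate weights
coor_weights = coor_weights * distance_mask(coors, threshold=8.0)   # zero out distant pairs
```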

lucidrains commented 3 years ago

Glad to hear you got it working!

gaceladri commented 3 years ago

Maybe helpful? https://arxiv.org/abs/2009.03294

You are not doing any norm in the EGNN?