torchmd / torchmd-net

Training neural network potentials

Bias terms in the Equivariant Block and Attention Mechanism Implementation #57

Closed charliexchen closed 2 years ago

charliexchen commented 2 years ago

Dear Philipp,

I've been looking at the code for this implementation to reproduce some of the results, and I noticed that the vector linear projections used in the Gated Equivariant Blocks and the attention blocks don't have their bias terms removed. This seems to deviate from the descriptions in your paper and the PaiNN paper.

https://github.com/torchmd/torchmd-net/blob/e23c178cf8a6eb6ca931995d61128bd394e5999d/torchmdnet/models/utils.py#L260-L261

https://github.com/torchmd/torchmd-net/blob/main/torchmdnet/models/torchmd_et.py#L224

Unfortunately, these bias terms prevent the model from being rotationally invariant: each time such a layer is called, a fixed per-channel vector along the (1, 1, 1) direction is added to all the vector features, independent of the input's orientation.
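For illustration, here is a minimal sketch (my own, not repo code) of why a bias on the feature dimension of vector features breaks equivariance: `nn.Linear` adds the same per-channel offset to the x, y, and z components, and that offset does not rotate with the input. Shapes and names below are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_atoms, n_features = 5, 8
vec = torch.randn(n_atoms, 3, n_features)  # vector features: (atoms, xyz, channels)

# Random orthogonal (rotation/reflection) matrix via QR decomposition
q, _ = torch.linalg.qr(torch.randn(3, 3))

proj_bias = nn.Linear(n_features, n_features, bias=True)
proj_nobias = nn.Linear(n_features, n_features, bias=False)

for name, proj in [("with bias", proj_bias), ("without bias", proj_nobias)]:
    # Equivariance check: rotate-then-project should equal project-then-rotate
    out_rot_first = proj(torch.einsum("ij,ajf->aif", q, vec))
    out_rot_last = torch.einsum("ij,ajf->aif", q, proj(vec))
    err = (out_rot_first - out_rot_last).abs().max().item()
    print(f"{name}: max equivariance error = {err:.2e}")
# "with bias" gives a large error; "without bias" is zero up to float precision.
```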

In my own implementation, the gated equivariant blocks had a propensity to produce NaNs during training when this bias term is removed, since the projection is applied immediately before an L2 norm (whose gradient is undefined at a zero vector). I'm still investigating this, however.
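A small repro of the NaN issue (again my own sketch, not repo code): autograd returns NaN for the gradient of the norm at exactly zero, and a stray bias keeps the projected vectors away from zero, masking the problem. A common workaround is an epsilon under the square root.

```python
import torch

v = torch.zeros(3, requires_grad=True)
torch.norm(v).backward()
print(v.grad)  # tensor([nan, nan, nan]) -- gradient of ||v|| undefined at v = 0

# Common workaround: add a small epsilon under the square root.
v2 = torch.zeros(3, requires_grad=True)
torch.sqrt(torch.sum(v2 * v2) + 1e-8).backward()
print(v2.grad)  # tensor([0., 0., 0.]) -- finite gradient at the origin
```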

All the best,

Charlie

giadefa commented 2 years ago

Thanks Charlie, we are looking into it.

giadefa commented 2 years ago

Charlie, thanks for spotting the bug. The results are the same or better without the bias terms. Philipp will soon merge the corrected version.

PhilippThoelke commented 2 years ago

Thanks Charlie, good spot! I have removed the bias from the vector projection layers in the Gated Equivariant Block and from the attention mechanism (see this commit https://github.com/torchmd/torchmd-net/commit/798468dff61fd6c6f034a177c8b120800dec275b).

It is not necessary to remove the bias from the v_proj layer as you suggested, though, since that layer only operates on scalar features. I have, however, removed it from the vec_proj layer in the attention mechanism: https://github.com/torchmd/torchmd-net/blob/69301dbaccda8e3c61718029b6dbaaa2f34ba1bf/torchmdnet/models/torchmd_et.py#L227
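To summarize the rule of thumb behind the fix, here is a sketch using the layer names from torchmd_et.py (the width and output multiplier are illustrative): layers acting on invariant scalar features may keep their bias, while any linear map applied to the equivariant vector features must use bias=False.

```python
import torch.nn as nn

hidden_channels = 128

# v_proj consumes scalar features x of shape (n_atoms, hidden_channels),
# so a bias term is harmless here.
v_proj = nn.Linear(hidden_channels, hidden_channels * 3, bias=True)

# vec_proj consumes vector features of shape (n_atoms, 3, hidden_channels);
# a bias would add a fixed, orientation-independent vector, so it is removed.
vec_proj = nn.Linear(hidden_channels, hidden_channels * 3, bias=False)
```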

charliexchen commented 2 years ago

Thanks for fixing this! Yeah, it's a super easy mistake to make, and I'm glad the performance isn't affected!