torchmd / torchmd-net

Training neural network potentials
MIT License

Something about Eq. 9 #176

Closed hhhhzzzzz closed 9 months ago

hhhhzzzzz commented 1 year ago

I understand the motivation for incorporating distance information into the feature vectors, but I can't understand the top line of Eq. 9. How does it work?

Thanks!

PhilippThoelke commented 1 year ago

We perform a graph convolution with the value matrix and the distance embeddings. This is a standard approach from graph convolutional networks (e.g. SchNet). Then we simply split the resulting feature vectors into three equally sized chunks for further processing in different branches of the architecture. The value and distance-embedding dimension is 3*embedding_dim, so the splitting operation yields three feature vectors, each of dimension embedding_dim.

See this part in the code for further details https://github.com/torchmd/torchmd-net/blob/d3e611e2d01890720621180691c87044884ab4aa/torchmdnet/models/torchmd_et.py#L317-L320
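To make the convolve-then-split step concrete, here is a minimal numpy sketch. It is illustrative only, not the actual torchmd-net code: the names `embedding_dim`, `num_edges`, `v_j`, and `dv` and the random values are assumptions standing in for the real tensors.

```python
import numpy as np

embedding_dim = 8  # hypothetical hidden size (illustrative)
num_edges = 5      # number of neighbor pairs (i, j) within the cutoff

# v_j: per-edge value vectors of the source nodes (already expanded to edges)
v_j = np.random.rand(num_edges, 3 * embedding_dim)
# dv: per-edge filters derived from the distance embedding of r_ij
dv = np.random.rand(num_edges, 3 * embedding_dim)

# the "graph convolution": element-wise modulation of values by distance filters
message = v_j * dv

# split the result into three equally sized chunks for the different branches
x, vec1, vec2 = np.split(message, 3, axis=-1)

print(x.shape, vec1.shape, vec2.shape)  # each chunk has width embedding_dim
```

The key point is that because the last dimension is 3*embedding_dim, the split produces three arrays of width embedding_dim each.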

hhhhzzzzz commented 1 year ago

v_j (1×d) is the feature vector of node j, and dv is the distance matrix (n×n×d). So how can a graph convolution be applied to v_j?

Thanks!

PhilippThoelke commented 1 year ago

The shape of dv is not (n, n, d): https://github.com/torchmd/torchmd-net/blob/d3e611e2d01890720621180691c87044884ab4aa/torchmdnet/models/torchmd_et.py#L281

Instead, it is a flattened version of the matrix that matches the expansion of the values v to v_j. This lets us convolve every v_j with the corresponding dv entry for the node pair (i, j). Feel free to explore the shapes, e.g. by adding print statements, if you want a complete understanding of what happens in the message function.
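A small numpy sketch of that flattening, under assumed shapes (the neighbor list `edge_index` and the sizes are made up for illustration; this is not the library code): dv is computed per edge rather than as a dense n×n×d matrix, so it lines up entry for entry with v_j after the values are gathered along the source nodes.

```python
import numpy as np

n, d = 4, 8  # hypothetical node count and embedding width

# hypothetical neighbor list: column k is the edge (i_k, j_k)
edge_index = np.array([[0, 0, 1, 2],   # row 0: target nodes i
                       [1, 2, 3, 3]])  # row 1: source nodes j
num_edges = edge_index.shape[1]

v = np.random.rand(n, 3 * d)     # per-node value vectors
v_j = v[edge_index[1]]           # expand to per-edge source values

# dv is produced from the scalar distance r_ij of each edge,
# so it is already "flat": one row per edge, not an n x n x d matrix
dv = np.random.rand(num_edges, 3 * d)

print(v_j.shape, dv.shape)  # identical shapes, so v_j * dv works edge-wise
```

Because both arrays have one row per edge, the element-wise product pairs each v_j with exactly the dv entry belonging to that (i, j) pair.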