torchmd / torchmd-net

Training neural network potentials
MIT License

[Feature Request] Vector output from TensorNet #297

Open · shenoynikhil opened this issue 4 months ago

shenoynikhil commented 4 months ago
Feature Request

Currently the vector output is set to None. It could easily be computed from the skew-symmetric component of the node tensor, based on Equation (3) in the paper. Some papers suggest predicting forces directly with equivariant heads, and a vector output would be required for that if TensorNet is used as the representation module.

We can add the following function to do it:

import torch

def skew_tensor_to_vector(tensor):
    """Extracts a vector from a skew-symmetric tensor.
    Based on Equation (3) in the paper.
    Transforms tensor (num_atoms, hidden_channels, 3, 3) to (num_atoms, 3, hidden_channels).
    """
    # Stack the three independent off-diagonal components, then swap the
    # spatial and channel axes.
    vector = torch.stack(
        (tensor[:, :, 1, 2], tensor[:, :, 2, 0], tensor[:, :, 0, 1]), dim=-1
    )
    return vector.transpose(1, 2)

And in the line https://github.com/torchmd/torchmd-net/blob/fdd4dac8852f1c42906e7c7a5f4ffa70319a41b2/torchmdnet/models/tensornet.py#L271 do:

v = skew_tensor_to_vector(A) # (num_atoms, 3, hidden_channels)
return x, v, z, pos, batch
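
As a quick sanity check of the extraction itself, the vector pulled out of a rotated tensor R S R^T should be the rotated vector. A minimal sketch, assuming the skew_tensor_to_vector function above (all other names are illustrative):

N, H = 5, 8
A = torch.randn(N, H, 3, 3)
S = A - A.transpose(-1, -2)  # random skew-symmetric tensors

Q = torch.linalg.qr(torch.randn(3, 3)).Q
R = Q * torch.sign(torch.linalg.det(Q))  # force det(R) = +1, i.e. a proper rotation

S_rot = R @ S @ R.T                   # rotate each 3x3 block as R S R^T
v = skew_tensor_to_vector(S)          # (N, 3, H)
v_rot = skew_tensor_to_vector(S_rot)  # (N, 3, H)
# the extracted vector must rotate along with the tensor
assert torch.allclose(v_rot, torch.einsum("ij,njh->nih", R, v), atol=1e-5)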

It might be useful if someone wants to use the EquivariantScalar or EquivariantVectorOutput modules on top of TensorNet.
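
A minimal sketch of that usage, assuming v comes out of the representation model with shape (num_atoms, 3, hidden_channels) and x with shape (num_atoms, hidden_channels); the constructor arguments and variable names here are illustrative:

from torchmdnet.models.output_modules import EquivariantVectorOutput

head = EquivariantVectorOutput(hidden_channels=128, activation="silu")
x, v, z, pos, batch = representation_model(z, pos, batch)  # e.g. TensorNet
per_atom_vectors = head.pre_reduce(x, v, z, pos, batch)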

I tested for equivariance with the EquivariantVectorOutput module and it works out.

[Screenshot: output of the equivariance test]

Would this be right? If yes, I can also add a test.

guillemsimeon commented 4 months ago

Hi! Thanks a lot for your interest; I can see that you have gone through the TensorNet paper in depth!

Regarding the new feature, you absolutely got it right. In fact, we think it is an enhancement, and it will make it easier for users who are less familiar with the formalism to use vector features from TensorNet. Therefore, we propose that you open a PR. Some things:

1) I am currently not sure whether vector features need to be [..., 3, hidden_channels] or the other way around at that point in the TorchMD_Net full model (the output of the representation model). I imagine you checked it, and that both EquivariantVectorOutput and EquivariantScalar expect vectors in that shape; otherwise it would not work.

2) In any case, I would remove the transpose from the skewtensor_to_vector function and apply it after getting v. That is:

   v = skewtensor_to_vector(A)
   v = v.transpose(-1,-2)

I think this is more consistent with the way we perform vector_to_skewtensor, which maps [..., whatever, 3] to [..., whatever, 3, 3].

3) Take this into account: https://github.com/torchmd/torchmd-net/blob/8ca7f607a0356282e0453000a81adcc7703e7989/torchmdnet/models/model.py#L121. TensorNet has is_equivariant = False, which avoids the 'Equivariant' prefix and uses the Scalar output rather than the EquivariantScalar one. This means that if you want to build a full model with create_model, you have to specify the full name 'EquivariantScalar' as the output model (see the sketch after this list). I also want to mention at this point that I did not do any tests with EquivariantScalar in terms of performance.

4) I am not sure if I completely understood your equivariance test. EquivariantVectorOutput returns just a vector per atom. Do you rename these vectors as 'forces', meaning that compute_forces = True is direct prediction of forces (without autograd)?
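
To illustrate point 3, a hedged sketch of building such a model; args stands for the usual torchmd-net hyperparameter dictionary (e.g. loaded from a training yaml), and only the relevant keys are shown:

from torchmdnet.models.model import create_model

args["model"] = "tensornet"
args["output_model"] = "EquivariantScalar"  # full name; no automatic prefix for TensorNet
model = create_model(args)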

Thanks again, feel free to open the PR.

Guillem

shenoynikhil commented 4 months ago

> I am not sure if I completely understood your equivariance test. EquivariantVectorOutput returns just a vector per atom. Do you rename these vectors as 'forces', meaning that compute_forces = True is direct prediction of forces (without autograd)?

Sorry for not making this clear. I raised this issue because some papers choose to train by predicting an equivariant vector directly as the forces (instead of taking the autograd of the energy with respect to positions). This seems to be computationally faster, so during pretraining you can use direct prediction and then switch to an autograd-based loss during fine-tuning (reference: Section 4 of https://arxiv.org/pdf/2310.16802.pdf). My equivariance test was something like:

rot = torch.linalg.qr(torch.randn(3, 3)).Q  # e.g. a random orthogonal rotation matrix
energy, forces = net(atomic_numbers, positions, batch)
energy_rot, forces_rot = net(atomic_numbers, positions @ rot, batch)
assert torch.allclose(forces @ rot, forces_rot, atol=1e-5)  # forces are equivariant
assert torch.allclose(energy, energy_rot, atol=1e-5)  # energy is invariant
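
For reference, the difference between the two training stages boils down to where the forces come from. A minimal sketch (pred_forces, energy, positions, and target_forces are illustrative names):

# pretraining: supervise the equivariant head's per-atom vectors directly as forces
loss = torch.nn.functional.mse_loss(pred_forces, target_forces)

# fine-tuning: forces as the negative gradient of the energy w.r.t. positions
forces = -torch.autograd.grad(energy.sum(), positions, create_graph=True)[0]
loss = torch.nn.functional.mse_loss(forces, target_forces)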

Let me start a PR.