torchmd / torchmd-net

Training neural network potentials

Make TensorNet compatible with TorchScript #186

Closed RaulPPelaez closed 1 year ago

RaulPPelaez commented 1 year ago

I can torch.jit.script TensorNet (the test_forward_torchscript passes), but when I try to torch.jit.script TorchMD_Net and run a training I get an error:

Traceback of TorchScript (most recent call last):
  File "/shared/raul/torchmd-net/torchmdnet/models/model.py", line 254, in forward
        if self.derivative:
            grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y)]
            dy = grad(
                 ~~~~ <--- HERE
                [y],
                [pos],
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The size of tensor a (128) must match the size of tensor b (3) at non-singleton dimension 3

I don't really understand why this error arises here and not in the test. I also do not understand why this error only happens when using jit.script.
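
For context, a minimal self-contained sketch of the pattern being exercised here (a toy potential instead of TensorNet, with the same grad-inside-forward structure as model.py); this is not a reproduction of the error, just what "script and train" means in this discussion:

import torch
from typing import List, Optional, Tuple
from torch import nn
from torch.autograd import grad

class ToyPotential(nn.Module):
    # Toy stand-in: predict an energy, then compute forces via autograd inside forward.
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(3, 32), nn.SiLU(), nn.Linear(32, 1))

    def forward(self, pos: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        pos.requires_grad_(True)
        y = self.net(pos).sum()
        grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y)]
        dy = grad([y], [pos], grad_outputs=grad_outputs, create_graph=True, retain_graph=True)[0]
        if dy is None:
            raise RuntimeError("Autograd returned None")
        return y, -dy

model = torch.jit.script(ToyPotential())
y, neg_dy = model(torch.randn(10, 3))
loss = neg_dy.pow(2).mean() + y.pow(2)
loss.backward()  # training needs a second backward through the scripted grad call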

guillemsimeon commented 1 year ago

Is the problem happening only when trying to train, or also when doing inference on both energies and forces?

RaulPPelaez commented 1 year ago

I updated the TorchScript test to try a double backward, and it passes without issues. I can do inference and backpropagation with a scripted TorchMD_Net module, but if I give LNNP a scripted module, or try to script LNNP itself (either by calling jit.script on it or by using the Lightning to_torchscript() method), I get that error.

RaulPPelaez commented 1 year ago

Given that TorchScript is not really intended to be used for training within TorchMD-Net as of now, this PR in its current state leaves TensorNet at the same level of compatibility as ET (I cannot use jit.script during training with any model); that is, the tests pass. One can script a TorchMD_Net set up with TensorNet, store it, load it, and do inference and backprop on it.
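
For reference, a minimal self-contained sketch of that round trip, using a placeholder module in place of a TensorNet-based TorchMD_Net (all names here are illustrative, not the actual torchmd-net model-building API):

import torch
from torch import nn

# Placeholder standing in for a TensorNet-based TorchMD_Net.
class Toy(nn.Module):
    def forward(self, pos: torch.Tensor) -> torch.Tensor:
        return (pos ** 2).sum()

scripted = torch.jit.script(Toy())
torch.jit.save(scripted, "model.pt")       # store it
loaded = torch.jit.load("model.pt")        # load it
pos = torch.randn(10, 3, requires_grad=True)
y = loaded(pos)                            # inference
(dy,) = torch.autograd.grad(y, pos)        # backprop for forces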

Given this, I think we should merge now. Please review!

guillemsimeon commented 1 year ago

I am OK with leaving it just compatible at this point, but accelerating training might be useful even if it is only 10%. Anyway, I remember that at some point we could train? That's when we saw the limited speedup of 10-20%, right?

RaulPPelaez commented 1 year ago

I was not able to trace the model. I am leaving that aside for now... Just for the record, the error I get in test_train when trying to trace is this:

torchmd-net/torchmdnet/module.py:128: UserWarning: Using a target size (torch.Size([8, 1])) that is different to the input size (torch.Size([50, 1])). 
This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.                                                                       
    loss_y = loss_fn(y, batch.y)   

The batch size is 8 in this test; I believe the issue is that the batch size differs between training, validation, and testing. The scripted/traced model is fine during training, but when the input shape changes for testing it gets confused.
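
As an illustrative sketch (not the torchmd-net code), one way such an input/target size mismatch can appear with tracing is that data-dependent sizes get frozen into the graph: a per-molecule readout traced with one batch size keeps that output shape for other batch sizes.

import torch

def readout(per_atom: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    num_mol = int(batch.max()) + 1            # becomes a constant at trace time
    out = torch.zeros(num_mol, 1)
    return out.index_add(0, batch, per_atom)  # sum per-atom values into per-molecule slots

# Trace with 50 molecules, then call with 8 molecules.
example = (torch.randn(200, 1), torch.arange(50).repeat_interleave(4))
traced = torch.jit.trace(readout, example)

small = (torch.randn(32, 1), torch.arange(8).repeat_interleave(4))
print(readout(*small).shape)  # torch.Size([8, 1])
print(traced(*small).shape)   # torch.Size([50, 1]), baked in at trace time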

I remember we were able to train with a scripted module, @guillemsimeon, but as far as I can recall that was without using LNNP.

RaulPPelaez commented 1 year ago

Interestingly, this test passes without any issue:

# imports as used in the repo's test suite (module paths assumed)
import torch
from pytest import mark
from torchmdnet import models
from torchmdnet.models.model import create_model
from utils import load_example_args, create_example_batch

@mark.parametrize("model_name", models.__all__)
def test_torchscript_dynamic_shapes(model_name):
    z, pos, batch = create_example_batch()
    model = torch.jit.script(
        create_model(load_example_args(model_name, remove_prior=True, derivative=True))
    )
    # Repeat the input to make the shapes dynamic
    for rep in range(0, 10):
        zi = z.repeat_interleave(rep + 1, dim=0)
        posi = pos.repeat_interleave(rep + 1, dim=0)
        batchi = torch.cat([batch + i for i in range(rep + 1)])
        y, neg_dy = model(zi, posi, batch=batchi)
        grad_outputs = [torch.ones_like(neg_dy)]
        ddy = torch.autograd.grad(
            [neg_dy],
            [posi],
            grad_outputs=grad_outputs,
        )[0]

RaulPPelaez commented 1 year ago

I found out why this test passes while training does not work: the tests are CPU only, and running the above test with GPU-stored tensors yields the error I posted initially. It's only with TensorNet; after a couple of differently sized inputs to the model it gives that error.
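
A hedged sketch of the GPU variant of the loop above (assuming a CUDA device is available; the variable names are the ones from the test):

device = torch.device("cuda")
model = model.to(device)
for rep in range(0, 10):
    zi = z.repeat_interleave(rep + 1, dim=0).to(device)
    posi = pos.repeat_interleave(rep + 1, dim=0).to(device)
    batchi = torch.cat([batch + i for i in range(rep + 1)]).to(device)
    # With TensorNet, this starts failing after a few differently sized inputs.
    y, neg_dy = model(zi, posi, batch=batchi)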

guillemsimeon commented 1 year ago

Does removing the enumerates fix the error?

RaulPPelaez commented 1 year ago

> Does removing the enumerates fix the error?

No, you cannot do that in TorchScript:

E       RuntimeError: 
E       Expected integer literal for index. ModuleList/Sequential indexing is only supported with integer literals. Enumeration is supported, e.g. 'for index, v in enumerate(self): ...':
E         File "torchmd-net/torchmdnet/models/tensornet.py", line 201
E               # Interaction layers
E               for i in range(self.num_layers):
E                   X = self.layers[i](X, edge_index, edge_weight, edge_attr)
E                       ~~~~~~~~~~~~~~ <--- HERE
E               I, A, S = decompose_tensor(X)
E               x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1)

../../mambaforge/envs/test2/lib/python3.10/site-packages/torch/jit/_recursive.py:397: RuntimeError
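
For illustration, a minimal self-contained sketch of the TorchScript constraint behind that error (a toy module, not tensornet.py):

import torch
from torch import nn

class Stack(nn.Module):
    def __init__(self, num_layers: int):
        super().__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(num_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Supported by TorchScript: iterating the ModuleList directly or via enumerate.
        for layer in self.layers:
            x = layer(x)
        # Not supported: indexing with a non-literal integer, e.g.
        # for i in range(self.num_layers):
        #     x = self.layers[i](x)
        return x

scripted = torch.jit.script(Stack(3))
print(scripted(torch.randn(4, 16)).shape)  # torch.Size([4, 16])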

RaulPPelaez commented 1 year ago

I would like to merge this now, since my upcoming PR with TensorNet optimizations depends on this. @raimis @guillemsimeon, could you please review again?

guillemsimeon commented 1 year ago

I understand that it has been made compatible, but in the end we cannot use this compatibility for training? Or has the situation changed?

RaulPPelaez commented 1 year ago

I still cannot train with a TorchScript model, and not only TensorNet: no model at all. I believe it is PyTorch Lightning related... I spoke to @PhilippThoelke and apparently he also had issues with this.

guillemsimeon commented 1 year ago

Well, we can deal with this later, or even never, if other optimizations make a larger difference. Should I proceed to close the PR?