Is the problem happening only when trying to train, or also when doing inference on both energies and forces?
I updated the TorchScript test to try a double backward; it passes without issues.
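For context, a minimal sketch of what the double-backward check looks like. Names mirror the test shown further down in this thread; `pos.requires_grad_(True)` and the `create_graph=True` flag are assumptions about how the model builds its graph, not necessarily how the test is written:

```python
# Hedged sketch of a double-backward check on a scripted model.
# Assumes pos requires grad and the model was created with derivative=True,
# so it returns (energy, forces) and built the forces graph with create_graph.
pos.requires_grad_(True)
y, neg_dy = model(z, pos, batch=batch)
grad_outputs = [torch.ones_like(neg_dy)]
# First-order gradient of the predicted forces w.r.t. positions,
# keeping the graph so we can backpropagate through it a second time.
ddy = torch.autograd.grad(
    [neg_dy], [pos], grad_outputs=grad_outputs, create_graph=True
)[0]
# Second backward pass ("double backward") through the derivative itself.
ddy.sum().backward()
```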
I can do inference and backpropagation with a scripted TorchMD_Net module, but if I give LNNP a scripted module, or try to script LNNP itself (either by calling jit.script on it or via Lightning's to_torchscript() method), I get that error.
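For clarity, a sketch of the three paths described above. LNNP is torchmdnet.module.LNNP; the constructor arguments and the `self.model` attribute are assumptions for illustration:

```python
# Hedged sketch; `model` and `hparams` are assumed to be set up elsewhere.
import torch
from torchmdnet.module import LNNP

scripted = torch.jit.script(model)  # OK: inference and backprop work
lnnp = LNNP(hparams)                # hypothetical setup, args elided
lnnp.model = scripted               # handing LNNP a scripted module -> error
torch.jit.script(LNNP(hparams))     # scripting LNNP directly -> error
LNNP(hparams).to_torchscript()      # Lightning's helper -> same error
```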
Given that TorchScript is not really intended to be used for training within TorchMD-Net right now, this PR in its current state leaves TensorNet at the same level of compatibility as ET (I cannot use a scripted model during training with any model), and the tests pass. One can script a TorchMD_Net set up with TensorNet, store it, load it, and do inference and backprop on it.
Given this, I think we should merge now. Please review!
I am OK with leaving it just compatible at this point, but accelerating training might be useful even if it is only 10%. Anyway, I remember that at some point we could train? That is when we saw the limited speedup of 10-20%, right?
I was not able to trace the model, so I am leaving that aside for now. Just for the record, the error I get in test_train when trying to trace is this:
```
torchmd-net/torchmdnet/module.py:128: UserWarning: Using a target size (torch.Size([8, 1])) that is different to the input size (torch.Size([50, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  loss_y = loss_fn(y, batch.y)
```
Batch size is 8 in this test; I believe the issue is that the batch size differs between training, validation, and testing. The scripted/traced model is fine during training, but when the input shape changes for testing it gets confused.
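As a diagnostic, tracing specializes on the shapes of the example inputs, and torch.jit.trace's check_inputs argument can surface this by re-running the trace check against a differently sized batch. A hypothetical sketch; `z_train`/`pos_train`/`batch_train` and `z_val`/`pos_val`/`batch_val` stand in for batches of different sizes:

```python
# Hypothetical sketch: verify a trace against a second batch size.
# The tracer re-checks the recorded graph with the check_inputs batches,
# which should flag shape specialization like the [50, 1] vs [8, 1] mismatch.
traced = torch.jit.trace(
    model,
    (z_train, pos_train, batch_train),
    check_inputs=[(z_val, pos_val, batch_val)],
)
```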
I remember we were able to train with a scripted module, @guillemsimeon, but as far as I can recall that was without using LNNP.
Interestingly this test passes without any issue:
```python
@mark.parametrize("model_name", models.__all__)
def test_torchscript_dynamic_shapes(model_name):
    z, pos, batch = create_example_batch()
    model = torch.jit.script(
        create_model(load_example_args(model_name, remove_prior=True, derivative=True))
    )
    # Repeat the input to make it dynamic
    for rep in range(0, 10):
        zi = z.repeat_interleave(rep + 1, dim=0)
        posi = pos.repeat_interleave(rep + 1, dim=0)
        batchi = torch.cat([batch + i for i in range(rep + 1)])
        y, neg_dy = model(zi, posi, batch=batchi)
        grad_outputs = [torch.ones_like(neg_dy)]
        ddy = torch.autograd.grad(
            [neg_dy],
            [posi],
            grad_outputs=grad_outputs,
        )[0]
```
I found out why this test passes while training does not work: the tests are CPU only. Running the above test with GPU-stored tensors yields the error I posted initially. It happens only with TensorNet; after a couple of differently sized inputs to the model it gives that error.
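A sketch of the GPU variant that reproduces the failure (assuming a CUDA device is available; otherwise it is the same loop as in the test above):

```python
# Hedged sketch: same loop as in the test, but with CUDA tensors.
device = torch.device("cuda")
model = model.to(device)
for rep in range(0, 10):
    zi = z.repeat_interleave(rep + 1, dim=0).to(device)
    posi = pos.repeat_interleave(rep + 1, dim=0).to(device)
    batchi = torch.cat([batch + i for i in range(rep + 1)]).to(device)
    # With TensorNet, the error appears after a couple of differently
    # sized inputs have gone through the scripted model.
    y, neg_dy = model(zi, posi, batch=batchi)
```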
Does removing the enumerates fix the error?
> Does removing the enumerates fix the error?
No, you cannot do that in TorchScript:
```
E   RuntimeError:
E   Expected integer literal for index. ModuleList/Sequential indexing is only supported with integer literals. Enumeration is supported, e.g. 'for index, v in enumerate(self): ...':
E     File "torchmd-net/torchmdnet/models/tensornet.py", line 201
E         # Interaction layers
E         for i in range(self.num_layers):
E             X = self.layers[i](X, edge_index, edge_weight, edge_attr)
E                 ~~~~~~~~~~~~~~ <--- HERE
E         I, A, S = decompose_tensor(X)
E         x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1)

../../mambaforge/envs/test2/lib/python3.10/site-packages/torch/jit/_recursive.py:397: RuntimeError
```
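For reference, the pattern the error message itself points to is iterating the ModuleList directly instead of indexing it with a loop variable. A sketch of the loop from tensornet.py in that form, not necessarily the final fix:

```python
# Hedged sketch: TorchScript accepts iterating a ModuleList directly,
# whereas self.layers[i] with a non-literal index is rejected.
for layer in self.layers:
    X = layer(X, edge_index, edge_weight, edge_attr)
    I, A, S = decompose_tensor(X)
    x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1)
```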
I would like to merge this now, since my upcoming PR with TensorNet optimizations depends on it. @raimis @guillemsimeon, could you please review again?
I understand that it has been made compatible, but in the end we cannot use this compatibility during training? Or has the situation changed?
I still cannot train with a TorchScript model, and not only TensorNet: no model at all. I believe it is PyTorch Lightning related... I spoke to @PhilippThoelke and apparently he also had issues with this.
Well, we can deal with this later, or even never, if other optimizations make a larger difference. Should I proceed to close the PR?
I can torch.jit.script TensorNet (test_forward_torchscript passes), but when I try to torch.jit.script TorchMD_Net and run training I get an error. I don't really understand why this error arises there and not in the test. I also do not understand why it only happens when using jit.script.