Closed — RaulPPelaez closed this 9 months ago
All tests are passing, and I do not see this breaking old checkpoints. @PhilippThoelke I would like you to review this one!
Thanks Philipp. I saw your comment about `Union`, but AFAICT it does not support `Tuple`, just `Tensor`, so one cannot do `Union[Tensor, Tuple[Tensor, Tensor]]`
`Union` supports any valid subtype, not only `Tensor`. According to the documentation:

> `Union[T0, T1, ...]` — One of the subtypes `T0`, `T1`, etc.
I also did a quick test and it works fine:
```python
from typing import Tuple, Union

import torch
from torch import Tensor, nn


class MyModel(nn.Module):
    def __init__(self, option: bool):
        super(MyModel, self).__init__()
        self.option = option

    def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        if self.option:
            return x
        else:
            return x, x


if __name__ == "__main__":
    model1 = torch.jit.script(MyModel(True))
    print(model1(torch.ones(5)))
    model2 = torch.jit.script(MyModel(False))
    print(model2(torch.ones(5)))
```
But anyway, this is not relevant for this PR, as we want to maintain backwards compatibility.
Currently this test always fails:
This was uncovered in https://github.com/openmm/openmm-torch/issues/135.
The error is this:
```shell
platform linux -- Python 3.11.6, pytest-7.4.3, pluggy-1.3.0 -- /home/raul/miniforge3/envs/torchmdnet/bin/python3.11
cachedir: .pytest_cache
rootdir: /home/raul/work/bcn/torchmd-net
plugins: typeguard-2.13.3, anyio-3.7.1, cov-4.1.0
collected 94 items / 93 deselected / 1 selected

test_model.py::test_torchscript_output_modification FAILED

================================= FAILURES =================================
___________________ test_torchscript_output_modification ___________________

    def test_torchscript_output_modification():
        z, pos, batch = create_example_batch()
        model = create_model(load_example_args("tensornet", remove_prior=True, derivative=True))
        class MyModel(torch.nn.Module):
            def __init__(self):
                super(MyModel, self).__init__()
                self.model = model
            def forward(self, z, pos, batch):
                y, neg_dy = self.model(z, pos, batch=batch)
                # A TorchScript bug is triggered if we modify an output of model marked as Optional[Tensor]
                return y, 2*neg_dy
>       torch.jit.script(MyModel())

test_model.py:75:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/home/raul/miniforge3/envs/torchmdnet/lib/python3.11/site-packages/torch/jit/_script.py:1324: in script
    return torch.jit._recursive.create_script_module(
/home/raul/miniforge3/envs/torchmdnet/lib/python3.11/site-packages/torch/jit/_recursive.py:559: in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
/home/raul/miniforge3/envs/torchmdnet/lib/python3.11/site-packages/torch/jit/_recursive.py:636: in create_script_module_impl
    create_methods_and_properties_from_stubs(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
```
This is due to a bug/limitation in TorchScript: one cannot script code that operates on an `Optional[Tensor]`. Since `TorchMD_Net` returns `Tuple[Tensor, Optional[Tensor]]`, the error pops up. In the test, the error can be fixed if one convinces TorchScript that `neg_dy` is actually a `Tensor`. However, this requires the user to know that the model returns `Tuple[Tensor, Optional[Tensor]]`, which can be challenging to discover if the user is, for instance, loading some checkpoint. Otherwise, the error message is really unhelpful.
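For illustration, here is a minimal sketch of that kind of workaround (the `Inner` and `Wrapper` modules below are hypothetical stand-ins, not the actual test or TorchMD_Net code): TorchScript refines `Optional[Tensor]` to `Tensor` after an explicit `is not None` check, so operating on the checked value scripts fine.

```python
from typing import Optional, Tuple

import torch
from torch import Tensor, nn


class Inner(nn.Module):
    # Stand-in for a model with TorchMD_Net's Tuple[Tensor, Optional[Tensor]] signature
    def forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        return x, x.clone()


class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = Inner()

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        y, neg_dy = self.model(x)
        # Without this check, `2 * neg_dy` fails to script because neg_dy is
        # Optional[Tensor]; the assert refines its type to Tensor.
        assert neg_dy is not None
        return y, 2 * neg_dy


scripted = torch.jit.script(Wrapper())
y, neg_dy = scripted(torch.ones(3))
```

Without the `assert`, `torch.jit.script` rejects `2 * neg_dy` with a type error about `Optional[Tensor]`.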
Another solution is to force `TorchMD_Net` to always return two Tensors: when `derivative` is `False`, the second output is an empty tensor instead of `None`. This PR implements this solution, making the test above pass.
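A minimal sketch of this approach, using a hypothetical model in place of the actual TorchMD_Net code: when the flag is off, the second output becomes `torch.empty(0)`, so the return type stays a plain `Tuple[Tensor, Tensor]` and downstream code can always operate on it.

```python
from typing import Tuple

import torch
from torch import Tensor, nn


class Model(nn.Module):
    # Hypothetical stand-in for TorchMD_Net's derivative handling
    def __init__(self, derivative: bool):
        super().__init__()
        self.derivative = derivative

    def forward(self, pos: Tensor) -> Tuple[Tensor, Tensor]:
        y = pos.sum().unsqueeze(0)
        if self.derivative:
            neg_dy = -2.0 * pos  # placeholder for the actual force computation
            return y, neg_dy
        # Empty tensor instead of None keeps the signature Tuple[Tensor, Tensor],
        # avoiding the Optional[Tensor] scripting problem entirely.
        return y, torch.empty(0)


scripted = torch.jit.script(Model(derivative=False))
y, neg_dy = scripted(torch.ones(4))
```

Callers can still distinguish the two cases cheaply by checking `neg_dy.numel() == 0`.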