openmm / NNPOps

High-performance operations for neural network potentials

Support for pytorch 2.0 #94

Closed: RaulPPelaez closed this 1 year ago

RaulPPelaez commented 1 year ago

This PR starts the work of making NNPOps compatible with PyTorch 2.0 and `torch.compile`.

RaulPPelaez commented 1 year ago

TorchANI cannot be installed with PyTorch 2, which forces us to skip some tests. EDIT: TorchANI has since made a new release that is compatible with PyTorch 2.
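As an illustration of how such tests can be skipped while TorchANI is unavailable, here is a minimal pytest sketch. The marker name `requires_torchani` and the test body are hypothetical and not taken from the NNPOps test suite; the sketch only assumes that TorchANI, when installed, is importable as `torchani`.

```python
import importlib.util

import pytest

# True when the torchani package can be found in the current environment
# (it may be missing when no release supports the installed PyTorch).
HAS_TORCHANI = importlib.util.find_spec("torchani") is not None

# Hypothetical reusable marker: skip instead of failing when TorchANI is absent.
requires_torchani = pytest.mark.skipif(
    not HAS_TORCHANI, reason="TorchANI is not available for this PyTorch version"
)


@requires_torchani
def test_torchani_is_importable():
    # Import lazily so that test collection works without TorchANI installed.
    import torchani

    assert torchani is not None
```

Because the condition is evaluated from the environment, the skipped tests start running again without code changes once a compatible TorchANI release is installed.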

RaulPPelaez commented 1 year ago

All tests pass and the CI works with a PyTorch 2 installation. I believe this should be merged now, and work on `torch.compile()` compatibility should be done in a separate PR. A new release could then be made so that users can install NNPOps alongside PyTorch 2.

RaulPPelaez commented 1 year ago

In case anyone has experience with `torch.compile` in PyTorch 2: this test fails miserably in CUDA mode:


```python
import pytest
import torch as pt
from NNPOps.neighbors import getNeighborPairs

@pytest.mark.parametrize('device', ['cpu', 'cuda'])
@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
def test_torch_compile_compatible(device, dtype):

    if device == 'cuda' and not pt.cuda.is_available():
        pytest.skip('No GPU')

    class ForceModule(pt.nn.Module):

        def forward(self, positions):

            neighbors, deltas, distances = getNeighborPairs(positions, cutoff=1.0)
            # Drop the NaN distances of pairs that are not neighbors
            mask = pt.isnan(distances)
            distances = distances[~mask]
            return pt.sum(distances**2)

    original_model = ForceModule()
    num_atoms = 10
    positions = (20 * pt.randn((num_atoms, 3), device=device, dtype=dtype)) - 10
    original_model(positions)  # eager mode works
    model = pt.compile(original_model)
    model(positions)  # compiled mode fails on CUDA
```

It yields an extremely verbose error about something called `FakeTensor`, one that makes the most obscure recursive gcc template error look clear and informative:

```python
TestNeighbors.py::test_torch_compile_compatible[dtype1-cuda] FAILED [600/1860]
=================================== FAILURES ===================================
__________________ test_torch_compile_compatible[dtype0-cuda] __________________

output_graph = , node = get_neighbor_pairs
args = (FakeTensor(FakeTensor(..., device='meta', size=(10, 3)), cuda:0), 1.0, -1, FakeTensor(FakeTensor(..., device='meta', size=(0, 0)), cuda:0))
kwargs = {}
nnmodule = None

    def run_node(output_graph, node, args, kwargs, nnmodule):
        """
        Runs a given node, with the given args and kwargs.

        Behavior is dicatated by a node's op.

        run_node is useful for extracting real values out of nodes.
        See get_real_value for more info on common usage.

        Note: The output_graph arg is only used for 'get_attr' ops
        Note: The nnmodule arg is only used for 'call_module' ops

        Nodes that are not call_function, call_method, call_module, or get_attr will
        raise an AssertionError.
        """
        op = node.op
        try:
            if op == "call_function":
>               return node.target(*args, **kwargs)

../../../mambaforge/envs/nnpops-torch2-nvidia/lib/python3.10/site-packages/torch/_dynamo/utils.py:1194:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self =
args = (FakeTensor(FakeTensor(..., device='meta', size=(10, 3)), cuda:0), 1.0, -1, FakeTensor(FakeTensor(..., device='meta', size=(0, 0)), cuda:0))
kwargs = {}

    def __call__(self, *args, **kwargs):
        # overloading __call__ to ensure torch.ops.foo.bar()
        # is still callable from JIT
        # We save the function ptr as the `op` attribute on
        # OpOverloadPacket to access it here.
>       return self._op(*args, **kwargs or {})
E       RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

../../../mambaforge/envs/nnpops-torch2-nvidia/lib/python3.10/site-packages/torch/_ops.py:502: RuntimeError
```

And it goes on like this for a gazillion lines.

I have not been able to solve this. From what I have gathered, this should not happen and it is a bug in PyTorch; there are many issues describing similar problems, e.g. https://github.com/pytorch/pytorch/issues/96742 and https://github.com/pytorch/pytorch/issues/95791.

raimis commented 1 year ago

Yes, we can skip the compile feature for now.
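Skipping the feature in the test suite could look like the minimal sketch below. It only shows the pytest mechanics: the reason string and the placeholder body are hypothetical, and a real version would keep the `torch.compile` test body from earlier in this thread.

```python
import pytest

# Hypothetical reason string: torch.compile support is deferred to a later PR.
COMPILE_SKIP_REASON = "torch.compile support for NNPOps ops is not ready yet"


@pytest.mark.skip(reason=COMPILE_SKIP_REASON)
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_torch_compile_compatible(device):
    # Placeholder body; the real test compiles a module that calls
    # getNeighborPairs and runs it on `device`.
    raise NotImplementedError
```

Marking the test as skipped, rather than deleting it, keeps it in the suite so it can be re-enabled by removing one line. `pytest.mark.xfail` is an alternative if you would rather have CI report when the compiled path starts passing.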

RaulPPelaez commented 1 year ago

OK, I think this is done now.

raimis commented 1 year ago

@RaulPPelaez, can I merge?

RaulPPelaez commented 1 year ago

Yes, thanks, @raimis.