openmm / NNPOps

High-performance operations for neural network potentials
Other

error when trying to jit.script getNeighbors #92

Closed: sef43 closed this issue 1 year ago

sef43 commented 1 year ago

When I torch.jit.script a module that uses getNeighborPairs, it fails (pytorch=1.13.1, nnpops=0.4). Example:

import torch
from NNPOps.neighbors import getNeighborPairs

class ForceModule(torch.nn.Module):

    def forward(self, positions):

        neighbors, deltas, distances = getNeighborPairs(positions, cutoff=1.0)
        mask = torch.isnan(distances)
        distances = distances[~mask]

        return torch.sum(distances**2)

model = ForceModule()
module = torch.jit.script(model)
module.save('model.pt')

Output:

Traceback (most recent call last):
  File "/home/sfarr/Documents/MLP_train/run_md_nequip/test_nn_nl.py", line 15, in <module>
    module = torch.jit.script(model)
  File "/home/sfarr/miniconda3/envs/mlp_train/lib/python3.9/site-packages/torch/jit/_script.py", line 1286, in script
    return torch.jit._recursive.create_script_module(
  File "/home/sfarr/miniconda3/envs/mlp_train/lib/python3.9/site-packages/torch/jit/_recursive.py", line 476, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/home/sfarr/miniconda3/envs/mlp_train/lib/python3.9/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/home/sfarr/miniconda3/envs/mlp_train/lib/python3.9/site-packages/torch/jit/_recursive.py", line 393, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
  File "/home/sfarr/miniconda3/envs/mlp_train/lib/python3.9/site-packages/torch/jit/_recursive.py", line 863, in try_compile_fn
    return torch.jit.script(fn, _rcb=rcb)
  File "/home/sfarr/miniconda3/envs/mlp_train/lib/python3.9/site-packages/torch/jit/_script.py", line 1343, in script
    fn = torch._C._jit_script_compile(
RuntimeError: 
Return value was annotated as having type Tuple[Tensor, Tensor] but is actually of type Tuple[Tensor, Tensor, Tensor]:
  File "/home/sfarr/miniconda3/envs/mlp_train/lib/python3.9/site-packages/NNPOps/neighbors/getNeighborPairs.py", line 126
    if box_vectors is None:
        box_vectors = empty((0, 0), device=positions.device, dtype=positions.dtype)
    return ops.neighbors.getNeighborPairs(positions, cutoff, max_num_neighbors, box_vectors)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
'getNeighborPairs' is being compiled since it was called from 'ForceModule.forward'
  File "/home/sfarr/Documents/MLP_train/run_md_nequip/test_nn_nl.py", line 8
    def forward(self, positions):

        neighbors, deltas, distances = getNeighborPairs(positions, cutoff=1.0)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        mask = torch.isnan(distances)
        distances = distances[~mask]

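It looks like the problem is just the type annotation: the traceback says getNeighborPairs is annotated as returning Tuple[Tensor, Tensor], but it actually returns three tensors.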

We should add torch.jit.script to the test cases.
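A minimal sketch of such a test, assuming only PyTorch: the helper `check_scriptable` (a hypothetical name, not part of NNPOps) scripts a module, round-trips it through `torch.jit.save`/`torch.jit.load` in memory, and checks that the reloaded module matches eager mode. `SumOfSquares` is a toy stand-in for `ForceModule` that does not need NNPOps, and the helper assumes the module returns a single tensor.

```python
import io

import torch


def check_scriptable(module: torch.nn.Module, *inputs: torch.Tensor) -> torch.jit.ScriptModule:
    """Script `module`, round-trip it through save/load, and compare with eager mode.

    Hypothetical test helper; assumes `module` returns a single tensor.
    """
    expected = module(*inputs)

    # Raises RuntimeError if the module, or anything it calls, is not scriptable
    # (for example, because of a wrong return-type annotation).
    scripted = torch.jit.script(module)

    # Save and reload in memory, mirroring the module.save('model.pt') step in the report.
    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)
    buffer.seek(0)
    reloaded = torch.jit.load(buffer)

    actual = reloaded(*inputs)
    assert torch.allclose(actual, expected), "scripted output differs from eager output"
    return reloaded


class SumOfSquares(torch.nn.Module):
    """Hypothetical toy stand-in for ForceModule that needs no NNPOps."""

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        distances = torch.sqrt((positions * positions).sum(dim=-1))
        mask = torch.isnan(distances)
        return torch.sum(distances[~mask] ** 2)


positions = torch.tensor([[3.0, 4.0, 0.0], [1.0, 0.0, 0.0]])
reloaded = check_scriptable(SumOfSquares(), positions)
print(reloaded(positions))  # tensor(26.)
```

With NNPOps installed, the same helper could be pointed at the `ForceModule` from the report, e.g. `check_scriptable(ForceModule(), positions)`; with nnpops 0.4 that call would fail inside `torch.jit.script` as in the traceback.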

RaulPPelaez commented 1 year ago

This particular issue seems easy to fix; in fact, #80 probably fixes it, since I modified the decorators there and I remember they were wrong. However, I do not think it is worth spending time on jit.script, since it is basically obsolete now due to torch.compile. jit.script requires things like type annotations, which torch.compile was engineered to avoid.
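To illustrate the contrast, here is a minimal sketch, assuming PyTorch 2.0 or later on a platform where `torch.compile` is supported. The function `energy` (a hypothetical toy, not NNPOps code) has no type annotations at all; `torch.compile` accepts it because TorchDynamo analyzes the function's Python bytecode as it runs rather than compiling its source as TorchScript does. `backend="eager"` is chosen only so the sketch needs no C++/Triton toolchain; it exercises graph capture, not speed.

```python
import torch


def energy(positions, cutoff):
    # No annotations anywhere. Under torch.jit.script, unannotated `cutoff`
    # would be assumed to be a Tensor; torch.compile does not ask.
    distances = torch.sqrt((positions * positions).sum(dim=-1))
    # Zero out contributions beyond the cutoff instead of boolean-indexing them away.
    weights = (distances < cutoff).to(distances.dtype)
    return torch.sum(weights * distances ** 2)


compiled = torch.compile(energy, backend="eager")

x = torch.tensor([[3.0, 4.0, 0.0], [1.0, 0.0, 0.0]])
print(compiled(x, 2.0))  # tensor(1.)
```

Note that the result is an ordinary Python callable, not a `ScriptModule`.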

raimis commented 1 year ago

jit.script is still in use (TorchMD-NET relies on it, for example), so it would be ideal if it worked.

raimis commented 1 year ago

@sef43 I think you just need to fix this line so that it says Tuple[Tensor, Tensor, Tensor]:

https://github.com/openmm/NNPOps/blob/c8690ac7b1538e5b9246d3d8ed02bb85c80ef25d/src/pytorch/neighbors/getNeighborPairs.py#L5
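The mismatch can be reproduced without NNPOps. In this sketch, `_toy_pairs` is a hypothetical pure-PyTorch stand-in for the compiled `ops.neighbors.getNeighborPairs` op (a brute-force O(N²) search, not NNPOps' algorithm; it simply drops pairs beyond the cutoff). `get_pairs_wrong` and `get_pairs_fixed` are hypothetical wrappers that differ only in their return annotation, mirroring the situation in the traceback.

```python
from typing import Tuple

import torch
from torch import Tensor


def _toy_pairs(positions: Tensor, cutoff: float) -> Tuple[Tensor, Tensor, Tensor]:
    """Hypothetical brute-force stand-in for the compiled neighbor-pair op."""
    n = positions.shape[0]
    pairs = torch.tril_indices(n, n, -1)  # all (i, j) with i > j, shape (2, P)
    deltas = positions[pairs[0]] - positions[pairs[1]]
    distances = torch.sqrt((deltas * deltas).sum(dim=-1))
    keep = torch.nonzero(distances < cutoff).flatten()
    return pairs.index_select(1, keep), deltas.index_select(0, keep), distances.index_select(0, keep)


def get_pairs_wrong(positions: Tensor, cutoff: float) -> Tuple[Tensor, Tensor]:
    # Declares two outputs but forwards three, like the annotation in the traceback.
    return _toy_pairs(positions, cutoff)


def get_pairs_fixed(positions: Tensor, cutoff: float) -> Tuple[Tensor, Tensor, Tensor]:
    # Annotation now matches the three returned tensors.
    return _toy_pairs(positions, cutoff)


# The corrected wrapper compiles...
scripted = torch.jit.script(get_pairs_fixed)

# ...while the original one is rejected at compile time.
try:
    torch.jit.script(get_pairs_wrong)
    wrong_rejected = False
except RuntimeError as err:
    wrong_rejected = True
    print(next(line for line in str(err).splitlines() if line.strip()))

positions = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [3.0, 0.0, 0.0]])
neighbors, deltas, distances = scripted(positions, 1.0)
print(neighbors.tolist(), distances.tolist())  # [[1], [0]] [0.5]
```

Python itself never checks the return annotation, so `get_pairs_wrong` works in eager mode; only TorchScript enforces it, which is why the bug surfaced only under `torch.jit.script`.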

RaulPPelaez commented 1 year ago

I tried it and your fix works, @raimis. I will open a PR.

sef43 commented 1 year ago

Thanks!

Do we have, or should we open, a general PyTorch 2.0 issue in this repo or in openmm-torch to discuss the various changes and features? The new torch.compile feature does look very useful!

RaulPPelaez commented 1 year ago

Totally, I believe that one is on me. I will open a PR soon. It will probably be mostly painless (hopefully...).

RaulPPelaez commented 1 year ago

I think you gave the perfect excuse to start a PR and discussion :P #94

peastman commented 1 year ago

If you use torch.compile, can you load the compiled module in C++? The documentation still only talks about loading TorchScript models.