torchmd / torchmd-net

Training neural network potentials

Simple TorchScript test fails #219

Closed RaulPPelaez closed 8 months ago

RaulPPelaez commented 1 year ago

This test runs exactly 4 iterations and then fails with the error below:

import torch

from torchmdnet.models.model import create_model

def test_really_simple():
    n_atoms = 10
    zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long)
    z = zs[torch.randint(0, len(zs), (n_atoms,))]
    pos = torch.randn(len(z), 3)
    batch = torch.zeros(len(z), dtype=torch.long)
    batch[len(batch) // 2 :] = 1
    args = {"model": "tensornet",
            "embedding_dimension": 128,
            "num_layers": 2,
            "num_rbf": 32,
            "rbf_type": "expnorm",
            "trainable_rbf": False,
            "activation": "silu",
            "cutoff_lower": 0.0,
            "cutoff_upper": 5.0,
            "max_z": 100,
            "max_num_neighbors": 128,
            "equivariance_invariance_group": "O(3)",
            "prior_model": None,
            "atom_filter": -1,
            "derivative": True,
            "output_model": "Scalar",
            "reduce_op": "sum",
            "precision": 32 }
    model = create_model(args).to(device="cuda")
    z = z.to("cuda")
    pos = pos.to("cuda").requires_grad_(True)
    batch = batch.to("cuda")
    model = torch.jit.script(model).to(device="cuda")
    for i in range(10):
        print("Running iteration {}".format(i))
        y, neg_dy = model(z, pos, batch)

The error (the mismatched sizes 128 and 3 presumably correspond to the embedding dimension and the spatial dimension of pos):

self = RecursiveScriptModule(
  original_name=TorchMD_Net
  (representation_model): RecursiveScriptModule(
    original_name=...      (1): RecursiveScriptModule(original_name=SiLU)
      (2): RecursiveScriptModule(original_name=Linear)
    )
  )
)
args = (tensor([7, 6, 7, 7, 8, 7, 8, 7, 8, 8], device='cuda:0'), tensor([[-0.6791,  0.2550, -0.8304],
        [-0.4754,  0.60...107, -1.6600,  0.7560]], device='cuda:0', requires_grad=True), tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], device='cuda:0')), kwargs = {}, forward_call = <torch.ScriptMethod object at 0x7f9d363cd4e0>

    def _call_impl(self, *args, **kwargs):
        forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
        # If we don't have any hooks, we want to skip the rest of the logic in
        # this function, and just call forward.
        if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
                or _global_backward_pre_hooks or _global_backward_hooks
                or _global_forward_hooks or _global_forward_pre_hooks):
>           return forward_call(*args, **kwargs)
E           RuntimeError: The following operation failed in the TorchScript interpreter.
E           Traceback of TorchScript (most recent call last):
E             File "/shared/raul/torchmd-net/torchmdnet/models/model.py", line 292, in forward
E                       print("Shape of pos")
E                       print(pos.shape)
E                       dy = grad(
E                            ~~~~ <--- HERE
E                           [y],
E                           [pos],
E           RuntimeError: The following operation failed in the TorchScript interpreter.
E           Traceback of TorchScript (most recent call last):
E           RuntimeError: The following operation failed in the TorchScript interpreter.
E           Traceback of TorchScript (most recent call last):
E           RuntimeError: The size of tensor a (128) must match the size of tensor b (3) at non-singleton dimension 3

../../mambaforge/envs/openmmtorch-test/lib/python3.10/site-packages/torch/nn/modules/module.py:1501: RuntimeError
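For comparison, running the same loop without the torch.jit.script call should confirm the failure is TorchScript-specific; something like this eager-mode control (a sketch, not part of the test above):

# Eager-mode control run: same inputs, no scripting
eager_model = create_model(args).to(device="cuda")
for i in range(10):
    y, neg_dy = eager_model(z, pos, batch)  # expected to complete all 10 iterations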
RaulPPelaez commented 1 year ago

Saving and loading the model on each iteration works around the issue, and it is in fact faster than not doing it:

    import io

    for i in range(10):
        print("Running iteration {}".format(i))
        # Round-trip the scripted module through an in-memory buffer,
        # getting back a fresh ScriptModule each iteration
        buffer = io.BytesIO()
        torch.jit.save(model, buffer)
        buffer.seek(0)
        model = torch.jit.load(buffer).to(device="cuda")
        y, neg_dy = model(z, pos, batch)
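Factoring the round trip into a helper makes the workaround easier to drop in (reload_scripted is a hypothetical name, not something in the codebase); my understanding is that torch.jit.load returns a fresh ScriptModule without the executor state accumulated on the old one, which seems to be why this avoids the failure:

import io
import torch

def reload_scripted(model, device="cuda"):
    # Serialize the ScriptModule to an in-memory buffer and load it back,
    # yielding a fresh module with no accumulated JIT profiling state
    buffer = io.BytesIO()
    torch.jit.save(model, buffer)
    buffer.seek(0)
    return torch.jit.load(buffer, map_location=device)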
RaulPPelaez commented 1 year ago

This is the offending line. Assigning tensor_m = tensor directly, or using 0 layers, makes the problem go away.

https://github.com/torchmd/torchmd-net/blob/fd8395463370c6b0510d4fe24b25d937382247a1/torchmdnet/models/tensornet.py#L331

Replacing it with the following does NOT solve the issue:

tensor_m = torch.zeros_like(tensor).scatter_add(0, (edge_index[0])[:,None, None, None].expand_as(msg), msg)

But my suspicion is that scatter_add is the same operation scatter uses under the hood anyway; a standalone check of that equivalence is sketched after the next snippet.

EDIT: Not saving the graph when calling grad also makes the error go away:

        if self.derivative:
            grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y)]
            # With create_graph=False and retain_graph=False the error no longer appears
            dy = grad(
                [y],
                [pos],
                grad_outputs=grad_outputs,
                create_graph=False,
                retain_graph=False,
            )[0]
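As for the scatter vs. scatter_add equivalence, a quick standalone check (with made-up shapes mimicking the per-edge message tensors; index_add stands in here for a torch_scatter-style sum reduction) behaves the same:

import torch

n_nodes, n_edges = 4, 10
src = torch.randint(0, n_nodes, (n_edges,))  # plays the role of edge_index[0]
msg = torch.randn(n_edges, 2, 3, 3)          # per-edge messages

# scatter_add with the index broadcast to the message shape, as in the snippet above
idx = src[:, None, None, None].expand_as(msg)
out_a = torch.zeros(n_nodes, 2, 3, 3).scatter_add(0, idx, msg)

# index_add along dim 0 computes the same sum reduction
out_b = torch.zeros(n_nodes, 2, 3, 3).index_add(0, src, msg)

assert torch.allclose(out_a, out_b)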
RaulPPelaez commented 1 year ago

5 is a magic number we have encountered before when dealing with CUDA torch models (note that the test above fails on its fifth call). See for instance https://github.com/openmm/openmm-torch/pull/122

I just made the connection with this issue while reading about torch.jit's operator fusion mechanism with NVFuser: https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/codegen/cuda/README.md#general-ideas-of-debug-no-fusion

What I believe is going on here is some kind of collision between jit.script, the NVFuser backend, and autograd: the profiling executor runs the first few calls unoptimized before swapping in a fused graph, which would explain why only the later iterations fail.
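One way to probe this hypothesis is to disable fusion and see whether the failure disappears. A sketch, assuming a PyTorch build from that era (these toggles are private and version-dependent, present around PyTorch 1.13/2.0):

import torch

# Private, version-dependent JIT toggles (PyTorch ~1.13/2.0)
torch._C._jit_set_nvfuser_enabled(False)       # disable the NVFuser backend
torch._C._jit_override_can_fuse_on_gpu(False)  # disable the legacy GPU fuser

# Alternatively, select a fuser only for a region of code:
with torch.jit.fuser("fuser0"):  # "fuser0"=legacy, "fuser1"=NNC, "fuser2"=NVFuser
    y, neg_dy = model(z, pos, batch)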

RaulPPelaez commented 11 months ago

This issue goes away with pytorch>2.0.0, so I am going to assume it is a bug there. Will leave this open until there is a conda-forge package for pytorch 2.1.

RaulPPelaez commented 8 months ago

Closing this as there is a package for 2.1 already.