aiqm / torchani

Accurate Neural Network Potential on PyTorch
https://aiqm.github.io/torchani/
MIT License

Error when converting model to TorchScript #586

Closed: peastman closed this issue 2 years ago

peastman commented 3 years ago

TorchANI models do not work correctly when they are converted to TorchScript. I'm pretty sure this is a PyTorch bug, but you should be aware of it, and perhaps you can find a workaround. The following script demonstrates the problem using PyTorch 1.7.1.

import torch
import torchani

device = torch.device('cuda')
species = torch.tensor([[2, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0]]).to(device)
positions = torch.tensor([[
    [19.8102, 18.8292, 37.9237],
    [18.5426, 19.4376, 38.3537],
    [20.8406, 19.9389, 38.3717],
    [20.0545, 17.8716, 37.4475],
    [20.6924, 15.5844, 39.6175],
    [19.8118, 20.1167, 36.7419],
    [19.2987, 20.0335, 36.7371],
    [18.6676, 23.4357, 36.8353],
    [20.7095, 18.4540, 36.3641],
    [19.3523, 18.2274, 36.3663],
    [20.8299, 15.0725, 34.6570],
    [20.2751, 20.4366, 35.4269],
    [18.6964, 20.1698, 35.4341],
    [19.4113, 20.8338, 35.7119],
    [20.6836, 18.2203, 33.5049],
    [19.8220, 18.9653, 32.5260],
    [19.3635, 17.9477, 33.4874]]], dtype=torch.float32).to(device)
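model = torchani.models.ANI1ccx().to(device)
module = torch.jit.script(model)
# Run the scripted model twice with identical inputs: the first call returns
# normally, the second call raises the RuntimeError shown below.
print(module((species, positions)))
print(module((species, positions)))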

The second call fails with this exception (the traceback includes the debugging print statements I had added to aev.py):

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/home/peastman/miniconda3/envs/openmm/lib/python3.8/site-packages/torchani/aev.py", line 440, in forward

        if cell is None and pbc is None:
            aev = compute_aev(species, coordinates, self.triu_index, self.constants(), self.sizes, None)
                  ~~~~~~~~~~~ <--- HERE
        else:
            assert (cell is not None and pbc is not None)
  File "/home/peastman/miniconda3/envs/openmm/lib/python3.8/site-packages/torchani/aev.py", line 305, in compute_aev

    # compute angular aev
    central_atom_index, pair_index12, sign12 = triple_by_molecule(atom_index12)
                                               ~~~~~~~~~~~~~~~~~~ <--- HERE
    species12_small = species12[:, pair_index12]
    vec12 = vec.index_select(0, pair_index12.view(-1)).view(2, -1, 3) * sign12.unsqueeze(-1)
  File "/home/peastman/miniconda3/envs/openmm/lib/python3.8/site-packages/torchani/aev.py", line 250, in triple_by_molecule
    print(pair_sizes)
    print(sorted_local_index12.shape, pair_indices.shape, intra_pair_indices.shape, atom_index12.shape, mask.shape)
    sorted_local_index12 += cumsum_from_zero(counts).index_select(0, pair_indices)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

    # unsort result from last part
RuntimeError: The size of tensor a (679) must match the size of tensor b (699) at non-singleton dimension 1

I debugged it a bit further, and I believe the error originates in this line:

https://github.com/aiqm/torchani/blob/2007d181c1d650e28ee4e2c6a5f8e74fae54d6bf/torchani/aev.py#L249

I printed both counts and pair_sizes, and I can see that on the second invocation pair_sizes is computed incorrectly:

counts:      10  9  9 11  1 10 11  2 12 12  2 12 12 11  7  4  7   [ CUDALongType{17} ]
pair_sizes:  49 40 40 60  0 49 60  1 71 71  1 71 71 60 24  7 24

For example, the first element should be 10*9/2 = 45, not 49.
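As a quick arithmetic check (an editorial sketch, not part of the original report), the 17 printed counts can be run through the intended formula and through a candidate misparse. Every wrong pair_sizes value matches (n * n - 1) // 2, i.e. as if the "- 1" were applied after the multiplication rather than inside the parentheses. That is close to, but not exactly, the (counts ** 2) // 2 mentioned later in the thread (for n = 10 it gives 49, not 50). This is an inference from the numbers only, not a confirmed diagnosis of the TorchScript bug; the helper names below are hypothetical.

```python
# Per-molecule atom counts as printed in the report (17 entries).
counts = [10, 9, 9, 11, 1, 10, 11, 2, 12, 12, 2, 12, 12, 11, 7, 4, 7]

# pair_sizes as printed on the second (failing) invocation.
observed = [49, 40, 40, 60, 0, 49, 60, 1, 71, 71, 1, 71, 71, 60, 24, 7, 24]


def expected_pair_sizes(counts):
    """Intended result: number of unordered atom pairs, n * (n - 1) // 2."""
    return [n * (n - 1) // 2 for n in counts]


def misparsed_pair_sizes(counts):
    """Hypothetical misparse: '- 1' applied after the product, (n * n - 1) // 2."""
    return [(n * n - 1) // 2 for n in counts]


print(expected_pair_sizes(counts)[:4])           # [45, 36, 36, 55]
print(misparsed_pair_sizes(counts) == observed)  # True
```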

yueyericardo commented 3 years ago

Hi, thanks for letting us know. As you mentioned in another thread, this should be a bug in PyTorch versions < 1.8.

IgnacioJPickering commented 3 years ago

@yueyericardo @peastman Yes, thank you, we are aware of the issues with JIT. This is actually still an issue in versions > 1.8; specifically, it also occurs in the current nightly if you load the model from C++. I have a workaround that I'm using in a private repo, and I may merge it here. Basically, the issue is related to TorchScript's parsing of

pair_sizes = counts * (counts - 1) // 2

For some reason, TorchScript generates buggy code that sometimes executes something like (counts ** 2) // 2 instead. By making the 1 into a tensor, the issue can be bypassed with no performance penalty:

one = torch.tensor(1, dtype=torch.long, device=counts.device)
pair_sizes = counts * (counts - one) // 2

It is ugly, but it works.
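A minimal sketch of the workaround inside a scripted function, assuming a recent PyTorch install. The function name pair_sizes_from_counts is hypothetical and not part of TorchANI's API. It compiles the expression with torch.jit.script, calls it twice (the original failure only appeared on the second call), and compares both results with the eager-mode formula. On a build without the bug both forms agree, so this check alone does not reproduce the original CUDA failure; to mirror the report more closely, move counts to a CUDA device first.

```python
import torch


@torch.jit.script
def pair_sizes_from_counts(counts: torch.Tensor) -> torch.Tensor:
    # Workaround from the comment above: use a tensor instead of the literal 1,
    # created with the same dtype and on the same device as counts.
    one = torch.tensor(1, dtype=torch.long, device=counts.device)
    return counts * (counts - one) // 2


counts = torch.tensor([10, 9, 9, 11, 1, 10, 11, 2, 12, 12, 2, 12, 12, 11, 7, 4, 7])

# Call twice, because the original failure only showed up on the second call.
first = pair_sizes_from_counts(counts)
second = pair_sizes_from_counts(counts)

# Eager-mode reference: number of unordered pairs per molecule.
reference = counts * (counts - 1) // 2
print(torch.equal(first, reference), torch.equal(second, reference))
```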