aiqm / torchani

Accurate Neural Network Potential on PyTorch
https://aiqm.github.io/torchani/
MIT License

Loading the TorchANI model in C++ gives runtime error #583

Closed: ndonyapour closed this issue 2 years ago

ndonyapour commented 3 years ago

Hello, I'm trying to load the TorchANI model (`compiled_model.pt`) in C++ using libtorch. The model loads, but running it gives the following error:

successfully loaded the model
libc++abi.dylib: terminating with uncaught exception of type std::runtime_error: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/torchani/aev.py", line 53, in forward
      _5, cell0 = False, unchecked_cast(Tensor, cell)
    if _5:
      aev0 = __torch__.torchani.aev.compute_aev(species, coordinates, self.triu_index, (self).constants(), (4, 16, 64, 32, 320), None, )
             ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
      aev = aev0
    else:
  File "code/__torch__/torchani/aev.py", line 119, in compute_aev
  radial_aev = torch.new_zeros(radial_terms_, [_20, radial_sublength], dtype=None, layout=None, device=None, pin_memory=None)
  index12 = torch.add(torch.mul(atom_index12, num_species), torch.flip(species12, [0]), alpha=1)
  _21 = torch.index_add_(radial_aev, 0, torch.select(index12, 0, 0), radial_terms_)
        ~~~~~~~~~~~~~~~~ <--- HERE
  _22 = torch.index_add_(radial_aev, 0, torch.select(index12, 0, 1), radial_terms_)
  _23 = [num_molecules, num_atoms, radial_length]

Traceback of TorchScript, original code (most recent call last):
  File "/Users/nazanin/programs/torchani/torchani/aev.py", line 302, in compute_aev
    radial_aev = radial_terms_.new_zeros((num_molecules * num_atoms * num_species, radial_sublength))
    index12 = atom_index12 * num_species + species12.flip(0)
    radial_aev.index_add_(0, index12[0], radial_terms_)
    ~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    radial_aev.index_add_(0, index12[1], radial_terms_)
    radial_aev = radial_aev.reshape(num_molecules, num_atoms, radial_length)
RuntimeError: index_add_(): Expected dtype int32/int64 for index

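Here is my code:

#include <torch/script.h>  // torch::jit::load, torch::jit::script::Module

#include <iostream>
#include <tuple>
#include <vector>

int main(int argc, const char* argv[]) {
  // Get the path of the model from the command line
  if (argc != 2) {
    std::cerr << "usage: test_model <path-to-exported-script-module>\n";
    return -1;
  }

  torch::jit::script::Module model;
  torch::Device device(torch::kCPU);
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    model = torch::jit::load(argv[1]);
    std::cout << "successfully loaded the model\n";
  } catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }

  // Set the model properties.
  // Note: per the discussion in this thread, the dtype cast below also
  // converts integer buffers (e.g. triu_index) to double.
  model.to(device);
  model.to(torch::kDouble);

  // Example inputs (illustrative values): one molecule with 5 atoms.
  // The species encoding must match the one the model was exported with.
  torch::Tensor species = torch::tensor({1, 0, 0, 0, 0}, torch::kLong).unsqueeze(0);
  torch::Tensor positions = torch::rand({1, 5, 3}, torch::kDouble);

  // The model takes a single (species, coordinates) tuple as input
  std::vector<torch::jit::IValue> inputs;
  std::tuple<torch::Tensor, torch::Tensor> in = {species, positions};
  inputs.push_back(in);

  // Run the model; TorchANI's built-in models return a (species, energies)
  // tuple rather than a single tensor
  auto output = model.forward(inputs).toTuple();
  torch::Tensor energies = output->elements()[1].toTensor();
  std::cout << energies << "\n";

  return 0;
}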
zasdfgbnm commented 3 years ago

Could this be because your PyTorch and libtorch are not the same version?

ndonyapour commented 3 years ago

I'm using the same version (1.8) for both. While debugging, I found that the triu_index tensor, used in the line just above the failing index_add_ call, has dtype torch.double at runtime. Casting the index back to an integer type fixes the issue, but I suspect this is not the best solution:

angular_aev.index_add_(0, index.to(dtype=torch.int64), angular_terms_)
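A minimal Python sketch of the failure and of the cast described above, assuming a recent PyTorch build. The tensors are illustrative stand-ins, not TorchANI's actual AEV tensors. `index_add_` only accepts `int32`/`int64` indices, so a float64 index is rejected until it is cast back:

```python
import torch

# Illustrative stand-ins (not TorchANI's real AEV tensors).
target = torch.zeros(3, 2, dtype=torch.double)
values = torch.ones(2, 2, dtype=torch.double)

# An index that has been converted to float64, mimicking what happens to
# integer buffers after a whole-module dtype cast in libtorch.
bad_index = torch.tensor([0, 2], dtype=torch.double)

try:
    target.index_add_(0, bad_index, values)
    rejected = False
except (RuntimeError, IndexError) as err:
    # Depending on the PyTorch version this surfaces as RuntimeError or IndexError.
    rejected = True
    print("index_add_ rejected the index:", err)

# The workaround: cast the index back to int64 before calling index_add_.
target.index_add_(0, bad_index.to(dtype=torch.int64), values)
print(target)  # rows 0 and 2 are ones, row 1 stays zero
```

Casting at the call site works, but it has to be repeated wherever an affected buffer is used as an index.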

zasdfgbnm commented 3 years ago

Looks like you hit a PyTorch bug. Could you open a bug report at https://github.com/pytorch/pytorch?

ndonyapour commented 3 years ago

I tried to reproduce the error using a simple model that uses the index_add_ method. However, there is no error when I load and run this model in C++. Here is my code:

import torch
from torch import Tensor
from typing import Tuple


class my_model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input_: Tuple[Tensor, Tensor]):
        species, coordinates = input_

        # Source shape (1, 3, 3) matches coordinates (1, 8, 3) except along dim 1
        tensor = torch.tensor([[[1, 2, 3], [1, 2, 3], [1, 2, 3]]],
                              dtype=torch.double)
        index = torch.tensor([0, 3, 2], dtype=torch.long)
        coordinates.index_add_(1, index, tensor)
        return coordinates


def save_model():
    save_path = './test_model.pt'
    model = my_model()
    script_module = torch.jit.script(model)
    script_module.save(save_path)


def test_saved_model():
    save_path = './test_model.pt'

    model = torch.jit.load(save_path)
    coordinates = torch.rand((1, 8, 3), dtype=torch.double)
    # One species entry per atom: shape (num_molecules, num_atoms)
    species = torch.ones((1, coordinates.shape[1]), dtype=torch.int64)
    print(model((species, coordinates)))


if __name__ == '__main__':
    save_model()
    test_saved_model()
IgnacioJPickering commented 3 years ago

Hello @ndonyapour, libtorch has a bug where integer tensors are cast to torch::kDouble tensors when you convert the whole model to double in C++, e.g. with model.to(torch::kDouble). I have reported this bug, but unfortunately it didn't get much attention; maybe I didn't explain it clearly enough at the time. You can work around it manually by adding a method to your model that casts these tensors back to integer tensors after the conversion. I don't really have time right now to provide a full working solution, but you can look at BuiltinModel's _recast_long_tensors method for an idea of how to solve this.

I should probably raise the bug with PyTorch again one of these days; unfortunately, there is not much we can do on our side to fix this.
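A minimal sketch of the recast workaround, assuming a toy module rather than TorchANI itself. `PairSum` and `recast_long_buffers` are hypothetical names, and this is not BuiltinModel's actual `_recast_long_tensors` implementation. The helper is marked with `@torch.jit.export` so that it is compiled into the scripted module and can be called after the dtype cast:

```python
import torch
from torch import Tensor


class PairSum(torch.nn.Module):
    """Hypothetical toy module holding an integer index buffer, standing in
    for buffers such as triu_index in TorchANI's AEV computer."""

    def __init__(self) -> None:
        super().__init__()
        # This buffer must stay int64 because forward() uses it in index_add_.
        self.register_buffer("index", torch.tensor([0, 2, 1], dtype=torch.long))

    def forward(self, values: Tensor) -> Tensor:
        out = values.new_zeros([3, values.shape[1]])
        out.index_add_(0, self.index, values)
        return out

    @torch.jit.export
    def recast_long_buffers(self) -> None:
        # Hypothetical helper: restore the integer dtype if a whole-module
        # cast (e.g. model.to(torch::kDouble) in libtorch) made it float64.
        self.index = self.index.to(dtype=torch.long)


# Simulate the libtorch behaviour in eager mode by hand-casting the buffer.
model = PairSum()
model.index = model.index.double()
model.recast_long_buffers()
print(model.index.dtype)  # torch.int64

# The exported method is compiled along with forward().
scripted = torch.jit.script(PairSum())
print(scripted(torch.ones(3, 2, dtype=torch.double)))
```

From C++, such an exported method could be called on the loaded module right after the cast, for example `model.to(torch::kDouble);` followed by `model.run_method("recast_long_buffers");`. Another option is to cast in Python before scripting: `torch.nn.Module.double()` only converts floating-point parameters and buffers, so integer buffers keep their dtype.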