aiqm / torchani

Accurate Neural Network Potential on PyTorch
https://aiqm.github.io/torchani/
MIT License

Pass pbc and cell as args to allow use of forward hook with TorchScript #648

Open lohedges opened 2 weeks ago

lohedges commented 2 weeks ago

Hello there,

I am writing a module based on ANI2x that requires AEVs, and I have been trying to use a forward hook on the ANI2x AEVComputer to avoid duplicating the calculation. While this works perfectly fine in PyTorch, I appear to be unable to serialise my model using TorchScript, since you pass the optional arguments of the hooked forward as kwargs rather than positionally. In particular, cell and pbc are always passed as cell=cell and pbc=pbc. The exception that I get is:

...
RuntimeError:

hook(__torch__.torchani.aev.AEVComputer module, ((Tensor, Tensor), Tensor?, Tensor?) input, __torch__.torchani.aev.SpeciesAEV output) -> NoneType:
Expected a value of type 'Tuple[Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]]' for argument 'input' but instead found type 'Tuple[Tuple[Tensor, Tensor]]'.
:
  File "/home/lester/.conda/envs/emle/lib/python3.10/site-packages/torchani/models.py", line 106
            raise ValueError(f'Unknown species found in {species_coordinates[0]}')

        species_aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
                       ~~~~~~~~~~~~~~~~~ <--- HERE
        species_energies = self.neural_networks(species_aevs)
        return self.energy_shifter(species_energies)

Applying the following patch to models.py gets things to work:

diff --git a/torchani/models.py b/torchani/models.py
index 117cb4a..522e6d0 100644
--- a/torchani/models.py
+++ b/torchani/models.py
@@ -103,7 +103,7 @@ class BuiltinModel(torch.nn.Module):
         if species_coordinates[0].ge(self.aev_computer.num_species).any():
             raise ValueError(f'Unknown species found in {species_coordinates[0]}')

-        species_aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
+        species_aevs = self.aev_computer(species_coordinates, cell, pbc)
         species_energies = self.neural_networks(species_aevs)
         return self.energy_shifter(species_energies)

@@ -135,7 +135,7 @@ class BuiltinModel(torch.nn.Module):
         """
         if self.periodic_table_index:
             species_coordinates = self.species_converter(species_coordinates)
-        species, aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
+        species, aevs = self.aev_computer(species_coordinates, cell, pbc)
         atomic_energies = self.neural_networks._atomic_energies((species, aevs))
         self_energies = self.energy_shifter.self_energies.clone().to(species.device)
         self_energies = self_energies[species]
@@ -236,7 +236,7 @@ class BuiltinEnsemble(BuiltinModel):
         """
         if self.periodic_table_index:
             species_coordinates = self.species_converter(species_coordinates)
-        species, aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
+        species, aevs = self.aev_computer(species_coordinates, cell, pbc)
         members_list = []
         for nnp in self.neural_networks:
             members_list.append(nnp._atomic_energies((species, aevs)).unsqueeze(0))
@@ -322,7 +322,7 @@ class BuiltinEnsemble(BuiltinModel):
         """
         if self.periodic_table_index:
             species_coordinates = self.species_converter(species_coordinates)
-        species, aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
+        species, aevs = self.aev_computer(species_coordinates, cell, pbc)
         member_outputs = []
         for nnp in self.neural_networks:
             unshifted_energies = nnp((species, aevs)).energies

(I've confirmed that this gives identical results to before.)

Is there a reason why kwargs need to be used? If not, would the proposed patch be acceptable? Using a forward hook gives an appreciable speed gain for my model by removing the need to compute AEVs twice.

Just to note that I am pretty new to TorchScript, so perhaps there is a way to get it to correctly detect the use of kwargs in hooks. Alternatively, perhaps there is another way to get the AEVs from the module. (Currently you use them as an intermediate in the calculation, but they could be stored as a module attribute.) I certainly could derive from ANI2x and override forward, but the way I have things now is quite a bit more flexible for my use case.
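
To sketch what I mean by storing them as an attribute (the CachingAEVComputer name and _last_aev attribute here are just illustrative, not part of torchani):

import torchani

from torch import Tensor
from typing import Optional, Tuple

class CachingAEVComputer(torchani.aev.AEVComputer):
    # Hypothetical sketch: stash the last computed AEVs on the module so
    # they can be read back after a forward pass, instead of using a hook.
    def forward(self, input_: Tuple[Tensor, Tensor],
                cell: Optional[Tensor] = None,
                pbc: Optional[Tensor] = None) -> torchani.aev.SpeciesAEV:
        out = super().forward(input_, cell, pbc)
        # Cache the AEVs for retrieval after the forward pass.
        self._last_aev = out.aevs
        return out

That said, the hook approach doesn't require swapping out the aev_computer on an existing model at all, which is part of why I prefer it.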

Many thanks.

yueyericardo commented 2 weeks ago

Hi Lester, could you pass empty tensors for (cell, pbc) in your hook?

lohedges commented 2 weeks ago

No, it doesn't seem to be possible, unless I'm misunderstanding things. The hook signature needs to be:

def hook_wrapper():
    def hook(
        module,
        input: Tuple[Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]],
        output: torchani.aev.SpeciesAEV,
    ):
        # Do something with the AEVComputer.forward output here.
        ...

    return hook

This is then registered with something like:

self._aev_hook = self._ani2x.aev_computer.register_forward_hook(hook_wrapper())

(Here self._ani2x is an instance of torchani.ANI2x that is used in my module. When I get the energies from self._ani2x I can get the AEVs from the hook's output.)

The input matches the signature of AEVComputer.forward, which is what I'm hooking. I believe that with TorchScript, the input defined in the hook should be a tuple containing all of the args to the hooked function, not just the input_ argument of that function. The hook must strictly have three inputs (module, input, and output); I've tried changing the type signature for input, but TorchScript complains that it must be exactly as above.

For reference, your signature is:

    def forward(self, input_: Tuple[Tensor, Tensor],
                cell: Optional[Tensor] = None,
                pbc: Optional[Tensor] = None) -> SpeciesAEV:

It seems that TorchScript can't resolve that cell and pbc are the optional arguments for AEVComputer.forward when they are passed as kwargs in your module.

Here's a minimal example that reproduces the issue:

import torch
import torchani

from torch import Tensor
from typing import Optional, Tuple

class Test(torch.nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self._ani2x = torchani.models.ANI2x(periodic_table_index=True)

        # Assign a tensor attribute that can be used for assigning the AEVs.
        self._ani2x.aev_computer._aev = torch.empty(0)

        # Hook the forward pass of the ANI2x model to get the AEV features.
        def hook_wrapper():
            def hook(
                module,
                input: Tuple[Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]],
                output: torchani.aev.SpeciesAEV,
            ):
                module._aev = output[1][0]

            return hook

        # Register the hook.
        self._aev_hook = self._ani2x.aev_computer.register_forward_hook(hook_wrapper())

    def forward(self, species: Tensor, coordinates: Tensor) -> Tensor:
        # Forward pass of the ANI2x model.
        energy = self._ani2x((species, coordinates)).energies

        # Do something with the AEV features.
        aevs = self._ani2x.aev_computer._aev

        return energy

# Create an instance of the model.
model = Test()

# Convert to TorchScript.
script_model = torch.jit.script(model)

lohedges commented 2 weeks ago

There's the possibility of using with_kwargs=True when registering the forward hook. When just using PyTorch, everything is okay:

import torch
import torchani

def hook(module, args, kwargs, output):
    print(module)
    print(len(args))
    print(args)
    print(len(kwargs))
    print(kwargs)

class Test(torch.nn.Module):
    def __init__(self):
        super(Test, self).__init__()

        self._ani2x = torchani.models.ANI2x(periodic_table_index=True)
        self._aev_hook = self._ani2x.aev_computer.register_forward_hook(
            hook, with_kwargs=True
        )

    def forward(self, coordinates, species):
        return self._ani2x((species, coordinates), cell=None, pbc=None).energies

coordinates = torch.tensor(
    [
        [
            [0.03192167, 0.00638559, 0.01301679],
            [-0.83140486, 0.39370209, -0.26395324],
            [-0.66518241, -0.84461308, 0.20759389],
            [0.45554739, 0.54289633, 0.81170881],
            [0.66091919, -0.16799635, -0.91037834],
        ]
    ],
    requires_grad=True,
    dtype=torch.float32,
)

# In periodic table, C = 6 and H = 1
species = torch.tensor([[6, 1, 1, 1, 1]])

# Create model.
model = Test()

# Convert to TorchScript.
#model = torch.jit.script(model)

# Compute energies.
energies = model(coordinates, species)

Gives:

AEVComputer()
1
(SpeciesCoordinates(species=tensor([[1, 0, 0, 0, 0]]), coordinates=tensor([[[ 0.0319,  0.0064,  0.0130],
         [-0.8314,  0.3937, -0.2640],
         [-0.6652, -0.8446,  0.2076],
         [ 0.4555,  0.5429,  0.8117],
         [ 0.6609, -0.1680, -0.9104]]], requires_grad=True)),)
2
{'cell': None, 'pbc': None}

However, this doesn't work with TorchScript, i.e. uncommenting the model = torch.jit.script(model) line above gives:

...
  File "/home/lester/.conda/envs/emle/lib/python3.10/site-packages/torch/jit/_recursive.py", line 479, in create_hooks_from_stubs
    concrete_type._create_hooks(hook_defs, hook_rcbs, pre_hook_defs, pre_hook_rcbs)
RuntimeError: Hook 'hook' on module 'AEVComputer' was expected to only have exactly 3 inputs but it had 4 inputs

So it looks like TorchScript only supports forward hooks that use positional args, and therefore expects the hooked function to be called with positional args too, which is why I needed to convert the kwargs to args in your models.py file.
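
As a sanity check (assuming the patched models.py above is applied locally), the first minimal example then scripts without error:

# Hypothetical check with the patch applied: the three-argument hook
# from my first minimal example no longer trips up TorchScript.
model = Test()                          # the Test class from my first example
script_model = torch.jit.script(model)  # no RuntimeError raised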

yueyericardo commented 1 week ago

Hi, thanks for the update, sounds good. Could you use your local fork to make the changes? We plan to have a major update later this year, and the current code base is frozen.

lohedges commented 1 week ago

No problem. I'll provide a patch for our users since it's just a modification to a single file. Do you want me to raise a PR with a fix and test anyway? I appreciate that it won't be merged, but at least you'll have a record if you want to apply it to whatever update comes later in the year.

Cheers.

yueyericardo commented 1 week ago

Yes, please open a PR, thank you!