Open lohedges opened 2 weeks ago
Hi Lester, could you pass empty tensors for cell and pbc to your hook?
No, it doesn't seem to be possible, unless I'm misunderstanding things. The hook signature needs to be:
def hook_wrapper():
    def hook(
        module,
        input: Tuple[Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]],
        output: torchani.aev.SpeciesAEV,
    ):
        # Do something with the AEVComputer.forward output here.
        pass

    return hook
This is then registered with something like:
self._aev_hook = self._ani2x.aev_computer.register_forward_hook(hook_wrapper())
(Here self._ani2x is an instance of torchani.models.ANI2x that is used in my module. When I get the energies from self._ani2x I can get the AEVs from the hook's output.)
The input matches the signature of AEVComputer.forward, which is what I'm hooking. I believe that with TorchScript, the input defined in the hook should be a tuple containing all of the args to the hooked function, not just the input argument of that function. The hook must strictly have three inputs (module, input, and output). I've tried changing the type signature for input, but it complains that it must be what I've put.
For reference, your signature is:
def forward(self, input_: Tuple[Tensor, Tensor],
            cell: Optional[Tensor] = None,
            pbc: Optional[Tensor] = None) -> SpeciesAEV:
It seems that TorchScript can't resolve that cell and pbc are the optional arguments for AEVComputer.forward when they are passed as kwargs in your module.
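The effect of keyword versus positional calls on what a forward hook receives can be checked in plain PyTorch. The following is a minimal sketch, assuming a hypothetical toy module Inner (not torchani code) whose forward takes an optional second argument in the same way that AEVComputer.forward takes cell and pbc:

```python
from typing import Optional

import torch
from torch import Tensor


class Inner(torch.nn.Module):
    """Hypothetical stand-in for a module with an optional forward argument."""

    def forward(self, x: Tensor, scale: Optional[Tensor] = None) -> Tensor:
        return x if scale is None else x * scale


# Number of entries in the hook's input tuple, recorded once per call.
seen = []


def hook(module, input, output):
    # Without with_kwargs=True, `input` holds the positional args only.
    seen.append(len(input))


inner = Inner()
handle = inner.register_forward_hook(hook)

x = torch.ones(3)
s = torch.tensor(2.0)

inner(x, scale=s)  # keyword call: the hook's input tuple holds only x
inner(x, s)        # positional call: the hook's input tuple holds x and s

handle.remove()
print(seen)
```

The keyword call leaves scale out of the hook's input tuple, so a hook typed against all of forward's arguments only lines up when the caller passes them positionally.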
Here's a minimal example that reproduces the issue:
import torch
import torchani

from torch import Tensor
from typing import Optional, Tuple


class Test(torch.nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self._ani2x = torchani.models.ANI2x(periodic_table_index=True)

        # Assign a tensor attribute that can be used for assigning the AEVs.
        self._ani2x.aev_computer._aev = torch.empty(0)

        # Hook the forward pass of the ANI2x model to get the AEV features.
        def hook_wrapper():
            def hook(
                module,
                input: Tuple[Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]],
                output: torchani.aev.SpeciesAEV,
            ):
                module._aev = output[1][0]

            return hook

        # Register the hook.
        self._aev_hook = self._ani2x.aev_computer.register_forward_hook(hook_wrapper())

    def forward(self, species: Tensor, coordinates: Tensor) -> Tensor:
        # Forward pass of the ANI2x model.
        energy = self._ani2x((species, coordinates))[0]

        # Do something with the AEV features.
        aevs = self._ani2x.aev_computer._aev

        return energy


# Create an instance of the model.
model = Test()

# Convert to TorchScript.
script_model = torch.jit.script(model)
There's the possibility of using with_kwargs=True when registering the forward hook. When just using PyTorch, everything is okay:
import torch
import torchani


def hook(module, args, kwargs, output):
    print(module)
    print(len(args))
    print(args)
    print(len(kwargs))
    print(kwargs)


class Test(torch.nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self._ani2x = torchani.models.ANI2x(periodic_table_index=True)
        self._aev_hook = self._ani2x.aev_computer.register_forward_hook(
            hook, with_kwargs=True
        )

    def forward(self, coordinates, species):
        return self._ani2x((species, coordinates), cell=None, pbc=None).energies


coordinates = torch.tensor(
    [
        [
            [0.03192167, 0.00638559, 0.01301679],
            [-0.83140486, 0.39370209, -0.26395324],
            [-0.66518241, -0.84461308, 0.20759389],
            [0.45554739, 0.54289633, 0.81170881],
            [0.66091919, -0.16799635, -0.91037834],
        ]
    ],
    requires_grad=True,
    dtype=torch.float32,
)

# In periodic table, C = 6 and H = 1
species = torch.tensor([[6, 1, 1, 1, 1]])

# Create model.
model = Test()

# Convert to TorchScript.
#model = torch.jit.script(model)

# Compute energies.
energies = model(coordinates, species)
Gives:
AEVComputer()
1
(SpeciesCoordinates(species=tensor([[1, 0, 0, 0, 0]]), coordinates=tensor([[[ 0.0319, 0.0064, 0.0130],
[-0.8314, 0.3937, -0.2640],
[-0.6652, -0.8446, 0.2076],
[ 0.4555, 0.5429, 0.8117],
[ 0.6609, -0.1680, -0.9104]]], requires_grad=True)),)
2
{'cell': None, 'pbc': None}
However, this doesn't work with TorchScript, i.e. uncommenting the model = torch.jit.script(model) line above gives:
...
  File "/home/lester/.conda/envs/emle/lib/python3.10/site-packages/torch/jit/_recursive.py", line 479, in create_hooks_from_stubs
    concrete_type._create_hooks(hook_defs, hook_rcbs, pre_hook_defs, pre_hook_rcbs)
RuntimeError: Hook 'hook' on module 'AEVComputer' was expected to only have exactly 3 inputs but it had 4 inputs
So it looks like TorchScript only supports forward hooks with the three-input signature (module, input, output), which means the hooked function has to be called with positional args too. That is why I needed to convert kwargs to args in your models.py file.
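For illustration, here is a toy sketch of the pattern that is expected to script. It assumes hypothetical Inner and Outer modules standing in for AEVComputer and the calling model; it is not torchani code, and behaviour may differ between PyTorch versions. The caller passes the optional argument positionally, and the hook has exactly three inputs, with input typed as a tuple of all of forward's arguments:

```python
from typing import Optional, Tuple

import torch
from torch import Tensor


class Inner(torch.nn.Module):
    """Hypothetical stand-in for the hooked module."""

    def forward(self, x: Tensor, scale: Optional[Tensor] = None) -> Tensor:
        if scale is not None:
            return x * scale
        return x


def cache_hook(module, input: Tuple[Tensor, Optional[Tensor]], output: Tensor):
    # Exactly three inputs; `input` is typed to cover every argument of
    # Inner.forward, not just the first one.
    module._cached = output


class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()
        # Tensor attribute on the hooked module that the hook writes to.
        self.inner._cached = torch.empty(0)
        self.inner.register_forward_hook(cache_hook)

    def forward(
        self, x: Tensor, scale: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        # Positional call, so `scale` is part of the hook's input tuple.
        y = self.inner(x, scale)
        return y + 1.0, self.inner._cached


scripted = torch.jit.script(Outer())
result, cached = scripted(torch.ones(3), torch.tensor(2.0))
print(result, cached)
```

Changing the call in Outer.forward to self.inner(x, scale=scale) is the toy equivalent of the kwargs call discussed above.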
Hi, thanks for the update, sounds good. Could you use your local fork to make the changes? We plan to have a major update later this year, and the current code base is frozen.
No problem. I'll provide a patch for our users since it's just a modification to a single file. Do you want me to raise a PR with a fix and test anyway? I appreciate that it won't be merged, but at least you'll have a record if you want to apply it to whatever update comes later in the year.
Cheers.
Yes, please open a PR, thank you!
Hello there,
I am writing a module based on ANI2x that requires AEVs and have been trying to use a forward hook on the ANI2x AEVComputer to avoid duplicating the calculation. While this works perfectly fine in PyTorch, I appear to be unable to serialise my model using TorchScript since you are passing args within the hook input as kwargs. In particular, cell and pbc are always passed as cell=cell and pbc=pbc. The exception that I get is:
Applying the following patch to models.py gets things to work:
(I've confirmed that the results are identical to before.)
Is there a reason why kwargs need to be used? If not, would the proposed patch be acceptable? Using a forward hook gives an appreciable speed gain for my model by removing the need to compute AEVs twice.
Just to note that I am pretty new to TorchScript, so perhaps there is a way to get it to correctly detect the use of kwargs in hooks. Alternatively, perhaps there is another way to get the AEVs from the module. (Currently you use them as an intermediate in the calculation, but they could be stored as a module attribute.) I certainly could derive from ANI2x and overload forward, but the way I have things now is quite a bit more flexible for my use case.
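As a rough illustration of the subclassing alternative, the sketch below overrides forward in a derived class and keeps the intermediate result as a module attribute. FeatureModel and CachingFeatureModel are hypothetical toy classes, not part of torchani, and the real ANI2x forward is more involved than this:

```python
import torch
from torch import Tensor


class FeatureModel(torch.nn.Module):
    """Hypothetical model: computes features, then an energy from them."""

    def features(self, x: Tensor) -> Tensor:
        return x * 2.0

    def forward(self, x: Tensor) -> Tensor:
        return self.features(x).sum()


class CachingFeatureModel(FeatureModel):
    """Overrides forward so the intermediate features stay on the module."""

    def __init__(self):
        super().__init__()
        self.cached_features = torch.empty(0)

    def forward(self, x: Tensor) -> Tensor:
        # Compute the features once and reuse them for the energy.
        self.cached_features = self.features(x)
        return self.cached_features.sum()


model = CachingFeatureModel()
energy = model(torch.ones(3))
print(energy, model.cached_features)
```

This avoids hooks entirely, at the cost of tying the derived class to the internals of the parent's forward.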
Many thanks.