aiqm / torchani

Accurate Neural Network Potential on PyTorch
https://aiqm.github.io/torchani/
MIT License

[Fix #648] Pass cell and pbc as args when calling AEVComputer so forward hook works via TorchScript #649

Status: Open · lohedges opened 3 months ago


This PR closes #648 by passing cell and pbc as args rather than kwargs when calling self.aev_computer within BuiltinModel and BuiltinEnsemble. This lets a user register a forward hook on the internal AEVComputer of an ANI model and read the AEVs from that hook during the same call that computes the energies.

While kwargs work fine in eager-mode PyTorch, it appears that TorchScript assumes the input parameter defined in the hook represents all input to the hooked function (AEVComputer.forward), i.e. both args and kwargs. When recursing to create the scripted version of a module, TorchScript therefore requires that the hooked function be called with positional args only, at every call site.
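The following is a minimal sketch assuming only plain PyTorch; torchani is not imported. ToyAEVComputer, ToyModel and capture_aevs are hypothetical stand-ins that only mimic the call shape described above, not torchani's actual classes. The sketch shows the eager-mode behavior behind the fix: a forward hook's input tuple holds only the positional arguments, so calling the submodule with cell and pbc as args makes that tuple cover every forward parameter, whereas a kwargs call leaves them out. It does not script the model.

```python
from typing import Optional, Tuple

import torch
from torch import Tensor, nn


class ToyAEVComputer(nn.Module):
    """Hypothetical stand-in for AEVComputer (same call shape only)."""

    def forward(
        self,
        input_: Tuple[Tensor, Tensor],
        cell: Optional[Tensor] = None,
        pbc: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        species, coordinates = input_
        aevs = coordinates.pow(2)  # placeholder for real AEVs
        return species, aevs


class ToyModel(nn.Module):
    """Hypothetical model mirroring the PR's change: positional cell/pbc."""

    def __init__(self) -> None:
        super().__init__()
        self.aev_computer = ToyAEVComputer()

    def forward(
        self,
        species_coordinates: Tuple[Tensor, Tensor],
        cell: Optional[Tensor] = None,
        pbc: Optional[Tensor] = None,
    ) -> Tensor:
        # Positional call (the fix). A kwargs call would look like
        # self.aev_computer(species_coordinates, cell=cell, pbc=pbc).
        _, aevs = self.aev_computer(species_coordinates, cell, pbc)
        return aevs.sum()  # placeholder "energy"


captured = {}


def capture_aevs(module: nn.Module, inputs: tuple, output: Tuple[Tensor, Tensor]) -> None:
    # In eager mode, `inputs` contains only the positional arguments.
    captured["n_inputs"] = len(inputs)
    captured["aevs"] = output[1].detach()


model = ToyModel()
handle = model.aev_computer.register_forward_hook(capture_aevs)

species = torch.tensor([[1, 1, 8]])
coordinates = torch.randn(1, 3, 3)
cell = torch.eye(3) * 10.0
pbc = torch.tensor([True, True, True])

# Energy and AEVs obtained from a single call.
energy = model((species, coordinates), cell, pbc)
n_positional = captured["n_inputs"]  # 3: input_, cell, pbc

# For comparison, a kwargs call: the hook sees a single positional argument.
model.aev_computer((species, coordinates), cell=cell, pbc=pbc)
n_keyword = captured["n_inputs"]  # 1: input_ only

handle.remove()
```

In eager mode the hook above works for both calls because it only reads the output; the differing length of inputs is what matters once the hook's input parameter is typed against all of forward's arguments, as described above for TorchScript.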

(I have been made aware of the code freeze and appreciate that this won't be merged. Just posting for posterity.)