aiqm / torchani

Accurate Neural Network Potential on PyTorch
https://aiqm.github.io/torchani/
MIT License

torch.jit.script profile-guided optimisations produce errors in aev_computer gradients #628

Status: Open (sef43 opened this issue 1 year ago)

sef43 commented 1 year ago

Hi, I have found that with PyTorch 1.13 and 2.0 (but not with PyTorch <= 1.12), the torch.jit.script profile-guided optimisations (which are on by default) cause significant errors in the position gradients calculated by backpropagating through aev_computer on a CUDA device. This is demonstrated in issue https://github.com/openmm/openmm-ml/issues/50.

An example is shown below. Manually turning off the JIT optimisations gives accurate forces:


from torchani.neurochem import parse_neurochem_resources, Constants
from torchani.aev import AEVComputer
import torch
import numpy as np

class Model(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        # Load the ANI-2x AEV constants (downloads the model resources if needed)
        info_file_path = 'ani-2x_8x.info'
        const_file, _, _, _ = parse_neurochem_resources(info_file_path)
        consts = Constants(const_file)
        self.aev_computer = AEVComputer(**consts)
        self.aev_computer.to(device)

    def forward(self, species, positions):
        # Add a batch dimension, then reduce the AEVs to a scalar we can backpropagate from
        aev = self.aev_computer((species.unsqueeze(0), positions.unsqueeze(0)))
        return torch.mean(aev.aevs)

## setup
N = 100
species = torch.randint(0, 7, (N,), device="cuda")  # ANI-2x supports 7 elements
pos = np.random.random((N, 3))

for optimize in [True, False]:
    print("JIT optimize = ", optimize)

    # Toggle the profiling executor (the profile-guided optimisations)
    torch._C._jit_set_profiling_executor(optimize)
    torch._C._jit_set_profiling_mode(optimize)

    model = Model("cuda")
    model = torch.jit.script(model)

    grads = []
    for i in range(10):
        # Fresh leaf tensor every iteration, so .grad never accumulates
        incoords = torch.tensor(pos, dtype=torch.float32, requires_grad=True, device="cuda")
        result = model(species, incoords)
        result.backward(retain_graph=True)
        grads.append(incoords.grad.cpu().numpy())
        # Compare this iteration's gradient with the first one
        print(i, "max percentage error: ", np.max(100.0*np.abs((grads[0]-grads[-1])/grads[0])))

The output I get on an RTX 3090 is:

JIT optimize =  True
Downloading ANI model parameters ...
0 max percentage error:  0.0
1 max percentage error:  0.00055674225
2 max percentage error:  217.80972
3 max percentage error:  217.80959
4 max percentage error:  217.81003
5 max percentage error:  217.80975
6 max percentage error:  217.80972
7 max percentage error:  217.81082
8 max percentage error:  217.80956
9 max percentage error:  217.81024
JIT optimize =  False
0 max percentage error:  0.0
1 max percentage error:  0.0003876826
2 max percentage error:  0.0002178617
3 max percentage error:  0.00021537923
4 max percentage error:  0.0005815239
5 max percentage error:  0.0010768962
6 max percentage error:  0.00017895782
7 max percentage error:  0.00035465648
8 max percentage error:  0.00039845158
9 max percentage error:  0.00018266498
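
The element-wise metric in the script divides by grads[0], so a single near-zero gradient component can make it large even when the gradients agree overall. The two configurations above are clearly separated either way, but when reproducing on other hardware a norm-based relative error is a useful cross-check. A minimal sketch, with hypothetical helper names and made-up input arrays:

```python
import numpy as np


def max_percentage_error(ref: np.ndarray, other: np.ndarray) -> float:
    """Element-wise metric used in the report; unstable when ref has near-zero entries."""
    return float(np.max(100.0 * np.abs((ref - other) / ref)))


def relative_norm_error(ref: np.ndarray, other: np.ndarray) -> float:
    """||ref - other|| / ||ref||: one scale-aware number for the whole gradient."""
    return float(np.linalg.norm(ref - other) / np.linalg.norm(ref))


# Hypothetical gradients that differ only in one tiny component.
ref = np.array([[1.0, -2.0, 1e-6]])
other = np.array([[1.0, -2.0, 2e-6]])

print(max_percentage_error(ref, other))  # 100.0
print(relative_norm_error(ref, other))   # well below 1e-6
```

Here the element-wise metric reports 100% while the norm-based error stays tiny, which is why looking at both helps distinguish a genuine mismatch, like the 217% one above, from a single small component.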

A workaround I found that removes the errors is to replace a ** operation with torch.float_power: https://github.com/aiqm/torchani/commit/172b6fe85d3ab2acd3faa7a025b5aded22f2537c
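
For context on that workaround: torch.float_power always computes in double precision and returns a float64 tensor (complex128 for complex inputs), whereas ** keeps the input dtype. A drop-in replacement therefore usually needs a cast back to the working dtype. A minimal sketch of that dtype behaviour on CPU; it does not reproduce the linked commit itself:

```python
import torch

x = torch.rand(5, dtype=torch.float32)

y_pow = x ** 2.0                   # stays float32
y_fp = torch.float_power(x, 2.0)   # computed and returned in float64
y_fp32 = y_fp.to(x.dtype)          # cast back to the working dtype

print(y_pow.dtype, y_fp.dtype)        # torch.float32 torch.float64
print(torch.allclose(y_pow, y_fp32))  # True
```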

yueyericardo commented 1 year ago

Thanks for reporting the issue!

This is a bug in NVFuser. A bug report has been filed at https://github.com/pytorch/pytorch/issues/84510

The minimal reproducible example I extracted from the angular function is the following:

import torch
from torch import Tensor

def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor,
                  ShfA: Tensor, vectors12: Tensor) -> Tensor:
    # Rca and EtaA are unused in this reduced version of the angular function
    vectors12 = vectors12.view(2, -1, 3, 1, 1, 1, 1)
    cos_angles = vectors12.prod(0).sum(1)

    ret = (cos_angles + ShfZ) * Zeta * ShfA * 2
    return ret.flatten(start_dim=1)
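
To check whether a given PyTorch build and device show the mismatch with this reduced function, one option is to compare gradients from eager execution against the scripted version over several calls, since the profiling executor only optimises after its profiled runs. A minimal sketch, assuming hypothetical parameter shapes (picked only so that broadcasting works) and an arbitrary Rca of 3.5. On CPU the two results should agree; on an affected CUDA setup the scripted gradients would be expected to diverge if this function reproduces the bug:

```python
import torch
from torch import Tensor


def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor,
                  ShfA: Tensor, vectors12: Tensor) -> Tensor:
    # Same reduced function as in the comment above.
    vectors12 = vectors12.view(2, -1, 3, 1, 1, 1, 1)
    cos_angles = vectors12.prod(0).sum(1)
    ret = (cos_angles + ShfZ) * Zeta * ShfA * 2
    return ret.flatten(start_dim=1)


def grad_wrt_vectors(fn, vectors: Tensor, consts) -> Tensor:
    """Gradient of sum(fn(...)) with respect to a fresh leaf copy of vectors."""
    v = vectors.detach().clone().requires_grad_(True)
    fn(3.5, *consts, v).sum().backward()
    return v.grad


def compare_eager_and_scripted(device: str = "cpu", calls: int = 5) -> float:
    """Worst relative gradient difference between eager and scripted runs."""
    torch.manual_seed(0)
    # Hypothetical shapes: ShfZ, EtaA, Zeta, ShfA broadcast to (P, 1, 2, 4, 8).
    consts = (
        torch.rand(1, 1, 1, 8, device=device),  # ShfZ
        torch.rand(1, 1, 1, 1, device=device),  # EtaA (unused)
        torch.rand(1, 2, 1, 1, device=device),  # Zeta
        torch.rand(1, 1, 4, 1, device=device),  # ShfA
    )
    vectors = torch.rand(2, 50, 3, device=device)
    reference = grad_wrt_vectors(angular_terms, vectors, consts)
    scripted = torch.jit.script(angular_terms)
    worst = 0.0
    # Several calls, so the profiling executor gets past its profiled runs.
    for _ in range(calls):
        g = grad_wrt_vectors(scripted, vectors, consts)
        worst = max(worst, (g - reference).abs().max().item())
    return worst / reference.abs().max().item()


if __name__ == "__main__":
    print("CPU worst relative error:", compare_eager_and_scripted("cpu"))
```

On a GPU machine, compare_eager_and_scripted("cuda") runs the same check on the device where NVFuser is the default TorchScript fuser in the affected versions.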

Replacing a ** operation with torch.float_power does not address the root cause of this problem.

For now, I recommend disabling NVFuser by running the following:

torch._C._jit_set_nvfuser_enabled(False)
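
Because _jit_set_nvfuser_enabled is a private function, it may not exist in every PyTorch release (for example, releases that no longer ship the TorchScript NVFuser integration). A minimal sketch of a guarded call, assuming only that the function, when present, takes a single bool; the helper name is hypothetical:

```python
import torch


def disable_nvfuser() -> bool:
    """Turn off NVFuser for TorchScript if this PyTorch build exposes the switch.

    Returns True if the switch was found and called, False otherwise.
    """
    setter = getattr(torch._C, "_jit_set_nvfuser_enabled", None)
    if setter is None:
        # Assumption: builds without the private switch have no TorchScript NVFuser to disable.
        return False
    setter(False)
    return True


if __name__ == "__main__":
    print("NVFuser switch found and disabled:", disable_nvfuser())
```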

This switches TorchScript to the NNC fuser (https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#fusers) instead of NVFuser, which I have tested and found to work correctly.
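
If a global switch is too broad, PyTorch also provides the torch.jit.fuser context manager, which selects a fuser only for code run inside the with block and restores the previous settings afterwards; "fuser1" corresponds to NNC. A minimal sketch, assuming a PyTorch version that ships torch.jit.fuser (1.13 and 2.0 do, as far as I know) and using a hypothetical toy function in place of the AEV computer:

```python
import torch


@torch.jit.script
def scaled_sum(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Hypothetical element-wise function that a fuser could combine into one kernel.
    return ((x + y) * 2.0).sum()


def run_with_nnc(device: str = "cpu", calls: int = 3) -> torch.Tensor:
    """Run forward and backward several times with NNC selected; return d(sum)/dx."""
    x = torch.rand(100, device=device, requires_grad=True)
    y = torch.rand(100, device=device)
    with torch.jit.fuser("fuser1"):  # "fuser1" selects NNC (TensorExpr)
        for _ in range(calls):
            x.grad = None  # reset so gradients do not accumulate across calls
            scaled_sum(x, y).backward()
    return x.grad


if __name__ == "__main__":
    print(run_with_nnc()[:5])  # each entry should be 2.0
```

The scoped form keeps the fuser choice local to the model evaluation, which may be preferable when other TorchScript code in the same process should keep the default fuser.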