aiqm / torchani

Accurate Neural Network Potential on PyTorch
https://aiqm.github.io/torchani/
MIT License

CUAEV RuntimeError: CUDA error: an illegal memory access was encountered #595

Closed: LiuCMU closed this issue 3 years ago

LiuCMU commented 3 years ago

Hi, when I use CUAEV on torch.device('cuda:1'), I get the following error:

Traceback (most recent call last):
  File "ani_e.py", line 208, in <module>
    rmse = validate()
  File "ani_e.py", line 194, in validate
    _, predicted_energies = model((species, coordinates))
  File "/home/jack/miniconda3/envs/CUAEV/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jack/miniconda3/envs/CUAEV/lib/python3.7/site-packages/torchani-2.3.dev60+ga85e330-py3.7-linux-x86_64.egg/torchani/nn.py", line 111, in forward
    input_ = module(input_, cell=cell, pbc=pbc)
  File "/home/jack/miniconda3/envs/CUAEV/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jack/miniconda3/envs/CUAEV/lib/python3.7/site-packages/torchani-2.3.dev60+ga85e330-py3.7-linux-x86_64.egg/torchani/nn.py", line 60, in forward
    atomic_energies = self._atomic_energies((species, aev))
  File "/home/jack/miniconda3/envs/CUAEV/lib/python3.7/site-packages/torchani-2.3.dev60+ga85e330-py3.7-linux-x86_64.egg/torchani/nn.py", line 72, in _atomic_energies
    output = aev.new_zeros(species_.shape)
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
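The last line of the error suggests a first debugging step. CUDA kernels run asynchronously, so the traceback may stop at a later call (here aev.new_zeros) rather than at the kernel that actually made the illegal access. Running the script with CUDA_LAUNCH_BLOCKING=1 in the environment (for example CUDA_LAUNCH_BLOCKING=1 python ani_e.py) makes kernel launches synchronous. A minimal Python sketch of the same idea; run_with_blocking_launches is a hypothetical helper name, not part of torchani:

```python
import os
import subprocess
import sys


def run_with_blocking_launches(argv, **run_kwargs):
    """Run `python *argv` in a child process with CUDA_LAUNCH_BLOCKING=1.

    CUDA reads this variable when it initializes, so it has to be in the
    environment before the script touches the GPU; starting a fresh process
    with the variable already set is the simplest way to guarantee that.
    Extra keyword arguments are forwarded to subprocess.run.
    """
    env = dict(os.environ)
    env["CUDA_LAUNCH_BLOCKING"] = "1"  # make kernel launches synchronous
    return subprocess.run([sys.executable, *argv], env=env, **run_kwargs)


# Example: rerun the failing script so the traceback stops at the faulting kernel.
# run_with_blocking_launches(["ani_e.py"], check=False)
```

Synchronous launches slow everything down, so this setting is only meant for debugging.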

If I use torch.device('cuda:0') with CUAEV, the code runs fine, and using torch.device('cuda:1') without CUAEV also works. So this looks like an issue with CUAEV on a machine with multiple GPUs when a GPU other than cuda:0 is used. Could anyone provide some debugging tips? Thanks!

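To sum up:
- cuda:0 with CUAEV: works
- cuda:1 with CUAEV: fails with the error above
- cuda:1 without CUAEV: works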

yueyericardo commented 3 years ago

Hi, thanks for the report! Could you please provide a minimal reproducible example? The AEV should be created on the same device as the input coordinates: https://github.com/aiqm/torchani/blob/master/torchani/cuaev/aev.cu#L755
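Allocating the output on the inputs' device is one requirement; a CUDA extension's kernels generally also need to be launched on that device. As a hedged debugging sketch, assuming (not confirmed in this thread) that CUAEV launches its kernels on torch.cuda.current_device(), which stays cuda:0 unless changed, you can make the inputs' GPU current around the failing call. current_device_matching is a hypothetical helper name; torch.cuda.device and torch.cuda.set_device are standard PyTorch APIs.

```python
import contextlib

import torch


def current_device_matching(tensor):
    """Return a context manager that makes tensor's GPU the current CUDA device.

    Hypothetical workaround helper. Assumption: the extension launches its
    kernels on torch.cuda.current_device() instead of on the device holding
    its inputs. For non-CUDA tensors this returns a no-op context.
    """
    if tensor.device.type != "cuda":
        return contextlib.nullcontext()
    return torch.cuda.device(tensor.device)


# Usage, wrapping the call from the traceback (model, species and coordinates
# come from the reporter's script):
# with current_device_matching(coordinates):
#     _, predicted_energies = model((species, coordinates))
#
# Global alternative: torch.cuda.set_device(coordinates.device) once at startup.
```

If the error disappears with the guard in place, that points at a device mismatch inside the extension rather than in the user code.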

LiuCMU commented 3 years ago

Hi, thanks for the reply. I created a folder containing a reproducible example.

This folder contains the following three files:

You can reproduce the same error by running the following in the folder: python ani_2x_f.py

The GPU device can be changed in the Python file.
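To compare cuda:0 and cuda:1 without editing the file each time, the device could be read from the command line. A minimal sketch; the --device flag and the parse_args helper are hypothetical additions, not something the shared script is known to have:

```python
import argparse


def parse_args(argv=None):
    """Parse a hypothetical --device flag for the reproduction script."""
    parser = argparse.ArgumentParser(
        description="Reproduce the CUAEV multi-GPU error"
    )
    parser.add_argument(
        "--device",
        default="cuda:0",
        help="torch device string, e.g. cuda:0 or cuda:1",
    )
    return parser.parse_args(argv)


# In the script (hypothetical): device = torch.device(parse_args().device)
```

Running python ani_2x_f.py --device cuda:0 and then python ani_2x_f.py --device cuda:1 would then exercise both cases with the same code.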

Sorry for the late example 😂 I got caught up in another script this past week. Please let me know how it goes. Thank you very much!

yueyericardo commented 3 years ago

Thanks! I was able to reproduce it with the following test and will try to fix it soon:

import unittest

import torch
import torchani
from torchani.testing import TestCase

class TestCUAEV(TestCase):

    def setUp(self):
        self.tolerance = 5e-5
        # Use the second GPU: the error does not occur on cuda:0
        i = 1
        self.device = torch.device(f'cuda:{i}' if torch.cuda.is_available() else 'cpu')
        print('{}: {}'.format(i, torch.cuda.get_device_name(f'cuda:{i}')))
        Rcr = 5.2000e+00
        Rca = 3.5000e+00
        EtaR = torch.tensor([1.6000000e+01], device=self.device)
        ShfR = torch.tensor([9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00, 1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00, 3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00, 4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00], device=self.device)
        Zeta = torch.tensor([3.2000000e+01], device=self.device)
        ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00, 1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00], device=self.device)
        EtaA = torch.tensor([8.0000000e+00], device=self.device)
        ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=self.device)
        num_species = 4
        # Reference AEVs from the pure PyTorch implementation
        self.aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
        # Same parameters, computed by the CUDA extension (CUAEV)
        self.cuaev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species, use_cuda_extension=True)
        self.nn = torch.nn.Sequential(torch.nn.Linear(384, 1, False)).to(self.device)
        self.radial_length = self.aev_computer.radial_length

    def testSimple(self):
        coordinates = torch.tensor([
            [[0.03192167, 0.00638559, 0.01301679],
             [-0.83140486, 0.39370209, -0.26395324],
             [-0.66518241, -0.84461308, 0.20759389],
             [0.45554739, 0.54289633, 0.81170881],
             [0.66091919, -0.16799635, -0.91037834]],
            [[-4.1862600, 0.0575700, -0.0381200],
             [-3.1689400, 0.0523700, 0.0200000],
             [-4.4978600, 0.8211300, 0.5604100],
             [-4.4978700, -0.8000100, 0.4155600],
             [0.00000000, -0.00000000, -0.00000000]]
        ], device=self.device)
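        # Species index of each atom; -1 marks padding in the second molecule
        species = torch.tensor([[1, 0, 0, 0, 0], [2, 0, 0, 0, -1]], device=self.device)

        # CUAEV should reproduce the pure PyTorch AEVs on the same device
        _, aev = self.aev_computer((species, coordinates))
        _, cu_aev = self.cuaev_computer((species, coordinates))
        self.assertEqual(cu_aev, aev)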

if __name__ == '__main__':
    unittest.main()