isayevlab / AIMNet2

MIT License
87 stars 24 forks source link

RuntimeError when using new AIMNet2ASE calculator #22

Closed mamo248 closed 1 month ago

mamo248 commented 4 months ago

I installed the new AIMNet2 calculator mostly according to the instructions. The one exception: during environment creation I had to combine the first two installation commands into a single command, because otherwise the environment could not be solved. When I now run the following example script:

from aimnet2calc import AIMNet2ASE
from ase import Atoms

calc = AIMNet2ASE('aimnet2')
atoms = Atoms('CO', positions=[(0, 0, 0), (0, 0, 1.1)])
atoms.calc = calc
atoms.get_potential_energy()

I get the following RuntimeError:

{
    "name": "RuntimeError",
    "message": "The following operation failed in the TorchScript interpreter.\nTraceback of TorchScript, serialized code (most recent call last):\n  File \"code/__torch__/aimnet/models/aimnet2.py\", line 70, in forward\n    data9 = (atomic_sum).forward(data8, )\n    data10 = (lrcoulomb).forward(data9, )\n    return (dftd3).forward(data10, )\n            ~~~~~~~~~~~~~~ <--- HERE\n  def prepare_data(self: __torch__.aimnet.models.aimnet2.AIMNet2,\n    data: Dict[str, Tensor]) -> Dict[str, Tensor]:\n  File \"code/__torch__/aimnet/modules.py\", line 290, in forward\n  def forward(self: __torch__.aimnet.modules.DFTD3,\n    data: Dict[str, Tensor]) -> Dict[str, Tensor]:\n    c6ij, d_ij, = (self)._calc_c6ij(data, )\n                   ~~~~~~~~~~~~~~~~ <--- HERE\n    r4r2 = self.r4r2\n    _70 = annotate(List[Optional[Tensor]], [data[\"numbers\"]])\n  File \"code/__torch__/aimnet/modules.py\", line 336, in _calc_c6ij\n    cnmax = self.cnmax\n    _84 = annotate(List[Optional[Tensor]], [numbers])\n    cn0 = torch.clamp(cn, None, torch.index(cnmax, _84))\n                                ~~~~~~~~~~~ <--- HERE\n    _85 = torch.unsqueeze(torch.unsqueeze(cn0, -1), -1)\n    cn_i = torch.unsqueeze(_85, -1)\n\nTraceback of TorchScript, original code (most recent call last):\n  File \"/home/roman/repo/AIMNET2NB/aimnet2nbmat/aimnet/models/aimnet2.py\", line 136, in forward\n    \n        for m in self.outputs.children():\n            data = m(data)\n                   ~ <--- HERE\n        return data\n  File \"/home/roman/repo/AIMNET2NB/aimnet2nbmat/aimnet/modules.py\", line 524, in forward\n    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:\n        c6ij, d_ij = self._calc_c6ij(data)\n                     ~~~~~~~~~~~~~~~ <--- HERE\n        rrij = 3 * ops.ij_product(self.r4r2[data['numbers']], data['nbmat_lr'])\n        r0ij = self.a1 * rrij.sqrt() + self.a2\n  File \"/home/roman/repo/AIMNET2NB/aimnet2nbmat/aimnet/modules.py\", line 508, in _calc_c6ij\n        rcov_ij 
= rcov_i.unsqueeze(-1) + rcov_j\n        cn = (1.0 / (1.0 + torch.exp(self.k1*((rcov_ij)/d_ij - 1.0)))).sum(dim=-1)\n        cn = torch.clamp(cn, max=self.cnmax[numbers])\n                                 ~~~~~~~~~~~~~~~~~~~ <--- HERE\n        cn_i = cn.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)\n        cn_j = ops.select_j(cn, data['nbmat_lr']).unsqueeze(-1).unsqueeze(-1)\nRuntimeError: tensors used as indices must be long, byte or bool tensors\n",
    "stack": "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m\n\u001b[1;31mRuntimeError\u001b[0m                              Traceback (most recent call last)\nCell \u001b[1;32mIn [6], line 4\u001b[0m\n\u001b[0;32m      2\u001b[0m atoms \u001b[39m=\u001b[39m Atoms(\u001b[39m'\u001b[39m\u001b[39mCO\u001b[39m\u001b[39m'\u001b[39m, positions\u001b[39m=\u001b[39m[(\u001b[39m0\u001b[39m, \u001b[39m0\u001b[39m, \u001b[39m0\u001b[39m), (\u001b[39m0\u001b[39m, \u001b[39m0\u001b[39m, \u001b[39m1.1\u001b[39m)])\n\u001b[0;32m      3\u001b[0m atoms\u001b[39m.\u001b[39mcalc \u001b[39m=\u001b[39m calc\n\u001b[1;32m----> 4\u001b[0m atoms\u001b[39m.\u001b[39;49mget_potential_energy()\n\nFile \u001b[1;32mc:\\ProgramData\\mambaforge\\envs\\aimnet2\\lib\\site-packages\\ase\\atoms.py:755\u001b[0m, in \u001b[0;36mAtoms.get_potential_energy\u001b[1;34m(self, force_consistent, apply_constraint)\u001b[0m\n\u001b[0;32m    752\u001b[0m     energy \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_calc\u001b[39m.\u001b[39mget_potential_energy(\n\u001b[0;32m    753\u001b[0m         \u001b[39mself\u001b[39m, force_consistent\u001b[39m=\u001b[39mforce_consistent)\n\u001b[0;32m    754\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m--> 755\u001b[0m     energy \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_calc\u001b[39m.\u001b[39;49mget_potential_energy(\u001b[39mself\u001b[39;49m)\n\u001b[0;32m    756\u001b[0m \u001b[39mif\u001b[39;00m apply_constraint:\n\u001b[0;32m    757\u001b[0m     \u001b[39mfor\u001b[39;00m constraint \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconstraints:\n\nFile \u001b[1;32mc:\\ProgramData\\mambaforge\\envs\\aimnet2\\lib\\site-packages\\ase\\calculators\\abc.py:24\u001b[0m, in \u001b[0;36mGetPropertiesMixin.get_potential_energy\u001b[1;34m(self, atoms, force_consistent)\u001b[0m\n\u001b[0;32m     22\u001b[0m 
\u001b[39melse\u001b[39;00m:\n\u001b[0;32m     23\u001b[0m     name \u001b[39m=\u001b[39m \u001b[39m'\u001b[39m\u001b[39menergy\u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m---> 24\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mget_property(name, atoms)\n\nFile \u001b[1;32mc:\\ProgramData\\mambaforge\\envs\\aimnet2\\lib\\site-packages\\ase\\calculators\\calculator.py:538\u001b[0m, in \u001b[0;36mBaseCalculator.get_property\u001b[1;34m(self, name, atoms, allow_calculation)\u001b[0m\n\u001b[0;32m    535\u001b[0m     \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39muse_cache:\n\u001b[0;32m    536\u001b[0m         \u001b[39mself\u001b[39m\u001b[39m.\u001b[39matoms \u001b[39m=\u001b[39m atoms\u001b[39m.\u001b[39mcopy()\n\u001b[1;32m--> 538\u001b[0m     \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcalculate(atoms, [name], system_changes)\n\u001b[0;32m    540\u001b[0m \u001b[39mif\u001b[39;00m name \u001b[39mnot\u001b[39;00m \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mresults:\n\u001b[0;32m    541\u001b[0m     \u001b[39m# For some reason the calculator was not able to do what we want,\u001b[39;00m\n\u001b[0;32m    542\u001b[0m     \u001b[39m# and that is OK.\u001b[39;00m\n\u001b[0;32m    543\u001b[0m     \u001b[39mraise\u001b[39;00m PropertyNotImplementedError(\n\u001b[0;32m    544\u001b[0m         \u001b[39m'\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m not present in this \u001b[39m\u001b[39m'\u001b[39m \u001b[39m'\u001b[39m\u001b[39mcalculation\u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mformat(name)\n\u001b[0;32m    545\u001b[0m     )\n\nFile \u001b[1;32mc:\\ProgramData\\mambaforge\\envs\\aimnet2\\lib\\site-packages\\aimnet2calc-0.0.1-py3.10.egg\\aimnet2calc\\aimnet2ase.py:65\u001b[0m, in \u001b[0;36mAIMNet2ASE.calculate\u001b[1;34m(self, atoms, properties, system_changes)\u001b[0m\n\u001b[0;32m     62\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m     
63\u001b[0m     cell \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m---> 65\u001b[0m results \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbase_calc({\n\u001b[0;32m     66\u001b[0m     \u001b[39m'\u001b[39;49m\u001b[39mcoord\u001b[39;49m\u001b[39m'\u001b[39;49m: torch\u001b[39m.\u001b[39;49mtensor(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49matoms\u001b[39m.\u001b[39;49mpositions, dtype\u001b[39m=\u001b[39;49mtorch\u001b[39m.\u001b[39;49mfloat32, device\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbase_calc\u001b[39m.\u001b[39;49mdevice),\n\u001b[0;32m     67\u001b[0m     \u001b[39m'\u001b[39;49m\u001b[39mnumbers\u001b[39;49m\u001b[39m'\u001b[39;49m: \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_t_numbers,\n\u001b[0;32m     68\u001b[0m     \u001b[39m'\u001b[39;49m\u001b[39mcell\u001b[39;49m\u001b[39m'\u001b[39;49m: cell,\n\u001b[0;32m     69\u001b[0m     \u001b[39m'\u001b[39;49m\u001b[39mmol_idx\u001b[39;49m\u001b[39m'\u001b[39;49m: \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_t_mol_idx,\n\u001b[0;32m     70\u001b[0m     \u001b[39m'\u001b[39;49m\u001b[39mcharge\u001b[39;49m\u001b[39m'\u001b[39;49m: \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_t_charge,\n\u001b[0;32m     71\u001b[0m     \u001b[39m'\u001b[39;49m\u001b[39mmult\u001b[39;49m\u001b[39m'\u001b[39;49m: \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_t_mult,\n\u001b[0;32m     72\u001b[0m }, forces\u001b[39m=\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39mforces\u001b[39;49m\u001b[39m'\u001b[39;49m \u001b[39min\u001b[39;49;00m properties, stress\u001b[39m=\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39mstress\u001b[39;49m\u001b[39m'\u001b[39;49m \u001b[39min\u001b[39;49;00m properties)\n\u001b[0;32m     73\u001b[0m \u001b[39mfor\u001b[39;00m k, v \u001b[39min\u001b[39;00m results\u001b[39m.\u001b[39mitems():\n\u001b[0;32m     74\u001b[0m     results[k] \u001b[39m=\u001b[39m 
v\u001b[39m.\u001b[39mdetach()\u001b[39m.\u001b[39mcpu()\u001b[39m.\u001b[39mnumpy()\n\nFile \u001b[1;32mc:\\ProgramData\\mambaforge\\envs\\aimnet2\\lib\\site-packages\\aimnet2calc-0.0.1-py3.10.egg\\aimnet2calc\\calculator.py:59\u001b[0m, in \u001b[0;36mAIMNet2Calculator.__call__\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m     58\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m---> 59\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39meval(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\nFile \u001b[1;32mc:\\ProgramData\\mambaforge\\envs\\aimnet2\\lib\\site-packages\\aimnet2calc-0.0.1-py3.10.egg\\aimnet2calc\\calculator.py:84\u001b[0m, in \u001b[0;36mAIMNet2Calculator.eval\u001b[1;34m(self, data, forces, stress, hessian)\u001b[0m\n\u001b[0;32m     82\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mset_grad_tensors(data, forces\u001b[39m=\u001b[39mforces, stress\u001b[39m=\u001b[39mstress, hessian\u001b[39m=\u001b[39mhessian)\n\u001b[0;32m     83\u001b[0m \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mjit\u001b[39m.\u001b[39moptimized_execution(\u001b[39mFalse\u001b[39;00m):\n\u001b[1;32m---> 84\u001b[0m     data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel(data)\n\u001b[0;32m     85\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mget_derivatives(data, forces\u001b[39m=\u001b[39mforces, stress\u001b[39m=\u001b[39mstress, hessian\u001b[39m=\u001b[39mhessian)\n\u001b[0;32m     86\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mprocess_output(data)\n\nFile \u001b[1;32mc:\\ProgramData\\mambaforge\\envs\\aimnet2\\lib\\site-packages\\torch\\nn\\modules\\module.py:1130\u001b[0m, in 
\u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1129\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1130\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39m\u001b[39minput\u001b[39m, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[0;32m   1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m   1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n\n\u001b[1;31mRuntimeError\u001b[0m: The following operation failed in the TorchScript interpreter.\nTraceback of TorchScript, serialized code (most recent call last):\n  File \"code/__torch__/aimnet/models/aimnet2.py\", line 70, in forward\n    data9 = (atomic_sum).forward(data8, )\n    data10 = (lrcoulomb).forward(data9, )\n    return (dftd3).forward(data10, )\n            ~~~~~~~~~~~~~~ <--- HERE\n  def prepare_data(self: __torch__.aimnet.models.aimnet2.AIMNet2,\n    data: Dict[str, Tensor]) -> Dict[str, Tensor]:\n  File \"code/__torch__/aimnet/modules.py\", line 290, in forward\n  def forward(self: __torch__.aimnet.modules.DFTD3,\n    data: Dict[str, Tensor]) -> Dict[str, Tensor]:\n    c6ij, d_ij, = (self)._calc_c6ij(data, )\n                   ~~~~~~~~~~~~~~~~ <--- HERE\n    r4r2 = self.r4r2\n    _70 = 
annotate(List[Optional[Tensor]], [data[\"numbers\"]])\n  File \"code/__torch__/aimnet/modules.py\", line 336, in _calc_c6ij\n    cnmax = self.cnmax\n    _84 = annotate(List[Optional[Tensor]], [numbers])\n    cn0 = torch.clamp(cn, None, torch.index(cnmax, _84))\n                                ~~~~~~~~~~~ <--- HERE\n    _85 = torch.unsqueeze(torch.unsqueeze(cn0, -1), -1)\n    cn_i = torch.unsqueeze(_85, -1)\n\nTraceback of TorchScript, original code (most recent call last):\n  File \"/home/roman/repo/AIMNET2NB/aimnet2nbmat/aimnet/models/aimnet2.py\", line 136, in forward\n    \n        for m in self.outputs.children():\n            data = m(data)\n                   ~ <--- HERE\n        return data\n  File \"/home/roman/repo/AIMNET2NB/aimnet2nbmat/aimnet/modules.py\", line 524, in forward\n    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:\n        c6ij, d_ij = self._calc_c6ij(data)\n                     ~~~~~~~~~~~~~~~ <--- HERE\n        rrij = 3 * ops.ij_product(self.r4r2[data['numbers']], data['nbmat_lr'])\n        r0ij = self.a1 * rrij.sqrt() + self.a2\n  File \"/home/roman/repo/AIMNET2NB/aimnet2nbmat/aimnet/modules.py\", line 508, in _calc_c6ij\n        rcov_ij = rcov_i.unsqueeze(-1) + rcov_j\n        cn = (1.0 / (1.0 + torch.exp(self.k1*((rcov_ij)/d_ij - 1.0)))).sum(dim=-1)\n        cn = torch.clamp(cn, max=self.cnmax[numbers])\n                                 ~~~~~~~~~~~~~~~~~~~ <--- HERE\n        cn_i = cn.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)\n        cn_j = ops.select_j(cn, data['nbmat_lr']).unsqueeze(-1).unsqueeze(-1)\nRuntimeError: tensors used as indices must be long, byte or bool tensors\n"
}

Does anyone have an idea what is causing this problem?