materialsvirtuallab / matgl

Graph deep learning library for materials
BSD 3-Clause "New" or "Revised" License

[Bug]: `RuntimeError: expected scalar type Float but found Double` #235

Closed: matthewkuner closed this issue 4 months ago

matthewkuner commented 4 months ago

Version

1.0.0

What happened?

I am attempting to run M3GNet to relax a structure using:

```python
import matgl
from matgl.ext import ase

# Load the pretrained M3GNet potential and wrap it in matgl's ASE-based relaxer
m3gnet_potential = matgl.load_model("M3GNet-MP-2021.2.8-PES")
m3gnet_relaxer = ase.Relaxer(potential=m3gnet_potential, relax_cell=False)
result = m3gnet_relaxer.relax(structure, steps=0)
```

where `structure` is a pymatgen `Structure` object.

I get the following stack trace:

```
  File "/global/home/users/matthewkuner/.conda/envs/atomate2/lib/python3.9/site-packages/ase/optimize/optimize.py", line 35, in get_forces
    return self.atoms.get_forces()
  File "/global/home/users/matthewkuner/.conda/envs/atomate2/lib/python3.9/site-packages/ase/atoms.py", line 812, in get_forces
    forces = self._calc.get_forces(self)
  File "/global/home/users/matthewkuner/.conda/envs/atomate2/lib/python3.9/site-packages/ase/calculators/abc.py", line 30, in get_forces
    return self.get_property('forces', atoms)
  File "/global/home/users/matthewkuner/.conda/envs/atomate2/lib/python3.9/site-packages/ase/calculators/calculator.py", line 538, in get_property
    self.calculate(atoms, [name], system_changes)
  File "/global/home/users/matthewkuner/.conda/envs/atomate2/lib/python3.9/site-packages/matgl/ext/ase.py", line 177, in calculate
    energies, forces, stresses, hessians = self.potential(graph, lattice, state_attr_default)
  File "/global/home/users/matthewkuner/.conda/envs/atomate2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/global/home/users/matthewkuner/.conda/envs/atomate2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/global/home/users/matthewkuner/.conda/envs/atomate2/lib/python3.9/site-packages/matgl/apps/pes.py", line 98, in forward
    lattice = lat @ (torch.eye(3).to(st.device) + st)
RuntimeError: expected scalar type Float but found Double
```
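The failing line in `matgl/apps/pes.py` can be mimicked with plain PyTorch. This is a minimal sketch, assuming that `lat` is a float32 lattice tensor (float32 being matgl's usual precision), that `st` is a float32 zero strain tensor derived from it, and that other code has switched PyTorch's global default dtype to float64 via `torch.set_default_dtype`. Under those assumptions, `torch.eye(3)` is created as float64, adding the float32 `st` promotes the sum to float64, and the float32 @ float64 matrix product raises a `RuntimeError`. The helper `pes_strain_step` is hypothetical and not part of matgl.

```python
import torch


def pes_strain_step(default_dtype: torch.dtype) -> torch.Tensor:
    """Mimic the failing line from matgl/apps/pes.py under a given default dtype.

    Hypothetical helper: `lat` and `st` are stand-ins for the tensors in the
    traceback, assumed here to be float32.
    """
    previous = torch.get_default_dtype()
    torch.set_default_dtype(default_dtype)  # simulate another package changing the global default
    try:
        lat = torch.eye(3, dtype=torch.float32).unsqueeze(0) * 4.0  # float32 lattice, batch of 1
        st = lat.new_zeros([1, 3, 3])  # float32 zero strain, same dtype as lat
        # torch.eye(3) follows the global default dtype, so this sum becomes float64
        # when the default is float64, and the float32 @ float64 product then fails.
        return lat @ (torch.eye(3).to(st.device) + st)
    finally:
        torch.set_default_dtype(previous)  # always restore the caller's default


# Works with PyTorch's usual float32 default
print(pes_strain_step(torch.float32).dtype)

# Fails once the default has been switched to float64
try:
    pes_strain_step(torch.float64)
except RuntimeError as err:
    print("RuntimeError:", err)
```

The exact error text depends on the PyTorch version, but the exception type is a `RuntimeError` either way.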

Any help with solving this issue would be much appreciated!

matthewkuner commented 4 months ago

Never mind, this only seems to happen when I run this code in the same script as other UIPs (universal interatomic potentials). It appears to be a PyTorch dtype setting issue, likely a global default dtype changed by one of the other packages.
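One possible workaround, assuming the other UIP leaves PyTorch's global default dtype set to float64, is to pin the default back to float32 only while matgl runs. The `default_dtype` context manager below is a hypothetical helper, not part of matgl or PyTorch; it relies only on `torch.get_default_dtype` and `torch.set_default_dtype`.

```python
import contextlib
from typing import Iterator

import torch


@contextlib.contextmanager
def default_dtype(dtype: torch.dtype) -> Iterator[None]:
    """Temporarily set PyTorch's global default dtype, restoring it afterwards."""
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)  # restore even if the body raises


# Simulate another UIP having switched the global default to float64
torch.set_default_dtype(torch.float64)

with default_dtype(torch.float32):
    # Tensors created here without an explicit dtype are float32 again,
    # e.g. the torch.eye(3) call seen in the traceback.
    inside = torch.eye(3).dtype

outside = torch.get_default_dtype()
print(inside, outside)  # torch.float32 torch.float64

torch.set_default_dtype(torch.float32)  # undo the simulation
```

With the snippet from the report, it is probably safest to run both `matgl.load_model(...)` and `m3gnet_relaxer.relax(structure, steps=0)` inside `with default_dtype(torch.float32):`, since module parameters created while the default is float64 may themselves end up float64. Restoring the previous default on exit keeps the other UIP's float64 setting intact.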