ACEsuit / mace

MACE - Fast and accurate machine learning interatomic potentials with higher order equivariant message passing.

Don't modify torch default dtype #328

Open janosh opened 4 months ago

janosh commented 4 months ago

The line below sets the global torch default dtype, which prevents running other models after MACE for relaxation in the same Python session: MACE recommends float64 for geometry optimization, while e.g. CHGNet and M3GNet use float32.

https://github.com/ACEsuit/mace/blob/88d49f9ed6925dec07d1777043a36e1fe4872ff3/mace/calculators/mace.py#L145

The resulting error messages are not helpful, so users will likely spend time troubleshooting this issue when they encounter it. The only current workaround is to manually reset the default dtype to float32 with

```py
torch.set_default_dtype(torch.float32)
```

after every MACE call.

Suggested fix

Only convert the model's inputs to the model's dtype, instead of modifying the global default dtype, which affects every float tensor created afterwards anywhere in the process.
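A minimal sketch of what the suggested fix could look like: cast only the floating-point input tensors to the model's dtype and leave `torch.get_default_dtype()` untouched. The `prepare_inputs` helper and the batch keys below are hypothetical, not MACE's actual API.

```python
import torch


def prepare_inputs(batch: dict, model_dtype: torch.dtype) -> dict:
    """Cast only floating-point input tensors to the model's dtype,
    leaving the global torch default dtype unchanged."""
    return {
        key: val.to(model_dtype)
        if torch.is_tensor(val) and val.is_floating_point()
        else val
        for key, val in batch.items()
    }


# hypothetical batch: positions are cast, integer tensors are left alone
batch = {
    "positions": torch.rand(8, 3),  # float32 under the default dtype
    "atomic_numbers": torch.tensor([29] * 8),  # int64, not touched
}
batch = prepare_inputs(batch, torch.float64)
print(batch["positions"].dtype)  # torch.float64
print(torch.get_default_dtype())  # still torch.float32
```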

Minimal example

```python
import torch
from ase.build import bulk
from mace.calculators import mace_mp

orig_dtype = torch.get_default_dtype()
print(f"{orig_dtype=}")
# orig_dtype=torch.float32

atoms = bulk("Cu") * (2, 2, 2)
atoms.calc = mace_mp(default_dtype="float64")
atoms.get_potential_energy()

new_dtype = torch.get_default_dtype()
print(f"{new_dtype=}")
# new_dtype=torch.float64
```
hatemhelal commented 4 months ago

FWIW you can use a context manager to run different models with different torch default dtype. There is an implementation in PR #310, commit: https://github.com/ACEsuit/mace/commit/80211fd1fe29f4aa6ce872805c32b2807b343930
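For reference, a generic context manager along these lines (not the PR #310 implementation, just a sketch of the pattern) can be written with `contextlib.contextmanager`, restoring the previous default dtype even if the body raises:

```python
from contextlib import contextmanager

import torch


@contextmanager
def default_dtype(dtype: torch.dtype):
    """Temporarily switch torch's default dtype, restoring the
    previous value on exit (even on exceptions)."""
    prev = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(prev)


with default_dtype(torch.float64):
    x = torch.zeros(3)  # created as float64 inside the block
print(x.dtype)  # torch.float64
print(torch.get_default_dtype())  # back to the original default
```

This would let float64 models like MACE and float32 models like CHGNet coexist in one session, as long as each model's calls are wrapped in the appropriate block.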