TomaSusi opened 1 month ago
Hey @TomaSusi,
The models were trained with float64, and because MPS does not support float64, it is a bit of a pain to deliver the model on MPS. The short-term hack for you is the following:
model = torch.load(model_path, device="cpu")
model = model.float()
torch.save(model, new_model_path)
calc = MACECalculator(new_model_path, device="mps")
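Because MPS has no float64 support, it can help to check which tensors a model still holds in float64 before and after calling `.float()`. Below is a minimal sketch in which a toy `nn.Linear` stands in for a MACE model. The helper name `float64_tensors` is hypothetical and not part of MACE or PyTorch.

```python
import torch
from torch import nn


def float64_tensors(model: nn.Module) -> list[str]:
    """Return the names of parameters and buffers stored as float64 (unsupported on MPS)."""
    names = [n for n, p in model.named_parameters() if p.dtype == torch.float64]
    names += [n for n, b in model.named_buffers() if b.dtype == torch.float64]
    return names


model = nn.Linear(3, 2).double()  # stand-in for a model trained in float64
print(float64_tensors(model))     # ['weight', 'bias']
model = model.float()             # casts all floating-point params and buffers to float32
print(float64_tensors(model))     # []
```

An empty list after `.float()` means no float64 parameters or buffers remain. Integer buffers are not touched by `.float()`.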
Thanks for the quick reply!
Loading the model with the given syntax doesn't work on the latest PyTorch:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[334], line 7
4 model_path_off = 'mace-off/MACE-OFF23_medium.model'
5 new_model_path_off = 'mace-off/MACE-OFF23_medium_mps.model'
----> 7 model_off = torch.load(model_path_off, device="cpu")
8 model_off = model_off.float()
9 torch.save(model_off, new_model_path_off)
File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/serialization.py:1114, in load(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)
1112 except RuntimeError as e:
1113 raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
-> 1114 return _legacy_load(
1115 opened_file, map_location, pickle_module, **pickle_load_args
1116 )
File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/serialization.py:1338, in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
1332 if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
1333 raise RuntimeError(
1334 "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
1335 f'Received object of type "{type(f)}". Please update to Python 3.8.2 or newer to restore this '
1336 "functionality.")
-> 1338 magic_number = pickle_module.load(f, **pickle_load_args)
1339 if magic_number != MAGIC_NUMBER:
1340 raise RuntimeError("Invalid magic number; corrupt file?")
TypeError: 'device' is an invalid keyword argument for load()
If I change this to map_location='cpu' (or remove the keyword), the model loads.
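With `map_location` in place of `device`, the conversion steps from the first reply can be wrapped as below. This is a sketch that assumes the file holds a whole pickled `nn.Module`. It passes `weights_only=False` because PyTorch 2.6 and later default `torch.load` to weights-only loading, which refuses arbitrary pickled objects; only do this for files you trust. `convert_to_float32` is a hypothetical helper name.

```python
import torch
from torch import nn


def convert_to_float32(model_path: str, new_model_path: str) -> nn.Module:
    """Load a pickled float64 model on CPU, cast it to float32, and save a copy."""
    # map_location (not device) tells torch.load where to place the tensors.
    # weights_only=False is required to unpickle a full nn.Module; trusted files only.
    model = torch.load(model_path, map_location="cpu", weights_only=False)
    model = model.float()  # cast floating-point parameters and buffers to float32
    torch.save(model, new_model_path)
    return model
```

The saved copy is what would then be passed to `MACECalculator(new_model_path, device="mps")` as in the first reply.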
However, this results in another error:
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In[358], line 9
7 prim.calc = calc_off
8 print("forces on primitive:")
----> 9 print(prim.get_forces())
File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/ase/atoms.py:812, in Atoms.get_forces(self, apply_constraint, md)
810 if self._calc is None:
811 raise RuntimeError('Atoms object has no calculator.')
--> 812 forces = self._calc.get_forces(self)
814 if apply_constraint:
815 # We need a special md flag here because for MD we want
816 # to skip real constraints but include special "constraints"
817 # Like Hookean.
818 for constraint in self.constraints:
File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/ase/calculators/abc.py:30, in GetPropertiesMixin.get_forces(self, atoms)
29 def get_forces(self, atoms=None):
---> 30 return self.get_property('forces', atoms)
File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/ase/calculators/calculator.py:538, in BaseCalculator.get_property(self, name, atoms, allow_calculation)
535 if self.use_cache:
536 self.atoms = atoms.copy()
--> 538 self.calculate(atoms, [name], system_changes)
540 if name not in self.results:
541 # For some reason the calculator was not able to do what we want,
542 # and that is OK.
543 raise PropertyNotImplementedError(
544 '{} not present in this ' 'calculation'.format(name)
545 )
File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/calculators/mace.py:244, in MACECalculator.calculate(self, atoms, properties, system_changes)
242 for i, model in enumerate(self.models):
243 batch = self._clone_batch(batch_base)
--> 244 out = model(
245 batch.to_dict(),
246 compute_stress=compute_stress,
247 training=self.use_compile,
248 )
249 if self.model_type in ["MACE", "EnergyDipoleMACE"]:
250 ret_tensors["energies"][i] = out["energy"].detach()
File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/modules/models.py:395, in ScaleShiftMACE.forward(self, data, training, compute_force, compute_virials, compute_stress, compute_displacement, compute_hessian)
393 total_energy = e0 + inter_e
394 node_energy = node_e0 + node_inter_es
--> 395 forces, virials, stress, hessian = get_outputs(
396 energy=inter_e,
397 positions=data["positions"],
398 displacement=displacement,
399 cell=data["cell"],
400 training=training,
401 compute_force=compute_force,
402 compute_virials=compute_virials,
403 compute_stress=compute_stress,
404 compute_hessian=compute_hessian,
405 )
406 output = {
407 "energy": total_energy,
408 "node_energy": node_energy,
(...)
415 "node_feats": node_feats_out,
416 }
418 return output
File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/modules/utils.py:185, in get_outputs(energy, positions, displacement, cell, training, compute_force, compute_virials, compute_stress, compute_hessian)
167 def get_outputs(
168 energy: torch.Tensor,
169 positions: torch.Tensor,
(...)
181 Optional[torch.Tensor],
182 ]:
183 if (compute_virials or compute_stress) and displacement is not None:
184 # forces come for free
--> 185 forces, virials, stress = compute_forces_virials(
186 energy=energy,
187 positions=positions,
188 displacement=displacement,
189 cell=cell,
190 compute_stress=compute_stress,
191 training=(training or compute_hessian),
192 )
193 elif compute_force:
194 forces, virials, stress = (
195 compute_forces(
196 energy=energy,
(...)
201 None,
202 )
File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/modules/utils.py:62, in compute_forces_virials(energy, positions, displacement, cell, training, compute_stress)
60 if compute_stress and virials is not None:
61 cell = cell.view(-1, 3, 3)
---> 62 volume = torch.linalg.det(cell).abs().unsqueeze(-1)
63 stress = virials / volume.view(-1, 1, 1)
64 stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress))
NotImplementedError: The operator 'aten::_linalg_det.result' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
Setting this environment variable does not seem to help. I tried %env within the notebook, !export, and editing my .zprofile file. I posted in the suggested issue (https://github.com/pytorch/pytorch/issues/77764#issuecomment-2403721666).
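A likely reason the variable has no effect: `!export` runs in a throwaway subshell, so the kernel never sees it; `.zprofile` only affects shells started afterwards, so an already-running Jupyter server keeps its old environment; and `%env` does change the kernel's environment, but PyTorch appears to read `PYTORCH_ENABLE_MPS_FALLBACK` when it is first imported (an assumption worth verifying), so it has to be set before `import torch` in a freshly restarted kernel. The sketch below does that. It also shows a hypothetical det-free way to get the cell volume from line 62 of `compute_forces_virials`, using the scalar triple product |a · (b × c)|, which equals |det(cell)| when the lattice vectors are the rows. `cell_volume` is not part of MACE; wiring it in would mean patching `mace/modules/utils.py` locally.

```python
import os

# Assumption: PyTorch reads this variable when it is first imported, so it is
# set before `import torch`. In Jupyter, put this in the first cell after a restart.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch


def cell_volume(cell: torch.Tensor) -> torch.Tensor:
    """|det(cell)| per structure via the scalar triple product a . (b x c).

    Hypothetical replacement for torch.linalg.det(cell).abs().unsqueeze(-1)
    that uses only indexing and elementwise ops. Returns shape (n_structures, 1).
    """
    cell = cell.view(-1, 3, 3)
    a, b, c = cell[:, 0], cell[:, 1], cell[:, 2]  # lattice vectors are the rows
    b_cross_c = torch.stack(
        (
            b[:, 1] * c[:, 2] - b[:, 2] * c[:, 1],
            b[:, 2] * c[:, 0] - b[:, 0] * c[:, 2],
            b[:, 0] * c[:, 1] - b[:, 1] * c[:, 0],
        ),
        dim=-1,
    )
    return (a * b_cross_c).sum(dim=-1).abs().unsqueeze(-1)


cell = torch.tensor([[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 4.0]])
print(cell_volume(cell))  # tensor([[24.]])
```

The output shape matches `torch.linalg.det(cell).abs().unsqueeze(-1)` from the traceback, so the following `virials / volume.view(-1, 1, 1)` line would be unaffected.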
Hi,
Just getting started with MACE but am really digging it! I was excited to see that you support Apple GPUs, but is that only for training? When I try to use a mace_off() or mace_mp() ASE calculator and specify both the dtype and the device, I get an error:
Or maybe this is just a simple bug? I am running PyTorch 2.4.1.
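Since MPS cannot hold float64 tensors, the dtype has to follow the device. Below is a minimal sketch of choosing a compatible pair before building the calculator. The helper name `pick_device_and_dtype` is hypothetical, and falling back to CPU with float64 is just one reasonable choice.

```python
import torch


def pick_device_and_dtype() -> tuple[str, str]:
    """Pair MPS with float32, since MPS has no float64; otherwise use CPU with float64."""
    if torch.backends.mps.is_available():
        return "mps", "float32"
    return "cpu", "float64"


device, dtype = pick_device_and_dtype()
print(device, dtype)
```

The pair would then be passed on, for example as `mace_off(model="medium", device=device, default_dtype=dtype)`, assuming your installed MACE version accepts `device` and `default_dtype` keywords. Per the traceback above, computing stresses on MPS may still hit the missing `linalg.det` operator.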