ACEsuit / mace-mp

MACE-MP models
https://arxiv.org/abs/2401.00096
MIT License

Apple GPU support & float32 dtype #14

Open · TomaSusi opened this issue 1 month ago

TomaSusi commented 1 month ago

Hi,

I'm just getting started with MACE and am really digging it! I was excited to see that you support Apple GPUs, but is that only for training? When I try to use a mace_off() or mace_mp() ASE calculator and specify both the dtype and the device, I get an error.
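The call that triggers it is just the one below (I'm importing from mace.calculators; mace_mp() behaves the same for me):

    from mace.calculators import mace_off  # mace_mp shows the same behaviour

    # requesting float32 on Apple's MPS backend produces the traceback below
    calc = mace_off(device='mps', default_dtype='float32')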

Using MACE-OFF23 MODEL for MACECalculator with /Users/tomasusi/.cache/mace/MACE-OFF23_medium.model
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[95], line 13
---> 13 calc = mace_off(device='mps', default_dtype='float32')

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/calculators/foundations_models.py:206, in mace_off(model, device, default_dtype, return_raw_model, **kwargs)
    202 if default_dtype == "float32":
    203     print(
    204         "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization."
    205     )
--> 206 mace_calc = MACECalculator(
    207     model_paths=model, device=device, default_dtype=default_dtype, **kwargs
    208 )
    209 return mace_calc

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/calculators/mace.py:127, in MACECalculator.__init__(self, model_paths, device, energy_units_to_eV, length_units_to_A, default_dtype, charges_key, model_type, compile_mode, fullgraph, **kwargs)
    125     self.use_compile = True
    126 else:
--> 127     self.models = [
    128         torch.load(f=model_path, map_location=device)
    129         for model_path in model_paths
    130     ]
    131     self.use_compile = False
    132 for model in self.models:

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/calculators/mace.py:128, in <listcomp>(.0)
    125     self.use_compile = True
    126 else:
    127     self.models = [
--> 128         torch.load(f=model_path, map_location=device)
    129         for model_path in model_paths
    130     ]
    131     self.use_compile = False
    132 for model in self.models:

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/serialization.py:1097, in load(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)
   1095             except RuntimeError as e:
   1096                 raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
-> 1097         return _load(
   1098             opened_zipfile,
   1099             map_location,
   1100             pickle_module,
   1101             overall_storage=overall_storage,
   1102             **pickle_load_args,
   1103         )
   1104 if mmap:
   1105     f_name = "" if not isinstance(f, str) else f"{f}, "

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/serialization.py:1525, in _load(zip_file, map_location, pickle_module, pickle_file, overall_storage, **pickle_load_args)
   1522 # Needed for tensors where storage device and rebuild tensor device are
   1523 # not connected (wrapper subclasses and tensors rebuilt using numpy)
   1524 torch._utils._thread_local_state.map_location = map_location
-> 1525 result = unpickler.load()
   1526 del torch._utils._thread_local_state.map_location
   1528 torch._utils._validate_loaded_sparse_tensors()

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/_utils.py:200, in _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
    197 def _rebuild_tensor_v2(
    198     storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None
    199 ):
--> 200     tensor = _rebuild_tensor(storage, storage_offset, size, stride)
    201     tensor.requires_grad = requires_grad
    202     if metadata:

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/_utils.py:178, in _rebuild_tensor(storage, storage_offset, size, stride)
    176 def _rebuild_tensor(storage, storage_offset, size, stride):
    177     # first construct a tensor with the correct dtype/device
--> 178     t = torch.empty((0,), dtype=storage.dtype, device=storage._untyped_storage.device)
    179     return t.set_(storage._untyped_storage, storage_offset, size, stride)

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

Or maybe this is just a simple bug? I am running PyTorch 2.4.1.

ilyes319 commented 1 month ago

Hey @TomaSusi,

The models were trained with float64, and because MPS does not support float64, it is a bit of a pain to deliver them on MPS. The short-term hack for you is the following:

  1. Download the model: for MP from https://github.com/ACEsuit/mace-mp/releases/tag/mace_mp_0, or for MACE-OFF from https://github.com/ACEsuit/mace-off. They are the .model files; select the size you want.
  2. Load the model on CPU:
    model = torch.load(model_path, device="cpu")
    model = model.float()
    torch.save(model, new_model_path)
  3. Create the calculator on MPS using the new model path:
    calc = MACECalculator(new_model_path, device="mps")
TomaSusi commented 1 month ago

Thanks for the quick reply!

Loading the model with the given syntax doesn't work on the latest PyTorch:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[334], line 7
      4 model_path_off = 'mace-off/MACE-OFF23_medium.model'
      5 new_model_path_off = 'mace-off/MACE-OFF23_medium_mps.model'
----> 7 model_off = torch.load(model_path_off, device="cpu")
      8 model_off = model_off.float()
      9 torch.save(model_off, new_model_path_off)

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/serialization.py:1114, in load(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)
   1112     except RuntimeError as e:
   1113         raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
-> 1114 return _legacy_load(
   1115     opened_file, map_location, pickle_module, **pickle_load_args
   1116 )

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/serialization.py:1338, in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
   1332 if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
   1333     raise RuntimeError(
   1334         "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
   1335         f'Received object of type "{type(f)}". Please update to Python 3.8.2 or newer to restore this '
   1336         "functionality.")
-> 1338 magic_number = pickle_module.load(f, **pickle_load_args)
   1339 if magic_number != MAGIC_NUMBER:
   1340     raise RuntimeError("Invalid magic number; corrupt file?")

TypeError: 'device' is an invalid keyword argument for load()

If I update this to map_location='cpu' (or remove the keyword), the model is loaded.
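For reference, this is the version of the conversion (plus step 3 from above) that runs for me; calc_off and the float32 default are just my choices:

    import torch
    from mace.calculators import MACECalculator

    model_path_off = 'mace-off/MACE-OFF23_medium.model'
    new_model_path_off = 'mace-off/MACE-OFF23_medium_mps.model'

    # load on CPU, cast the weights to float32, and save under a new name
    model_off = torch.load(model_path_off, map_location='cpu')
    model_off = model_off.float()
    torch.save(model_off, new_model_path_off)

    # build the calculator on MPS from the converted model
    calc_off = MACECalculator(new_model_path_off, device='mps', default_dtype='float32')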

However, this results in another error:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[358], line 9
      7 prim.calc = calc_off
      8 print("forces on primitive:")
----> 9 print(prim.get_forces())

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/ase/atoms.py:812, in Atoms.get_forces(self, apply_constraint, md)
    810 if self._calc is None:
    811     raise RuntimeError('Atoms object has no calculator.')
--> 812 forces = self._calc.get_forces(self)
    814 if apply_constraint:
    815     # We need a special md flag here because for MD we want
    816     # to skip real constraints but include special "constraints"
    817     # Like Hookean.
    818     for constraint in self.constraints:

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/ase/calculators/abc.py:30, in GetPropertiesMixin.get_forces(self, atoms)
     29 def get_forces(self, atoms=None):
---> 30     return self.get_property('forces', atoms)

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/ase/calculators/calculator.py:538, in BaseCalculator.get_property(self, name, atoms, allow_calculation)
    535     if self.use_cache:
    536         self.atoms = atoms.copy()
--> 538     self.calculate(atoms, [name], system_changes)
    540 if name not in self.results:
    541     # For some reason the calculator was not able to do what we want,
    542     # and that is OK.
    543     raise PropertyNotImplementedError(
    544         '{} not present in this ' 'calculation'.format(name)
    545     )

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/calculators/mace.py:244, in MACECalculator.calculate(self, atoms, properties, system_changes)
    242 for i, model in enumerate(self.models):
    243     batch = self._clone_batch(batch_base)
--> 244     out = model(
    245         batch.to_dict(),
    246         compute_stress=compute_stress,
    247         training=self.use_compile,
    248     )
    249     if self.model_type in ["MACE", "EnergyDipoleMACE"]:
    250         ret_tensors["energies"][i] = out["energy"].detach()

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/modules/models.py:395, in ScaleShiftMACE.forward(self, data, training, compute_force, compute_virials, compute_stress, compute_displacement, compute_hessian)
    393 total_energy = e0 + inter_e
    394 node_energy = node_e0 + node_inter_es
--> 395 forces, virials, stress, hessian = get_outputs(
    396     energy=inter_e,
    397     positions=data["positions"],
    398     displacement=displacement,
    399     cell=data["cell"],
    400     training=training,
    401     compute_force=compute_force,
    402     compute_virials=compute_virials,
    403     compute_stress=compute_stress,
    404     compute_hessian=compute_hessian,
    405 )
    406 output = {
    407     "energy": total_energy,
    408     "node_energy": node_energy,
   (...)
    415     "node_feats": node_feats_out,
    416 }
    418 return output

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/modules/utils.py:185, in get_outputs(energy, positions, displacement, cell, training, compute_force, compute_virials, compute_stress, compute_hessian)
    167 def get_outputs(
    168     energy: torch.Tensor,
    169     positions: torch.Tensor,
   (...)
    181     Optional[torch.Tensor],
    182 ]:
    183     if (compute_virials or compute_stress) and displacement is not None:
    184         # forces come for free
--> 185         forces, virials, stress = compute_forces_virials(
    186             energy=energy,
    187             positions=positions,
    188             displacement=displacement,
    189             cell=cell,
    190             compute_stress=compute_stress,
    191             training=(training or compute_hessian),
    192         )
    193     elif compute_force:
    194         forces, virials, stress = (
    195             compute_forces(
    196                 energy=energy,
   (...)
    201             None,
    202         )

File ~/miniforge3/envs/mlphonons/lib/python3.11/site-packages/mace/modules/utils.py:62, in compute_forces_virials(energy, positions, displacement, cell, training, compute_stress)
     60 if compute_stress and virials is not None:
     61     cell = cell.view(-1, 3, 3)
---> 62     volume = torch.linalg.det(cell).abs().unsqueeze(-1)
      63     stress = virials / volume.view(-1, 1, 1)
     64     stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress))

NotImplementedError: The operator 'aten::_linalg_det.result' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

Setting this environment variable does not seem to help (I tried both %env within the notebook and !export, as well as editing my .zprofile file). I posted in the suggested issue (https://github.com/pytorch/pytorch/issues/77764#issuecomment-2403721666).
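
For completeness, this is roughly how I am setting the variable in the notebook, before torch is imported, in case the ordering matters:

    import os

    # the fallback flag appears to be read when torch initialises, so set it first
    os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

    import torch  # imported only after the variable is set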