ACEsuit / mace

MACE - Fast and accurate machine learning interatomic potentials with higher order equivariant message passing.

Error while using multihead interface #513

Closed: jungsdao closed this issue 3 weeks ago

jungsdao commented 3 weeks ago

Describe the bug

I wanted to try the multihead interface branch, so I trained a model with multihead fine-tuning. But when I attach the calculator to a structure and call atoms.get_potential_energy(), I get the following error. Why is this happening, and how can it be solved?

{
    "name": "RuntimeError",
    "message": "mat1 and mat2 shapes cannot be multiplied (39x78 and 2x78)",
    "stack": "---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[658], line 1
----> 1 atoms.get_potential_energy()

File ~/miniforge3/envs/mace_env/lib/python3.11/site-packages/ase/atoms.py:755, in Atoms.get_potential_energy(self, force_consistent, apply_constraint)
    752     energy = self._calc.get_potential_energy(
    753         self, force_consistent=force_consistent)
    754 else:
--> 755     energy = self._calc.get_potential_energy(self)
    756 if apply_constraint:
    757     for constraint in self.constraints:

File ~/miniforge3/envs/mace_env/lib/python3.11/site-packages/ase/calculators/abc.py:24, in GetPropertiesMixin.get_potential_energy(self, atoms, force_consistent)
     22 else:
     23     name = 'energy'
---> 24 return self.get_property(name, atoms)

File ~/miniforge3/envs/mace_env/lib/python3.11/site-packages/ase/calculators/calculator.py:538, in BaseCalculator.get_property(self, name, atoms, allow_calculation)
    535     if self.use_cache:
    536         self.atoms = atoms.copy()
--> 538     self.calculate(atoms, [name], system_changes)
    540 if name not in self.results:
    541     # For some reason the calculator was not able to do what we want,
    542     # and that is OK.
    543     raise PropertyNotImplementedError(
    544         '{} not present in this ' 'calculation'.format(name)
    545     )

File ~/miniforge3/envs/mace_env/lib/python3.11/site-packages/mace/calculators/mace.py:234, in calculate(self, atoms, properties, system_changes)
    219 config = data.config_from_atoms(atoms, charges_key=self.charges_key)
    220 data_loader = torch_geometric.dataloader.DataLoader(
    221     dataset=[
    222         data.AtomicData.from_config(
   (...)
    231     drop_last=False,
    232 )
--> 234 if self.model_type in [\"MACE\", \"EnergyDipoleMACE\"]:
    235     batch = next(iter(data_loader)).to(self.device)
    236     node_heads = batch[\"head\"][batch[\"batch\"]]

File ~/miniforge3/envs/mace_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniforge3/envs/mace_env/lib/python3.11/site-packages/mace/modules/blocks.py:146, in forward(self, x)
    145 def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]):
--> 146     super().__init__()
    147     # assert len(atomic_energies.shape) == 1
    149     self.register_buffer(
    150         \"atomic_energies\",
    151         torch.tensor(atomic_energies, dtype=torch.get_default_dtype()),
    152     )

RuntimeError: mat1 and mat2 shapes cannot be multiplied (39x78 and 2x78)"
}

I've attached the log file from my MACE training run: umbrella_run-123.log
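The RuntimeError above is a plain matrix-multiplication shape check: for a product of an (m x k) matrix with a (k' x n) matrix, the inner sizes k and k' must be equal. Here they are 78 and 2, so the product is rejected. The sketch below reproduces only that rule in pure Python, with a hypothetical helper `matmul_shape` whose message mimics the PyTorch wording. It does not model which MACE layer performed the multiplication.

```python
from typing import Tuple

Shape = Tuple[int, int]


def matmul_shape(a: Shape, b: Shape) -> Shape:
    """Hypothetical helper: result shape of (m x k) @ (k x n), or raise on mismatch."""
    m, k = a
    k2, n = b
    if k != k2:
        # Inner dimensions disagree, mirroring the PyTorch error text.
        raise ValueError(
            f"mat1 and mat2 shapes cannot be multiplied ({m}x{k} and {k2}x{n})"
        )
    return (m, n)


if __name__ == "__main__":
    # Shapes taken from the error message in this issue.
    try:
        matmul_shape((39, 78), (2, 78))
    except ValueError as err:
        print(err)
```

A mismatch like this usually means the tensors were built with different assumptions about a layer's size, for example when weights saved by one version of a model are evaluated by code from another.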

CheukHinHoJerry commented 3 weeks ago

Did you also use the multi-head interface branch when you did the evaluation?

jungsdao commented 3 weeks ago

Yeah, I think this was the issue. Thanks for pointing it out :)
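Because the fix was to evaluate with the same branch used for training, it can help to confirm which installation the evaluation environment actually imports. The sketch below uses only the standard library (`importlib.util.find_spec` and `importlib.metadata.version`). The helper name `describe_install` is hypothetical, and the distribution name `"mace-torch"` is an assumption about how MACE was installed; a Git-branch install may report a version string that does not identify the branch.

```python
import importlib.metadata
import importlib.util
from typing import Optional


def describe_install(module_name: str, dist_name: Optional[str] = None) -> dict:
    """Hypothetical helper: report where a module is imported from and its version."""
    spec = importlib.util.find_spec(module_name)
    # spec.origin is the file the module would be loaded from (None if not found).
    location = spec.origin if spec is not None else None
    version = None
    if dist_name is not None:
        try:
            version = importlib.metadata.version(dist_name)
        except importlib.metadata.PackageNotFoundError:
            version = None
    return {"module": module_name, "location": location, "version": version}


if __name__ == "__main__":
    # Assumption: MACE is installed under the distribution name "mace-torch".
    print(describe_install("mace", "mace-torch"))
```

Running this in both the training and the evaluation environments and comparing the reported locations is a quick way to spot that two different copies of the code are in use.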