ACEsuit / mace

MACE - Fast and accurate machine learning interatomic potentials with higher order equivariant message passing.

models created from multihead fine tuning don't work in lammps #557

Closed: bernstei closed this 1 week ago

bernstei commented 3 weeks ago

Models created by multihead fine-tuning (multi-head-interface branch) and converted for LAMMPS with mace_create_lammps_model do not work. They give the error below, which suggests that something (the conversion to LAMMPS? the actual LAMMPS code itself?) does not know about the data structures involved in the different heads. Also, it might in general be nice for the results of multihead fine-tuning to be completely normal MACE models.

Exception: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/mace/calculators/lammps_mace.py", line 21, in forward
      compute_displacement = False
    model = self.model
    out = (model).forward(data, False, False, False, False, compute_displacement, )
           ~~~~~~~~~~~~~~ <--- HERE
    node_energy = out["node_energy"]
    if torch.__is__(node_energy, None):
  File "code/__torch__/mace/modules/models.py", line 33, in forward
    num_graphs = torch.sub(torch.numel(data["ptr"]), 1)
    num_atoms_arange = torch.arange((torch.size(data["positions"]))[0])
    _6 = data["head"]
         ~~~~~~~~~~~~ <--- HERE
    _7 = annotate(List[Optional[Tensor]], [data["batch"]])
    node_heads = torch.index(_6, _7)

Traceback of TorchScript, original code (most recent call last):
  File "/home/cluster/bernstei/.local/lib/python3.9/site-packages/mace/calculators/lammps_mace.py", line 30, in forward
        if compute_virials:
            compute_displacement = True
        out = self.model(
              ~~~~~~~~~~ <--- HERE
            data,
            training=False,
  File "/home/cluster/bernstei/.local/lib/python3.9/site-packages/mace/modules/models.py", line 340, in forward
        num_graphs = data["ptr"].numel() - 1
        num_atoms_arange = torch.arange(data["positions"].shape[0])
        node_heads = data["head"][data["batch"]]
                     ~~~~~~~~~~~ <--- HERE
        displacement = torch.zeros(
            (num_graphs, 3, 3),
RuntimeError: KeyError: head
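For reference, a minimal sketch of the failure mode (hypothetical tensors, not the actual data dict the LAMMPS pair style builds): the multihead forward indexes a "head" entry that the LAMMPS side never populates.

    import torch

    # Hypothetical stand-in for the data dict passed in from LAMMPS: it
    # carries positions/batch/ptr but, crucially, no "head" entry.
    data = {
        "positions": torch.zeros(4, 3),
        "batch": torch.zeros(4, dtype=torch.long),
        "ptr": torch.tensor([0, 4]),
    }

    # The lookup from models.py in the traceback above:
    node_heads = data["head"][data["batch"]]  # raises KeyError: 'head'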
wcwitt commented 3 weeks ago

Happy to address this quickly. It would help if someone could point me to the main changes in the model structure for multi-head. I've not been deeply involved there.

bernstei commented 3 weeks ago

FWIW, I'm not convinced this isn't better addressed by stripping all the multihead stuff out of the saved model, at least when it's only being used to stabilize the fine-tuning.
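Even short of actually deleting the extra heads' weights, something like the following would make the result behave like a normal single-head model from the outside. This is only a sketch: the wrapper and its names are mine, and it assumes (based on the traceback above) that data["head"] holds one index per graph and that num_graphs = data["ptr"].numel() - 1; none of this is the actual MACE API.

    import torch
    from typing import Dict

    class SingleHeadWrapper(torch.nn.Module):
        # Hypothetical wrapper that bakes a fixed head into a fine-tuned
        # multihead model, so downstream callers see what looks like a
        # completely normal single-head MACE model.
        def __init__(self, model: torch.nn.Module, head_index: int = 0):
            super().__init__()
            self.model = model
            self.head_index = head_index

        def forward(self, data: Dict[str, torch.Tensor], **kwargs):
            if "head" not in data:
                num_graphs = data["ptr"].numel() - 1
                data["head"] = torch.full(
                    (num_graphs,),
                    self.head_index,
                    dtype=torch.long,
                    device=data["batch"].device,
                )
            return self.model(data, **kwargs)

Scripting this for LAMMPS would need an explicit forward signature, since TorchScript does not support **kwargs.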

wcwitt commented 3 weeks ago

Maybe the create_lammps_model script could/should create a separate LAMMPS model for each head that it finds? Assuming I'm using that concept correctly.
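For example, something like this loop (hypothetical throughout: the model.heads attribute, the filenames, and the SingleHeadWrapper from the sketch above are assumptions, not the real mace_create_lammps_model internals):

    import torch

    model = torch.load("finetuned_multihead.model")  # hypothetical filename

    # "heads" is an assumption about where the head names are stored
    for i, head_name in enumerate(getattr(model, "heads", ["default"])):
        single = SingleHeadWrapper(model, head_index=i)
        torch.save(single, f"finetuned_{head_name}.model")
        # the real create_lammps_model step would then jit-script each file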

ilyes319 commented 3 weeks ago

Yes, I will just make the new model a completely normal MACE model. I do not think there is any reason to support multihead in LAMMPS.
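One way to make the model tolerate callers that do not supply heads is a fallback at the lookup site that fails in the traceback; a sketch, not necessarily the change that landed in develop:

    import torch
    from typing import Dict

    def node_heads_or_default(data: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Drop-in replacement for the failing lookup in models.py: use the
        # per-graph head indices when the caller provides them, otherwise
        # treat every node as belonging to head 0 (a plain MACE model).
        if "head" in data:
            return data["head"][data["batch"]]
        return torch.zeros_like(data["batch"])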

bernstei commented 3 weeks ago

If you could update the multi-head-interface branch, I'd very much like to use it. It'll make my GAP-ACE-MACE workflow cleaner (although I also need info on the way MACE fitting uses the weights).

wcwitt commented 3 weeks ago

Thanks @ilyes319. It doesn't sound like there is anything for me to do here, but feel free to ping me if that changes.

ilyes319 commented 1 week ago

Closing this as it is supported in develop.