ACEsuit / mace

MACE - Fast and accurate machine learning interatomic potentials with higher order equivariant message passing.

Specifying E0s with .json During Multihead Training #651

Closed: ThomasWarford closed this issue 3 weeks ago

ThomasWarford commented 1 month ago

Describe the bug: Multihead training fails when the E0s are specified with a .json file, like here.

To reproduce

config.yaml:

name: multihead
model: MACE
heads:
    PBE:
        train_file: "./multi-theory/data/PBE.xyz"
        valid_fraction: 0.05
        E0s: "./multi-theory/data/PBE.json"
        energy_key: "key_energy"
        forces_key: "key_forces"

    PBEsol:
        train_file: "./multi-theory/data/PBEsol.xyz"
        valid_fraction: 0.05
        E0s: "./multi-theory/data/PBEsol.json"
        energy_key: "key_energy"
        forces_key: "key_forces"

scaling: rms_forces_scaling
batch_size: 2
max_num_epochs: 6
ema: true
ema_decay: 0.99
amsgrad: true
default_dtype: float32
device: cuda
seed: 3

PBE.json:

{"29": -0.126164925,
"28": -0.75244326,
"46": -1.47561768,
"45": -1.487234645,
"78": -0.56514218,
"1": -0.558860735,
"6": 0.686087735}
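One plausible reading of `KeyError: 1`, offered as an assumption and not confirmed in this thread: JSON object keys are always strings, so loading PBE.json with Python's standard `json` module yields keys "1", "6", and so on, whereas the traceback looks up an int `z` taken from `head_config.z_table.zs`. The sketch below only illustrates that key-type mismatch and one way to normalise the keys; it is not MACE's own loading code.

```python
import json

# Contents of PBE.json from the report above.
PBE_JSON = """{"29": -0.126164925,
"28": -0.75244326,
"46": -1.47561768,
"45": -1.487234645,
"78": -0.56514218,
"1": -0.558860735,
"6": 0.686087735}"""

raw = json.loads(PBE_JSON)
# JSON object keys are always strings, so raw is keyed by "29", "28", ...
print(type(next(iter(raw))))  # <class 'str'>

# Assumption: the trainer indexes E0s by int atomic number, so convert the keys.
e0s = {int(z): float(e) for z, e in raw.items()}
print(1 in raw, 1 in e0s)  # False True
```

If string keys are indeed the cause, an int-keyed lookup such as `e0s[1]` succeeds only after the conversion.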

train_multihead.sh:

OUTPUT_DIR="multihead"

mace_run_train \
  --config="./multi-theory/config/multihead.yaml" \
  --log_dir="./${OUTPUT_DIR}/logs" \
  --model_dir="./${OUTPUT_DIR}" \
  --checkpoints_dir="./${OUTPUT_DIR}/checkpoints" \
  --results_dir="./${OUTPUT_DIR}/results"

Error:

...
2024-10-22 16:46:16.366 INFO: Loading atomic energies from ./multi-theory/data/PBEsol.json
Traceback (most recent call last):
  File "/rds/user/tdw50/hpc-work/mace/mace/cli/run_train.py", line 407, in run
    logging.info(f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}")
                                                                                                               ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^
KeyError: 1

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/tdw50/miniforge3/envs/mace/bin/mace_run_train", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/rds/user/tdw50/hpc-work/mace/mace/cli/run_train.py", line 63, in main
    run(args)
  File "/rds/user/tdw50/hpc-work/mace/mace/cli/run_train.py", line 409, in run
    raise KeyError(f"Atomic number {e} not found in atomic_energies_dict for head {head_config.head_name}, add E0s for this atomic number") from e
KeyError: 'Atomic number 1 not found in atomic_energies_dict for head PBE, add E0s for this atomic number'

Possibly related: https://github.com/ACEsuit/mace/issues/371
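The traceback fails on the first missing element only (z = 1). The sketch below assumes the dict is keyed first by head name and then by atomic number, as the expression `atomic_energies_dict[head_config.head_name][z]` suggests. It collects every missing atomic number for a head at once. `missing_e0s` is a hypothetical helper for illustration, not part of MACE. With string keys, as a raw `json.load` would produce, every element is reported missing.

```python
from typing import Iterable, List, Mapping


def missing_e0s(
    atomic_energies_dict: Mapping[str, Mapping],
    head_name: str,
    zs: Iterable[int],
) -> List[int]:
    """Return the atomic numbers in zs that have no E0 entry for head_name.

    Hypothetical helper: assumes a {head_name: {z: E0}} layout.
    """
    head_e0s = atomic_energies_dict.get(head_name, {})
    return [z for z in zs if z not in head_e0s]


zs = [1, 6, 28, 29, 45, 46, 78]
string_keyed = {"PBE": {"1": -0.558860735, "6": 0.686087735}}
int_keyed = {"PBE": {1: -0.558860735, 6: 0.686087735}}

print(missing_e0s(string_keyed, "PBE", zs))  # [1, 6, 28, 29, 45, 46, 78]
print(missing_e0s(int_keyed, "PBE", [1, 6]))  # []
```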