ACEsuit / mace

MACE - Fast and accurate machine learning interatomic potentials with higher order equivariant message passing.

multihead fine-tuning starting from a previously fit model gives a tensor size error (E0 related?) #622

Open bernstei opened 1 month ago

bernstei commented 1 month ago

Running

mace_run_train --foundation_model='small' --ema_decay=0.995 --energy_weight=1.0 --forces_weight=1.0 --stress_weight=1.0 --max_num_epochs=2 --scheduler_patience=5 --patience=40 --clip_grad=100.0 --num_samples_pt=10 --batch_size=2 --valid_batch_size=2 --energy_key='REF_energy' --forces_key='REF_forces' --stress_key='REF_stress' --name='MACE' --valid_file='_MACE_valid_file_configs._rde9yoa.xyz' --train_file='_MACE_train_file_configs.qm_t3m5e.xyz'

rm -r mp_finetuning-MACE_run-123*xyz logs results checkpoints  MACE_compiled.model

mace_run_train --foundation_model='MACE.model' --ema_decay=0.995 --energy_weight=1.0 --forces_weight=1.0 --stress_weight=1.0 --max_num_epochs=2 --scheduler_patience=5 --patience=40 --clip_grad=100.0 --num_samples_pt=10 --batch_size=2 --valid_batch_size=2 --energy_key='REF_energy' --forces_key='REF_forces' --stress_key='REF_stress' --name='MACE' --valid_file='_MACE_valid_file_configs._rde9yoa.xyz' --train_file='_MACE_train_file_configs.qm_t3m5e.xyz'

the second fit gives the error

Traceback (most recent call last):
  File "/home/cluster/bernstei/.local/bin/mace_run_train", line 8, in <module>
    sys.exit(main())
  File "/home/cluster/bernstei/.local/lib/python3.9/site-packages/mace/cli/run_train.py", line 63, in main
    run(args)
  File "/home/cluster/bernstei/.local/lib/python3.9/site-packages/mace/cli/run_train.py", line 356, in run
    atomic_energies_dict[head_config.head_name] = {
  File "/home/cluster/bernstei/.local/lib/python3.9/site-packages/mace/cli/run_train.py", line 357, in <dictcomp>
    z: model_foundation.atomic_energies_fn.atomic_energies[
RuntimeError: a Tensor with 13 elements cannot be converted to Scalar
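
For context on where this error plausibly comes from: the saved MACE.model from the first fit has two heads and a 13-element z-table, so its E0 table is 2-D rather than 1-D, and indexing it as if it were per-element returns a whole row. A minimal PyTorch sketch of that failure mode (the shapes and indexing are assumptions for illustration, not MACE's actual code):

import torch

# Illustration only (assumed shapes): after the first multihead fine-tune the
# saved model has 2 heads and 13 elements, so the E0 table is 2-D.
atomic_energies = torch.randn(2, 13)  # [n_heads, n_elements]

# Indexing with a single index (as if the table were 1-D, one value per element)
# returns a whole 13-element row, and .item() then fails exactly as above.
row = atomic_energies[0]
try:
    row.item()
except RuntimeError as err:
    print(err)  # a Tensor with 13 elements cannot be converted to Scalar

# Selecting a head first gives a true scalar
e0 = atomic_energies[0, 7].item()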

xyz files are here, with an extra .txt suffix to allow for github upload [edited - replaced txt files with a single zip] _MACE_files.zip

bernstei commented 1 month ago

@ilyes319 have you had any chance to look at this?

ilyes319 commented 3 weeks ago

@bernstei I should have fixed that in the main branch. Could you try it and tell me whether it is indeed fixed?

bernstei commented 3 weeks ago

The real code that tries to do this still fails. Let me see what's going on with my reproducible example above.

bernstei commented 3 weeks ago

When I try to run the commands from the original description with the current main branch, I still get an error:

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "<eval_with_key>.122", line 13, in forward
    reshape_1 = w.reshape(-1, 256);  w = None
    reshape_2 = reshape.reshape(getitem_1, 128, 1);  reshape = None
    reshape_3 = reshape_1.reshape((128, 2));  reshape_1 = None
                ~~~~~~~~~~~~~~~~~ <--- HERE
    tensordot = torch.functional.tensordot(reshape_2, reshape_3, ([1], [0]), out = None);  reshape_2 = reshape_3 = None
    mul = tensordot * 0.08838834764831845;  tensordot = None
RuntimeError: shape '[128, 2]' is invalid for input of size 512

Note that I'm not sure that's the error that the real code is getting, but the example above should be working regardless.

ilyes319 commented 3 weeks ago

Is that the full error? Just to check, can you confirm that you are not using the "compiled" model but the vanilla model for restarting the finetuning? (btw I still cannot run your example because of the error: ValueError: invalid literal for int() with base 10: 'Lattice="2.6989481088809497')

bernstei commented 3 weeks ago

(btw I still can not run your example because of the error: ValueError: invalid literal for int() with base 10: 'Lattice="2.6989481088809497')

Dunno - it works for me. Are you using some weird ASE version? Can you post the full error? Those are all normal extxyz files.

MACE_compiled.model does not exist (it's deleted after the 1st fit) when I run the 2nd fit command. Its full output is

tin 1044 : mace_run_train --foundation_model='MACE.model' --ema_decay=0.995 --energy_weight=1.0 --forces_weight=1.0 --stress_weight=1.0 --max_num_epochs=2 --scheduler_patience=5 --patience=40 --clip_grad=100.0 --num_samples_pt=10 --batch_size=2 --valid_batch_size=2 --energy_key='REF_energy' --forces_key='REF_forces' --stress_key='REF_stress' --name='MACE' --valid_file='_MACE_valid_file_configs._rde9yoa.xyz' --train_file='_MACE_train_file_configs.qm_t3m5e.xyz'

2024-11-04 12:00:17.299 INFO: ===========VERIFYING SETTINGS===========
2024-11-04 12:00:17.299 INFO: MACE version: 0.3.7
2024-11-04 12:00:17.299 INFO: Using CPU
2024-11-04 12:00:17.435 INFO: Using foundation model MACE.model as initial checkpoint.
2024-11-04 12:00:17.436 INFO: ===========LOADING INPUT DATA===========
2024-11-04 12:00:17.436 INFO: Using heads: ['default']
2024-11-04 12:00:17.436 INFO: =============    Processing head default     ===========
2024-11-04 12:00:17.457 INFO: Using isolated atom energies from training file
2024-11-04 12:00:17.457 INFO: Training set [32 configs, 32 energy, 192 forces] loaded from '_MACE_train_file_configs.qm_t3m5e.xyz'
2024-11-04 12:00:17.462 INFO: Validation set [8 configs, 8 energy, 48 forces] loaded from '_MACE_valid_file_configs._rde9yoa.xyz'
2024-11-04 12:00:17.462 INFO: Total number of configurations: train=32, valid=8, tests=[],
2024-11-04 12:00:17.462 INFO: ==================Using multiheads finetuning mode==================
2024-11-04 12:00:17.462 INFO: Using foundation model for multiheads finetuning with Materials Project data
2024-11-04 12:00:17.462 INFO: Using Materials Project dataset with /home/cluster/bernstei/.cache/mace/mp_traj_combinedxyz
2024-11-04 12:00:17.462 INFO: Using Materials Project descriptors with /home/cluster/bernstei/.cache/mace/descriptorsnpy
2024-11-04 12:00:17.514 INFO: Using CPU
2024-11-04 12:00:17.531 INFO: Filtering configurations based on the finetuning set, filtering type: combinations, elements: ['Ni']
2024-11-04 12:01:06.788 INFO: Number of configurations after filtering 6 is less than the number of samples 10, selecting random configurations for the rest.
2024-11-04 12:01:07.128 INFO: Saving the selected configurations
2024-11-04 12:01:07.131 INFO: Saving a combined XYZ file
2024-11-04 12:01:07.465 WARNING: Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'energy' to 'REF_energy'. You need to use --energy_key='REF_energy' to specify the chosen key name.
2024-11-04 12:01:07.466 WARNING: Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'forces' to 'REF_forces'. You need to use --forces_key='REF_forces' to specify the chosen key name.
2024-11-04 12:01:07.467 WARNING: Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'stress' to 'REF_stress'. You need to use --stress_key='REF_stress' to specify the chosen key name.
2024-11-04 12:01:07.469 INFO: Training set [10 configs, 10 energy, 360 forces] loaded from 'mp_finetuning-MACE_run-123.xyz'
2024-11-04 12:01:07.469 INFO: Using random 10% of training set for validation with following indices: [9]
2024-11-04 12:01:07.469 INFO: Validaton set contains 1 configurations [1 energy, 129 forces]
2024-11-04 12:01:07.469 INFO: Total number of configurations: train=9, valid=1
2024-11-04 12:01:07.469 INFO: Atomic Numbers used: [7, 8, 9, 11, 13, 15, 27, 28, 42, 52, 55, 68, 82]
2024-11-04 12:01:07.469 INFO: Foundation model has multiple heads, using the first head as foundation E0s.
2024-11-04 12:01:07.469 INFO: Foundation model has multiple heads, using the first head as foundation E0s.
2024-11-04 12:01:07.470 INFO: Atomic Energies used (z: eV) for head default: {28: -0.10007476}
2024-11-04 12:01:07.470 INFO: Atomic Energies used (z: eV) for head pt_head: {7: -7.360100452662763, 8: -7.28459863421322, 9: -4.896490881731322, 11: -2.7593613569762425, 13: -4.846881245288104, 15: -6.9632957911820235, 27: -5.577439766222147, 28: -5.172747618813715, 42: -8.791678800595722, 52: -2.8804045971118897, 55: -2.765284507132287, 68: -6.85029094445494, 82: -3.730042357127322}
2024-11-04 12:01:07.500 INFO: Average number of neighbors: 61.964672446250916
2024-11-04 12:01:07.500 INFO: During training the following quantities will be reported: energy, forces, virials, stress
2024-11-04 12:01:07.500 INFO: ===========MODEL DETAILS===========
2024-11-04 12:01:07.516 INFO: Loading FOUNDATION model
2024-11-04 12:01:07.516 INFO: Model configuration extracted from foundation model
2024-11-04 12:01:07.516 INFO: Using universal loss function for fine-tuning
2024-11-04 12:01:07.516 INFO: Message passing with hidden irreps 128x0e)
2024-11-04 12:01:07.516 INFO: 2 layers, each with correlation order: 3 (body order: 4) and spherical harmonics up to: l=3
2024-11-04 12:01:07.516 INFO: Radial cutoff: 6.0 A (total receptive field for each atom: 12.0 A)
2024-11-04 12:01:07.517 INFO: Distance transform for radial basis functions: None
/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/jit/_check.py:177: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn(
/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/jit/_check.py:177: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn(
/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/jit/_check.py:177: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn(
/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/jit/_check.py:177: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn(
/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/jit/_check.py:177: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn(
/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/jit/_check.py:177: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn(
/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/jit/_check.py:177: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn(
/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/jit/_check.py:177: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn(
/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/jit/_check.py:177: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn(
/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/jit/_check.py:177: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn(
/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/jit/_check.py:177: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn(
/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/jit/_check.py:177: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn(
/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/jit/_check.py:177: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn(
/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/jit/_check.py:177: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn(
2024-11-04 12:01:08.365 INFO: Total number of parameters: 809610
2024-11-04 12:01:08.366 INFO: 
2024-11-04 12:01:08.366 INFO: ===========OPTIMIZER INFORMATION===========
2024-11-04 12:01:08.366 INFO: Using ADAM as parameter optimizer
2024-11-04 12:01:08.366 INFO: Batch size: 2
2024-11-04 12:01:08.366 INFO: Number of gradient updates: 41
2024-11-04 12:01:08.366 INFO: Learning rate: 0.01, weight decay: 5e-07
2024-11-04 12:01:08.366 INFO: UniversalLoss(energy_weight=1.000, forces_weight=1.000, stress_weight=1.000)
2024-11-04 12:01:08.366 INFO: Using gradient clipping with tolerance=100.000
2024-11-04 12:01:08.366 INFO: 
2024-11-04 12:01:08.367 INFO: ===========TRAINING===========
2024-11-04 12:01:08.367 INFO: Started training, reporting errors on validation set
2024-11-04 12:01:08.367 INFO: Loss metrics on validation set
Traceback (most recent call last):
  File "/home/cluster/bernstei/.local/bin/mace_run_train", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/cluster/bernstei/.local/lib/python3.11/site-packages/mace/cli/run_train.py", line 63, in main
    run(args)
  File "/home/cluster/bernstei/.local/lib/python3.11/site-packages/mace/cli/run_train.py", line 591, in run
    tools.train(
  File "/home/cluster/bernstei/.local/lib/python3.11/site-packages/mace/tools/train.py", line 183, in train
    valid_loss_head, eval_metrics = evaluate(
                                    ^^^^^^^^^
  File "/home/cluster/bernstei/.local/lib/python3.11/site-packages/mace/tools/train.py", line 411, in evaluate
    output = model(
             ^^^^^^
  File "/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cluster/bernstei/.local/lib/python3.11/site-packages/mace/modules/models.py", line 412, in forward
    readout(node_feats, node_heads)[num_atoms_arange, node_heads]
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cluster/bernstei/.local/lib/python3.11/site-packages/mace/modules/blocks.py", line 59, in forward
    return self.linear(x)  # [n_nodes, 1]
           ^^^^^^^^^^^^^^
  File "/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cluster/bernstei/.local/lib/python3.11/site-packages/e3nn/o3/_linear.py", line 280, in forward
    return self._compiled_main(features, weight, bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/Software/python/conda/torch/2.2.2/cpu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "<eval_with_key>.122", line 13, in forward
    reshape_1 = w.reshape(-1, 256);  w = None
    reshape_2 = reshape.reshape(getitem_1, 128, 1);  reshape = None
    reshape_3 = reshape_1.reshape((128, 2));  reshape_1 = None
                ~~~~~~~~~~~~~~~~~ <--- HERE
    tensordot = torch.functional.tensordot(reshape_2, reshape_3, ([1], [0]), out = None);  reshape_2 = reshape_3 = None
    mul = tensordot * 0.08838834764831845;  tensordot = None
RuntimeError: shape '[128, 2]' is invalid for input of size 512

bernstei commented 3 weeks ago

By the way, in case it's helpful, the real code (not the little toy example above) gives a different error. My data only has Z = 28 (Ni), so the Z=12 issue must be related to the PT head.

tin 4955 : /home/cluster/bernstei/.local/bin/mace_run_train --foundation_model='/home/cluster/bernstei/src/work/MLIP_iterfit/wif/pytest_wif/test_cli_vasp0/stage_0_md_step_00.step_030.MACE_fit/MACE.model' --lr=0.001 --ema --ema_decay=0.995 --energy_weight=1.0 --forces_weight=1.0 --stress_weight=1.0 --max_num_epochs=2 --scheduler_patience=5 --patience=40 --clip_grad=100.0 --device='cpu' --save_cpu --distance_transform='Agnesi' --pair_repulsion --num_samples_pt=10 --batch_size=2 --valid_batch_size=2 --energy_key='REF_energy' --forces_key='REF_forces' --stress_key='REF_stress' --seed='2735729615' --name='MACE' --valid_file='/home/cluster/bernstei/src/work/MLIP_iterfit/wif/pytest_wif/test_cli_vasp0/stage_1_md_step_12.step_030.MACE_fit/_MACE_valid_file_configs.ecyxofd0.xyz' --train_file='/home/cluster/bernstei/src/work/MLIP_iterfit/wif/pytest_wif/test_cli_vasp0/stage_1_md_step_12.step_030.MACE_fit/_MACE_train_file_configs.3ry0vf15.xyz'
2024-11-04 12:02:18.244 INFO: ===========VERIFYING SETTINGS===========
2024-11-04 12:02:18.244 INFO: MACE version: 0.3.7
2024-11-04 12:02:18.244 INFO: Using CPU
2024-11-04 12:02:18.413 INFO: Using foundation model /home/cluster/bernstei/src/work/MLIP_iterfit/wif/pytest_wif/test_cli_vasp0/stage_0_md_step_00.step_030.MACE_fit/MACE.model as initial checkpoint.
2024-11-04 12:02:18.413 INFO: ===========LOADING INPUT DATA===========
2024-11-04 12:02:18.413 INFO: Using heads: ['default']
2024-11-04 12:02:18.413 INFO: =============    Processing head default     ===========
2024-11-04 12:02:18.421 INFO: Using isolated atom energies from training file
2024-11-04 12:02:18.421 INFO: Training set [4 configs, 4 energy, 24 forces] loaded from '/home/cluster/bernstei/src/work/MLIP_iterfit/wif/pytest_wif/test_cli_vasp0/stage_1_md_step_12.step_030.MACE_fit/_MACE_train_file_configs.3ry0vf15.xyz'
2024-11-04 12:02:18.423 INFO: Validation set [2 configs, 2 energy, 12 forces] loaded from '/home/cluster/bernstei/src/work/MLIP_iterfit/wif/pytest_wif/test_cli_vasp0/stage_1_md_step_12.step_030.MACE_fit/_MACE_valid_file_configs.ecyxofd0.xyz'
2024-11-04 12:02:18.423 INFO: Total number of configurations: train=4, valid=2, tests=[],
2024-11-04 12:02:18.423 INFO: ==================Using multiheads finetuning mode==================
2024-11-04 12:02:18.423 INFO: Using foundation model for multiheads finetuning with Materials Project data
2024-11-04 12:02:18.424 INFO: Using Materials Project dataset with /home/cluster/bernstei/.cache/mace/mp_traj_combinedxyz
2024-11-04 12:02:18.424 INFO: Using Materials Project descriptors with /home/cluster/bernstei/.cache/mace/descriptorsnpy
2024-11-04 12:02:18.490 INFO: Using CPU
2024-11-04 12:02:18.494 INFO: Filtering configurations based on the finetuning set, filtering type: combinations, elements: ['Ni']
2024-11-04 12:03:07.861 INFO: Number of configurations after filtering 6 is less than the number of samples 10, selecting random configurations for the rest.
2024-11-04 12:03:08.200 INFO: Saving the selected configurations
2024-11-04 12:03:08.204 INFO: Saving a combined XYZ file
2024-11-04 12:03:08.541 WARNING: Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'energy' to 'REF_energy'. You need to use --energy_key='REF_energy' to specify the chosen key name.
2024-11-04 12:03:08.542 WARNING: Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'forces' to 'REF_forces'. You need to use --forces_key='REF_forces' to specify the chosen key name.
2024-11-04 12:03:08.543 WARNING: Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'stress' to 'REF_stress'. You need to use --stress_key='REF_stress' to specify the chosen key name.
2024-11-04 12:03:08.545 INFO: Training set [10 configs, 10 energy, 315 forces] loaded from 'mp_finetuning-MACE_run-2735729615.xyz'
2024-11-04 12:03:08.545 INFO: Using random 10% of training set for validation with following indices: [5]
2024-11-04 12:03:08.545 INFO: Validaton set contains 1 configurations [1 energy, 3 forces]
2024-11-04 12:03:08.545 INFO: Total number of configurations: train=9, valid=1
2024-11-04 12:03:08.545 INFO: Atomic Numbers used: [1, 8, 12, 14, 25, 27, 28, 57, 62, 68, 77, 82]
2024-11-04 12:03:08.545 INFO: Foundation model has multiple heads, using the first head as foundation E0s.
Traceback (most recent call last):
  File "/home/cluster/bernstei/.local/bin/mace_run_train", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/cluster/bernstei/.local/lib/python3.11/site-packages/mace/cli/run_train.py", line 63, in main
    run(args)
  File "/home/cluster/bernstei/.local/lib/python3.11/site-packages/mace/cli/run_train.py", line 362, in run
    atomic_energies_dict[head_config.head_name] = {
                                                  ^
  File "/home/cluster/bernstei/.local/lib/python3.11/site-packages/mace/cli/run_train.py", line 364, in <dictcomp>
    z_table_foundation.z_to_index(z)
  File "/home/cluster/bernstei/.local/lib/python3.11/site-packages/mace/tools/utils.py", line 107, in z_to_index
    return self.zs.index(atomic_number)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: 12 is not in list
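
The mechanism here seems simpler: the replay (pt_head) subset drawn from mptraj contains elements such as Mg (Z=12) that are absent from the saved fine-tuned model's z-table, and the element lookup is essentially a list .index() call, so it raises ValueError. A small sketch under that assumption (the z-table contents below are taken from the toy example's log, not from the real model):

# Sketch of the lookup failure, assuming the saved model only covers the
# elements selected in the earlier run (list copied from the toy example's log).
foundation_zs = [7, 8, 9, 11, 13, 15, 27, 28, 42, 52, 55, 68, 82]

def z_to_index(z: int) -> int:
    # mirrors a plain list.index() lookup, as in the traceback above
    return foundation_zs.index(z)

print(z_to_index(28))   # fine: Ni is in the table
try:
    z_to_index(12)      # Mg comes from the mptraj replay set, not the fine-tuning data
except ValueError as err:
    print(err)          # 12 is not in list
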
ilyes319 commented 3 weeks ago

Ok thank you, I think these are two separate bugs. Re ASE, the only weird thing about my setup is that I am on Windows; that might be the problem. I'll try on a Linux machine.

bernstei commented 3 weeks ago

I'll try on a Linux machine.

I guess it could be an EOL mismatch type issue.
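
If it is an end-of-line issue, a quick generic check (this assumes the Windows failure comes from CRLF vs LF line endings in the uploaded .xyz files; the filename is the one from the example above):

# Count CRLF vs LF line endings in the training file; a nonzero CRLF count for a
# file written on Linux (or vice versa) would point to an EOL mismatch.
with open("_MACE_train_file_configs.qm_t3m5e.xyz", "rb") as fh:
    data = fh.read()
print(data.count(b"\r\n"), "CRLF,", data.count(b"\n"), "LF line endings")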

bernstei commented 2 weeks ago

So it's not a seed issue - both are defaulting to 123. However, I diffed the two log files (the initial fit from the foundation model and the additional fine-tune that starts from the earlier saved model), stripping out the timestamps, and got the following (excluding some trivial things, like the initial model being "small" vs my initial saved model file). Perhaps it's useful. The initial is the "small" run, the final is the second run from the initial run's saved "MACE.model".

33a33,34
> INFO: Foundation model has multiple heads, using the first head as foundation E0s.
> INFO: Foundation model has multiple heads, using the first head as foundation E0s.
93,95c94,96
<     (linear_1): Linear(128x0e -> 32x0e | 4096 weights)
<     (non_linearity): Activation [x] (32x0e -> 32x0e)
<     (linear_2): Linear(32x0e -> 2x0e | 64 weights)
---
>     (linear_1): Linear(128x0e -> 64x0e | 8192 weights)
>     (non_linearity): Activation [x] (64x0e -> 64x0e)
>     (linear_2): Linear(64x0e -> 2x0e | 128 weights)
98c99
< (scale_shift): ScaleShiftBlock(scale=0.8042, 0.8042, shift=0.1641, 0.1641)
---
> (scale_shift): ScaleShiftBlock(scale=0.8042, 0.8042, 0.8042, 0.8042, shift=0.1641, 0.1641, 0.1641, 0.1641)
100c101
< INFO: Total number of parameters: 805066
---
> INFO: Total number of parameters: 809610
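
The doubled readout and ScaleShiftBlock in this diff line up with the earlier TorchScript error: 512 = 128 x 4, i.e. the stored readout weight is sized for four head columns while the compiled linear tries to view it as [128, 2]. A torch-only sketch of that symptom (the head-count interpretation and shapes are assumptions, not MACE internals):

import torch

# The stored weight has 128 * 4 = 512 entries, but the readout was compiled
# expecting 128 * 2; the reshape then fails exactly as in the traceback.
w = torch.randn(512)
try:
    w.reshape(128, 2)
except RuntimeError as err:
    print(err)  # shape '[128, 2]' is invalid for input of size 512
print(w.reshape(128, 4).shape)  # consistent with the doubled blocks in the diff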
ilyes319 commented 1 week ago

I think I finally understood what is happening. The code was not working when the foundation model was a multihead model. The workaround I have made for now is to require the foundation model to be a single-head model. If it is not, it will be transformed into a single-head model by selecting one head. I have added a new arg to specify which head one wants to keep. I think in the long run we might want to support multihead finetuning on different heads at the same time, but that would involve specifying replay data for each previous head, and that is currently not handled.

In your workflow, the pt_head of the first multihead finetuning was not removed, so the output model was a two-head model. At the moment, if an input model has two heads and one of them is named "pt_head", that head is deleted when you finetune again. I preferred this solution to the alternative of just deleting the "pt_head" when saving the model, because I figured some people might want to keep it for some application I have not thought of.

bernstei commented 1 week ago

I'm confused - don't we want the second step of fine-tuning to start with the fine-tuned pt_head and the fine-tuned real head? I don't quite follow your description of what you changed, but if you delete anything, don't you lose those parameters, so they're not used the second time?

[added] Or is it just turning the first-stage fine-tuned 2-head model into a single head from the fine-tuned pt_head, while also using the fine-tuned real head to initialize the real head during the 2nd fine-tuning, so nothing is lost?

Also, is there a branch for me to try this with?

ilyes319 commented 1 week ago

You start finetuning a model, so you end up with two heads: 1. pt_head 2. ft_head. Then when you finetune again, I think you have to remove the pt_head and re-initialize the pt_head from your current ft_head.

If you just want to continue training with new data on mptraj, then I suggest you restart from the checkpoint instead of using the finetuning. Or use the finetuning, but pass the mptraj subsample from the first finetuning as the pt_train_file, not the full mptraj. Because if you use the full mptraj, it might select elements that are no longer covered.
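
Concretely, for the reproducer at the top, that would mean reusing the subset file the first run already wrote out, e.g. (assuming the flag is spelled --pt_train_file in this MACE version; the filename is the one reported in the second run's log):

mace_run_train --foundation_model='MACE.model' --pt_train_file='mp_finetuning-MACE_run-123.xyz' ... (remaining options as in the second command at the top)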

bernstei commented 1 week ago

If I do use the checkpoint, will that sidestep the pt_head structure-selection issue but use my newly added structures for the real head? Or does the checkpoint not include the pt_head structures, so I have to be careful anyway?

ilyes319 commented 1 week ago

You do have to be careful, but hopefully starting from the same seed would make the selection the same. I guess the safest way is to use the newly created subset as the pt_train_file.

bernstei commented 1 week ago

OK, I'll see what I can do to massage that. It'd be nice if it worked even if the pt_head gets different structures.