Closed: bfocassio closed this issue 6 months ago
Hey @bfocassio,
Indeed, you will need to put the path of the model you want to finetune; however, if you do that you will need to provide the full hypers of the model. You can find them here: https://github.com/ACEsuit/mace-mp/tree/main/mace_mp_0
It selects only the atoms in your training set for two reasons: to avoid keeping a large, memory-intensive model even for a small dataset, and because of the isolated atom energies. If you want to keep an element, you need to add a dummy isolated atom with the right isolated atom energy.
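The dummy isolated-atom trick above can be sketched as follows. This is not code from the thread: the extended-XYZ property keys (`energy`, `forces`) match the defaults shown in the training log, the E0 value is the hydrogen entry from the E0s dict used later in this thread, and the file name and box size are arbitrary choices.

```python
# Sketch: write a dummy isolated-atom frame in extended-XYZ format so
# that an element survives fine-tuning. The E0 value below is the H
# entry from the E0s dict in this thread; keys are assumptions.
def isolated_atom_frame(symbol: str, e0: float, box: float = 50.0) -> str:
    """Return one extxyz frame: a single atom in a large box with E = E0."""
    header = (
        f'Lattice="{box} 0.0 0.0 0.0 {box} 0.0 0.0 0.0 {box}" '
        'Properties=species:S:1:pos:R:3:forces:R:3 '
        f'energy={e0} pbc="T T T"'
    )
    # One atom at the box centre, with zero forces.
    atom = f"{symbol} {box/2} {box/2} {box/2} 0.0 0.0 0.0"
    return f"1\n{header}\n{atom}\n"

frame = isolated_atom_frame("H", -3.667168021358939)
with open("dummy_isolated_atoms.xyz", "a") as fh:
    fh.write(frame)
```

Appending one such frame per element you want to preserve to the training file should keep those rows of the embedding alive, provided each frame's energy equals that element's E0.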
That is due to not providing the right input hypers. See my first response for the hypers.
I am not sure I understand this question. The fine-tuning currently fine-tunes the whole mace-mp model. It might change if we find better fine-tuning protocols.
Hi @ilyes319 ,
Ok, I understand.
So, just to try using the 2024-01-07-mace-128-L2_epoch-199.model for fine-tuning, I've created dummy structures, each with 2 atoms, for all 89 different atomic species in the MPTrj dataset.
Here is the command for training:
python /home/bruno.focassio/codes/mace-foundations/scripts/run_train.py \
--name="2024-01-07-mace-128-L2" \
--foundation_model="2024-01-07-mace-128-L2.model" \
--train_file="training_data.xyz" \
--test_file="test_data.xyz" \
--valid_fraction=0.05 \
--loss="universal" \
--energy_weight=1 \
--forces_weight=10 \
--compute_stress=True \
--stress_weight=100 \
--stress_key='stress' \
--eval_interval=1 \
--error_table='PerAtomMAE' \
--E0s="{1: -3.667168021358939, 2: -1.3320953124042916, 3: -3.482100566595956, 4: -4.736697230897597, 5: -7.724935420523256, 6: -8.405573550273285, 7: -7.360100452662763, 8: -7.28459863421322, 9: -4.896490881731322, 10: 1.3917755836700962e-12, 11: -2.7593613569762425, 12: -2.814047612069227, 13: -4.846881245288104, 14: -7.694793133351899, 15: -6.9632957911820235, 16: -4.672630400190884, 17: -2.8116892814008096, 18: -0.06259504416367478, 19: -2.6176454856894793, 20: -5.390461060484104, 21: -7.8857952163517675, 22: -10.268392986214433, 23: -8.665147785496703, 24: -9.233050763772013, 25: -8.304951520770791, 26: -7.0489865771593765, 27: -5.577439766222147, 28: -5.172747618813715, 29: -3.2520726958619472, 30: -1.2901611618726314, 31: -3.527082192997912, 32: -4.70845955030298, 33: -3.9765109025623238, 34: -3.886231055836541, 35: -2.5184940099633986, 36: 6.766947645687137, 37: -2.5634958965928316, 38: -4.938005211501922, 39: -10.149818838085771, 40: -11.846857579882572, 41: -12.138896361658485, 42: -8.791678800595722, 43: -8.78694939675911, 44: -7.78093221529871, 45: -6.850021409115055, 46: -4.891019073240479, 47: -2.0634296773864045, 48: -0.6395695518943755, 49: -2.7887442084286693, 50: -3.818604275441892, 51: -3.587068329278862, 52: -2.8804045971118897, 53: -1.6355986842433357, 54: 9.846723842807721, 55: -2.765284507132287, 56: -4.990956432167774, 57: -8.933684809576345, 58: -8.735591176647514, 59: -8.018966025544966, 60: -8.251491970213372, 61: -7.591719594359237, 62: -8.169659881166858, 63: -13.592664636171698, 64: -18.517523458456985, 65: -7.647396572993602, 66: -8.122981037851925, 67: -7.607787319678067, 68: -6.85029094445494, 69: -7.8268821327130365, 70: -3.584786591677161, 71: -7.455406192077973, 72: -12.796283502572146, 73: -14.108127281277586, 74: -9.354916969477486, 75: -11.387537567890853, 76: -9.621909492152557, 77: -7.324393429417677, 78: -5.3046964808341945, 79: -2.380092582080244, 80: 0.24948924158195362, 81: -2.3239789120665026, 82: -3.730042357127322, 83: 
-3.438792347649683, 89: -5.062878214511315, 90: -11.02462566385297, 91: -12.265613551943261, 92: -13.855648206100362, 93: -14.933092020258243, 94: -15.282826131998245}" \
--interaction_first="RealAgnosticResidualInteractionBlock" \
--interaction="RealAgnosticResidualInteractionBlock" \
--num_interactions=2 \
--correlation=3 \
--max_ell=3 \
--r_max=6.0 \
--max_L=2 \
--num_channels=128 \
--num_radial_basis=10 \
--MLP_irreps="16x0e" \
--scaling='rms_forces_scaling' \
--lr=0.005 \
--weight_decay=1e-8 \
--ema \
--ema_decay=0.995 \
--scheduler_patience=5 \
--batch_size=2 \
--valid_batch_size=4 \
--max_num_epochs=300 \
--patience=30 \
--amsgrad \
--device="cuda" \
--default_dtype="float64" \
--seed=1 \
--clip_grad=100 \
--keep_checkpoints \
--restart_latest \
--save_cpu
Notice I kept the average atomic energies from the training input (https://github.com/ACEsuit/mace-mp/issues/1#issuecomment-1931665658)
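Since the `--E0s` string above is long and easy to mangle, one can sanity-check it locally before launching a run. This is an assumption-laden sketch, not the MACE parser itself: it only verifies that the string evaluates to a `{atomic_number: energy}` dict and that every element in the training set has an entry (the short dict and `train_zs` below are illustrative).

```python
import ast

# Sketch: validate an --E0s string before passing it to run_train.py.
# A truncated shortened example of the dict used in this thread:
e0s_str = "{1: -3.667168021358939, 6: -8.405573550273285, 8: -7.28459863421322}"

e0s = ast.literal_eval(e0s_str)  # raises if the string is malformed
assert isinstance(e0s, dict)

# Every atomic number present in the training set should have an E0.
train_zs = {1, 6, 8}  # hypothetical training-set elements
missing = train_zs - set(e0s)
print("missing E0s for Z =", sorted(missing))
```

Catching a malformed or incomplete dict this way is cheaper than discovering it mid-training.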
I've created a checkpoints folder with the checkpoint downloaded from Hugging Face.
Here is the log with the error:
2024-02-19 16:15:27.752 INFO: MACE version: 0.3.4
2024-02-19 16:15:27.753 INFO: Configuration: Namespace(name='2024-01-07-mace-128-L2', seed=1, log_dir='logs', model_dir='.', checkpoints_dir='checkpoints', results_dir='results', downloads_dir='downloads', device='cuda', default_dtype='float64', log_level='INFO', error_table='PerAtomMAE', model='MACE', r_max=6.0, radial_type='bessel', num_radial_basis=10, num_cutoff_basis=5, pair_repulsion=False, distance_transform=False, interaction='RealAgnosticResidualInteractionBlock', interaction_first='RealAgnosticResidualInteractionBlock', max_ell=3, correlation=3, num_interactions=2, MLP_irreps='16x0e', radial_MLP='[64, 64, 64]', hidden_irreps='128x0e + 128x1o', num_channels=128, max_L=2, gate='silu', scaling='rms_forces_scaling', avg_num_neighbors=1, compute_avg_num_neighbors=True, compute_stress=True, compute_forces=True, train_file='/home/bruno.focassio/mace_large_model/uip/fine_tune_2024_example/training_data.xyz', valid_file=None, valid_fraction=0.05, test_file='/home/bruno.focassio/mace_large_model/uip/fine_tune_2024_example/test_data.xyz', E0s='{1: -3.667168021358939, 2: -1.3320953124042916, 3: -3.482100566595956, 4: -4.736697230897597, 5: -7.724935420523256, 6: -8.405573550273285, 7: -7.360100452662763, 8: -7.28459863421322, 9: -4.896490881731322, 10: 1.3917755836700962e-12, 11: -2.7593613569762425, 12: -2.814047612069227, 13: -4.846881245288104, 14: -7.694793133351899, 15: -6.9632957911820235, 16: -4.672630400190884, 17: -2.8116892814008096, 18: -0.06259504416367478, 19: -2.6176454856894793, 20: -5.390461060484104, 21: -7.8857952163517675, 22: -10.268392986214433, 23: -8.665147785496703, 24: -9.233050763772013, 25: -8.304951520770791, 26: -7.0489865771593765, 27: -5.577439766222147, 28: -5.172747618813715, 29: -3.2520726958619472, 30: -1.2901611618726314, 31: -3.527082192997912, 32: -4.70845955030298, 33: -3.9765109025623238, 34: -3.886231055836541, 35: -2.5184940099633986, 36: 6.766947645687137, 37: -2.5634958965928316, 38: -4.938005211501922, 39: 
-10.149818838085771, 40: -11.846857579882572, 41: -12.138896361658485, 42: -8.791678800595722, 43: -8.78694939675911, 44: -7.78093221529871, 45: -6.850021409115055, 46: -4.891019073240479, 47: -2.0634296773864045, 48: -0.6395695518943755, 49: -2.7887442084286693, 50: -3.818604275441892, 51: -3.587068329278862, 52: -2.8804045971118897, 53: -1.6355986842433357, 54: 9.846723842807721, 55: -2.765284507132287, 56: -4.990956432167774, 57: -8.933684809576345, 58: -8.735591176647514, 59: -8.018966025544966, 60: -8.251491970213372, 61: -7.591719594359237, 62: -8.169659881166858, 63: -13.592664636171698, 64: -18.517523458456985, 65: -7.647396572993602, 66: -8.122981037851925, 67: -7.607787319678067, 68: -6.85029094445494, 69: -7.8268821327130365, 70: -3.584786591677161, 71: -7.455406192077973, 72: -12.796283502572146, 73: -14.108127281277586, 74: -9.354916969477486, 75: -11.387537567890853, 76: -9.621909492152557, 77: -7.324393429417677, 78: -5.3046964808341945, 79: -2.380092582080244, 80: 0.24948924158195362, 81: -2.3239789120665026, 82: -3.730042357127322, 83: -3.438792347649683, 89: -5.062878214511315, 90: -11.02462566385297, 91: -12.265613551943261, 92: -13.855648206100362, 93: -14.933092020258243, 94: -15.282826131998245}', energy_key='energy', forces_key='forces', virials_key='virials', stress_key='stress', dipole_key='dipole', charges_key='charges', loss='universal', forces_weight=10.0, swa_forces_weight=100.0, energy_weight=1.0, swa_energy_weight=1000.0, virials_weight=1.0, swa_virials_weight=10.0, stress_weight=100.0, swa_stress_weight=10.0, dipole_weight=1.0, swa_dipole_weight=1.0, config_type_weights='{"Default":1.0}', huber_delta=0.01, optimizer='adam', batch_size=2, valid_batch_size=4, lr=0.005, swa_lr=0.001, weight_decay=1e-08, amsgrad=True, scheduler='ReduceLROnPlateau', lr_factor=0.8, scheduler_patience=5, lr_scheduler_gamma=0.9993, swa=False, start_swa=None, ema=True, ema_decay=0.995, max_num_epochs=300, patience=30, 
foundation_model='2024-01-07-mace-128-L2.model', foundation_model_readout=True, eval_interval=1, keep_checkpoints=True, restart_latest=True, save_cpu=True, clip_grad=100.0, wandb=False, wandb_project='', wandb_entity='', wandb_name='', wandb_log_hypers=['num_channels', 'max_L', 'correlation', 'lr', 'swa_lr', 'weight_decay', 'batch_size', 'max_num_epochs', 'start_swa', 'energy_weight', 'forces_weight'])
2024-02-19 16:15:27.814 INFO: CUDA version: 11.3, CUDA device: 0
2024-02-19 16:15:27.920 INFO: Loaded 89 training configurations from '/home/bruno.focassio/mace_large_model/uip/fine_tune_2024_example/training_data.xyz'
2024-02-19 16:15:27.921 INFO: Using random 5.0% of training set for validation
2024-02-19 16:15:27.957 INFO: Loaded 89 test configurations from '/home/bruno.focassio/mace_large_model/uip/fine_tune_2024_example/test_data.xyz'
2024-02-19 16:15:27.957 INFO: Total number of configurations: train=85, valid=4, tests=[Default: 89]
2024-02-19 16:15:27.958 INFO: AtomicNumberTable: (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 89, 90, 91, 92, 93, 94)
2024-02-19 16:15:27.958 INFO: Atomic Energies not in training file, using command line argument E0s
2024-02-19 16:15:27.959 INFO: Atomic energies: [-3.667168021358939, -1.3320953124042916, -3.482100566595956, -4.736697230897597, -7.724935420523256, -8.405573550273285, -7.360100452662763, -7.28459863421322, -4.896490881731322, 1.3917755836700962e-12, -2.7593613569762425, -2.814047612069227, -4.846881245288104, -7.694793133351899, -6.9632957911820235, -4.672630400190884, -2.8116892814008096, -0.06259504416367478, -2.6176454856894793, -5.390461060484104, -7.8857952163517675, -10.268392986214433, -8.665147785496703, -9.233050763772013, -8.304951520770791, -7.0489865771593765, -5.577439766222147, -5.172747618813715, -3.2520726958619472, -1.2901611618726314, -3.527082192997912, -4.70845955030298, -3.9765109025623238, -3.886231055836541, -2.5184940099633986, 6.766947645687137, -2.5634958965928316, -4.938005211501922, -10.149818838085771, -11.846857579882572, -12.138896361658485, -8.791678800595722, -8.78694939675911, -7.78093221529871, -6.850021409115055, -4.891019073240479, -2.0634296773864045, -0.6395695518943755, -2.7887442084286693, -3.818604275441892, -3.587068329278862, -2.8804045971118897, -1.6355986842433357, 9.846723842807721, -2.765284507132287, -4.990956432167774, -8.933684809576345, -8.735591176647514, -8.018966025544966, -8.251491970213372, -7.591719594359237, -8.169659881166858, -13.592664636171698, -18.517523458456985, -7.647396572993602, -8.122981037851925, -7.607787319678067, -6.85029094445494, -7.8268821327130365, -3.584786591677161, -7.455406192077973, -12.796283502572146, -14.108127281277586, -9.354916969477486, -11.387537567890853, -9.621909492152557, -7.324393429417677, -5.3046964808341945, -2.380092582080244, 0.24948924158195362, -2.3239789120665026, -3.730042357127322, -3.438792347649683, -5.062878214511315, -11.02462566385297, -12.265613551943261, -13.855648206100362, -14.933092020258243, -15.282826131998245]
2024-02-19 16:15:28.078 INFO: UniversalLoss(energy_weight=1.000, forces_weight=10.000, stress_weight=100.000)
2024-02-19 16:15:28.127 INFO: Average number of neighbors: nan
2024-02-19 16:15:28.128 INFO: Selected the following outputs: {'energy': True, 'forces': True, 'virials': False, 'stress': True, 'dipoles': False}
2024-02-19 16:15:28.128 INFO: Building model
2024-02-19 16:15:28.130 INFO: Hidden irreps: 128x0e+128x1o+128x2e
2024-02-19 16:15:28.289 WARNING: Standard deviation of the scaling is zero, Changing to no scaling
2024-02-19 16:15:31.924 INFO: Using foundation model 2024-01-07-mace-128-L2.model as initial checkpoint.
Traceback (most recent call last):
File "/home/bruno.focassio/codes/mace-foundations/scripts/run_train.py", line 6, in <module>
main()
File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/cli/run_train.py", line 398, in main
model = load_foundations(
File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/tools/utils.py", line 173, in load_foundations
indices_weights = [z_table.z_to_index(z) for z in new_z_table.zs]
File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/tools/utils.py", line 173, in <listcomp>
indices_weights = [z_table.z_to_index(z) for z in new_z_table.zs]
File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/tools/utils.py", line 89, in z_to_index
return self.zs.index(atomic_number)
ValueError: 1 is not in list
Attached are the dummy train and test files. The energies were predicted with the MACE calculator from this model.
Hi again @ilyes319
I realized the error above was caused by a model file that had been overwritten with a different table of atomic numbers. I downloaded 2024-01-07-mace-128-L2.model again and it gets past that error.
Now the problem is, as I mentioned before, a shape mismatch. Can you please take a look?
python /home/bruno.focassio/codes/mace-foundations/scripts/run_train.py \
--name="2024-01-07-mace-128-L2" \
--foundation_model="2024-01-07-mace-128-L2.model" \
--train_file="training_data.xyz" \
--test_file="test_data.xyz" \
--valid_fraction=0.05 \
--loss="universal" \
--energy_weight=1 \
--forces_weight=10 \
--compute_stress=True \
--stress_weight=100 \
--stress_key='stress' \
--eval_interval=1 \
--error_table='PerAtomMAE' \
--E0s="{1: -3.667168021358939, 2: -1.3320953124042916, 3: -3.482100566595956, 4: -4.736697230897597, 5: -7.724935420523256, 6: -8.405573550273285, 7: -7.360100452662763, 8: -7.28459863421322, 9: -4.896490881731322, 10: 1.3917755836700962e-12, 11: -2.7593613569762425, 12: -2.814047612069227, 13: -4.846881245288104, 14: -7.694793133351899, 15: -6.9632957911820235, 16: -4.672630400190884, 17: -2.8116892814008096, 18: -0.06259504416367478, 19: -2.6176454856894793, 20: -5.390461060484104, 21: -7.8857952163517675, 22: -10.268392986214433, 23: -8.665147785496703, 24: -9.233050763772013, 25: -8.304951520770791, 26: -7.0489865771593765, 27: -5.577439766222147, 28: -5.172747618813715, 29: -3.2520726958619472, 30: -1.2901611618726314, 31: -3.527082192997912, 32: -4.70845955030298, 33: -3.9765109025623238, 34: -3.886231055836541, 35: -2.5184940099633986, 36: 6.766947645687137, 37: -2.5634958965928316, 38: -4.938005211501922, 39: -10.149818838085771, 40: -11.846857579882572, 41: -12.138896361658485, 42: -8.791678800595722, 43: -8.78694939675911, 44: -7.78093221529871, 45: -6.850021409115055, 46: -4.891019073240479, 47: -2.0634296773864045, 48: -0.6395695518943755, 49: -2.7887442084286693, 50: -3.818604275441892, 51: -3.587068329278862, 52: -2.8804045971118897, 53: -1.6355986842433357, 54: 9.846723842807721, 55: -2.765284507132287, 56: -4.990956432167774, 57: -8.933684809576345, 58: -8.735591176647514, 59: -8.018966025544966, 60: -8.251491970213372, 61: -7.591719594359237, 62: -8.169659881166858, 63: -13.592664636171698, 64: -18.517523458456985, 65: -7.647396572993602, 66: -8.122981037851925, 67: -7.607787319678067, 68: -6.85029094445494, 69: -7.8268821327130365, 70: -3.584786591677161, 71: -7.455406192077973, 72: -12.796283502572146, 73: -14.108127281277586, 74: -9.354916969477486, 75: -11.387537567890853, 76: -9.621909492152557, 77: -7.324393429417677, 78: -5.3046964808341945, 79: -2.380092582080244, 80: 0.24948924158195362, 81: -2.3239789120665026, 82: -3.730042357127322, 83: 
-3.438792347649683, 89: -5.062878214511315, 90: -11.02462566385297, 91: -12.265613551943261, 92: -13.855648206100362, 93: -14.933092020258243, 94: -15.282826131998245}" \
--interaction_first="RealAgnosticResidualInteractionBlock" \
--interaction="RealAgnosticResidualInteractionBlock" \
--num_interactions=2 \
--correlation=3 \
--max_ell=3 \
--r_max=6.0 \
--max_L=2 \
--num_channels=128 \
--num_radial_basis=10 \
--MLP_irreps="16x0e" \
--scaling='rms_forces_scaling' \
--lr=0.005 \
--weight_decay=1e-8 \
--ema \
--ema_decay=0.995 \
--scheduler_patience=5 \
--batch_size=2 \
--valid_batch_size=4 \
--max_num_epochs=300 \
--patience=30 \
--amsgrad \
--device="cuda" \
--default_dtype="float64" \
--seed=1 \
--clip_grad=100 \
--keep_checkpoints \
--restart_latest \
--save_cpu
And log:
2024-02-20 08:42:06.620 INFO: MACE version: 0.3.4
2024-02-20 08:42:06.621 INFO: Configuration: Namespace(name='2024-01-07-mace-128-L2', seed=1, log_dir='logs', model_dir='.', checkpoints_dir='checkpoints', results_dir='results', downloads_dir='downloads', device='cuda', default_dtype='float64', log_level='INFO', error_table='PerAtomMAE', model='MACE', r_max=6.0, radial_type='bessel', num_radial_basis=10, num_cutoff_basis=5, pair_repulsion=False, distance_transform=False, interaction='RealAgnosticResidualInteractionBlock', interaction_first='RealAgnosticResidualInteractionBlock', max_ell=3, correlation=3, num_interactions=2, MLP_irreps='16x0e', radial_MLP='[64, 64, 64]', hidden_irreps='128x0e + 128x1o', num_channels=128, max_L=2, gate='silu', scaling='rms_forces_scaling', avg_num_neighbors=1, compute_avg_num_neighbors=True, compute_stress=True, compute_forces=True, train_file='/home/bruno.focassio/mace_large_model/uip/train_2024/training_data.xyz', valid_file=None, valid_fraction=0.05, test_file='/home/bruno.focassio/mace_large_model/uip/train_2024/test_data.xyz', E0s='{1: -3.667168021358939, 2: -1.3320953124042916, 3: -3.482100566595956, 4: -4.736697230897597, 5: -7.724935420523256, 6: -8.405573550273285, 7: -7.360100452662763, 8: -7.28459863421322, 9: -4.896490881731322, 10: 1.3917755836700962e-12, 11: -2.7593613569762425, 12: -2.814047612069227, 13: -4.846881245288104, 14: -7.694793133351899, 15: -6.9632957911820235, 16: -4.672630400190884, 17: -2.8116892814008096, 18: -0.06259504416367478, 19: -2.6176454856894793, 20: -5.390461060484104, 21: -7.8857952163517675, 22: -10.268392986214433, 23: -8.665147785496703, 24: -9.233050763772013, 25: -8.304951520770791, 26: -7.0489865771593765, 27: -5.577439766222147, 28: -5.172747618813715, 29: -3.2520726958619472, 30: -1.2901611618726314, 31: -3.527082192997912, 32: -4.70845955030298, 33: -3.9765109025623238, 34: -3.886231055836541, 35: -2.5184940099633986, 36: 6.766947645687137, 37: -2.5634958965928316, 38: -4.938005211501922, 39: -10.149818838085771, 40: 
-11.846857579882572, 41: -12.138896361658485, 42: -8.791678800595722, 43: -8.78694939675911, 44: -7.78093221529871, 45: -6.850021409115055, 46: -4.891019073240479, 47: -2.0634296773864045, 48: -0.6395695518943755, 49: -2.7887442084286693, 50: -3.818604275441892, 51: -3.587068329278862, 52: -2.8804045971118897, 53: -1.6355986842433357, 54: 9.846723842807721, 55: -2.765284507132287, 56: -4.990956432167774, 57: -8.933684809576345, 58: -8.735591176647514, 59: -8.018966025544966, 60: -8.251491970213372, 61: -7.591719594359237, 62: -8.169659881166858, 63: -13.592664636171698, 64: -18.517523458456985, 65: -7.647396572993602, 66: -8.122981037851925, 67: -7.607787319678067, 68: -6.85029094445494, 69: -7.8268821327130365, 70: -3.584786591677161, 71: -7.455406192077973, 72: -12.796283502572146, 73: -14.108127281277586, 74: -9.354916969477486, 75: -11.387537567890853, 76: -9.621909492152557, 77: -7.324393429417677, 78: -5.3046964808341945, 79: -2.380092582080244, 80: 0.24948924158195362, 81: -2.3239789120665026, 82: -3.730042357127322, 83: -3.438792347649683, 89: -5.062878214511315, 90: -11.02462566385297, 91: -12.265613551943261, 92: -13.855648206100362, 93: -14.933092020258243, 94: -15.282826131998245}', energy_key='energy', forces_key='forces', virials_key='virials', stress_key='stress', dipole_key='dipole', charges_key='charges', loss='universal', forces_weight=10.0, swa_forces_weight=100.0, energy_weight=1.0, swa_energy_weight=1000.0, virials_weight=1.0, swa_virials_weight=10.0, stress_weight=100.0, swa_stress_weight=10.0, dipole_weight=1.0, swa_dipole_weight=1.0, config_type_weights='{"Default":1.0}', huber_delta=0.01, optimizer='adam', batch_size=2, valid_batch_size=4, lr=0.005, swa_lr=0.001, weight_decay=1e-08, amsgrad=True, scheduler='ReduceLROnPlateau', lr_factor=0.8, scheduler_patience=5, lr_scheduler_gamma=0.9993, swa=False, start_swa=None, ema=True, ema_decay=0.995, max_num_epochs=300, patience=30, foundation_model='2024-01-07-mace-128-L2.model', 
foundation_model_readout=True, eval_interval=1, keep_checkpoints=True, restart_latest=True, save_cpu=True, clip_grad=100.0, wandb=False, wandb_project='', wandb_entity='', wandb_name='', wandb_log_hypers=['num_channels', 'max_L', 'correlation', 'lr', 'swa_lr', 'weight_decay', 'batch_size', 'max_num_epochs', 'start_swa', 'energy_weight', 'forces_weight'])
2024-02-20 08:42:06.654 INFO: CUDA version: 11.3, CUDA device: 0
2024-02-20 08:42:07.619 INFO: Loaded 1491 training configurations from '/home/bruno.focassio/mace_large_model/uip/train_2024/training_data.xyz'
2024-02-20 08:42:07.620 INFO: Using random 5.0% of training set for validation
2024-02-20 08:42:07.739 INFO: Loaded 253 test configurations from '/home/bruno.focassio/mace_large_model/uip/train_2024/test_data.xyz'
2024-02-20 08:42:07.740 INFO: Total number of configurations: train=1417, valid=74, tests=[Default: 253]
2024-02-20 08:42:07.756 INFO: AtomicNumberTable: (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 89, 90, 91, 92, 93, 94)
2024-02-20 08:42:07.756 INFO: Atomic Energies not in training file, using command line argument E0s
2024-02-20 08:42:07.757 INFO: Atomic energies: [-3.667168021358939, -1.3320953124042916, -3.482100566595956, -4.736697230897597, -7.724935420523256, -8.405573550273285, -7.360100452662763, -7.28459863421322, -4.896490881731322, 1.3917755836700962e-12, -2.7593613569762425, -2.814047612069227, -4.846881245288104, -7.694793133351899, -6.9632957911820235, -4.672630400190884, -2.8116892814008096, -0.06259504416367478, -2.6176454856894793, -5.390461060484104, -7.8857952163517675, -10.268392986214433, -8.665147785496703, -9.233050763772013, -8.304951520770791, -7.0489865771593765, -5.577439766222147, -5.172747618813715, -3.2520726958619472, -1.2901611618726314, -3.527082192997912, -4.70845955030298, -3.9765109025623238, -3.886231055836541, -2.5184940099633986, 6.766947645687137, -2.5634958965928316, -4.938005211501922, -10.149818838085771, -11.846857579882572, -12.138896361658485, -8.791678800595722, -8.78694939675911, -7.78093221529871, -6.850021409115055, -4.891019073240479, -2.0634296773864045, -0.6395695518943755, -2.7887442084286693, -3.818604275441892, -3.587068329278862, -2.8804045971118897, -1.6355986842433357, 9.846723842807721, -2.765284507132287, -4.990956432167774, -8.933684809576345, -8.735591176647514, -8.018966025544966, -8.251491970213372, -7.591719594359237, -8.169659881166858, -13.592664636171698, -18.517523458456985, -7.647396572993602, -8.122981037851925, -7.607787319678067, -6.85029094445494, -7.8268821327130365, -3.584786591677161, -7.455406192077973, -12.796283502572146, -14.108127281277586, -9.354916969477486, -11.387537567890853, -9.621909492152557, -7.324393429417677, -5.3046964808341945, -2.380092582080244, 0.24948924158195362, -2.3239789120665026, -3.730042357127322, -3.438792347649683, -5.062878214511315, -11.02462566385297, -12.265613551943261, -13.855648206100362, -14.933092020258243, -15.282826131998245]
2024-02-20 08:42:14.706 INFO: UniversalLoss(energy_weight=1.000, forces_weight=10.000, stress_weight=100.000)
2024-02-20 08:42:15.579 INFO: Average number of neighbors: 62.46425785800387
2024-02-20 08:42:15.580 INFO: Selected the following outputs: {'energy': True, 'forces': True, 'virials': False, 'stress': True, 'dipoles': False}
2024-02-20 08:42:15.580 INFO: Building model
2024-02-20 08:42:15.588 INFO: Hidden irreps: 128x0e+128x1o+128x2e
2024-02-20 08:42:21.455 INFO: Using foundation model 2024-01-07-mace-128-L2.model as initial checkpoint.
2024-02-20 08:42:21.523 WARNING: No SWA checkpoint found, while SWA is enabled. Compare the swa_start parameter and the latest checkpoint.
2024-02-20 08:42:21.524 INFO: Loading checkpoint: checkpoints/2024-01-07-mace-128-L2_run-1_epoch-199.pt
Traceback (most recent call last):
File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/cli/run_train.py", line 525, in main
opt_start_epoch = checkpoint_handler.load_latest(
File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/tools/checkpoint.py", line 210, in load_latest
result = self.io.load_latest(swa=swa, device=device)
File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/tools/checkpoint.py", line 171, in load_latest
path = self._get_latest_checkpoint_path(swa=swa)
File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/tools/checkpoint.py", line 152, in _get_latest_checkpoint_path
return latest_checkpoint_info.path
UnboundLocalError: local variable 'latest_checkpoint_info' referenced before assignment
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/bruno.focassio/codes/mace-foundations/scripts/run_train.py", line 6, in <module>
main()
File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/cli/run_train.py", line 531, in main
opt_start_epoch = checkpoint_handler.load_latest(
File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/tools/checkpoint.py", line 215, in load_latest
self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict)
File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/tools/checkpoint.py", line 40, in load_checkpoint
state.model.load_state_dict(checkpoint["model"], strict=strict) # type: ignore
File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ScaleShiftMACE:
size mismatch for interactions.0.skip_tp.weight: copying a param with shape torch.Size([1458176]) from checkpoint, the shape in current model is torch.Size([5832704]).
size mismatch for interactions.0.skip_tp.output_mask: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([2048]).
size mismatch for interactions.1.linear_up.weight: copying a param with shape torch.Size([16384]) from checkpoint, the shape in current model is torch.Size([49152]).
size mismatch for interactions.1.linear_up.output_mask: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1152]).
size mismatch for interactions.1.conv_tp.output_mask: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([9088]).
size mismatch for interactions.1.conv_tp_weights.layer3.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([64, 2176]).
size mismatch for interactions.1.linear.weight: copying a param with shape torch.Size([65536]) from checkpoint, the shape in current model is torch.Size([278528]).
size mismatch for products.0.linear.weight: copying a param with shape torch.Size([16384]) from checkpoint, the shape in current model is torch.Size([49152]).
size mismatch for products.0.linear.output_mask: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1152]).
I see here that you have a checkpoint being picked up: "INFO: Loading checkpoint: checkpoints/2024-01-07-mace-128-L2_run-1_epoch-199.pt". Can you please change your run name so it does not load this checkpoint?
Indeed, not loading the checkpoint from the .pt file works. However, I was actually trying to load the checkpoint available on Hugging Face: https://huggingface.co/cyrusyc/mace-universal/tree/main/pretrained If the checkpoint was generated from the model training, shouldn't it be compatible for continuing the training?
however if you do that you will need to give the full hypers of the model
Is there a technical reason why it has to be this way? Could you come up with a way to save the hypers with the model, so this kind of problem happens less often?
It would be ideal to save all the hyperparameters in a kwargs dictionary and have flexible parsing like MACE(**kwargs). However, that seems pretty hard with the current implementation, and the training CLI and argparser are still evolving...
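The kwargs idea above, storing the constructor arguments next to the weights so a checkpoint becomes self-describing, can be sketched generically. The class, function, and file names below are hypothetical illustrations, not MACE API:

```python
import json

# Hypothetical sketch of the MACE(**kwargs) idea: persist the constructor
# arguments alongside the weights so no hypers need re-typing on the CLI.
class ToyModel:
    def __init__(self, r_max: float, num_channels: int, max_L: int):
        # Remember the exact arguments this instance was built with.
        self.kwargs = {"r_max": r_max, "num_channels": num_channels, "max_L": max_L}

def save_hypers(model: ToyModel, path: str) -> None:
    with open(path, "w") as fh:
        json.dump(model.kwargs, fh)

def load_model(path: str) -> ToyModel:
    with open(path) as fh:
        kwargs = json.load(fh)
    return ToyModel(**kwargs)  # rebuild with identical hypers

m = ToyModel(r_max=6.0, num_channels=128, max_L=2)
save_hypers(m, "hypers.json")
m2 = load_model("hypers.json")
```

In a real checkpoint the dict would live inside the same file as the state dict (e.g. one `torch.save` of `{"kwargs": ..., "model": ...}`), which would avoid exactly the shape mismatches discussed in this thread.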
I am curious about the function of --foundation_model_readout: does it greatly influence the fine-tuning result?
Describe the bug
Dear developers,
I've been trying to fine-tune a large mace-mp-0 model; however, I'm running into some problems.
I'm using the foundational branch.
First, what works?
The following training command runs successfully:
python /home/bruno.focassio/codes/mace-foundations/scripts/run_train.py \
--name="mace_fine_tunning_100" \
--foundation_model="large" \
--train_file="training_data.xyz" \
--test_file="test_data.xyz" \
--valid_fraction=0.05 \
--energy_weight=1 \
--forces_weight=10 \
--compute_stress=True \
--stress_weight=100 \
--stress_key='stress' \
--eval_interval=1 \
--error_table='PerAtomMAE' \
--E0s="average" \
--interaction_first="RealAgnosticResidualInteractionBlock" \
--interaction="RealAgnosticResidualInteractionBlock" \
--scaling='rms_forces_scaling' \
--lr=0.005 \
--weight_decay=1e-8 \
--ema \
--ema_decay=0.995 \
--scheduler_patience=5 \
--batch_size=2 \
--valid_batch_size=4 \
--max_num_epochs=100 \
--patience=20 \
--amsgrad \
--device="cuda" \
--seed=1 \
--clip_grad=100 \
--keep_checkpoints \
--restart_latest \
--save_cpu
And you can check the log. There are a couple of questions from this:

1. The foundational branch downloads the large model from
large="http://tinyurl.com/5f5yavf3", # MACE_MPtrj_2022.9.model
however, I find that the 0.3.3 release uses the model from
large="https://figshare.com/ndownloader/files/43117273"
In that regard, how can I use the most up-to-date one? I've tried replacing
--foundation_model="large" \
with the path of several different models, including the one from https://figshare.com/ndownloader/files/43117273 and even the one available on Hugging Face: 2024-01-07-mace-128-L2_epoch-199.model

2. The linear embedding block (the first one) shows:
(node_embedding): LinearNodeEmbeddingBlock( (linear): Linear(6x0e -> 128x0e | 768 weights) )
because my fine-tuning training set only has 6 elements; however, the full large mace-mp-0 model is supposed to be:
(node_embedding): LinearNodeEmbeddingBlock( (linear): Linear(89x0e -> 128x0e | 11392 weights) )
How can I keep the original shape of the linear embedding and only fine-tune for the elements in my training set? Should I create dummy samples with single atoms at the average atomic energy?

3. I have tried to use the 2024-01-07-mace-128-L2_epoch-199 model; however, when I try that I run into very similar problems. When I try to use the checkpoint available on Hugging Face for that model, it gives me a size mismatch between the model I'm loading and the model from the checkpoint, which I suspect is related to the questions above.
Any help is appreciated.