ACEsuit / mace

MACE - Fast and accurate machine learning interatomic potentials with higher order equivariant message passing.

Error at end of training #672

Closed: cecilia-hong closed this issue 1 week ago

cecilia-hong commented 3 weeks ago

Hi, hope you are doing well.

Quite often, at the end of my training, I get this error:

2024-11-04 18:46:59.098 INFO: ===========RESULTS===========
2024-11-04 18:46:59.098 INFO: Computing metrics for training, validation, and test sets

-------------------------------------------------------------------------------
mace_run_train 8 <module>
sys.exit(main())

run_train.py 63 main
run(args)

run_train.py 695 run
epoch = checkpoint_handler.load_latest(

checkpoint.py 210 load_latest
result = self.io.load_latest(swa=swa, device=device)

checkpoint.py 171 load_latest
path = self._get_latest_checkpoint_path(swa=swa)

checkpoint.py 149 _get_latest_checkpoint_path
latest_checkpoint_info = max(

ValueError:
max() arg is an empty sequence

and as a result, I do not get a model from the training.
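
For what it is worth, the error message itself is just what Python raises when max() is given an empty sequence, so the traceback suggests checkpoint.py ends up with an empty list of candidate checkpoint files. A minimal standalone reproduction of the message (not the MACE code itself):

# Stands in for "no matching checkpoint files were found on disk".
candidates = []
latest = max(candidates)  # ValueError: max() arg is an empty sequence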

I have not been able to work out what causes this, as I usually use the same input script and only vary the model size. I noticed a similar issue, but that one was caused by the training stopping before the swa starting epoch, which is not the case here. I have also tried increasing max_num_epochs and restarting the training, but once that epoch is reached, the same error occurs again.

My inputs to the training are:

mace_run_train \
    --name="model" \
    --restart_latest \
    --train_file="train.xyz" \
    --valid_fraction=0.05 \
    --test_file="valid.xyz" \
    --E0s='{1:-11.54248, 6:-142.28746, 8:-426.0449, 31:-2017.672}' \
    --energy_key="energy_dft" \
    --forces_key="forces_dft" \
    --model="MACE" \
    --num_interactions=2 \
    --max_ell=2 \
    --hidden_irreps="16x0e + 16x1o + 16x2e" \
    --lr=0.002 \
    --num_cutoff_basis=3 \
    --correlation=3 \
    --r_max=5.0 \
    --batch_size=5 \
    --max_num_epochs=1000 \
    --patience=200 \
    --eval_int="1" \
    --ema \
    --ema_decay=0.995 \
    --amsgrad \
    --error_table="PerAtomRMSE" \
    --default_dtype="float32" \
    --start_swa="300" \
    --device=cuda \
    --seed=1234

(I have checked that my MACE version is 0.3.7 this time!)

Many thanks in advance!

ilyes319 commented 2 weeks ago

Hello, could you share your full log file? That would greatly help. Thank you!

cecilia-hong commented 2 weeks ago

model_run-1234_debug.log

Hi, here it is, many thanks!

cecilia-hong commented 2 weeks ago

Hello, I see that this issue has been grouped with similar issues related to multi-head finetuning. I just want to add that I tried the same training with multiheads_finetuning=False and still get the same error. Not sure if that helps at all. Many thanks!

beckobert commented 2 weeks ago

I think the problem is related to swa and checkpoint handling. Here is what I assume is happening: you specify swa_start but never enable swa itself (which should ideally not be possible). The checkpoint handler therefore gets a start value from which to save swa checkpoints, epoch 300. swa is never actually used (or mentioned in the log), but the checkpoints after epoch 300 are still saved under the swa name. Because old checkpoints are deleted unless specified otherwise (--keep_checkpoints), the log shows that at epoch 300 the old non-swa checkpoint is deleted and a new "fake-swa" checkpoint is created. Therefore, when MACE searches for a non-swa checkpoint at the end of training, it cannot find any and throws the error above.

Can you please try running training with --swa as an additional keyword?
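
To illustrate the hypothesis, here is a toy sketch with made-up file names (not the actual CheckpointHandler code): once the only checkpoint left on disk carries the swa tag, filtering for non-swa checkpoints yields an empty list, and the max() call in _get_latest_checkpoint_path has nothing to work with.

# Made-up file names; after epoch 300 only a swa-tagged checkpoint survives.
checkpoints_on_disk = ["model_run-1234_epoch-312_swa.pt"]

# Searching for a non-swa checkpoint then comes back empty.
non_swa = [p for p in checkpoints_on_disk if "_swa" not in p]  # -> []
latest = max(non_swa)  # ValueError: max() arg is an empty sequence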

gabor1 commented 2 weeks ago

hm.. can you propose a better handling of the various option states?

beckobert commented 1 week ago

As a user, it is reasonable to expect that setting --swa_start to a value results in swa actually being used during training from that epoch onwards.

The mildest solution would be that, if args.swa_start is set to a value, there is a check whether swa itself is activated, and a warning is printed if it isn't.

The more radical approach would be scrapping args.swa completely and running swa whenever a starting epoch is set, and not running it when the starting epoch is left unset. To me that looks like an efficient solution, but there might be consequences that I am not aware of.

The in-between solution would be to automatically set args.swa to True if args.start_swa is set, and to log that this happened; a sketch of that check is below.
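
For illustration, a minimal sketch of that check, assuming it runs right after argument parsing and that the relevant attributes are called args.swa and args.start_swa (hypothetical placement and helper name, not the actual MACE code):

import logging

def reconcile_swa_args(args):
    # Hypothetical helper: if a swa start epoch is given but swa itself is
    # off, enable it and record that in the log.
    if getattr(args, "start_swa", None) is not None and not getattr(args, "swa", False):
        logging.warning(
            "start_swa=%s was set but --swa was not; enabling swa so that "
            "checkpoints and the final model are handled consistently.",
            args.start_swa,
        )
        args.swa = True
    return args

Dropping the args.swa = True line would turn this into the warning-only variant described above.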

cecilia-hong commented 1 week ago

Hello, just want to confirm that adding --swa does indeed solve my error. Many thanks!