mir-group / nequip

NequIP is a code for building E(3)-equivariant interatomic potentials
https://www.nature.com/articles/s41467-022-29939-5
MIT License

🐛 [BUG] Cannot restart run with different dataset #349

Open pablo-unzueta opened 1 year ago

pablo-unzueta commented 1 year ago

Describe the bug

I am trying to restart a training run using load_model_state or initialize_from_state. I keep receiving an error saying that scale_by in the checkpoint's state_dict is a 0-dimensional tensor (shape torch.Size([])), while in the new run it has shape torch.Size([1]):

RuntimeError: Error(s) in loading state_dict for RescaleOutput: size mismatch for scale_by: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([1]).

I also tried load_model_state_strict: false, but that yielded the same error.
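The mismatch described above can be reproduced outside NequIP with a small PyTorch sketch. `ToyRescale` and `reshape_to_match` are hypothetical names invented for illustration; they are not NequIP's `RescaleOutput` or any NequIP API. The sketch assumes only standard `torch.nn.Module.load_state_dict` behaviour, where `strict=False` tolerates missing or unexpected keys but still rejects tensors whose shapes differ, which may be why the non-strict option did not help here.

```python
import torch
from torch import nn


class ToyRescale(nn.Module):
    """Toy stand-in for a rescaling layer (not NequIP's RescaleOutput)."""

    def __init__(self, scale: float = 1.0):
        super().__init__()
        # Shape [1], like the "current model" in the error message.
        self.register_buffer("scale_by", torch.full((1,), scale))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale_by


# Checkpoint whose scale_by is a 0-dim tensor, i.e. torch.Size([]).
checkpoint = {"scale_by": torch.tensor(2.0)}

model = ToyRescale()
try:
    # strict=False skips missing/unexpected keys, but not shape checks.
    model.load_state_dict(checkpoint, strict=False)
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = str(err)


def reshape_to_match(state: dict, module: nn.Module) -> dict:
    """Hypothetical helper: reshape checkpoint tensors that hold the same
    number of elements as the module's tensor but have a different shape."""
    target = module.state_dict()
    fixed = {}
    for key, value in state.items():
        if (
            key in target
            and value.shape != target[key].shape
            and value.numel() == target[key].numel()
        ):
            value = value.reshape(target[key].shape)
        fixed[key] = value
    return fixed


# After reshaping, the same checkpoint loads without error.
model.load_state_dict(reshape_to_match(checkpoint, model))
print(mismatch_error)
print(model.scale_by)
```

Reshaping the checkpoint by hand is only one possible workaround, and it is safe only when both tensors hold the same number of elements, as they do here.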

To Reproduce

Attached are the YAML files I used. I start the training with energy_only.yaml. After training for some time, I want to restart with a different dataset using the restart.yaml file.

Expected behavior

Training should resume according to #343 or #297.

Environment (please complete the following information):

Additional context

configs.zip

Linux-cpp-lisp commented 1 year ago

Hi @pablo-unzueta ,

Thanks for your interest in our code!

I'm not sure why this is happening, but initialize_from_state_strict: false would be the correct option in this case.

You could also try adding global_rescale_scale: 0.0 to your restart config...

pablo-unzueta commented 1 year ago

Hi @Linux-cpp-lisp

Thanks for your advice! I tried initialize_from_state_strict: false and received the same error:

RuntimeError: Error(s) in loading state_dict for RescaleOutput: size mismatch for scale_by: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([1]).

I also tried global_rescale_scale: 0.0 and received the following error:

ValueError: Global energy scaling was very low: 0.0. If dataset values were used, does the dataset contain insufficient variation? Maybe try disabling global scaling with global_scale=None.

pablo-unzueta commented 1 year ago

I couldn't figure out how to set global_scale=None in the config, so I just set global_rescale_scale: 1.1e-6 so that it wouldn't raise the ValueError for being below the threshold. Does this seem OK?

Linux-cpp-lisp commented 1 year ago

Yes, in principle, if you just set it to some number, it will get overridden by the loaded state dict, but I'm still not totally sure why this is happening at all.

If you do this, does it pass sanity checks? For example, are the starting validation and training losses the same as before if you restart with the same dataset?
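A sanity check along these lines could be sketched in plain PyTorch as follows. `ToyRescale` is again a hypothetical stand-in, not NequIP code, and the input data and the 2.0 scale are made up for illustration. The sketch builds a "restarted" model with the 1.1e-6 placeholder, loads the saved state into it, and then checks that the placeholder was overwritten and that the loss on identical data is unchanged.

```python
import torch
from torch import nn


class ToyRescale(nn.Module):
    """Toy stand-in for a rescaling layer (not NequIP's RescaleOutput)."""

    def __init__(self, scale: float = 1.0):
        super().__init__()
        self.register_buffer("scale_by", torch.full((1,), scale))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale_by


# Stands in for the model at the end of the first run (illustrative scale).
original = ToyRescale(scale=2.0)
saved_state = original.state_dict()

# Stands in for the restarted model, built with the placeholder scale.
restarted = ToyRescale(scale=1.1e-6)
restarted.load_state_dict(saved_state)

# Evaluate both models on the same made-up data and compare the losses.
x = torch.linspace(-1.0, 1.0, 5)
target = torch.zeros(5)
loss_fn = nn.MSELoss()
loss_before = loss_fn(original(x), target)
loss_after = loss_fn(restarted(x), target)

placeholder_overridden = torch.equal(restarted.scale_by, original.scale_by)
same_loss = torch.allclose(loss_before, loss_after)
print(placeholder_overridden, same_loss)
```

In a real restart, the equivalent check would compare the first logged training and validation losses of the restarted run against the last values from the original run on the same dataset.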