mir-group / nequip

NequIP is a code for building E(3)-equivariant interatomic potentials
https://www.nature.com/articles/s41467-022-29939-5
MIT License

πŸ› [BUG] Using initialize_from_state #415

Closed JunsuAndrewLee closed 3 months ago

JunsuAndrewLee commented 3 months ago

Describe the bug
Hello, I tried to re-train from a pretrained NequIP .pth file by adding the following lines to the .yaml config file:

model_builders:
 - SimpleIrrepsConfig         # update the config with all the irreps for the network if using the simplified `l_max` / `num_features` / `parity` syntax
 - EnergyModel                # build a full NequIP model
 - PerSpeciesRescale          # add per-atom / per-species scaling and shifting to the NequIP model before the total energy sum
 - ForceOutput                # wrap the energy model in a module that uses autodifferentiation to compute the forces
 - RescaleEnergyEtc           # wrap the entire model in the appropriate global rescaling of the energy, forces, etc.
 - initialize_from_state
initial_model_state: /home/users/user/nequip_test/7.3.sampled_water-deployed.pth

Then I got the following error:

Replace string dataset_forces_rms to 469.19586181640625
Replace string dataset_per_atom_total_energy_mean to -15927.2939453125
Atomic outputs are scaled by: [H, O: 469.195862], shifted by [H, O: -15927.293945].
Replace string dataset_forces_rms to 469.19586181640625
Initially outputs are globally scaled by: 469.19586181640625, total_energy are globally shifted by None.
/home/users/user/anaconda3/envs/nequip1/lib/python3.10/site-packages/torch/serialization.py:1007: UserWarning: 'torch.load' received a zip file that looks like a TorchScript archive dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to silence this warning)
  warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
Traceback (most recent call last):
  File "/home/users/user/anaconda3/envs/nequip1/bin/nequip-train", line 8, in <module>
    sys.exit(main())
  File "/home/users/user/anaconda3/envs/nequip1/lib/python3.10/site-packages/nequip/scripts/train.py", line 72, in main
    trainer = fresh_start(config)
  File "/home/users/user/anaconda3/envs/nequip1/lib/python3.10/site-packages/nequip/scripts/train.py", line 163, in fresh_start
    final_model = model_from_config(
  File "/home/users/user/anaconda3/envs/nequip1/lib/python3.10/site-packages/nequip/model/_build.py", line 96, in model_from_config
    model = builder(**params)
  File "/home/users/user/anaconda3/envs/nequip1/lib/python3.10/site-packages/nequip/model/_weight_init.py", line 26, in initialize_from_state
    return load_model_state(
  File "/home/users/user/anaconda3/envs/nequip1/lib/python3.10/site-packages/nequip/model/_weight_init.py", line 52, in load_model_state
    model.load_state_dict(state, strict=config.get(_prefix + "_strict", True))
  File "/home/users/user/anaconda3/envs/nequip1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2104, in load_state_dict
    raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")
TypeError: Expected state_dict to be dict-like, got <class 'torch.jit._script.RecursiveScriptModule'>.

Can you help me correct this error?

Thank you.
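As a rough way to see why load_state_dict rejects the file, here is a minimal sketch (the paths are just the files mentioned in this thread; adapt them to your own setup) that distinguishes a deployed TorchScript archive from a training checkpoint:

import torch

# A deployed model (produced by `nequip-deploy`) is a TorchScript archive:
# torch.load() dispatches to torch.jit.load() and returns a
# RecursiveScriptModule, which is not dict-like, hence the TypeError above.
deployed = torch.load("7.3.sampled_water-deployed.pth", map_location="cpu")
print(type(deployed))  # <class 'torch.jit._script.RecursiveScriptModule'>

# A training checkpoint such as best_model.pth should instead hold a plain
# state dict (a dict-like mapping of parameter names to tensors), which is
# what initialize_from_state expects.
checkpoint = torch.load("best_model.pth", map_location="cpu")
print(type(checkpoint))  # dict-like, e.g. collections.OrderedDict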


Linux-cpp-lisp commented 3 months ago

Hi @JunsuAndrewLee ,

Thanks for your interest in our code!

initial_model_state should point to a checkpoint from a training directory, like best_model.pth, and not a deployed model.
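For anyone hitting the same error: the config above should work once initial_model_state points at a training checkpoint instead of the deployed file, for example (the exact path to your run's best_model.pth will differ):

model_builders:
 - SimpleIrrepsConfig
 - EnergyModel
 - PerSpeciesRescale
 - ForceOutput
 - RescaleEnergyEtc
 - initialize_from_state
# a state-dict checkpoint written by nequip-train, not a nequip-deploy output:
initial_model_state: /path/to/training_run_dir/best_model.pth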

JunsuAndrewLee commented 3 months ago

Oh, my bad. Thank you for your kind answer!

Linux-cpp-lisp commented 3 months ago

No worries, glad this resolved your issue!