CederGroupHub / chgnet

Pretrained universal neural network potential for charge-informed atomistic modeling https://chgnet.lbl.gov
https://doi.org/10.1038/s42256-023-00716-3

[Bug]: Error when loading trainer state using trainer.load #168

Closed: naveenmohandas closed this issue 3 months ago

naveenmohandas commented 3 months ago

Email (Optional)

n.k.mohandas@tudelft.nl

Version

v0.3.8

Which OS(es) are you using?

What happened?

1. I was attempting to restart training from a particular epoch, but when I try to load the checkpoint using `Trainer.load()` I get `KeyError: 'decay_fraction'`. The code snippet I use to load the checkpoint is:

   ```python
   from pathlib import Path

   from chgnet.trainer import Trainer
   from chgnet.model import CHGNet


   def load_trainer(trainer_path: Path):
       """Load the trainer from the path"""
       # load the trainer
       trainer = Trainer.load(trainer_path)
       return trainer


   if __name__ == "__main__":
       file_path = '/output/epoch4_e0_f0_sNA_mNA.pth.tar'
       trainer = load_trainer(file_path)
   ```


2. When I loaded the saved checkpoint using `torch.load`, `scheduler_params` was empty.
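A possible user-side workaround, sketched under assumptions: the traceback shows that `Trainer.load` rebuilds the trainer from `state["trainer_args"]`, so restoring the missing key in the checkpoint before loading should avoid the `KeyError`. The helpers `restore_scheduler_params` and `patch_checkpoint` are hypothetical names, and the value to restore must be the one originally passed (`0.5e-2` in the training script).

```python
from __future__ import annotations

from typing import Any


def restore_scheduler_params(state: dict[str, Any], defaults: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical helper: fill keys missing from state["trainer_args"]["scheduler_params"].

    Keys that are already present are left untouched. The state dict is
    modified in place and also returned for convenience.
    """
    trainer_args = state.setdefault("trainer_args", {})
    # scheduler_params may be an emptied dict or None in a broken checkpoint
    scheduler_params = trainer_args.get("scheduler_params") or {}
    for key, value in defaults.items():
        scheduler_params.setdefault(key, value)
    trainer_args["scheduler_params"] = scheduler_params
    return state


def patch_checkpoint(path: str, out_path: str, defaults: dict[str, Any]) -> None:
    """Hypothetical helper: load a checkpoint, restore scheduler params, save a copy."""
    import torch  # imported lazily so the pure-dict helper above works without torch

    # weights_only=False unpickles arbitrary objects: only use it on checkpoints you trust
    state = torch.load(path, map_location="cpu", weights_only=False)
    torch.save(restore_scheduler_params(state, defaults), out_path)
```

For example, `patch_checkpoint('/output/epoch4_e0_f0_sNA_mNA.pth.tar', '/output/epoch4_patched.pth.tar', {"decay_fraction": 0.5e-2})` followed by `Trainer.load('/output/epoch4_patched.pth.tar')`; the patched file name is only illustrative. Writing to a new file keeps the original checkpoint untouched.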

A dummy script representing my initial training script is given below (in case I am doing something wrong there).

On another note (not sure if I need to raise this as a separate issue), I also run into an error when I set `save_test_result=True` in the trainer: `TypeError: Object of type int64 is not JSON serializable`, raised from the line `raise TypeError(f'Object of type {o.__class__.__name__} '`.
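Python's `json` module only serializes built-in types, so numpy scalars such as `np.int64` raise exactly this `TypeError`. A minimal sketch of a common workaround is a `default=` hook for `json.dump`/`json.dumps` that converts numpy scalars and arrays to built-in types; `to_builtin` is a hypothetical name, and this is not necessarily how chgnet fixes it internally.

```python
import json

import numpy as np


def to_builtin(obj):
    """json `default` hook: convert numpy scalars/arrays to built-in Python types."""
    if isinstance(obj, np.generic):  # covers np.int64, np.float32, np.bool_, ...
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # keep json's normal behavior for anything else it cannot handle
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


record = {"index": np.int64(7), "energy": np.float32(-10.5), "forces": np.zeros((2, 3))}
print(json.dumps(record, default=to_builtin))
```

`json.dumps(record)` without the hook fails on `np.int64`; with it, the record serializes to plain ints, floats and nested lists.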

### Code snippet

```python
# The code snippet for training
#
from pathlib import Path
import numpy as np

from pymatgen.core import Structure, Lattice
from chgnet.data.dataset import StructureData, get_train_val_test_loader
from chgnet.trainer import Trainer
from chgnet.model import CHGNet

def get_structures():
    model = CHGNet.load()
    structures = []
    energies_per_atom = []
    forces = []
    a_0 = 2.85
    lattice = Lattice.cubic(a_0)
    mo_structure = Structure(lattice, ["Mo", "Mo"], [[0, 0, 0], [0.5, 0.5, 0.5]])
    mo_structure.make_supercell([3,3,3])
    for _ in range(100):
        structure = mo_structure.copy()
        # stretch the cell by a small amount
        structure.apply_strain(np.random.uniform(-0.1, 0.1, size=3))
        # perturb all atom positions by a small amount
        structure.perturb(0.1) 
        pred = model.predict_structure(structure)

        structures.append(structure)
        energies_per_atom.append(pred["e"] + np.random.uniform(-0.1, 0.1, size=1))
        forces.append(pred["f"] + np.random.uniform(-0.01, 0.01, size=pred["f"].shape))

    return structures, energies_per_atom, forces

def train_chgnet(output_dir, epochs):
    # note: `epochs` is not used below; the Trainer is created with epochs=5

    structures, energies_per_atom, forces = get_structures()
    print("Len of dataset:",len(structures))
    dataset = StructureData(
        structures=structures,
        energies=energies_per_atom,
        forces = forces
    )

    # load the trainer
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset, batch_size=5, train_ratio=0.9, val_ratio=0.05
    )
    train_test_split = { "train_loader": train_loader,
                         "test_loader": test_loader,
                         "val_loader": val_loader
                       }

    model = CHGNet.load()

    trainer = Trainer(
                    model=model,
                    targets='ef',
                    energy_loss_ratio=1,
                    force_loss_ratio=1,
                    stress_loss_ratio=0.1,
                    mag_loss_ratio=0.1,
                    optimizer='Adam',
                    weight_decay=0,
                    scheduler='CosLR',
                    scheduler_params={'decay_fraction': 0.5e-2},
                    criterion='Huber',
                    delta=0.1,
                    epochs=5,
                    starting_epoch=0,
                    learning_rate=5e-3,
                    use_device='cpu',
                    print_freq=1
                )

    for param in model.parameters():
        param.requires_grad = True

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    trainer.train(train_loader, 
                  val_loader, 
                  test_loader, 
                  train_composition_model=False,
                  save_dir=output_dir,
                  )

if __name__ == "__main__":
    train_chgnet(output_dir='output/error', epochs=2)
```

### Log output

```
CHGNet initialized with 412,525 parameters
Loaded model params = 412,525
Traceback (most recent call last):
  File "{path}/trainer_reload.py", line 31, in <module>
    trainer = load_trainer(file_path)
              ^^^^^^^^^^^^^^^^^^^^^^^
  File "{path}/trainer_reload.py", line 25, in load_trainer
    trainer = Trainer.load(trainer_path)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "{path}/lib/python3.11/site-packages/chgnet/trainer/trainer.py", line 545, in load
    trainer = Trainer(model=model, **state["trainer_args"])
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "{path}/lib/python3.11/site-packages/chgnet/trainer/trainer.py", line 147, in __init__
    decay_fraction = scheduler_params.pop("decay_fraction")
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyError: 'decay_fraction'
```
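The last frame shows `Trainer.__init__` popping `decay_fraction` out of `scheduler_params`. If that dict is the same object the trainer records as its saved arguments, the pop also empties the saved copy, which would explain both the empty `scheduler_params` in the checkpoint and the `KeyError` on reload. A minimal pure-Python sketch of that aliasing pitfall and the usual fix (copy the dict before popping); `BuggyTrainer`, `FixedTrainer` and `reload` are hypothetical toy names, not chgnet code.

```python
from __future__ import annotations

from typing import Any


class BuggyTrainer:
    """Toy trainer that records its init args, then mutates one of them."""

    def __init__(self, scheduler_params: dict[str, Any] | None = None) -> None:
        # shallow snapshot: the dict object itself is shared, not copied
        self.trainer_args = {"scheduler_params": scheduler_params}
        scheduler_params = scheduler_params or {}
        self.decay_fraction = scheduler_params.pop("decay_fraction")  # also empties trainer_args


class FixedTrainer:
    """Same toy trainer, but it pops from a copy so the recorded args stay intact."""

    def __init__(self, scheduler_params: dict[str, Any] | None = None) -> None:
        self.trainer_args = {"scheduler_params": scheduler_params}
        scheduler_params = dict(scheduler_params or {})  # copy before mutating
        self.decay_fraction = scheduler_params.pop("decay_fraction")


def reload(cls, trainer):
    """Rebuild a trainer from its recorded args, mimicking a save/load round trip."""
    return cls(**trainer.trainer_args)


buggy = BuggyTrainer(scheduler_params={"decay_fraction": 0.5e-2})
print(buggy.trainer_args)  # scheduler_params is now an empty dict
try:
    reload(BuggyTrainer, buggy)
except KeyError as err:
    print("reload failed:", repr(err))

fixed = FixedTrainer(scheduler_params={"decay_fraction": 0.5e-2})
print(reload(FixedTrainer, fixed).decay_fraction)
```

Copying with `dict(...)` is enough here because only top-level keys are popped; nested values would need `copy.deepcopy` if they were mutated too.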


janosh commented 3 months ago

I also recently noticed this error:

`TypeError: Object of type int64 is not JSON serializable`

Both should be easy to fix.

naveenmohandas commented 3 months ago

Hi Janosh, thanks for the quick fix. I am a bit new to this, so could you let me know how to get the fix?

I tried upgrading with `pip install --upgrade --force-reinstall chgnet --no-cache-dir`, but for some reason numpy gets upgraded to 2.0.0, and I then have to manually downgrade it to 1.26.4, which was the version that worked for me earlier. Even so, I don't get the new fixes when I upgrade this way.

Is there some other way to get the new fixes?

janosh commented 3 months ago

We have to make a new PyPI release first for `pip install chgnet` to contain the fixes. In the meantime, you would have to build from source by cloning the repo with git and running `python setup.py build_ext --inplace` to compile manually (see test.yml#L39).
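To check what an environment actually contains (for example, whether pip pulled in numpy 2.0.0, or whether `import chgnet` resolves to the git checkout rather than the installed wheel), here is a small standard-library sketch. Assumption: the compiled graph extension built by `build_ext --inplace` is importable as `chgnet.graph.cygraph`; treat that module name as unverified. The helper names are hypothetical.

```python
from __future__ import annotations

import importlib.metadata
import importlib.util


def installed_version(dist_name: str) -> str | None:
    """Installed version of a distribution according to pip metadata, or None."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def module_origin(module_name: str) -> str | None:
    """File a module would be imported from (e.g. a git checkout vs site-packages)."""
    try:
        spec = importlib.util.find_spec(module_name)
    except ImportError:  # a missing parent package raises instead of returning None
        return None
    return spec.origin if spec is not None else None


def module_available(module_name: str) -> bool:
    """True if the module can be located without importing it."""
    return module_origin(module_name) is not None


if __name__ == "__main__":
    print("chgnet:", installed_version("chgnet"), "imported from", module_origin("chgnet"))
    print("numpy:", installed_version("numpy"))
    # assumption: module name of the compiled graph extension
    print("compiled graph extension found:", module_available("chgnet.graph.cygraph"))
```

Run it from the directory you start training in: when the clone is used without being pip-installed, `installed_version("chgnet")` may still report the old release while `module_origin("chgnet")` points into the clone.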