materialsvirtuallab / matgl

Graph deep learning library for materials
BSD 3-Clause "New" or "Revised" License
232 stars 57 forks

[Bug]: Cannot load from checkpoint #237

Closed JunsuAndrewLee closed 4 months ago

JunsuAndrewLee commented 4 months ago


Version

1.0.0

Which OS(es) are you using?

What happened?

  1. I am attempting to load from a checkpoint that was saved during fit (as shown in the code snippet).
  2. However, it keeps failing with the message: argument 'model' is missing (as shown in the log output).
  3. I found self.save_hyperparameters(ignore=["model"]) in src/matgl/utils/training.py, so I changed the line to self.save_hyperparameters() and reinstalled.
  4. But the result is the same...

Please let me know if I have made any mistake. Thank you!

Code snippet

from __future__ import annotations

import os
import shutil

import numpy as np
import pytorch_lightning as pl

import matgl
from matgl.utils.training import PotentialLightningModule

# Save a checkpoint every epoch during fit
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath="./checkpoints/",
    filename="chkpoint_{epoch:03d}_{val_loss:.4f}",
    every_n_epochs=1,
    save_top_k=-1,
)
trainer = pl.Trainer(
    max_epochs=10,
    accelerator="gpu",
    callbacks=[checkpoint_callback],
    logger=logger,
    inference_mode=False,
)

# Attempting to reload the checkpoint raises the TypeError in the log output
my_lit_module = PotentialLightningModule.load_from_checkpoint("PATH_TO_CHECKPOINT")

Log output

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[6], line 1
----> 1 my_lit_module = PotentialLightningModule.load_from_checkpoint("PATH_TO_CHECKPOINT")

File ~/anaconda3/envs/mgldebug2/lib/python3.10/site-packages/pytorch_lightning/core/module.py:1561, in LightningModule.load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
   1480 @_restricted_classmethod
   1481 def load_from_checkpoint(
   1482     cls,
   (...)
   1487     **kwargs: Any,
   1488 ) -> Self:
   1489     r"""Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments
   1490     passed to ``__init__``  in the checkpoint under ``"hyper_parameters"``.
   1491 
   (...)
   1559 
   1560     """
-> 1561     loaded = _load_from_checkpoint(
   1562         cls,  # type: ignore[arg-type]
   1563         checkpoint_path,
   1564         map_location,
   1565         hparams_file,
   1566         strict,
   1567         **kwargs,
   1568     )
   1569     return cast(Self, loaded)

File ~/anaconda3/envs/mgldebug2/lib/python3.10/site-packages/pytorch_lightning/core/saving.py:89, in _load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
     87     return _load_state(cls, checkpoint, **kwargs)
     88 if issubclass(cls, pl.LightningModule):
---> 89     model = _load_state(cls, checkpoint, strict=strict, **kwargs)
     90     state_dict = checkpoint["state_dict"]
     91     if not state_dict:

File ~/anaconda3/envs/mgldebug2/lib/python3.10/site-packages/pytorch_lightning/core/saving.py:156, in _load_state(cls, checkpoint, strict, **cls_kwargs_new)
    152 if not cls_spec.varkw:
    153     # filter kwargs according to class init unless it allows any argument via kwargs
    154     _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}
--> 156 obj = cls(**_cls_kwargs)
    158 if isinstance(obj, pl.LightningModule):
    159     # give model a chance to load something
    160     obj.on_load_checkpoint(checkpoint)

TypeError: PotentialLightningModule.__init__() missing 1 required positional argument: 'model'

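Editor's note: the TypeError above follows from how Lightning restores a module. Because PotentialLightningModule calls save_hyperparameters(ignore=["model"]), the model argument is never written into the checkpoint's "hyper_parameters", so when _load_state merges saved hyperparameters with caller kwargs and calls cls(**_cls_kwargs), the required model argument is absent unless the caller supplies it. A minimal pure-Python sketch of that mechanism (PotentialModule and this simplified load_from_checkpoint are stand-ins, not the real Lightning API):

```python
import inspect

class PotentialModule:
    # Mimics PotentialLightningModule: 'model' is required by __init__ but
    # excluded from the saved hyperparameters via ignore=["model"].
    def __init__(self, model, lr=1e-3):
        self.model = model
        self.lr = lr

def load_from_checkpoint(cls, saved_hparams, **kwargs):
    # Simplified version of Lightning's _load_state: merge the checkpoint's
    # hyperparameters with caller kwargs, filter to __init__'s signature,
    # then instantiate the class.
    init_params = inspect.signature(cls.__init__).parameters
    merged = {**saved_hparams, **kwargs}
    filtered = {k: v for k, v in merged.items() if k in init_params}
    return cls(**filtered)  # TypeError if 'model' was never supplied

saved = {"lr": 1e-3}  # checkpoint hparams: no 'model' because it was ignored
try:
    load_from_checkpoint(PotentialModule, saved)
except TypeError as e:
    print("fails without model:", e)

module = load_from_checkpoint(PotentialModule, saved, model="my_model")
print("works with model:", module.model)
```

This is why re-saving with self.save_hyperparameters() alone does not help an already-written checkpoint: the fix is to pass model= at load time, as suggested below in the thread.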
kenko911 commented 4 months ago

Hi @JunsuAndrewLee, you should pass your model instance to load_from_checkpoint, like this:

model = M3GNet(input_args)
my_lit_module = PotentialLightningModule.load_from_checkpoint("PATH_TO_CHECKPOINT", model=model)

This should work. Please let me know if you have any questions.

JonathanSchmidt1 commented 3 months ago

Using the method you described, I now get a different error (the checkpoint is still from version 0.9.2; I tried to load it with both the original and a newer version, with dgl 1.1.3). Any idea?:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/cluster/apps/nss/gcc-8.2.0/python/3.11.2/x86_64/lib64/python3.11/site-packages/pytorch_lightning/core/module.py", line 1531, in load_from_checkpoint
    loaded = _load_from_checkpoint(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/cluster/apps/nss/gcc-8.2.0/python/3.11.2/x86_64/lib64/python3.11/site-packages/pytorch_lightning/core/saving.py", line 60, in _load_from_checkpoint
    checkpoint = pl_load(checkpoint_path, map_location=map_location)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cluster/apps/nss/gcc-8.2.0/python/3.11.2/x86_64/lib64/python3.11/site-packages/lightning_fabric/utilities/cloud_io.py", line 51, in _load
    return torch.load(f, map_location=map_location)  # type: ignore[arg-type]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cluster/apps/nss/gcc-8.2.0/python/3.11.2/x86_64/lib64/python3.11/site-packages/torch/serialization.py", line 809, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cluster/apps/nss/gcc-8.2.0/python/3.11.2/x86_64/lib64/python3.11/site-packages/torch/serialization.py", line 1172, in _load
    result = unpickler.load()
             ^^^^^^^^^^^^^^^^
  File "/cluster/apps/nss/gcc-8.2.0/python/3.11.2/x86_64/lib64/python3.11/pickle.py", line 1213, in load
    dispatch[key[0]](self)
  File "/cluster/apps/nss/gcc-8.2.0/python/3.11.2/x86_64/lib64/python3.11/pickle.py", line 1254, in load_binpersid
    self.append(self.persistent_load(pid))
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cluster/apps/nss/gcc-8.2.0/python/3.11.2/x86_64/lib64/python3.11/site-packages/torch/serialization.py", line 1142, in persistent_load
    typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cluster/apps/nss/gcc-8.2.0/python/3.11.2/x86_64/lib64/python3.11/site-packages/torch/serialization.py", line 1116, in load_tensor
    wrap_storage=restore_location(storage, location),
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cluster/apps/nss/gcc-8.2.0/python/3.11.2/x86_64/lib64/python3.11/site-packages/torch/serialization.py", line 1089, in restore_location
    result = map_location(storage, location)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cluster/apps/nss/gcc-8.2.0/python/3.11.2/x86_64/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/cluster/home/sjonathan/dgl/lib64/python3.11/site-packages/matgl/models/_m3gnet.py", line 236, in forward
    node_types = g.ndata["node_type"]
                 ^^^^^^^
AttributeError: 'torch.storage.UntypedStorage' object has no attribute 'ndata'. Did you mean: '_cdata'?

matthewkuner commented 4 days ago

@kenko911 I am also interested in a follow-up to @JonathanSchmidt1's comment.