divelab / DIG

A library for graph deep learning research
https://diveintographs.readthedocs.io/
GNU General Public License v3.0

Issue on running subgraphx.ipynb example #98

Closed: gouliang1997 closed this issue 2 years ago

gouliang1997 commented 2 years ago

When I run model.load_state_dict(torch.load(ckpt_path)['state_dict']) in subgraphx.ipynb, I get the error KeyError: 'conv1.lin.weight'. Setting strict=False does not help either.

My PyTorch version is 1.11.0 and PyTorch Geometric version is 2.0.4.
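One quick way to see what is going on is to compare the key names stored in the checkpoint with the ones the model defines under the installed PyG version. The helper below is a small, hypothetical utility (not part of DIG or PyTorch) that returns the keys the model expects but the checkpoint lacks, and the keys the checkpoint has but the model does not. Given the error, 'conv1.lin.weight' should show up in the first list.

```python
from typing import Iterable, List, Tuple


def diff_keys(checkpoint_keys: Iterable[str],
              model_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: return (missing, unexpected), both sorted.

    missing:    keys the model expects but the checkpoint does not contain
    unexpected: keys the checkpoint contains but the model does not define
    """
    ckpt, model = set(checkpoint_keys), set(model_keys)
    return sorted(model - ckpt), sorted(ckpt - model)
```

Usage in the notebook, after the model is built and before load_state_dict is called:

```python
# missing, unexpected = diff_keys(torch.load(ckpt_path)['state_dict'].keys(),
#                                 model.state_dict().keys())
# print('missing:', missing)
# print('unexpected:', unexpected)
```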

The full error log:

KeyError: 'conv1.lin.weight'
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Input In [7], in <cell line: 13>()
     11 check_checkpoints()
     12 ckpt_path = osp.join('checkpoints', 'ba_shapes', 'GCN_2l', '0', 'GCN_2l_best.ckpt')
---> 13 model.load_state_dict(torch.load(ckpt_path)['state_dict'], strict=False)

File /mnt/test_li/miniconda3/envs/dig/lib/python3.8/site-packages/torch/nn/modules/module.py:1483, in Module.load_state_dict(self, state_dict, strict)
   1480         if child is not None:
   1481             load(child, prefix + name + '.')
-> 1483 load(self)
   1484 del load
   1486 if strict:

File /mnt/test_li/miniconda3/envs/dig/lib/python3.8/site-packages/torch/nn/modules/module.py:1481, in Module.load_state_dict.<locals>.load(module, prefix)
   1479 for name, child in module._modules.items():
   1480     if child is not None:
-> 1481         load(child, prefix + name + '.')

File /mnt/test_li/miniconda3/envs/dig/lib/python3.8/site-packages/torch/nn/modules/module.py:1481, in Module.load_state_dict.<locals>.load(module, prefix)
   1479 for name, child in module._modules.items():
   1480     if child is not None:
-> 1481         load(child, prefix + name + '.')

File /mnt/test_li/miniconda3/envs/dig/lib/python3.8/site-packages/torch/nn/modules/module.py:1477, in Module.load_state_dict.<locals>.load(module, prefix)
   1475 def load(module, prefix=''):
   1476     local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-> 1477     module._load_from_state_dict(
   1478         state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
   1479     for name, child in module._modules.items():
   1480         if child is not None:

File /mnt/test_li/miniconda3/envs/dig/lib/python3.8/site-packages/torch/nn/modules/module.py:1380, in Module._load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
   1348 r"""Copies parameters and buffers from :attr:`state_dict` into only
   1349 this module, but not its descendants. This is called on every submodule
   1350 in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
   (...)
   1377         :meth:`~torch.nn.Module.load_state_dict`
   1378 """
   1379 for hook in self._load_state_dict_pre_hooks.values():
-> 1380     hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
   1382 persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
   1383 local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())

File /mnt/test_li/miniconda3/envs/dig/lib/python3.8/site-packages/torch_geometric/nn/dense/linear.py:140, in Linear._lazy_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
    137 def _lazy_load_hook(self, state_dict, prefix, local_metadata, strict,
    138                     missing_keys, unexpected_keys, error_msgs):
--> 140     weight = state_dict[prefix + 'weight']
    141     if is_uninitialized_parameter(weight):
    142         self.in_channels = -1

KeyError: 'conv1.lin.weight'
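Reading the traceback from the bottom: the KeyError is raised inside PyG's Linear._lazy_load_hook, a load_state_dict pre-hook that indexes state_dict[prefix + 'weight'] directly. In this PyTorch version the hook runs at the start of _load_from_state_dict (which is called with strict=True internally, see the True argument in the frame above), and the user-facing strict flag only decides afterwards whether missing or unexpected keys become an error. That is why strict=False cannot suppress it. The expected name conv1.lin.weight indicates that under PyG 2.0.4 the first layer keeps its weight in a lin submodule, while the checkpoint, which loads under PyG 1.7.0, presumably stores it under another name.

If you would rather stay on the newer versions, one option is to rename keys before loading. The sketch below is hypothetical and rests on an unverified assumption: that the old checkpoint stores each GCN layer's weight as '<conv>.weight' with shape (in, out), while the newer lin submodule expects '<conv>.lin.weight' with shape (out, in), so the value is also transposed. Check the checkpoint's actual keys and tensor shapes before relying on it.

```python
from typing import Callable, Dict, Iterable, TypeVar

T = TypeVar("T")


def remap_conv_weights(state_dict: Dict[str, T],
                       conv_names: Iterable[str],
                       transpose: Callable[[T], T]) -> Dict[str, T]:
    """Hypothetical: rename '<conv>.weight' to '<conv>.lin.weight'.

    Assumes (not confirmed by DIG) that old weights are (in, out) and the
    new Linear layer wants (out, in), so each moved value is transposed.
    All other keys, e.g. '<conv>.bias', are copied unchanged.
    """
    remapped = dict(state_dict)  # shallow copy so the input is not mutated
    for name in conv_names:
        old_key, new_key = f"{name}.weight", f"{name}.lin.weight"
        # Only rename when the old key exists and the new one is absent.
        if old_key in remapped and new_key not in remapped:
            remapped[new_key] = transpose(remapped.pop(old_key))
    return remapped
```

With PyTorch tensors, transpose would be something like lambda w: w.t(). The layer names other than conv1 (for example 'conv2') are guesses for a two-layer GCN and need to be checked against model.state_dict().keys():

```python
# ckpt = torch.load(ckpt_path)['state_dict']
# fixed = remap_conv_weights(ckpt, ['conv1', 'conv2'], transpose=lambda w: w.t())
# model.load_state_dict(fixed)
```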
gouliang1997 commented 2 years ago

I solved the problem by downgrading to PyTorch 1.8.0 and PyTorch Geometric 1.7.0 (torch-sparse 0.6.12).
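Since pinning the versions is the workaround reported here, a notebook could check the installed PyG version up front instead of failing deep inside load_state_dict. A minimal sketch, assuming (from this thread, not from DIG's documentation) that the shipped checkpoint needs PyG 1.x; parse_version and pyg_matches_checkpoint are hypothetical helpers.

```python
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn '2.0.4' into (2, 0, 4) and '1.7.0+cu102' into (1, 7, 0)."""
    core = version.split("+")[0]  # drop a local build tag such as '+cu102'
    parts = []
    for piece in core.split("."):
        digits = ""
        for ch in piece:  # keep leading digits only, e.g. '0rc1' -> '0'
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def pyg_matches_checkpoint(pyg_version: str) -> bool:
    # Assumption: the checkpoint was written with PyG 1.x layer layouts.
    return parse_version(pyg_version) < (2,)
```

Usage before loading the checkpoint:

```python
# import torch_geometric
# assert pyg_matches_checkpoint(torch_geometric.__version__), (
#     'checkpoint appears to need PyG 1.x, found ' + torch_geometric.__version__)
```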