When I run `model.load_state_dict(torch.load(ckpt_path)['state_dict'])` in `subgraphx.ipynb`, I get the error `KeyError: 'conv1.lin.weight'`.
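Setting `strict=False` does not help: the traceback below shows that the `KeyError` is raised inside PyTorch Geometric's `Linear._lazy_load_hook` load-state-dict pre-hook. That hook indexes `state_dict[prefix + 'weight']` directly, before `load_state_dict` ever performs its `strict` key check.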
I am using PyTorch 1.11.0 and PyTorch Geometric 2.0.4.
Full error log:
KeyError: 'conv1.lin.weight'
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Input In [7], in <cell line: 13>()
11 check_checkpoints()
12 ckpt_path = osp.join('checkpoints', 'ba_shapes', 'GCN_2l', '0', 'GCN_2l_best.ckpt')
---> 13 model.load_state_dict(torch.load(ckpt_path)['state_dict'], strict=False)
File /mnt/test_li/miniconda3/envs/dig/lib/python3.8/site-packages/torch/nn/modules/module.py:1483, in Module.load_state_dict(self, state_dict, strict)
1480 if child is not None:
1481 load(child, prefix + name + '.')
-> 1483 load(self)
1484 del load
1486 if strict:
File /mnt/test_li/miniconda3/envs/dig/lib/python3.8/site-packages/torch/nn/modules/module.py:1481, in Module.load_state_dict.<locals>.load(module, prefix)
1479 for name, child in module._modules.items():
1480 if child is not None:
-> 1481 load(child, prefix + name + '.')
File /mnt/test_li/miniconda3/envs/dig/lib/python3.8/site-packages/torch/nn/modules/module.py:1481, in Module.load_state_dict.<locals>.load(module, prefix)
1479 for name, child in module._modules.items():
1480 if child is not None:
-> 1481 load(child, prefix + name + '.')
File /mnt/test_li/miniconda3/envs/dig/lib/python3.8/site-packages/torch/nn/modules/module.py:1477, in Module.load_state_dict.<locals>.load(module, prefix)
1475 def load(module, prefix=''):
1476 local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-> 1477 module._load_from_state_dict(
1478 state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
1479 for name, child in module._modules.items():
1480 if child is not None:
File /mnt/test_li/miniconda3/envs/dig/lib/python3.8/site-packages/torch/nn/modules/module.py:1380, in Module._load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
1348 r"""Copies parameters and buffers from :attr:`state_dict` into only
1349 this module, but not its descendants. This is called on every submodule
1350 in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
(...)
1377 :meth:`~torch.nn.Module.load_state_dict`
1378 """
1379 for hook in self._load_state_dict_pre_hooks.values():
-> 1380 hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
1382 persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
1383 local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
File /mnt/test_li/miniconda3/envs/dig/lib/python3.8/site-packages/torch_geometric/nn/dense/linear.py:140, in Linear._lazy_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
137 def _lazy_load_hook(self, state_dict, prefix, local_metadata, strict,
138 missing_keys, unexpected_keys, error_msgs):
--> 140 weight = state_dict[prefix + 'weight']
141 if is_uninitialized_parameter(weight):
142 self.in_channels = -1
KeyError: 'conv1.lin.weight'
When I run `model.load_state_dict(torch.load(ckpt_path)['state_dict'])` in `subgraphx.ipynb`, I get the error `KeyError: 'conv1.lin.weight'`. Even if I set `strict=False`, it doesn't work. My PyTorch version is `1.11.0` and my PyTorch Geometric version is `2.0.4`. The specific error logs: