divelab / DIG

A library for graph deep learning research
https://diveintographs.readthedocs.io/
GNU General Public License v3.0

KeyError: 'conv1.lin.weight' #121

Closed. MishanyaGeniyInformatiki closed this issue 2 years ago.

MishanyaGeniyInformatiki commented 2 years ago

Hi there! I'm using DIG and trying to run the code from the 'Tutorial for GNN Explainability':

import os
import os.path as osp
import torch
from torch_geometric.data import download_url, extract_zip
from dig.xgraph.dataset import SynGraphDataset
from dig.xgraph.models import GCN_2l

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Load the synthetic BA-Shapes dataset and keep only the first node feature, as in the tutorial
dataset = SynGraphDataset('./datasets', 'BA_shapes')
dataset.data.x = dataset.data.x.to(torch.float32)
dataset.data.x = dataset.data.x[:, :1]
dim_node = dataset.num_node_features
dim_edge = dataset.num_edge_features
num_classes = dataset.num_classes

def check_checkpoints(root='./'):
    # Download and unpack the pretrained checkpoints only if they are not already present
    if osp.exists(osp.join(root, 'checkpoints')):
        return
    url = 'https://github.com/divelab/DIG_storage/raw/main/xgraph/checkpoints.zip'
    path = download_url(url, root)
    extract_zip(path, root)
    os.unlink(path)

model = GCN_2l(model_level='node', dim_node=dim_node, dim_hidden=300, num_classes=num_classes)
model.to(device)
check_checkpoints()
ckpt_path = osp.join('checkpoints', 'ba_shapes', 'GCN_2l', '0', 'GCN_2l_best.ckpt')
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
model.load_state_dict(torch.load(ckpt_path, map_location=device)['state_dict'])

Versions of installed packages: Python 3.10.4, PyTorch 1.12.0, torch-geometric 2.0.4, dive-into-graphs 0.1.2.

But I get an error (screenshot from 2022-07-19 01-17-42): KeyError: 'conv1.lin.weight'. What does it mean? Does the model fail to find the required weights in GCN_2l_best.ckpt?
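One way to check that is to compare the parameter names stored in the checkpoint with the ones the model expects. A minimal sketch: diff_state_dict_keys is a hypothetical helper (not part of DIG or PyTorch), and the key names in the demo are illustrative, modeled on the names mentioned in this thread.

```python
from typing import Iterable, List, Tuple


def diff_state_dict_keys(
    ckpt_keys: Iterable[str], model_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: return (missing, unexpected) parameter names.

    missing:    names the model expects but the checkpoint lacks
    unexpected: names in the checkpoint that the model does not know
    """
    ckpt = set(ckpt_keys)
    model = set(model_keys)
    missing = sorted(model - ckpt)
    unexpected = sorted(ckpt - model)
    return missing, unexpected


if __name__ == "__main__":
    # Illustrative names: a checkpoint saved with PyG 1.x vs. a model built with PyG 2.0
    ckpt_keys = ["conv1.weight", "conv1.bias"]
    model_keys = ["conv1.lin.weight", "conv1.bias"]
    missing, unexpected = diff_state_dict_keys(ckpt_keys, model_keys)
    print("missing:", missing)        # missing: ['conv1.lin.weight']
    print("unexpected:", unexpected)  # unexpected: ['conv1.weight']
```

With the objects from the snippet above, it would be called as diff_state_dict_keys(torch.load(ckpt_path, map_location=device)['state_dict'].keys(), model.state_dict().keys()).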

IdrissEQ commented 2 years ago

Apparently, in PyG 2.0 the GCNConv parameter self.weight was replaced by a linear layer self.lin, whose weight matrix is also transposed. So you have to both rename the keys in the loaded state dict and transpose the corresponding weights. The following works for me:

params_dict = torch.load(ckpt_path)['state_dict']
# PyG 2.0 expects <layer>.lin.weight with the transposed shape, so copy each old weight under the new name
params_dict["convs.0.lin.weight"] = params_dict["convs.0.weight"].t()
params_dict["conv1.lin.weight"] = params_dict["conv1.weight"].t()

model.load_state_dict(params_dict, strict=False)

The strict=False is needed because the old keys (conv1.weight and convs.0.weight) are still in the dict and are now unexpected.
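A more general variant of this workaround would rename and transpose every affected weight, then load with the default strict=True, so any remaining mismatch still raises an error (strict=False silences missing keys too, not only unexpected ones). The transpose is needed because the old layer computed x @ W with W of shape [in, out], while a linear layer stores its weight as [out, in] and computes x @ W_lin^T, so W_lin = W^T gives the same result. A minimal sketch, assuming (as the comment above describes) that each affected layer's <prefix>.weight became <prefix>.lin.weight; convert_gcn_state_dict is a hypothetical helper, not part of DIG or PyG. By default it calls .t() on each converted tensor; the demo passes a list-based transpose so it runs without PyTorch.

```python
from typing import Any, Callable, Dict, Iterable


def convert_gcn_state_dict(
    state_dict: Dict[str, Any],
    model_keys: Iterable[str],
    transpose: Callable[[Any], Any] = lambda t: t.t(),  # torch.Tensor.t() for 2-D tensors
) -> Dict[str, Any]:
    """Hypothetical helper: rename '<prefix>.weight' to '<prefix>.lin.weight'
    (transposing the value) whenever the model expects the new name but not the old one.
    All other entries (e.g. biases) are copied unchanged."""
    expected = set(model_keys)
    converted: Dict[str, Any] = {}
    for key, value in state_dict.items():
        if key.endswith(".weight") and key not in expected:
            new_key = key[: -len("weight")] + "lin.weight"
            if new_key in expected:
                converted[new_key] = transpose(value)
                continue
        converted[key] = value
    return converted


if __name__ == "__main__":
    def list_transpose(m):
        # Transpose a nested-list matrix (stand-in for Tensor.t() in this demo)
        return [list(row) for row in zip(*m)]

    # Illustrative entries with made-up values and PyG 1.x-style names
    old = {"conv1.weight": [[1, 2, 3], [4, 5, 6]], "conv1.bias": [0, 0, 0]}
    model_keys = ["conv1.lin.weight", "conv1.bias"]
    new = convert_gcn_state_dict(old, model_keys, transpose=list_transpose)
    print(sorted(new))              # ['conv1.bias', 'conv1.lin.weight']
    print(new["conv1.lin.weight"])  # [[1, 4], [2, 5], [3, 6]]
```

With PyTorch, the call would be model.load_state_dict(convert_gcn_state_dict(params_dict, model.state_dict().keys())). If you keep strict=False instead, the value returned by load_state_dict has missing_keys and unexpected_keys fields that are worth inspecting.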

CM-BF commented 2 years ago

Thanks for pointing this out. The tutorial example is a little out of date; we will fix it soon. In the meantime, please refer to the PyG 2.0 version: https://github.com/divelab/DIG/blob/dig-stable/examples/xgraph/gnnexplainer.ipynb

CM-BF commented 2 years ago

Since there is no more discussion, I will close this issue.