divelab / DIG

A library for graph deep learning research
https://diveintographs.readthedocs.io/
GNU General Public License v3.0

GCN_2l pretrained poor prediction accuracy #85

Closed: Cupcee closed this issue 2 years ago

Cupcee commented 2 years ago

Hey, I want to use a pretrained model to experiment with knowledge distillation from a GNN to an MLP, so I wrote some code for this. I use the BA_shapes dataset (700 nodes) and load the pretrained GCN_2l model for distillation, just like in the generic explanation example code:

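```python
import os.path as osp
import torch
from dig.xgraph.models import GCN_2l
# dim_node, num_classes, device and check_checkpoints() come from the example code
model = GCN_2l(model_level='node', dim_node=dim_node, dim_hidden=300, num_classes=num_classes)
model.to(device)
check_checkpoints()
ckpt_path = osp.join('checkpoints', 'ba_shapes', 'GCN_2l', '0', 'GCN_2l_best.ckpt')
model.load_state_dict(torch.load(ckpt_path, map_location=device)['state_dict'])
```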

The part of the code that computes the GNN predictions is:

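```python
# predict node labels with GNN
model.eval()  # switch to evaluation mode before inference
y = data.y[data.train_mask]  # assumed: labels of the training nodes
with torch.no_grad():
    logits = model(data.x, data.edge_index)
    z = logits[data.train_mask]
    gnn_pred = z.argmax(dim=-1)
    gnn_acc = float(gnn_pred.eq(y).sum().item()) / len(gnn_pred)
    print(f"[GNN BASELINE ACC]: {gnn_acc}")
```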

The problem is that this GNN's prediction accuracy on this dataset seems very poor. The code above prints: [GNN BASELINE ACC]: 0.42857142857142855

Am I doing something wrong, or is this model just not good at predicting this dataset? Or am I overestimating the accuracy such a model should reach on this dataset?

Oceanusity commented 2 years ago

Hello, could you tell us which torch_geometric version you are using? Because of a version incompatibility, when the provided GCN_2l checkpoint (saved with torch_geometric 1.6.0) is loaded with a newer version, some model parameters cannot be loaded into the GCN. Therefore, I recommend using torch_geometric 1.6.0, or renaming the parameter keys of the state_dict if you want to load it with a newer torch_geometric version.
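Renaming the state_dict keys could look roughly like the minimal sketch below. The helper names `rename_state_dict_keys` and `key_mismatch` are hypothetical, and the example pattern (`convN.weight` becoming `convN.lin.weight`) is only an assumed illustration, not the confirmed difference between torch_geometric versions; inspect the real key names first. If a newer layer also stores its weight in a different layout (for example transposed), that has to be handled as well.

```python
import re


def rename_state_dict_keys(state_dict, pattern, repl):
    """Return a new dict whose keys are rewritten with re.sub(pattern, repl, key).

    Values are passed through unchanged; any change in tensor layout
    (e.g. a transposed weight) must be handled separately.
    """
    return {re.sub(pattern, repl, key): value for key, value in state_dict.items()}


def key_mismatch(checkpoint_keys, model_keys):
    """Return (missing, unexpected): keys the model expects but the checkpoint
    lacks, and keys the checkpoint has but the model does not."""
    ckpt, model = set(checkpoint_keys), set(model_keys)
    return sorted(model - ckpt), sorted(ckpt - model)


# Hypothetical rename rule, for illustration only.
old = {"conv1.weight": "W1", "conv1.bias": "b1", "ffn.0.weight": "W2"}
new = rename_state_dict_keys(old, r"^(conv\d+)\.weight$", r"\1.lin.weight")
print(new)
```

In practice one would call `key_mismatch(ckpt_state_dict.keys(), model.state_dict().keys())` to see which names differ before choosing a pattern. PyTorch's own `model.load_state_dict(sd, strict=False)` also returns the `missing_keys` and `unexpected_keys` it encountered.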

Cupcee commented 2 years ago

I switched from 1.7.0 to 1.6.0, but the result barely changes: [GNN BASELINE ACC]: 0.45714285714285713

Oceanusity commented 2 years ago

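Hello, it seems that the trained model's labels are shifted: the checkpoint predicts class 2 for nodes of class 1, class 3 for class 2, and class 1 for class 3. Remapping the ground-truth labels accordingly gives: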

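```python
# predict node labels with GNN
model.eval()
y = data.y[data.train_mask]  # assumed: labels of the training nodes
with torch.no_grad():
    logits = model(data.x, data.edge_index)
    z = logits[data.train_mask]
    gnn_pred = z.argmax(dim=-1)
    # remap labels 1 -> 2, 2 -> 3, 3 -> 1; the masks are built from the
    # original y, so earlier reassignments do not affect later ones
    shift_y = y.clone()
    shift_y[y == 1] = 2
    shift_y[y == 2] = 3
    shift_y[y == 3] = 1
    gnn_acc = float(gnn_pred.eq(shift_y).sum().item()) / len(gnn_pred)
    print(f"[GNN BASELINE ACC]: {gnn_acc}")
```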

With this remapping, the accuracy is 0.9571.
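Rather than guessing the shift by hand, one could estimate it from the counts of (true label, predicted label) pairs. The sketch below, with the hypothetical helper `best_label_mapping`, brute-forces every class permutation and keeps the one with the most agreements. This is an assumed diagnostic approach, not part of DIG; with only 4 classes there are 24 permutations, so exhaustive search is cheap.

```python
from collections import Counter
from itertools import permutations


def best_label_mapping(preds, labels, num_classes):
    """Return (mapping, accuracy), where mapping[true_class] is the predicted
    class that, over all permutations of the classes, maximises agreement."""
    # Count how often each (true label, predicted label) pair occurs.
    pair_counts = Counter(zip(labels, preds))
    best_mapping, best_hits = None, -1
    for perm in permutations(range(num_classes)):
        # perm[c] is the predicted class assumed to stand for true class c.
        hits = sum(pair_counts[(c, perm[c])] for c in range(num_classes))
        if hits > best_hits:
            best_mapping, best_hits = dict(enumerate(perm)), hits
    return best_mapping, best_hits / len(labels)


# Toy data in which every prediction follows the 1 -> 2, 2 -> 3, 3 -> 1 shift.
mapping, acc = best_label_mapping(
    preds=[0, 2, 3, 1, 2, 3], labels=[0, 1, 2, 3, 1, 2], num_classes=4
)
print(mapping, acc)
```

With the real model, the inputs would be `gnn_pred.tolist()` and `y.tolist()`. For many classes, `scipy.optimize.linear_sum_assignment` on the negated confusion matrix solves the same matching problem without enumerating permutations.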

Cupcee commented 2 years ago

Alright, thanks, that solves it! You might want to fix this model at some point 😄