lightaime / deep_gcns_torch

Pytorch Repo for DeepGCNs (ICCV'2019 Oral, TPAMI'2021), DeeperGCN (arXiv'2020) and GNN1000(ICML'2021): https://www.deepgcns.org

Cannot load pretrained model #56

Closed: ccandee closed this issue 3 years ago

ccandee commented 3 years ago

Hi,

When I run the following command for the ogbg-molpcba dataset, it returns shape mismatch errors. Could you take a look?

    python test.py --use_gpu --conv_encode_edge --add_virtual_node --mlp_layers 2 --num_layers 14 --dataset ogbg-molpcba --block res+ --gcn_aggr softmax_sg --t 0.1 --model_load_path ogbg_molpcba_pretrained_model.pth

Elizabeth1997 commented 3 years ago

Hi, thanks for your interest in DeeperGCN, and sorry for the slightly late reply. I double-checked and reran test.py with the command we provide, and I got the expected results:

    {'Train': 0.523825154367263, 'Validation': 0.28569296697443686, 'Test': 0.27966258091068474}

I also re-downloaded ogbg_molpcba_pretrained_model.pth from Google Drive, in case we had made a mistake when uploading it, but I still got the expected results. Could you paste the specific error here, or show us the printed model, so we can better figure out what happened?

ccandee commented 3 years ago

Hi, thanks for your answer. I cloned the repo again into another directory and reran test.py with the pretrained model. This time the error is:

    Namespace(add_virtual_node=True, batch_size=32, block='res+', conv='gen', conv_encode_edge=True, dataset='ogbg-molpcba', device=0, dropout=0.5, epochs=300, feature='full', gcn_aggr='softmax_sg', graph_pooling='mean', hidden_channels=256, learn_msg_scale=False, learn_p=False, learn_t=False, lr=0.01, mlp_layers=2, model_load_path='ogbg_molpcba_pretrained_model.pth', model_save_path='model_ckpt', msg_norm=False, norm='batch', num_layers=14, num_tasks=128, num_workers=0, p=1.0, save='EXP', t=0.1, use_gpu=True)
    The number of layers 14
    Aggr aggregation method softmax_sg
    block: res+
    LN/BN->ReLU->GraphConv->Res
    Traceback (most recent call last):
      File "test.py", line 84, in <module>
        main()
      File "test.py", line 67, in main
        model = DeeperGCN(args)
      File "/home/muhan/projects/tmp/deep_gcns_torch/examples/ogb/ogbg_mol/model.py", line 61, in __init__
        self.mlp_virtualnode_list.append(MLP([hidden_channels, hidden_channels],
    TypeError: __init__() got an unexpected keyword argument 'last_act'

So I renamed the 'last_lin' keyword argument of MLP() to 'last_act' in torch_nn.py and torch_vertex.py to make it consistent with the call in model.py. Rerunning with the pretrained model, I then got the following error:

    Namespace(add_virtual_node=True, batch_size=32, block='res+', conv='gen', conv_encode_edge=True, dataset='ogbg-molpcba', device=0, dropout=0.5, epochs=300, feature='full', gcn_aggr='softmax_sg', graph_pooling='mean', hidden_channels=256, learn_msg_scale=False, learn_p=False, learn_t=False, lr=0.01, mlp_layers=2, model_load_path='ogbg_molpcba_pretrained_model.pth', model_save_path='model_ckpt', msg_norm=False, norm='batch', num_layers=14, num_tasks=128, num_workers=0, p=1.0, save='EXP', t=0.1, use_gpu=True)
    The number of layers 14
    Aggr aggregation method softmax_sg
    block: res+
    LN/BN->ReLU->GraphConv->Res
    Traceback (most recent call last):
      File "test.py", line 84, in <module>
        main()
      File "test.py", line 69, in main
        model.load_state_dict(torch.load(args.model_load_path)['model_state_dict'])
      File "/home/muhan/anaconda3/envs/seal/lib/python3.8/site-packages/torch/nn/modules/module.py", line 846, in load_state_dict
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
    RuntimeError: Error(s) in loading state_dict for DeeperGCN:
        Unexpected key(s) in state_dict:
        "mlp_virtualnode_list.0.3.weight", "mlp_virtualnode_list.0.3.bias", "mlp_virtualnode_list.0.4.weight", "mlp_virtualnode_list.0.4.bias", "mlp_virtualnode_list.0.4.running_mean", "mlp_virtualnode_list.0.4.running_var", "mlp_virtualnode_list.0.4.num_batches_tracked",
        "mlp_virtualnode_list.1.3.weight", "mlp_virtualnode_list.1.3.bias", "mlp_virtualnode_list.1.4.weight", "mlp_virtualnode_list.1.4.bias", "mlp_virtualnode_list.1.4.running_mean", "mlp_virtualnode_list.1.4.running_var", "mlp_virtualnode_list.1.4.num_batches_tracked",
        "mlp_virtualnode_list.2.3.weight", "mlp_virtualnode_list.2.3.bias", "mlp_virtualnode_list.2.4.weight", "mlp_virtualnode_list.2.4.bias", "mlp_virtualnode_list.2.4.running_mean", "mlp_virtualnode_list.2.4.running_var", "mlp_virtualnode_list.2.4.num_batches_tracked",
        "mlp_virtualnode_list.3.3.weight", "mlp_virtualnode_list.3.3.bias", "mlp_virtualnode_list.3.4.weight", "mlp_virtualnode_list.3.4.bias", "mlp_virtualnode_list.3.4.running_mean", "mlp_virtualnode_list.3.4.running_var", "mlp_virtualnode_list.3.4.num_batches_tracked",
        "mlp_virtualnode_list.4.3.weight", "mlp_virtualnode_list.4.3.bias", "mlp_virtualnode_list.4.4.weight", "mlp_virtualnode_list.4.4.bias", "mlp_virtualnode_list.4.4.running_mean", "mlp_virtualnode_list.4.4.running_var", "mlp_virtualnode_list.4.4.num_batches_tracked",
        "mlp_virtualnode_list.5.3.weight", "mlp_virtualnode_list.5.3.bias", "mlp_virtualnode_list.5.4.weight", "mlp_virtualnode_list.5.4.bias", "mlp_virtualnode_list.5.4.running_mean", "mlp_virtualnode_list.5.4.running_var", "mlp_virtualnode_list.5.4.num_batches_tracked",
        "mlp_virtualnode_list.6.3.weight", "mlp_virtualnode_list.6.3.bias", "mlp_virtualnode_list.6.4.weight", "mlp_virtualnode_list.6.4.bias", "mlp_virtualnode_list.6.4.running_mean", "mlp_virtualnode_list.6.4.running_var", "mlp_virtualnode_list.6.4.num_batches_tracked",
        "mlp_virtualnode_list.7.3.weight", "mlp_virtualnode_list.7.3.bias", "mlp_virtualnode_list.7.4.weight", "mlp_virtualnode_list.7.4.bias", "mlp_virtualnode_list.7.4.running_mean", "mlp_virtualnode_list.7.4.running_var", "mlp_virtualnode_list.7.4.num_batches_tracked",
        "mlp_virtualnode_list.8.3.weight", "mlp_virtualnode_list.8.3.bias", "mlp_virtualnode_list.8.4.weight", "mlp_virtualnode_list.8.4.bias", "mlp_virtualnode_list.8.4.running_mean", "mlp_virtualnode_list.8.4.running_var", "mlp_virtualnode_list.8.4.num_batches_tracked",
        "mlp_virtualnode_list.9.3.weight", "mlp_virtualnode_list.9.3.bias", "mlp_virtualnode_list.9.4.weight", "mlp_virtualnode_list.9.4.bias", "mlp_virtualnode_list.9.4.running_mean", "mlp_virtualnode_list.9.4.running_var", "mlp_virtualnode_list.9.4.num_batches_tracked",
        "mlp_virtualnode_list.10.3.weight", "mlp_virtualnode_list.10.3.bias", "mlp_virtualnode_list.10.4.weight", "mlp_virtualnode_list.10.4.bias", "mlp_virtualnode_list.10.4.running_mean", "mlp_virtualnode_list.10.4.running_var", "mlp_virtualnode_list.10.4.num_batches_tracked",
        "mlp_virtualnode_list.11.3.weight", "mlp_virtualnode_list.11.3.bias", "mlp_virtualnode_list.11.4.weight", "mlp_virtualnode_list.11.4.bias", "mlp_virtualnode_list.11.4.running_mean", "mlp_virtualnode_list.11.4.running_var", "mlp_virtualnode_list.11.4.num_batches_tracked",
        "mlp_virtualnode_list.12.3.weight", "mlp_virtualnode_list.12.3.bias", "mlp_virtualnode_list.12.4.weight", "mlp_virtualnode_list.12.4.bias", "mlp_virtualnode_list.12.4.running_mean", "mlp_virtualnode_list.12.4.running_var", "mlp_virtualnode_list.12.4.num_batches_tracked".

Elizabeth1997 commented 3 years ago

Hi, I just found where the problem comes from. I made some minor changes in my local repo and forgot to push them. I have already submitted a pull request to fix it. If you want, you can change the code in model.py as follows before the merge; I think you will get the expected results after that. Thanks for finding this bug!

    for layer in range(self.num_layers - 1):
        self.mlp_virtualnode_list.append(MLP([hidden_channels]*3, norm=norm))
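For context (my reading of the fix): with [hidden_channels]*3, each virtual-node MLP has two Linear transitions instead of one, which should match the extra ".3"/".4" Linear/BatchNorm parameters stored in the pretrained state_dict, so load_state_dict can find every key it expects.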
lightaime commented 3 years ago

Hi @ccandee. Sorry for the mistake in our code. You can pull the new version of the code and try it again. Let us know if it works fine or not.

ccandee commented 3 years ago

@Elizabeth1997 @lightaime Thanks for the quick fix! It works now after pulling.