GentleZhu / Shift-Robust-GNNs

"Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data" (NeurIPS 21')

Errors when running GAT and GraphSAGE #1

Open · simonzhang00 opened this issue 2 years ago

simonzhang00 commented 2 years ago


I get the following errors when running the two models:

  1. GAT:
    Using backend: pytorch
    number of classes 7
    Using CUDA
    Traceback (most recent call last):
      File "", line 729, in <module>
        micro_f1, macro_f1, out_acc = main(args, [])
      File "", line 463, in main
      File ".../Shift-Robust-GNNs/", line 254, in __init__
        self.layers.append(GATConv(in_feats, n_hidden, num_heads=num_heads, feat_drop=dropout, activation=activation))
      File ".../python3.7/site-packages/dgl/nn/pytorch/conv/", line 160, in __init__
        self._in_src_feats, out_feats * num_heads, bias=False)
      File ".../python3.7/site-packages/torch/nn/modules/", line 81, in __init__
        self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
    TypeError: empty() received an invalid combination of arguments - got (tuple, dtype=NoneType, device=NoneType), but expected one of:
     * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
     * (tuple of ints size, *, torch.memory_format memory_format, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)

  2. GraphSAGE:

Using backend: pytorch
number of classes 7
Using CUDA
Traceback (most recent call last):
  File "", line 729, in <module>
    micro_f1, macro_f1, out_acc = main(args, [])
  File "", line 544, in main
    total_loss = loss + 1 * cmd(model.h[idx_train, :], model.h[iid_train, :])
  File ".../python3.7/site-packages/torch/nn/modules/", line 1131, in __getattr__
    type(self).__name__, name))
AttributeError: 'GraphSAGE' object has no attribute 'h'

Relevant packages from pip freeze: dgl-cu102==0.6.1, torch==1.9.0+cu102
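The GraphSAGE traceback says the training loop reads `model.h` (the hidden node representations fed to the `cmd` regularizer), but the `GraphSAGE` class never sets that attribute. Below is a minimal sketch of the pattern that avoids this, not the repository's actual fix: `forward()` caches its last hidden layer in `self.h`. Everything here is an assumption for illustration. `HiddenCachingModel` is a hypothetical stand-in that uses plain `nn.Linear` layers instead of DGL's `SAGEConv`, and `cmd` is written from the standard central moment discrepancy definition (without the interval-width normalization of Zellinger et al.), which may differ from the repository's own `cmd`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def l2diff(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Euclidean norm of the difference between two moment vectors.
    return torch.norm(a - b, p=2)


def cmd(x1: torch.Tensor, x2: torch.Tensor, k_max: int = 5) -> torch.Tensor:
    """Central moment discrepancy between two sets of hidden features.

    Rows are samples (nodes), columns are feature dimensions. Sums the L2
    distance between the means and between the element-wise central
    moments of order 2..k_max. Sketch only; may differ from the repo's cmd.
    """
    m1, m2 = x1.mean(dim=0), x2.mean(dim=0)
    c1, c2 = x1 - m1, x2 - m2
    total = l2diff(m1, m2)
    for k in range(2, k_max + 1):
        total = total + l2diff((c1 ** k).mean(dim=0), (c2 ** k).mean(dim=0))
    return total


class HiddenCachingModel(nn.Module):
    """Hypothetical stand-in for GraphSAGE: Linear layers, no graph.

    The only point is that forward() stores the last hidden representation
    in self.h, which is what the training loop accesses as model.h.
    """

    def __init__(self, in_feats: int, n_hidden: int, n_classes: int):
        super().__init__()
        self.hidden = nn.Linear(in_feats, n_hidden)
        self.out = nn.Linear(n_hidden, n_classes)
        self.h = None  # set on every forward pass

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = F.relu(self.hidden(x))
        self.h = h  # cache hidden features for the CMD regularizer
        return self.out(h)


if __name__ == "__main__":
    torch.manual_seed(0)
    model = HiddenCachingModel(in_feats=16, n_hidden=8, n_classes=7)
    x = torch.randn(100, 16)
    labels = torch.randint(0, 7, (100,))
    # Hypothetical index sets standing in for idx_train / iid_train.
    idx_train = torch.arange(0, 20)
    iid_train = torch.arange(50, 100)

    logits = model(x)  # populates model.h
    loss = F.cross_entropy(logits[idx_train], labels[idx_train])
    total_loss = loss + 1 * cmd(model.h[idx_train, :], model.h[iid_train, :])
    total_loss.backward()
    print(float(total_loss))
```

Storing the activation as a plain attribute keeps the model's `forward` signature unchanged, so the training loop can keep calling `model(...)` and then read `model.h`. The trade-off is that `model.h` is only valid after a forward pass.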


GentleZhu commented 2 years ago


I've fixed the GraphSAGE issue. I will look into the first one; it seems to be caused by a change in the DGL library. I will test it and get back to you later.
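The GAT traceback shows that DGL's `GATConv` builds `nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False)`, and `nn.Linear` passes both sizes to `torch.empty((out_features, in_features))`. A non-integer in either position therefore surfaces as the opaque `empty()` TypeError. The maintainer suspects a DGL change. Independently of that, one possibility worth ruling out is a non-int size reaching `GATConv`. For example, if `num_heads` were a per-layer list, `n_hidden * num_heads` would repeat the list instead of scaling a number. The hypothetical helper below, which is not part of DGL or the repository, normalizes layer sizes to plain ints with `operator.index` and fails early with a readable message.

```python
import operator


def check_layer_dims(**dims):
    """Hypothetical helper: return the sizes as plain Python ints.

    operator.index accepts int-like values (e.g. numpy integers) and
    raises TypeError for lists, floats, etc., which would otherwise only
    fail deep inside nn.Linear / torch.empty.
    """
    checked = {}
    for name, value in dims.items():
        try:
            checked[name] = operator.index(value)
        except TypeError:
            raise TypeError(
                f"{name} must be an integer size, got {type(value).__name__}: {value!r}"
            ) from None
    return checked


if __name__ == "__main__":
    n_hidden, num_heads = 8, [8, 1]  # hypothetical values
    # int * list repeats the list, so this "size" is a list, not an int.
    try:
        check_layer_dims(in_feats=100, out_feats=n_hidden * num_heads)
    except TypeError as err:
        print(err)
```

Calling such a check just before constructing `GATConv` would show which argument is malformed before the error disappears into DGL and PyTorch internals.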