PatWalters closed this issue 1 year ago.
Hey Pat -- I had the same issue; for MPNN.py, GAT.py and GCN.py try this in the forward pass:
out = self.transformer(x=node_feats, index=batch, edge_index=edge_index)
This is for mpnn.py; GAT and GCN are slightly different.
The fix is simply to name the arguments explicitly instead of passing them positionally.
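Here is a plain-Python sketch of why naming the arguments can matter. The two functions are hypothetical stand-ins, not torch_geometric code. `old_forward` mimics a pooling call shaped like `(x, batch, edge_index)`. `new_forward` mimics a signature where extra parameters such as `ptr` sit between the batch index and `edge_index`. That parameter layout is an assumption for illustration, not the real PyG signature. If an upgrade reorders parameters this way, a positional call silently binds `edge_index` to the wrong slot, while a keyword call still binds correctly.

```python
# Hypothetical stand-ins for two versions of a pooling layer's forward();
# these are NOT the actual torch_geometric signatures.
def old_forward(x, batch, edge_index=None):
    return {"x": x, "batch": batch, "edge_index": edge_index}


def new_forward(x, index=None, ptr=None, dim_size=None, dim=-2, edge_index=None):
    return {"x": x, "index": index, "ptr": ptr, "edge_index": edge_index}


node_feats, batch, edge_index = "node_feats", "batch", "edge_index"

# A positional call written for the old signature binds as intended there...
print(old_forward(node_feats, batch, edge_index))

# ...but against the new signature, edge_index lands in `ptr`,
# and the real `edge_index` parameter silently stays None.
positional = new_forward(node_feats, batch, edge_index)
print(positional)

# Naming the arguments binds each value to the intended parameter.
named = new_forward(x=node_feats, index=batch, edge_index=edge_index)
print(named)
```

A side benefit of keyword arguments is that if a parameter name disappears in a later release, the call fails loudly with a `TypeError` instead of misbinding silently.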
This is at least the case in my venv:
print(torch.__version__) #1.13.1+cu117
print(torch_geometric.__version__) #2.2.0
GraphMultisetTransformer may be deprecated in this release, and naming the arguments explicitly makes a difference.
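If the same model code has to run across releases where the pooling layer's parameter names differ, one option is to inspect the signature at call time. This is a minimal sketch under stated assumptions. The helper `call_pool` is hypothetical, and the two candidate names for the batch vector (`index` and `batch`) are assumptions drawn from the two call styles in this thread. For a `torch.nn.Module`, `__call__` accepts `(*args, **kwargs)`, so the helper inspects `.forward` when it exists but still calls the object itself.

```python
import inspect


def call_pool(pool_fn, node_feats, batch, edge_index):
    """Hypothetical helper: call pool_fn with keyword args matching its signature.

    Passes the batch vector as `index` if that name is accepted,
    otherwise as `batch`; both names are assumptions for illustration.
    """
    # nn.Module.__call__ is (*args, **kwargs), so inspect forward() instead.
    target = getattr(pool_fn, "forward", pool_fn)
    params = inspect.signature(target).parameters
    kwargs = {"x": node_feats, "edge_index": edge_index}
    if "index" in params:
        kwargs["index"] = batch
    elif "batch" in params:
        kwargs["batch"] = batch
    else:
        raise TypeError(f"{pool_fn!r} accepts neither 'index' nor 'batch'")
    return pool_fn(**kwargs)


# Hypothetical stubs standing in for two versions of the pooling call.
def old_style(x, batch, edge_index=None):
    return ("old", batch)


def new_style(x, index=None, ptr=None, dim_size=None, dim=-2, edge_index=None):
    return ("new", index)


print(call_pool(old_style, "feats", "b", "e"))  # ('old', 'b')
print(call_pool(new_style, "feats", "b", "e"))  # ('new', 'b')
```

Pinning the library versions is usually simpler. A shim like this only helps when pinning is not an option.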
More details are in this issue -- cheers! https://github.com/pyg-team/pytorch_geometric/issues/3443
Thanks! I'll give this a try.
I've been unable to run the example. It doesn't seem possible to directly reproduce the environment you used, and I'm getting an exception when I try to run your code using an environment I created with.
When I try to run the README example, I get an exception on model.train(data.x_train, data.y_train).
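Since the original environment could not be reproduced directly, it may help for each report in a thread like this to include an exact snapshot of the installed versions, as the earlier reply did with print statements. This is a minimal sketch using only the standard library's `importlib.metadata`. The helper name `package_versions` is hypothetical, and the distribution names listed are assumptions, so adjust them to whatever the project actually depends on.

```python
import platform
from importlib import metadata


def package_versions(names):
    """Hypothetical helper: map each distribution name to its installed version, or None."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    print("python", platform.python_version())
    # Distribution names are assumptions; list the project's real dependencies here.
    for name, version in package_versions(["torch", "torch_geometric"]).items():
        print(name, version or "not installed")
```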