pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

NoneType error during forward pass of heterogeneous model #3772

Closed sidhantls closed 2 years ago

sidhantls commented 2 years ago

🐛 Bug

I created a directed heterogeneous graph dataset and a heterogeneous GAT model. An error occurs during the forward pass when the graph dataset is modified in a certain way.

When every node type in the heterogeneous graph has edges going both to and from it, there are no errors. But when some node types only have outgoing edges and no incoming edges, I get the following error during the forward pass:

Traceback (most recent call last):
  File "/home/miniconda3/lib/python3.8/site-packages/torch/fx/graph_module.py", line 505, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/home/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "<eval_with_key_5>", line 29, in forward
    add__d = None + lin1__d;  lin1__d = None
TypeError: unsupported operand type(s) for +: 'NoneType' and 'Tensor'

Call using an FX-traced Module, line 29 of the traced Module's generated forward function:
    add__p = conv1__p + lin1__p;  conv1__p = lin1__p = None
    add__d = None + lin1__d;  lin1__d = None

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    add__s = None + lin1__s;  lin1__s = None

    relu__c = add__c.relu();  add__c = None

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-27-67b9136ef67a> in <module>
      1 model.train()
----> 2 out = model(graph_data.x_dict, graph_data.edge_index_dict)

~/miniconda3/lib/python3.8/site-packages/torch/fx/graph_module.py in wrapped_call(self, *args, **kwargs)
    511                     print(generate_error_message(topmost_framesummary),
    512                           file=sys.stderr)
--> 513                 raise e.with_traceback(None)
    514 
    515         cls.__call__ = wrapped_call

TypeError: unsupported operand type(s) for +: 'NoneType' and 'Tensor'

I suspect this has something to do with a particular key not being found in the dataset, as in #1103, although that issue was related to the data loader. I'm not sure why it doesn't work when some node types in the graph only have outgoing edges and no incoming edges.
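A simplified, pure-Python sketch of the mechanism behind the `None` in the traced code. It assumes, as the generated line `add__d = None + lin1__d` suggests, that the heterogeneous model sums the convolution outputs of all edge types sharing a destination node type, so a node type that is never a destination is left with `None`. The helper `aggregate_by_destination` and the float "outputs" are hypothetical stand-ins for per-edge-type tensors, not PyG internals.

```python
from typing import Dict, List, Optional, Tuple

EdgeType = Tuple[str, str, str]  # (source type, relation, destination type)


def aggregate_by_destination(
    node_types: List[str],
    edge_outputs: Dict[EdgeType, float],
) -> Dict[str, Optional[float]]:
    """Sum per-edge-type outputs into their destination node type.

    Node types that never appear as a destination keep the value None,
    mirroring the `None` that shows up in the traced forward function.
    """
    out: Dict[str, Optional[float]] = {nt: None for nt in node_types}
    for (_src, _rel, dst), value in edge_outputs.items():
        current = out[dst]
        out[dst] = value if current is None else current + value
    return out


node_types = ["a", "b", "q", "r"]
# Edge types of the failing dataset; the floats stand in for conv outputs.
edge_outputs = {
    ("a", "same_transition", "a"): 1.0,
    ("a", "diff_transition", "b"): 2.0,
    ("b", "same_transition", "b"): 3.0,
    ("b", "diff_transition", "a"): 4.0,
    ("q", "diff_transition", "a"): 5.0,
    ("r", "diff_transition", "a"): 6.0,
}
conv_out = aggregate_by_destination(node_types, edge_outputs)
print(conv_out)  # {'a': 16.0, 'b': 5.0, 'q': None, 'r': None}

lin_out = {nt: 0.5 for nt in node_types}  # stand-in for self.lin1(x)
try:
    # Equivalent of `self.conv1(x, edge_index) + self.lin1(x)` for node type 'q'
    conv_out["q"] + lin_out["q"]
    error_message = ""
except TypeError as err:
    error_message = str(err)
print(error_message)  # unsupported operand type(s) for +: 'NoneType' and 'float'
```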

The graph dataset where the forward pass does work looks something like this:

HeteroData(
  a={ x=[327724, 70] },
  b={ x=[103, 70] },
  q={ x=[7, 70] },
  r={ x=[42912, 70] },
  (a, same_transition, a)={
    edge_index=[2, 77],
    edge_attr=[77, 2]
  },
  (a, diff_transition, b)={
    edge_index=[2, 88],
    edge_attr=[88, 2]
  },
  (b, same_transition, b)={
    edge_index=[2, 99],
    edge_attr=[99, 2]
  },
  (b, diff_transition, a)={
    edge_index=[2, 66],
    edge_attr=[66, 2]
  },
  (q, diff_transition, a)={
    edge_index=[2, 5],
    edge_attr=[5, 2]
  },
  (a, diff_transition, q)={
    edge_index=[2, 5],
    edge_attr=[5, 2]
  },
  (r, diff_transition, a)={
    edge_index=[2, 10],
    edge_attr=[10, 2]
  },
  (a, diff_transition, r)={
    edge_index=[2, 10],
    edge_attr=[10, 2]
  }
)

The graph dataset where the error occurs is shown below. Here, there are only edges from q to a and from r to a, but no edges from a to q or from a to r as in the previous dataset. This is what seems to cause the error.

HeteroData(
  a={ x=[327724, 70] },
  b={ x=[103, 70] },
  q={ x=[7, 70] },
  r={ x=[42912, 70] },
  (a, same_transition, a)={
    edge_index=[2, 77],
    edge_attr=[77, 2]
  },
  (a, diff_transition, b)={
    edge_index=[2, 88],
    edge_attr=[88, 2]
  },
  (b, same_transition, b)={
    edge_index=[2, 99],
    edge_attr=[99, 2]
  },
  (b, diff_transition, a)={
    edge_index=[2, 66],
    edge_attr=[66, 2]
  },
  (q, diff_transition, a)={
    edge_index=[2, 5],
    edge_attr=[5, 2]
  },
  (r, diff_transition, a)={
    edge_index=[2, 10],
    edge_attr=[10, 2]
  }
)
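One way to spot this situation before calling `to_hetero` is to list the node types that never occur as a destination in any edge type. This is a minimal sketch, assuming `graph_data.metadata()` returns a `(node_types, edge_types)` tuple with edge types as `(src, rel, dst)` triplets; the helper name is hypothetical, and this is not the check that was added to PyG.

```python
from typing import List, Sequence, Tuple

EdgeType = Tuple[str, str, str]  # (source type, relation, destination type)


def node_types_without_incoming_edges(
    metadata: Tuple[Sequence[str], Sequence[EdgeType]],
) -> List[str]:
    """Return node types that never appear as a destination type."""
    node_types, edge_types = metadata
    destinations = {dst for _src, _rel, dst in edge_types}
    return [nt for nt in node_types if nt not in destinations]


# Metadata as it would look for the failing dataset above
metadata = (
    ["a", "b", "q", "r"],
    [
        ("a", "same_transition", "a"),
        ("a", "diff_transition", "b"),
        ("b", "same_transition", "b"),
        ("b", "diff_transition", "a"),
        ("q", "diff_transition", "a"),
        ("r", "diff_transition", "a"),
    ],
)
missing = node_types_without_incoming_edges(metadata)
print(missing)  # ['q', 'r']
```

Running the same check on the working dataset should return an empty list, because it also contains the `(a, diff_transition, q)` and `(a, diff_transition, r)` edge types.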

Graph model used:

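import torch
from torch_geometric.nn import GATConv, Linear, to_hetero

class GAT(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # -1 input sizes are lazily inferred on the first forward pass
        self.conv1 = GATConv((-1, -1), hidden_channels, add_self_loops=False)
        self.lin1 = Linear(-1, hidden_channels)
        self.conv2 = GATConv((-1, -1), out_channels, add_self_loops=False)
        self.lin2 = Linear(-1, out_channels)

    def forward(self, x, edge_index):
        # Message passing plus a linear skip connection
        x = self.conv1(x, edge_index) + self.lin1(x)
        x = x.relu()
        x = self.conv2(x, edge_index) + self.lin2(x)
        return x

model = GAT(hidden_channels=100, out_channels=100)
# Convert to a heterogeneous model; outputs of different edge types that
# point to the same destination node type are summed (aggr='sum')
model = to_hetero(model, graph_data.metadata(), aggr='sum')
graph_data.metadata()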

Is there any reason why this happens?

To Reproduce

I'm unable to at the moment because the graph is custom-created. I can work on creating a similar dataset if a reproducible example is required.

Expected behavior

The forward pass runs without error in both cases.

Environment

rusty1s commented 2 years ago

Yes, this is actually intended. There is already some discussion about this, see, e.g., https://github.com/pyg-team/pytorch_geometric/discussions/3760. I just merged a check for this into master, see https://github.com/pyg-team/pytorch_geometric/pull/3775.

This is really a challenging problem to tackle. What would you expect to happen in this case? How can we update a node type during GNN propagation if it does not receive any new information?

sidhantls commented 2 years ago

Interesting. I would actually expect it to still work, because the node type would still be a source of information for the other nodes and could therefore still be useful.

For example, if node type C has no edges going to it, only edges going from it, it is currently not supported because node type C will never be updated. But it would still be able to improve the feature representations of its neighbouring nodes through message passing. So wouldn't it be better if these nodes remained connected in the graph?

@rusty1s what do you think?

rusty1s commented 2 years ago

Yes, I understand that this can be useful (that's why we enforce bidirectional edges in the first place). However, without these bidirectional edges, it is not clear how an isolated node type actually updates its information in layer L, since it does not receive any new information. I think the current strategy of informing the user about this is better than fixing it magically in the background (e.g., by applying a Linear transformation to isolated node types), but I'm happy to take any alternative suggestions.
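A workaround in line with the bidirectional-edge requirement is to add a reverse edge type for every edge type, so that source-only node types such as q and r also receive messages. PyG provides a `ToUndirected` transform in `torch_geometric.transforms` for this; the pure-Python sketch below only illustrates the idea on a plain dict of edge lists. The helper name and the `rev_` naming of the added relations are assumptions made for illustration.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]  # (source type, relation, destination type)
EdgeList = List[Tuple[int, int]]  # (source index, destination index) pairs


def add_reverse_edge_types(edges: Dict[EdgeType, EdgeList]) -> Dict[EdgeType, EdgeList]:
    """Add a reversed copy of every edge type whose endpoint types differ.

    Edge types like (a, rel, a) already deliver messages to their own node
    type, so this sketch leaves them untouched.
    """
    result = dict(edges)
    for (src, rel, dst), pairs in edges.items():
        if src == dst:
            continue
        rev_type = (dst, f"rev_{rel}", src)  # hypothetical naming scheme
        # Swap each (source, destination) pair; keep any existing entry.
        result.setdefault(rev_type, [(j, i) for i, j in pairs])
    return result


edges = {
    ("a", "same_transition", "a"): [(0, 1)],
    ("q", "diff_transition", "a"): [(0, 2), (1, 3)],
    ("r", "diff_transition", "a"): [(5, 0)],
}
undirected = add_reverse_edge_types(edges)
destinations = {dst for _src, _rel, dst in undirected}
print(sorted(destinations))  # ['a', 'q', 'r']
```

With PyG itself, the corresponding step would be something like `graph_data = T.ToUndirected()(graph_data)` (with `import torch_geometric.transforms as T`), applied before passing `graph_data.metadata()` to `to_hetero`.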