pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

to_hetero() tries to activate None #4404

Open fierval opened 2 years ago

fierval commented 2 years ago

🐛 Describe the bug

Converting the following simple model with to_hetero():

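    import torch
    from torch_geometric.nn import TransformerConv


    class GNN(torch.nn.Module):  # class name assumed; the issue shows only the methods
        # ELConfig is the reporter's own config class (not shown in the issue)
        def __init__(self, config: "ELConfig"):
            super().__init__()

            self.config = config
            self.conv_layers = 2

            # 4 heads with concat=True (the default) -> 256 * 4 output channels
            self.conv1 = TransformerConv(
                256,
                256,
                heads=4,
                dropout=0.6,
                edge_dim=4
            )

            self.conv2 = TransformerConv(
                256 * 4,
                256,
                heads=1,
                dropout=0.6,
                edge_dim=4
            )

        def forward(self, x, edge_index, edge_attr):

            x = self.conv1(x, edge_index, edge_attr).relu()
            x = self.conv2(x, edge_index, edge_attr).relu()

            return x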

With metadata: (['category', 'vertex'], [('category', 'assign', 'vertex'), ('vertex', 'groups', 'vertex'), ('vertex', 'continues', 'vertex'), ('vertex', 'keyof', 'vertex')])

Yields the following:

    def forward(self, x, edge_index, edge_attr):
        x__category = x.get('category')
        x__vertex = x.get('vertex');  x = None
        edge_index__category__assign__vertex = edge_index.get(('category', 'assign', 'vertex'))
        edge_index__vertex__groups__vertex = edge_index.get(('vertex', 'groups', 'vertex'))
        edge_index__vertex__continues__vertex = edge_index.get(('vertex', 'continues', 'vertex'))
        edge_index__vertex__keyof__vertex = edge_index.get(('vertex', 'keyof', 'vertex'));  edge_index = None
        edge_attr__category__assign__vertex = edge_attr.get(('category', 'assign', 'vertex'))
        edge_attr__vertex__groups__vertex = edge_attr.get(('vertex', 'groups', 'vertex'))
        edge_attr__vertex__continues__vertex = edge_attr.get(('vertex', 'continues', 'vertex'))
        edge_attr__vertex__keyof__vertex = edge_attr.get(('vertex', 'keyof', 'vertex'));  edge_attr = None
        conv1__vertex1 = self.conv1.category__assign__vertex((x__category, x__vertex), edge_index__category__assign__vertex, edge_attr__category__assign__vertex);  x__category = None
        conv1__vertex2 = self.conv1.vertex__groups__vertex((x__vertex, x__vertex), edge_index__vertex__groups__vertex, edge_attr__vertex__groups__vertex)
        conv1__vertex3 = self.conv1.vertex__continues__vertex((x__vertex, x__vertex), edge_index__vertex__continues__vertex, edge_attr__vertex__continues__vertex)
        conv1__vertex4 = self.conv1.vertex__keyof__vertex((x__vertex, x__vertex), edge_index__vertex__keyof__vertex, edge_attr__vertex__keyof__vertex);  x__vertex = None
        conv1__vertex5 = torch.add(conv1__vertex1, conv1__vertex2);  conv1__vertex1 = conv1__vertex2 = None
        conv1__vertex6 = torch.add(conv1__vertex3, conv1__vertex4);  conv1__vertex3 = conv1__vertex4 = None
        conv1__vertex = torch.add(conv1__vertex5, conv1__vertex6);  conv1__vertex5 = conv1__vertex6 = None
        relu__category = None.relu()
        relu__vertex = conv1__vertex.relu();  conv1__vertex = None
        conv2__vertex1 = self.conv2.category__assign__vertex((relu__category, relu__vertex), edge_index__category__assign__vertex, edge_attr__category__assign__vertex);  relu__category = edge_index__category__assign__vertex = edge_attr__category__assign__vertex = None
        conv2__vertex2 = self.conv2.vertex__groups__vertex((relu__vertex, relu__vertex), edge_index__vertex__groups__vertex, edge_attr__vertex__groups__vertex);  edge_index__vertex__groups__vertex = edge_attr__vertex__groups__vertex = None
        conv2__vertex3 = self.conv2.vertex__continues__vertex((relu__vertex, relu__vertex), edge_index__vertex__continues__vertex, edge_attr__vertex__continues__vertex);  edge_index__vertex__continues__vertex = edge_attr__vertex__continues__vertex = None
        conv2__vertex4 = self.conv2.vertex__keyof__vertex((relu__vertex, relu__vertex), edge_index__vertex__keyof__vertex, edge_attr__vertex__keyof__vertex);  relu__vertex = edge_index__vertex__keyof__vertex = edge_attr__vertex__keyof__vertex = None
        conv2__vertex5 = torch.add(conv2__vertex1, conv2__vertex2);  conv2__vertex1 = conv2__vertex2 = None
        conv2__vertex6 = torch.add(conv2__vertex3, conv2__vertex4);  conv2__vertex3 = conv2__vertex4 = None
        conv2__vertex = torch.add(conv2__vertex5, conv2__vertex6);  conv2__vertex5 = conv2__vertex6 = None
        relu_1__category = None.relu()
        relu_1__vertex = conv2__vertex.relu();  conv2__vertex = None
        return {'category': relu_1__category, 'vertex': relu_1__vertex}

The line relu__category = None.relu() then causes it to crash.


rusty1s commented 2 years ago

The issue is that there is no edge type pointing to category. Upgrading to torch-geometric==2.0.4 should at least warn you about this.
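To illustrate why category ends up as None, here is a simplified, pure-Python model of the per-destination aggregation visible in the generated forward() above: each edge type's output is routed to its destination node type and summed (the generated code uses torch.add). Plain floats stand in for tensors, and aggregate_by_destination() is a hypothetical helper, not PyG's actual implementation.

```python
from typing import Dict, List, Optional, Tuple

EdgeType = Tuple[str, str, str]
Metadata = Tuple[List[str], List[EdgeType]]


def aggregate_by_destination(
    metadata: Metadata,
    edge_outputs: Dict[EdgeType, float],
) -> Dict[str, Optional[float]]:
    """Sum per-edge-type outputs into their destination node type.

    Hypothetical helper: floats stand in for the tensors that each
    per-edge-type copy of a conv layer would return.
    """
    node_types, edge_types = metadata
    out: Dict[str, Optional[float]] = {}
    for node_type in node_types:
        # Collect outputs of every edge type whose destination is node_type.
        incoming = [edge_outputs[et] for et in edge_types if et[2] == node_type]
        # No incoming edge type means nothing to sum, so the node type is
        # left as None, just like relu__category in the generated code.
        out[node_type] = sum(incoming) if incoming else None
    return out


metadata: Metadata = (
    ['category', 'vertex'],
    [('category', 'assign', 'vertex'), ('vertex', 'groups', 'vertex'),
     ('vertex', 'continues', 'vertex'), ('vertex', 'keyof', 'vertex')],
)
outputs = {et: 1.0 for et in metadata[1]}
print(aggregate_by_destination(metadata, outputs))
# {'category': None, 'vertex': 4.0}
```

All four edge types point to vertex, so vertex receives four contributions while category receives none.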

fierval commented 2 years ago

Figured as much, but thought generated code like None.relu() was worth reporting.

rusty1s commented 2 years ago

Yes, definitely :) I think the current workaround of warning the user is okay, but I agree that in your case it should definitely crash prior to model execution. I need to look into torch.fx to see if there is some way to check for this.
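Until such a check exists in the library, a user-side approximation is possible. A minimal sketch, assuming you have the metadata tuple (node_types, edge_types) in hand before calling to_hetero(): the hypothetical check_every_node_type_is_updated() raises if any node type never appears as the destination of an edge type, so the problem surfaces before the model runs.

```python
from typing import List, Tuple

EdgeType = Tuple[str, str, str]
Metadata = Tuple[List[str], List[EdgeType]]


def check_every_node_type_is_updated(metadata: Metadata) -> None:
    """Raise before execution if some node type is never a destination.

    Hypothetical pre-flight check run on the metadata; it is not part of PyG.
    """
    node_types, edge_types = metadata
    destinations = {dst for _, _, dst in edge_types}
    missing = [nt for nt in node_types if nt not in destinations]
    if missing:
        raise ValueError(
            f"No edge type points to node type(s) {missing}; their outputs "
            f"would be None after message passing. Consider adding reverse "
            f"edge types, e.g. with the ToUndirected transform."
        )


metadata: Metadata = (
    ['category', 'vertex'],
    [('category', 'assign', 'vertex'), ('vertex', 'groups', 'vertex'),
     ('vertex', 'continues', 'vertex'), ('vertex', 'keyof', 'vertex')],
)
try:
    check_every_node_type_is_updated(metadata)
except ValueError as err:
    print(err)
```

Checking the metadata instead of the traced graph keeps the sketch independent of torch.fx internals.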

stayones commented 2 years ago

So every node type in the heterogeneous graph should have an edge type pointing to it? How should we resolve this problem? I tried converting the graph to undirected, but it didn't work. Should we add the reverse edge types to our data ourselves?

rusty1s commented 2 years ago

Yes, otherwise certain node types will not get properly updated during message passing. The ToUndirected transform should take care of that. Let me know if you run into any issues with it.
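To show what this does to the metadata in this issue, here is a metadata-level sketch. It assumes, as we understand ToUndirected's default behaviour on heterogeneous data, that an edge type between two different node types gains a reverse type named 'rev_<relation>', while an edge type within one node type is made undirected in place and keeps its name. add_reverse_edge_types() is a hypothetical helper that only manipulates the tuple; it does not touch any tensors.

```python
from typing import List, Tuple

EdgeType = Tuple[str, str, str]
Metadata = Tuple[List[str], List[EdgeType]]


def add_reverse_edge_types(metadata: Metadata) -> Metadata:
    """Return metadata with a reverse edge type for every bipartite edge type.

    Assumption: mirrors the 'rev_<relation>' naming that ToUndirected uses;
    same-type edge types (src == dst) are left unchanged here.
    """
    node_types, edge_types = metadata
    new_edge_types = list(edge_types)
    for src, rel, dst in edge_types:
        if src != dst:
            rev = (dst, f'rev_{rel}', src)
            if rev not in new_edge_types:
                new_edge_types.append(rev)
    return list(node_types), new_edge_types


metadata: Metadata = (
    ['category', 'vertex'],
    [('category', 'assign', 'vertex'), ('vertex', 'groups', 'vertex'),
     ('vertex', 'continues', 'vertex'), ('vertex', 'keyof', 'vertex')],
)
node_types, edge_types = add_reverse_edge_types(metadata)
print(edge_types[-1])  # ('vertex', 'rev_assign', 'category')
```

With real data, the transform is applied to the HeteroData object itself (data = T.ToUndirected()(data), with import torch_geometric.transforms as T), and data.metadata() is then passed to to_hetero(), so category now receives messages from vertex.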