pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

[BUG] RuntimeError when Tracing a Graph-UNet with Torch JIT #1229

Open liaopeiyuan opened 4 years ago

liaopeiyuan commented 4 years ago

🐛 Bug

There seems to be a type error when tracing a Graph-UNet with Torch JIT.

[omitted]/torch_sparse/matmul.py in spspmm(src, other, reduce)
     94     if reduce == 'sum' or reduce == 'add':
---> 95         return spspmm_sum(src, other)

[omitted]/torch_sparse/matmul.py in spspmm_sum(src, other)
     82     rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
---> 83         rowptrA, colA, valueA, rowptrB, colB, valueB, K)

RuntimeError: unsupported output type: Tensor?

To Reproduce

Steps to reproduce the behavior:

import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GraphUNet
from torch_geometric.utils import dropout_adj

dataset = 'Cora'
path = osp.join('..', 'data', dataset)
dataset = Planetoid(path, dataset)
data = dataset[0]

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        pool_ratios = [2000 / data.num_nodes, 0.5]
        self.unet = GraphUNet(1433, 32, 7,
                              depth=3, pool_ratios=pool_ratios)

    def forward(self, x, edge_index):
        e, _ = dropout_adj(edge_index, p=0.2,
                           force_undirected=True,
                           num_nodes=2708,
                           training=self.training)
        d1 = F.dropout(x, p=0.92, training=self.training)

        u = self.unet(d1, e)
        print(type(u))
        r = F.log_softmax(u, dim=1)
        return r

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.001)

# data was already moved to `device` above, so no extra .cuda() call is needed
inp = (data.x, data.edge_index)
scripted_model = torch.jit.trace(model, inp).eval()

Expected behavior

Tracing should succeed and return a value of type torch.jit.TopLevelTracedModule.
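For comparison, here is a minimal sketch of what a successful trace looks like on a plain PyTorch module. `TinyNet` is a hypothetical stand-in with made-up layer sizes; it is not the GraphUNet model from the report and does not touch torch_geometric or torch_sparse.

```python
import torch


class TinyNet(torch.nn.Module):
    # Hypothetical stand-in module (assumption), used only to show the
    # type that torch.jit.trace returns when tracing succeeds.
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(8, 3)

    def forward(self, x):
        return torch.log_softmax(self.lin(x), dim=1)


example = torch.randn(5, 8)

# Trace in eval mode so the recorded graph has no training-only behavior.
traced = torch.jit.trace(TinyNet().eval(), example)

print(type(traced).__name__)
out = traced(example)
```

The traced object can be called like the original module, which is the behavior expected from the GraphUNet model above.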

Environment

Additional context

I'm working on compiling models written with torch_geometric for deployment on a Jetson Nano with tvm, which requires JIT tracing first.

rusty1s commented 4 years ago

Thanks for this issue. I will look into this. We are currently in the process of providing jit support for all PyTorch modules, so please stay tuned!

liaopeiyuan commented 4 years ago

Do you have a sense of where it might be going wrong? I will be working closely on integrating torch_geometric with tvm over the following months, so I may be able to help with some of the issues.

rusty1s commented 4 years ago

We are currently in the process of making all convs jittable, see here, but tracing should generally work fine. In your case, it might be a problem with torch-sparse.
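One way to check whether a single component is responsible (for example, the torch-sparse call) could be to trace the pieces of the model in isolation and see which one raises. The sketch below only illustrates that idea: `try_trace` is a hypothetical helper, and both stage functions are made-up stand-ins rather than torch_geometric or torch_sparse code. The failing stage simply simulates an op that the tracer rejects.

```python
import torch


def try_trace(fn, example_inputs):
    """Hypothetical helper: trace `fn` and report (success, error message)."""
    try:
        torch.jit.trace(fn, example_inputs)
        return True, ""
    except RuntimeError as err:
        return False, str(err)


def dense_stage(x):
    # Stand-in for a stage that traces cleanly.
    return torch.relu(x @ x.t())


def failing_stage(x):
    # Stand-in that simulates an op rejected during tracing (assumption:
    # the real failure also surfaces as a RuntimeError, as in the report).
    raise RuntimeError("unsupported output type (simulated)")


x = torch.randn(4, 4)
results = {
    "dense_stage": try_trace(dense_stage, (x,)),
    "failing_stage": try_trace(failing_stage, (x,)),
}

for name, (ok, msg) in results.items():
    print(name, "ok" if ok else f"failed: {msg}")
```

In the real model, the stand-ins would be replaced with the individual layers or utility calls used inside `forward`, each traced with inputs of the shape it actually receives.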