pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

TorchScript conversion tutorial not working for pytorch_geometric #3104

Closed: ronuchit closed this issue 3 years ago

ronuchit commented 3 years ago

🐛 Bug

Hi, I am trying to follow the tutorial at https://pytorch-geometric.readthedocs.io/en/latest/notes/jit.html, but I am running into an error. Code and traceback are provided below.

To Reproduce

Run the following code:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid


class GNN(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 64).jittable()
        self.conv2 = GCNConv(64, out_channels).jittable()

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


dataset = Planetoid(root='/tmp/Cora', name='Cora')
model = GNN(dataset.num_features, dataset.num_classes)
model = torch.jit.script(model)

This raises the following traceback:

RuntimeError:
Tried to access nonexistent attribute or method 'dtype' of type 'Optional[Tensor]'.:
  File "/Users/rohan/.virtualenvs/coding/lib/python3.7/site-packages/torch_sparse/matmul.py", line 99
        valueA = valueA.to(torch.float)
    if valueB is not None:
        valueB = valueB.to(valueA.dtype)
                           ~~~~~~~~~~~~ <--- HERE
    M, K = src.sparse_size(0), other.sparse_size(1)
    rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
'spspmm_sum' is being compiled since it was called from 'spspmm'
  File "/Users/rohan/.virtualenvs/coding/lib/python3.7/site-packages/torch_sparse/matmul.py", line 116
           reduce: str = "sum") -> SparseTensor:
    if reduce == 'sum' or reduce == 'add':
        return spspmm_sum(src, other)
               ~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
        raise NotImplementedError
'spspmm' is being compiled since it was called from 'matmul'
  File "/Users/rohan/.virtualenvs/coding/lib/python3.7/site-packages/torch_sparse/matmul.py", line 139
        return spmm(src, other, reduce)
    elif isinstance(other, SparseTensor):
        return spspmm(src, other, reduce)
               ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    raise ValueError
'matmul' is being compiled since it was called from 'GCNConvJittable_20110e.message_and_aggregate'
  File "/Users/rohan/.virtualenvs/coding/lib/python3.7/site-packages/torch_geometric/nn/conv/gcn_conv.py", line 194
    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        return matmul(adj_t, x, reduce=self.aggr)
               ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
'GCNConvJittable_20110e.message_and_aggregate' is being compiled since it was called from 'GCNConvJittable_20110e.propagate__1'
  File "/var/folders/5v/9njppv2x52q_c237jzkpls7m0000gn/T/rohan_pyg_jit/tmppo0w0nzw.py", line 165
        if self.fuse:
            if isinstance(edge_index, SparseTensor):
                out = self.message_and_aggregate(edge_index, x=in_kwargs.x)
                ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
                return self.update(out)

'GCNConvJittable_20110e.propagate__1' is being compiled since it was called from 'GCNConvJittable_20110e.forward__0'
  File "/var/folders/5v/9njppv2x52q_c237jzkpls7m0000gn/T/rohan_pyg_jit/tmppo0w0nzw.py", line 213
        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                             size=None)
                                  ~~~~ <--- HERE

        if self.bias is not None:

Expected behavior

The JIT compilation should succeed without producing any output.

Environment

Additional context

N/A

rusty1s commented 3 years ago

Can you try to re-install torch-sparse? This should fix the error. There was a bug in the recently released wheels.
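
Before and after re-installing, it can help to confirm which versions are actually installed in the active environment. The snippet below is a minimal sketch, not part of the original reply: the helper name `installed_versions` and the package list are illustrative assumptions, and `importlib.metadata` requires Python 3.8 or newer (the traceback above comes from Python 3.7, where the `importlib_metadata` backport would be needed instead).

```python
from importlib import metadata
from typing import Dict, Optional, Sequence

# Distribution names as assumed to be published on PyPI; adjust them if
# your environment uses different names.
PACKAGES = ("torch", "torch-sparse", "torch-scatter", "torch-geometric")


def installed_versions(
    packages: Sequence[str] = PACKAGES,
) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

If the reported torch-sparse version does not change after re-installing, pip may be reusing a cached wheel; its `--force-reinstall` and `--no-cache-dir` options are one way to rule that out.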

ronuchit commented 3 years ago

Looks like this was addressed by https://github.com/rusty1s/pytorch_sparse/commit/85ce67e4de51ab72b5b5ae9edc4fc6e377232438, thank you! Closing.