pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License
20.96k stars 3.61k forks

Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults #9638

Open KevinHooah opened 1 week ago

KevinHooah commented 1 week ago

🐛 Describe the bug

Bug Description

The latest release mentions a fix for converting a model to TorchScript when it contains message_passing. However, we tested it and found that the bug remains.

Bug details:

The error message:

Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "some_path/pytorch_geometric-master/torch_geometric/nn/conv/message_passing.py", line 425
        edge_index: Adj,
        size: Size = None,
        **kwargs: Any,
    ) -> Tensor:
        r"""The initial call to start propagating messages.

The toy model, taken from the official tutorial:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import MessagePassing
from torch.nn import Linear, Parameter
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add') 
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.empty(out_channels))
        self.reset_parameters()
    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()
    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = self.lin(x)
        row, col = torch.split(edge_index, 1, dim=0) #<-- We modified this line
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        out = self.propagate(edge_index, x=x, norm=norm)
        out = out + self.bias
        return out
    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

test_model = GCNConv(10, 10)

The line of code that triggers the issue:

script_test_model = torch.jit.script(test_model)

A side bug we found during this process:

If we follow the official tutorial's code exactly, without modifying the line marked with a comment above, i.e.,

# the original code in the tutorial
row, col = edge_index

The error message will be:

Tensor cannot be used as a tuple:
  File "/var/folders/3f/5776h3152rs7rlmtldx8gyxh0000gn/T/ipykernel_15097/3913400558.py", line 14
    x = self.lin(x)
    # row, col = torch.split(edge_index, 1, dim=0)
    row, col = edge_index
               ~~~~~~~~~~ <--- HERE
    deg = degree(col, x.size(0), dtype=x.dtype)
    deg_inv_sqrt = deg.pow(-0.5)
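
For reference, TorchScript does not unpack a Tensor as if it were a tuple, and a commonly used alternative is to index the two rows explicitly. The sketch below runs in eager mode only and uses a made-up 2x3 `edge_index` (an assumption for illustration, not data from this report). It compares the shapes produced by explicit indexing and by `torch.split(edge_index, 1, dim=0)`: indexing yields 1-D tensors, while `torch.split` keeps a leading dimension of size 1, which may matter to downstream calls that expect a 1-D index.

```python
import torch

# Made-up edge list with 3 edges, in the usual [2, num_edges] layout.
edge_index = torch.tensor([[0, 1, 2],
                           [1, 2, 0]])

# Explicit indexing: each result is 1-D, shape [num_edges].
row, col = edge_index[0], edge_index[1]

# torch.split with chunk size 1 along dim 0: each chunk keeps shape [1, num_edges].
row_split, col_split = torch.split(edge_index, 1, dim=0)

print(row.shape, row_split.shape)
```

If the `torch.split` form is kept, squeezing the leading dimension (for example `row_split.squeeze(0)`) gives the same values as the indexed form.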

Versions (Corrected)

I am testing on my Mac in an environment where torch_geometric is not installed: I downloaded the latest code base and imported it directly from this local source. The issue also persists when we test our model on a GPU machine where version 2.5.3 is installed.

Versions of relevant libraries:
[pip3] flake8==3.8.4
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.24.4
[pip3] numpydoc==1.1.0
[pip3] torch==2.2.1
[pip3] torchaudio==2.2.1
[pip3] torchvision==0.17.1
[conda] blas 1.0 mkl
[conda] mkl 2021.4.0 hecd8cb5_637
[conda] mkl-service 2.4.0 py38h9ed2024_0
[conda] mkl_fft 1.3.1 py38h4ab4a9b_0
[conda] mkl_random 1.2.2 py38hb2f4e1b_0
[conda] numpy 1.24.4 pypi_0 pypi
[conda] numpydoc 1.1.0 pyhd3eb1b0_1
[conda] torch 2.2.1 pypi_0 pypi
[conda] torchaudio 2.2.1 pypi_0 pypi
[conda] torchvision 0.17.1 pypi_0 pypi

rachitk commented 1 week ago

The version of torch_geometric implied by your environment is 2.2.0, rather than the most recent release of 2.5.3. Does the issue persist in the most recent version of torch_geometric?

KevinHooah commented 1 week ago

Hello, thank you for the reply. I found that I had copied the environment information from the wrong environment.

We tried two approaches:

Neither approach works; both produce this error.

rusty1s commented 4 days ago

It works for me. Can you do me a favor and modify this line to

except Exception as e:
    print(e)

and see what comes back?

KevinHooah commented 3 days ago

Hi, following that step, the exception I get when I instantiate GCNConv(MessagePassing) is: <class '__main__.GCNConv'> is a built-in class

KevinHooah commented 3 days ago

Hello, a little more info: even after renaming the model class to GCNConv_test(MessagePassing), the exception remains the same.
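
The wording "is a built-in class" matches the `TypeError` that Python's standard `inspect.getfile` raises when the module a class belongs to has no `__file__` attribute, which is commonly the case for classes defined in a notebook or interactive session. Whether that is the cause in this thread is an assumption, not something the thread confirms. The sketch below reproduces the message with the standard library only. The names `fake_interactive_module` and `Demo` are hypothetical stand-ins, not part of torch_geometric.

```python
import inspect
import sys
import types

# Hypothetical stand-in for an interactive namespace: a module object
# that has no __file__ attribute.
fake_module = types.ModuleType("fake_interactive_module")
sys.modules[fake_module.__name__] = fake_module

# A class that claims to live in that module.
Demo = type("Demo", (), {"__module__": fake_module.__name__})

try:
    inspect.getfile(Demo)
    message = None
except TypeError as exc:
    # inspect cannot map the class to a source file.
    message = str(exc)

print(message)
```

If this is what is happening here, one thing to try would be to move the class definition into a `.py` file and import it from there, so that its source file can be located.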