pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License
21.12k stars 3.63k forks source link

torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults #9115

Closed UpCoder closed 5 months ago

UpCoder commented 6 months ago

🐛 Describe the bug

When I convert the model with `torch.jit.script`, I get the error described in the title. Here is a minimal reproducer:

import torch
from torch_geometric.nn import GATConv, GraphConv, TopKPooling
class CustomeModel(torch.nn.Module):
    """Minimal reproducer: a module wrapping a single ``GraphConv(512, 256)``.

    Used to demonstrate the reported ``torch.jit.script`` failure when
    scripting a model that contains a PyG ``GraphConv`` layer.
    """

    def __init__(self):
        super().__init__()
        self.gcn_layer = GraphConv(512, 256)

    def forward(self, x1, x2):
        """Apply the graph convolution and pass the connectivity through.

        Args:
            x1: Node feature tensor (presumably ``[num_nodes, 512]`` to match
                ``GraphConv``'s input width -- confirm).
            x2: Edge connectivity, passed to ``GraphConv`` as ``edge_index``.

        Returns:
            A tuple of the convolved features and ``x2`` unchanged.
        """
        out = self.gcn_layer(x1, x2)
        return out, x2
# Build the eager model, then compile it with TorchScript. Per the report,
# this torch.jit.script call raises NotSupportedError on torch_geometric 2.5.2.
model = CustomeModel()
model = torch.jit.script(model)

Versions

Versions

Error Info

raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/home/tiger/.local/lib/python3.9/site-packages/torch_geometric/nn/conv/message_passing.py", line 459
        edge_index: Adj,
        size: Size = None,
        **kwargs: Any,
    ) -> Tensor:
        r"""The initial call to start propagating messages.
UpCoder commented 6 months ago

I worked around the bug by downgrading torch_geometric from 2.5.2 to 2.5.0.

rusty1s commented 6 months ago

I have been trying to debug this for a while, but I must admit that I am not totally sure what might cause this. Can you please do me a favor and replace the except block here with

except Exception as e:
     raise e

and post the output? It looks like we are ignoring an exception here which we shouldn't.

UpCoder commented 6 months ago

> I have been trying to debug this for a while, but I must admit that I am not totally sure what might cause this. Can you please do me a favor and replace the except block here with
>
> except Exception as e:
>      raise e
>
> and post the output? It looks like we are ignoring an exception here which we shouldn't.

The output looks like this:

module 'torch_geometric.nn.conv.graph_conv_GraphConv_propagate' has no attribute 'propagate'
module 'torch_geometric.nn.conv.graph_conv_GraphConv_propagate' has no attribute 'propagate'
module 'torch_geometric.nn.conv.graph_conv_GraphConv_propagate' has no attribute 'propagate'

image
rusty1s commented 6 months ago

Thanks :)

rusty1s commented 6 months ago

One last request: can you post the output of `print(module_repr)` here?

viktortnk commented 6 months ago
import typing
from typing import Union

import torch
from torch import Tensor

import torch_geometric.typing
from torch_geometric import is_compiling
from torch_geometric.utils import is_sparse
from torch_geometric.typing import Size, SparseTensor

from typing import List, NamedTuple, Optional, Union

import torch
from torch import Tensor

from torch_geometric.utils import is_torch_sparse_tensor
from torch_geometric.utils.sparse import ptr2index
from torch_geometric.typing import SparseTensor

class CollectArgs(NamedTuple):
    # Bundle of per-edge arguments returned by ``collect``.
    x_j: Tensor  # features gathered at edge_index_j
    edge_weight: Optional[Tensor]  # optional per-edge weights
    index: Tensor  # index handed to ``aggregate``
    ptr: Optional[Tensor]  # CSR row pointer, or None
    dim_size: Optional[int]  # aggregation output size, if known

def collect(
    self,
    edge_index: Union[Tensor, SparseTensor],
    x: OptPairTensor,
    edge_weight: OptTensor,
    size: List[Optional[int]],
) -> CollectArgs:
    r"""Gather the per-edge arguments used by ``message`` and ``aggregate``.

    Args:
        edge_index: Connectivity as a dense index ``Tensor``, a PyTorch
            sparse tensor (COO or CSR layout), or a ``SparseTensor``.
        x: Node features, either a single tensor or a 2-element tuple/list.
        edge_weight: Optional per-edge weights. When ``None`` and
            ``edge_index`` is sparse, the sparse tensor's values are used.
        size: Two-entry size list handed to ``self._set_size`` (presumably
            filled in place from the feature tensors -- confirm).

    Returns:
        CollectArgs holding ``x_j``, the resolved ``edge_weight``,
        ``index`` (``edge_index_i``), ``ptr`` (CSR pointer or ``None``) and
        ``dim_size``.

    Raises:
        ValueError: If a PyTorch sparse ``edge_index`` is neither COO nor CSR.
        NotImplementedError: If ``edge_index`` is of an unsupported type.
    """
    # ``i`` selects the side that becomes ``index``; ``j`` the side whose
    # features are gathered into ``x_j``.
    i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)

    # Collect special arguments:
    if isinstance(edge_index, Tensor):
        if is_torch_sparse_tensor(edge_index):
            adj_t = edge_index
            if adj_t.layout == torch.sparse_coo:
                edge_index_i = adj_t.indices()[0]
                edge_index_j = adj_t.indices()[1]
                ptr = None
            elif adj_t.layout == torch.sparse_csr:
                # CSR only stores compressed row pointers; ptr2index presumably
                # expands them to one row index per stored column entry.
                ptr = adj_t.crow_indices()
                edge_index_j = adj_t.col_indices()
                edge_index_i = ptr2index(ptr, output_size=edge_index_j.numel())
            else:
                raise ValueError(f"Received invalid layout '{adj_t.layout}'")
            if edge_weight is None:
                edge_weight = adj_t.values()

        else:
            edge_index_i = edge_index[i]
            edge_index_j = edge_index[j]
            ptr = None

    elif isinstance(edge_index, SparseTensor):
        adj_t = edge_index
        edge_index_i, edge_index_j, _value = adj_t.coo()
        ptr, _, _ = adj_t.csr()
        if edge_weight is None:
            edge_weight = _value

    else:
        raise NotImplementedError

    # Collect user-defined arguments:
    # (1) - Collect `x_j`:
    # The isinstance checks below are kept in this exact form so TorchScript
    # can statically refine the types of ``x`` and its elements.
    if isinstance(x, (tuple, list)):
        assert len(x) == 2
        _x_0, _x_1 = x[0], x[1]
        if isinstance(_x_0, Tensor):
            self._set_size(size, 0, _x_0)
            x_j = self._index_select(_x_0, edge_index_j)
        else:
            x_j = None
        if isinstance(_x_1, Tensor):
            self._set_size(size, 1, _x_1)
    elif isinstance(x, Tensor):
        self._set_size(size, j, x)
        x_j = self._index_select(x, edge_index_j)
    else:
        x_j = None

    # Collect default arguments:
    index = edge_index_i
    # Fall back to the opposite dimension when size[i] is still unknown.
    dim_size = size[i] if size[i] is not None else size[j]

    return CollectArgs(
        x_j,
        edge_weight,
        index,
        ptr,
        dim_size,
    )

def propagate(
    self,
    edge_index: Union[Tensor, SparseTensor],
    x: OptPairTensor,
    edge_weight: OptTensor,
    size: Size = None,
) -> Tensor:
    r"""Propagate messages over ``edge_index`` and return the layer output.

    Two code paths exist:

    * fused: when ``is_sparse(edge_index)`` and ``self.fuse`` are both true,
      ``self.message_and_aggregate(edge_index, x)`` is called directly;
    * unfused: otherwise ``self.collect`` gathers per-edge arguments, which
      flow through ``self.message`` and then ``self.aggregate``.

    Both paths finish with ``self.update``. Registered Python hooks are
    invoked around each stage only when neither ``torch.jit.is_scripting()``
    nor ``is_compiling()`` is true, so under TorchScript the hook blocks are
    skipped.

    Args:
        edge_index: Connectivity as a ``Tensor`` or ``SparseTensor``.
        x: Node features (``OptPairTensor``; presumably a tensor or a pair of
            optional tensors -- confirm against the type alias).
        edge_weight: Optional per-edge weights forwarded to ``message``.
        size: Optional size passed to ``self._check_input``; the result
            (``mutable_size``) is what ``collect`` receives.

    Returns:
        Tensor: The output of ``self.update``, possibly replaced by a
        propagate forward hook.
    """

    # Begin Propagate Forward Pre Hook #########################################
    if not torch.jit.is_scripting() and not is_compiling():
        for hook in self._propagate_forward_pre_hooks.values():
            hook_kwargs = dict(
                x=x,
                edge_weight=edge_weight,
            )
            res = hook(self, (edge_index, size, hook_kwargs))
            # A non-None result replaces edge_index, size and the message
            # arguments for the remainder of the call.
            if res is not None:
                edge_index, size, hook_kwargs = res
                x = hook_kwargs['x']
                edge_weight = hook_kwargs['edge_weight']
    # End Propagate Forward Pre Hook ###########################################

    mutable_size = self._check_input(edge_index, size)
    # Take the fused message_and_aggregate path only for sparse connectivity
    # when the layer has fusion enabled.
    fuse = is_sparse(edge_index) and self.fuse

    if fuse:
        # Begin Message and Aggregate Forward Pre Hook #########################
        if not torch.jit.is_scripting() and not is_compiling():
            for hook in self._message_and_aggregate_forward_pre_hooks.values():
                hook_kwargs = dict(
                    x=x,
                )
                res = hook(self, (edge_index, hook_kwargs))
                if res is not None:
                    edge_index, hook_kwargs = res
                    x = hook_kwargs['x']
        # End Message and Aggregate Forward Pre Hook ##########################

        out = self.message_and_aggregate(
            edge_index,
            x,
        )

        # Begin Message and Aggregate Forward Hook #############################
        if not torch.jit.is_scripting() and not is_compiling():
            for hook in self._message_and_aggregate_forward_hooks.values():
                hook_kwargs = dict(
                    x=x,
                )
                res = hook(self, (edge_index, hook_kwargs, ), out)
                # A forward hook may replace the output; None keeps it.
                out = res if res is not None else out
        # End Message and Aggregate Forward Hook ###############################

        out = self.update(
            out,
        )

    else:

        kwargs = self.collect(
            edge_index,
            x,
            edge_weight,
            mutable_size,
        )

        # Begin Message Forward Pre Hook #######################################
        if not torch.jit.is_scripting() and not is_compiling():
            for hook in self._message_forward_pre_hooks.values():
                hook_kwargs = dict(
                    x_j=kwargs.x_j,
                    edge_weight=kwargs.edge_weight,
                )
                res = hook(self, (hook_kwargs, ))
                # Hooks may return the kwargs dict directly or wrapped as the
                # first element of a tuple.
                hook_kwargs = res[0] if isinstance(res, tuple) else res
                if res is not None:
                    kwargs = CollectArgs(
                        x_j=hook_kwargs['x_j'],
                        edge_weight=hook_kwargs['edge_weight'],
                        index=kwargs.index,
                        ptr=kwargs.ptr,
                        dim_size=kwargs.dim_size,
                    )
        # End Message Forward Pre Hook #########################################

        out = self.message(
            x_j=kwargs.x_j,
            edge_weight=kwargs.edge_weight,
        )

        # Begin Message Forward Hook ###########################################
        if not torch.jit.is_scripting() and not is_compiling():
            for hook in self._message_forward_hooks.values():
                hook_kwargs = dict(
                    x_j=kwargs.x_j,
                    edge_weight=kwargs.edge_weight,
                )
                res = hook(self, (hook_kwargs, ), out)
                out = res if res is not None else out
        # End Message Forward Hook #############################################

        # Begin Aggregate Forward Pre Hook #####################################
        if not torch.jit.is_scripting() and not is_compiling():
            for hook in self._aggregate_forward_pre_hooks.values():
                hook_kwargs = dict(
                    index=kwargs.index,
                    ptr=kwargs.ptr,
                    dim_size=kwargs.dim_size,
                )
                res = hook(self, (hook_kwargs, ))
                hook_kwargs = res[0] if isinstance(res, tuple) else res
                if res is not None:
                    kwargs = CollectArgs(
                        x_j=kwargs.x_j,
                        edge_weight=kwargs.edge_weight,
                        index=hook_kwargs['index'],
                        ptr=hook_kwargs['ptr'],
                        dim_size=hook_kwargs['dim_size'],
                    )
        # End Aggregate Forward Pre Hook #######################################

        out = self.aggregate(
            out,
            index=kwargs.index,
            ptr=kwargs.ptr,
            dim_size=kwargs.dim_size,
        )

        # Begin Aggregate Forward Hook #########################################
        if not torch.jit.is_scripting() and not is_compiling():
            for hook in self._aggregate_forward_hooks.values():
                hook_kwargs = dict(
                    index=kwargs.index,
                    ptr=kwargs.ptr,
                    dim_size=kwargs.dim_size,
                )
                res = hook(self, (hook_kwargs, ), out)
                out = res if res is not None else out
        # End Aggregate Forward Hook ###########################################

        out = self.update(
            out,
        )

    # Begin Propagate Forward Hook ############################################
    if not torch.jit.is_scripting() and not is_compiling():
        for hook in self._propagate_forward_hooks.values():
            hook_kwargs = dict(
                x=x,
                edge_weight=edge_weight,
            )
            # NOTE(review): this hook receives mutable_size, whereas the
            # pre-hook received the caller's size -- confirm this is intended.
            res = hook(self, (edge_index, mutable_size, hook_kwargs), out)
            out = res if res is not None else out
    # End Propagate Forward Hook ##############################################

    return out
rusty1s commented 6 months ago

I think I fixed this in https://github.com/pyg-team/pytorch_geometric/pull/9151. Can you confirm by adding `tmp.flush()` to `template.py` (as in the PR) and trying again? If it still does not work, can you show me the output of `print(module.__dict__)`?

viktortnk commented 6 months ago

@rusty1s Unfortunately, it still doesn't work.

  ret = run_job(
{'__name__': 'torch_geometric.nn.conv.gatv2_conv_GATv2Conv_edge_updater', '__doc__': None, '__package__': 'torch_geometric.nn.conv', '__loader__': <_frozen_importlib_external.SourceFileLoader object at 0x7f6ddaf8f5d0>, '__spec__': ModuleSpec(name='torch_geometric.nn.conv.gatv2_conv_GATv2Conv_edge_updater', loader=<_frozen_importlib_external.SourceFileLoader object at 0x7f6ddaf8f5d0>, origin='/tmp/torch_geometric.nn.conv.gatv2_conv_GATv2Conv_edge_updater_91zbt5hq.py'), '__file__': '/tmp/torch_geometric.nn.conv.gatv2_conv_GATv2Conv_edge_updater_91zbt5hq.py', '__cached__': '/tmp/__pycache__/torch_geometric.nn.conv.gatv2_conv_GATv2Conv_edge_updater_91zbt5hq.cpython-311.pyc', '__builtins__': {'__name__': 'builtins', '__doc__': "Built-in functions, types, exceptions, and other objects.\n\nThis module provides direct access to all 'built-in'\nidentifiers of Python; for example, builtins.len is\nthe full name for the built-in function len().\n\nThis module is not normally accessed explicitly by most\napplications, but can be useful in modules that provide\nobjects with the same name as a built-in value, but in\nwhich the built-in of that name is also needed.", '__package__': '', '__loader__': <class '_frozen_importlib.BuiltinImporter'>, '__spec__': ModuleSpec(name='builtins', loader=<class '_frozen_importlib.BuiltinImporter'>, origin='built-in'), '__build_class__': <built-in function __build_class__>, '__import__': <built-in function __import__>, 'abs': <built-in function abs>, 'all': <built-in function all>, 'any': <built-in function any>, 'ascii': <built-in function ascii>, 'bin': <built-in function bin>, 'breakpoint': <built-in function breakpoint>, 'callable': <built-in function callable>, 'chr': <built-in function chr>, 'compile': <built-in function compile>, 'delattr': <built-in function delattr>, 'dir': <built-in function dir>, 'divmod': <built-in function divmod>, 'eval': <built-in function eval>, 'exec': <built-in function exec>, 'format': <built-in function 
format>, 'getattr': <built-in function getattr>, 'globals': <built-in function globals>, 'hasattr': <built-in function hasattr>, 'hash': <built-in function hash>, 'hex': <built-in function hex>, 'id': <built-in function id>, 'input': <built-in function input>, 'isinstance': <built-in function isinstance>, 'issubclass': <built-in function issubclass>, 'iter': <built-in function iter>, 'aiter': <built-in function aiter>, 'len': <built-in function len>, 'locals': <built-in function locals>, 'max': <built-in function max>, 'min': <built-in function min>, 'next': <built-in function next>, 'anext': <built-in function anext>, 'oct': <built-in function oct>, 'ord': <built-in function ord>, 'pow': <built-in function pow>, 'print': <built-in function print>, 'repr': <built-in function repr>, 'round': <built-in function round>, 'setattr': <built-in function setattr>, 'sorted': <built-in function sorted>, 'sum': <built-in function sum>, 'vars': <built-in function vars>, 'None': None, 'Ellipsis': Ellipsis, 'NotImplemented': NotImplemented, 'False': False, 'True': True, 'bool': <class 'bool'>, 'memoryview': <class 'memoryview'>, 'bytearray': <class 'bytearray'>, 'bytes': <class 'bytes'>, 'classmethod': <class 'classmethod'>, 'complex': <class 'complex'>, 'dict': <class 'dict'>, 'enumerate': <class 'enumerate'>, 'filter': <class 'filter'>, 'float': <class 'float'>, 'frozenset': <class 'frozenset'>, 'property': <class 'property'>, 'int': <class 'int'>, 'list': <class 'list'>, 'map': <class 'map'>, 'object': <class 'object'>, 'range': <class 'range'>, 'reversed': <class 'reversed'>, 'set': <class 'set'>, 'slice': <class 'slice'>, 'staticmethod': <class 'staticmethod'>, 'str': <class 'str'>, 'super': <class 'super'>, 'tuple': <class 'tuple'>, 'type': <class 'type'>, 'zip': <class 'zip'>, '__debug__': True, 'BaseException': <class 'BaseException'>, 'BaseExceptionGroup': <class 'BaseExceptionGroup'>, 'Exception': <class 'Exception'>, 'GeneratorExit': <class 'GeneratorExit'>, 
'KeyboardInterrupt': <class 'KeyboardInterrupt'>, 'SystemExit': <class 'SystemExit'>, 'ArithmeticError': <class 'ArithmeticError'>, 'AssertionError': <class 'AssertionError'>, 'AttributeError': <class 'AttributeError'>, 'BufferError': <class 'BufferError'>, 'EOFError': <class 'EOFError'>, 'ImportError': <class 'ImportError'>, 'LookupError': <class 'LookupError'>, 'MemoryError': <class 'MemoryError'>, 'NameError': <class 'NameError'>, 'OSError': <class 'OSError'>, 'ReferenceError': <class 'ReferenceError'>, 'RuntimeError': <class 'RuntimeError'>, 'StopAsyncIteration': <class 'StopAsyncIteration'>, 'StopIteration': <class 'StopIteration'>, 'SyntaxError': <class 'SyntaxError'>, 'SystemError': <class 'SystemError'>, 'TypeError': <class 'TypeError'>, 'ValueError': <class 'ValueError'>, 'Warning': <class 'Warning'>, 'FloatingPointError': <class 'FloatingPointError'>, 'OverflowError': <class 'OverflowError'>, 'ZeroDivisionError': <class 'ZeroDivisionError'>, 'BytesWarning': <class 'BytesWarning'>, 'DeprecationWarning': <class 'DeprecationWarning'>, 'EncodingWarning': <class 'EncodingWarning'>, 'FutureWarning': <class 'FutureWarning'>, 'ImportWarning': <class 'ImportWarning'>, 'PendingDeprecationWarning': <class 'PendingDeprecationWarning'>, 'ResourceWarning': <class 'ResourceWarning'>, 'RuntimeWarning': <class 'RuntimeWarning'>, 'SyntaxWarning': <class 'SyntaxWarning'>, 'UnicodeWarning': <class 'UnicodeWarning'>, 'UserWarning': <class 'UserWarning'>, 'BlockingIOError': <class 'BlockingIOError'>, 'ChildProcessError': <class 'ChildProcessError'>, 'ConnectionError': <class 'ConnectionError'>, 'FileExistsError': <class 'FileExistsError'>, 'FileNotFoundError': <class 'FileNotFoundError'>, 'InterruptedError': <class 'InterruptedError'>, 'IsADirectoryError': <class 'IsADirectoryError'>, 'NotADirectoryError': <class 'NotADirectoryError'>, 'PermissionError': <class 'PermissionError'>, 'ProcessLookupError': <class 'ProcessLookupError'>, 'TimeoutError': <class 'TimeoutError'>, 
'IndentationError': <class 'IndentationError'>, 'IndexError': <class 'IndexError'>, 'KeyError': <class 'KeyError'>, 'ModuleNotFoundError': <class 'ModuleNotFoundError'>, 'NotImplementedError': <class 'NotImplementedError'>, 'RecursionError': <class 'RecursionError'>, 'UnboundLocalError': <class 'UnboundLocalError'>, 'UnicodeError': <class 'UnicodeError'>, 'BrokenPipeError': <class 'BrokenPipeError'>, 'ConnectionAbortedError': <class 'ConnectionAbortedError'>, 'ConnectionRefusedError': <class 'ConnectionRefusedError'>, 'ConnectionResetError': <class 'ConnectionResetError'>, 'TabError': <class 'TabError'>, 'UnicodeDecodeError': <class 'UnicodeDecodeError'>, 'UnicodeEncodeError': <class 'UnicodeEncodeError'>, 'UnicodeTranslateError': <class 'UnicodeTranslateError'>, 'ExceptionGroup': <class 'ExceptionGroup'>, 'EnvironmentError': <class 'OSError'>, 'IOError': <class 'OSError'>, 'open': <built-in function open>, 'quit': Use quit() or Ctrl-D (i.e. EOF) to exit, 'exit': Use exit() or Ctrl-D (i.e. EOF) to exit, 'copyright': Copyright (c) 2001-2023 Python Software Foundation.
All Rights Reserved.

Copyright (c) 2000 BeOpen.com.
All Rights Reserved.

Copyright (c) 1995-2001 Corporation for National Research Initiatives.
All Rights Reserved.

Copyright (c) 1991-1995 Stichting Mathematisch Centrum, Amsterdam.
All Rights Reserved., 'credits':     Thanks to CWI, CNRI, BeOpen.com, Zope Corporation and a cast of thousands
    for supporting Python development.  See www.python.org for more information., 'license': Type license() to see the full license text, 'help': Type help() for interactive help, or help(object) for help about object., '__pybind11_internals_v4_gcc_libstdcpp_cxxabi1011__': <capsule object NULL at 0x7f6e3c705b00>, '__pybind11_internals_v4_gcc_libstdcpp_cxxabi1017__': <capsule object NULL at 0x7f6dee886940>}, 'typing': <module 'typing' from '/home/vltaranenko/miniforge3/envs/qc22/lib/python3.11/typing.py'>, 'Union': typing.Union, 'torch': <module 'torch' from '/home/vltaranenko/miniforge3/envs/qc22/lib/python3.11/site-packages/torch/__init__.py'>, 'Tensor': <class 'torch.Tensor'>, 'torch_geometric': <module 'torch_geometric' from '/home/vltaranenko/miniforge3/envs/qc22/lib/python3.11/site-packages/torch_geometric/__init__.py'>, 'is_compiling': <function is_compiling at 0x7f6df5169e40>, 'is_sparse': <function is_sparse at 0x7f6deb492b60>, 'Size': typing.Optional[typing.Tuple[int, int]], 'SparseTensor': <class 'torch_sparse.tensor.SparseTensor'>, 'Optional': typing.Optional, 'Tuple': typing.Tuple, 'F': <module 'torch.nn.functional' from '/home/vltaranenko/miniforge3/envs/qc22/lib/python3.11/site-packages/torch/nn/functional.py'>, 'Parameter': <class 'torch.nn.parameter.Parameter'>, 'MessagePassing': <class 'torch_geometric.nn.conv.message_passing.MessagePassing'>, 'Linear': <class 'torch_geometric.nn.dense.linear.Linear'>, 'glorot': <function glorot at 0x7f6de3853420>, 'zeros': <function zeros at 0x7f6de3853600>, 'Adj': typing.Union[torch.Tensor, torch_sparse.tensor.SparseTensor], 'NoneType': typing.Optional[torch.Tensor], 'OptTensor': typing.Optional[torch.Tensor], 'PairTensor': typing.Tuple[torch.Tensor, torch.Tensor], 'torch_sparse': <module 'torch_sparse' from '/home/vltaranenko/miniforge3/envs/qc22/lib/python3.11/site-packages/torch_sparse/__init__.py'>, 'add_self_loops': <function add_self_loops at 0x7f6deb483b00>, 'is_torch_sparse_tensor': <function 
is_torch_sparse_tensor at 0x7f6deb492ac0>, 'remove_self_loops': <function remove_self_loops at 0x7f6deb493560>, 'softmax': <function softmax at 0x7f6deb490e00>, 'set_sparse_value': <function set_sparse_value at 0x7f6deb492fc0>, 'overload': <function _overload_method at 0x7f6df6e793a0>, 'GATv2Conv': <class 'torch_geometric.nn.conv.gatv2_conv.GATv2Conv'>, 'List': typing.List, 'NamedTuple': <function NamedTuple at 0x7f6ea0ceede0>, 'ptr2index': <function ptr2index at 0x7f6deb493060>, 'CollectArgs': <class 'torch_geometric.nn.conv.gatv2_conv_GATv2Conv_edge_updater.CollectArgs'>, 'edge_collect': <function edge_collect at 0x7f6dcf261e40>, 'edge_updater': <function edge_updater at 0x7f6dcf262c00>}
rusty1s commented 6 months ago

Mh, what a bummer. I'll take a look. Thanks.

rusty1s commented 5 months ago

Just pushed 2.5.3 with a fix :)

aJay0422 commented 1 month ago

I upgraded torch_geometric to version 2.5.3 with `pip install torch_geometric --upgrade`, but the issue still remains. Here is the error:

Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/home/shangjie/miniconda3/envs/env_ai/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py", line 466
        edge_index: Adj,
        size: Size = None,
        **kwargs: Any,
         ~~~~~~~ <--- HERE
    ) -> Tensor:
        r"""The initial call to start propagating messages.

Here is the line in my code that triggers the issue: scripted_model = torch.jit.script(model)