pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
MIT License
21.12k stars 3.63k forks source link

torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults #9115

Closed — UpCoder closed this issue 5 months ago

UpCoder commented 6 months ago

🐛 Describe the bug

When I convert the model with `torch.jit.script`, I get the error shown in the title. A minimal reproduction follows:

from typing import Tuple

import torch
from torch import Tensor

from torch_geometric.nn import GATConv, GraphConv, TopKPooling


class CustomeModel(torch.nn.Module):
    """Minimal model wrapping a single ``GraphConv(512, 256)`` layer.

    Scripting this model with ``torch.jit.script`` reproduces the
    ``NotSupportedError`` reported for torch_geometric 2.5.2. The reporter
    found 2.5.0 works; a fix was released in 2.5.3, although one later
    report still saw the error there.
    """

    def __init__(self) -> None:
        super().__init__()
        # NOTE(review): presumably in_channels=512, out_channels=256 — confirm
        # against the GraphConv signature of the installed PyG version.
        self.gcn_layer = GraphConv(512, 256)

    def forward(self, x1: Tensor, x2: Tensor) -> Tuple[Tensor, Tensor]:
        # x1 is treated as node features and x2 as edge_index connectivity.
        x, edge_index = x1, x2
        x = self.gcn_layer(x, edge_index)
        return x, edge_index


model = CustomeModel()
# This call is what raises NotSupportedError on the affected PyG versions.
model = torch.jit.script(model)



Error Info

    raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/home/tiger/.local/lib/python3.9/site-packages/torch_geometric/nn/conv/message_passing.py", line 459
        edge_index: Adj,
        size: Size = None,
        **kwargs: Any,
    ) -> Tensor:
        r"""The initial call to start propagating messages.
UpCoder commented 6 months ago

I worked around the bug by downgrading torch_geometric from 2.5.2 to 2.5.0.

rusty1s commented 6 months ago

I have been trying to debug this for a while, but I must admit that I am not totally sure what might cause this. Can you please do me a favor and replace the except block here with

except Exception as e:
     raise e

and post the output? It looks like we are ignoring an exception here which we shouldn't.

UpCoder commented 6 months ago

I have been trying to debug this for a while, but I must admit that I am not totally sure what might cause this. Can you please do me a favor and replace the except block here with

except Exception as e:
     raise e

and post the output? It looks like we are ignoring an exception here which we shouldn't.

The output looks like this:

module 'torch_geometric.nn.conv.graph_conv_GraphConv_propagate' has no attribute 'propagate'
module 'torch_geometric.nn.conv.graph_conv_GraphConv_propagate' has no attribute 'propagate'
module 'torch_geometric.nn.conv.graph_conv_GraphConv_propagate' has no attribute 'propagate'

rusty1s commented 6 months ago

Thanks :)

rusty1s commented 6 months ago

One last request: can you post the output of `print(module_repr)` here?

viktortnk commented 6 months ago
import typing
from typing import Union

import torch
from torch import Tensor

import torch_geometric.typing
from torch_geometric import is_compiling
from torch_geometric.utils import is_sparse
from torch_geometric.typing import Size, SparseTensor

from typing import List, NamedTuple, Optional, Union

import torch
from torch import Tensor

from torch_geometric.utils import is_torch_sparse_tensor
from torch_geometric.utils.sparse import ptr2index
from torch_geometric.typing import SparseTensor

class CollectArgs(NamedTuple):
    r"""Bundle of arguments returned by the generated :func:`collect`.

    The generated ``propagate`` builds one of these via ``collect`` and
    passes it on to the message/aggregate steps. Field names and order must
    stay as declared: TorchScript reads them from the class annotations.
    """
    # Source-side features gathered with ``_index_select(x, edge_index_j)``.
    x_j: Tensor
    # Optional per-edge weights; ``collect`` fills this from the sparse
    # adjacency's values when the caller passed ``None``.
    edge_weight: typing.Optional[Tensor]
    # Set to ``edge_index_i`` in ``collect``.
    index: Tensor
    # CSR row pointer when the adjacency provides one (sparse CSR tensor or
    # ``SparseTensor``), otherwise ``None``.
    ptr: typing.Optional[Tensor]
    # Set to ``size_i`` in ``collect`` (``size[i]``, falling back to ``size[j]``).
    dim_size: typing.Optional[int]

def collect(
    edge_index: Union[Tensor, SparseTensor],
    x: OptPairTensor,
    edge_weight: OptTensor,
    size: List[Optional[int]],
) -> CollectArgs:

    i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)

    # Collect special arguments:
    if isinstance(edge_index, Tensor):
        if is_torch_sparse_tensor(edge_index):
            adj_t = edge_index
            if adj_t.layout == torch.sparse_coo:
                edge_index_i = adj_t.indices()[0]
                edge_index_j = adj_t.indices()[1]
                ptr = None
            elif adj_t.layout == torch.sparse_csr:
                ptr = adj_t.crow_indices()
                edge_index_j = adj_t.col_indices()
                edge_index_i = ptr2index(ptr, output_size=edge_index_j.numel())
                raise ValueError(f"Received invalid layout '{adj_t.layout}'")
            if edge_weight is None:
                edge_weight = adj_t.values()

            edge_index_i = edge_index[i]
            edge_index_j = edge_index[j]
            ptr = None

    elif isinstance(edge_index, SparseTensor):
        adj_t = edge_index
        edge_index_i, edge_index_j, _value = adj_t.coo()
        ptr, _, _ = adj_t.csr()
        if edge_weight is None:
            edge_weight = _value

        raise NotImplementedError

    # Collect user-defined arguments:
    # (1) - Collect `x_j`:
    if isinstance(x, (tuple, list)):
        assert len(x) == 2
        _x_0, _x_1 = x[0], x[1]
        if isinstance(_x_0, Tensor):
            self._set_size(size, 0, _x_0)
            x_j = self._index_select(_x_0, edge_index_j)
            x_j = None
        if isinstance(_x_1, Tensor):
            self._set_size(size, 1, _x_1)
    elif isinstance(x, Tensor):
        self._set_size(size, j, x)
        x_j = self._index_select(x, edge_index_j)
        x_j = None

    # Collect default arguments:

    index = edge_index_i
    size_i = size[i] if size[i] is not None else size[j]
    size_j = size[j] if size[j] is not None else size[i]
    dim_size = size_i

    return CollectArgs(

def propagate(
    edge_index: Union[Tensor, SparseTensor],
    x: OptPairTensor,
    edge_weight: OptTensor,
    size: Size = None,
) -> Tensor:

    # Begin Propagate Forward Pre Hook #########################################
    if not torch.jit.is_scripting() and not is_compiling():
        for hook in self._propagate_forward_pre_hooks.values():
            hook_kwargs = dict(
            res = hook(self, (edge_index, size, hook_kwargs))
            if res is not None:
                edge_index, size, hook_kwargs = res
                x = hook_kwargs['x']
                edge_weight = hook_kwargs['edge_weight']
    # End Propagate Forward Pre Hook ###########################################

    mutable_size = self._check_input(edge_index, size)
    fuse = is_sparse(edge_index) and self.fuse

    if fuse:
        # Begin Message and Aggregate Forward Pre Hook #########################
        if not torch.jit.is_scripting() and not is_compiling():
            for hook in self._message_and_aggregate_forward_pre_hooks.values():
                hook_kwargs = dict(
                res = hook(self, (edge_index, hook_kwargs))
                if res is not None:
                    edge_index, hook_kwargs = res
                    x = hook_kwargs['x']
        # End Message and Aggregate Forward Pre Hook ##########################

        out = self.message_and_aggregate(

        # Begin Message and Aggregate Forward Hook #############################
        if not torch.jit.is_scripting() and not is_compiling():
            for hook in self._message_and_aggregate_forward_hooks.values():
                hook_kwargs = dict(
                res = hook(self, (edge_index, hook_kwargs, ), out)
                out = res if res is not None else out
        # End Message and Aggregate Forward Hook ###############################

        out = self.update(


        kwargs = self.collect(

        # Begin Message Forward Pre Hook #######################################
        if not torch.jit.is_scripting() and not is_compiling():
            for hook in self._message_forward_pre_hooks.values():
                hook_kwargs = dict(
                res = hook(self, (hook_kwargs, ))
                hook_kwargs = res[0] if isinstance(res, tuple) else res
                if res is not None:
                    kwargs = CollectArgs(
        # End Message Forward Pre Hook #########################################

        out = self.message(

        # Begin Message Forward Hook ###########################################
        if not torch.jit.is_scripting() and not is_compiling():
            for hook in self._message_forward_hooks.values():
                hook_kwargs = dict(
                res = hook(self, (hook_kwargs, ), out)
                out = res if res is not None else out
        # End Message Forward Hook #############################################

        # Begin Aggregate Forward Pre Hook #####################################
        if not torch.jit.is_scripting() and not is_compiling():
            for hook in self._aggregate_forward_pre_hooks.values():
                hook_kwargs = dict(
                res = hook(self, (hook_kwargs, ))
                hook_kwargs = res[0] if isinstance(res, tuple) else res
                if res is not None:
                    kwargs = CollectArgs(
        # End Aggregate Forward Pre Hook #######################################

        out = self.aggregate(

        # Begin Aggregate Forward Hook #########################################
        if not torch.jit.is_scripting() and not is_compiling():
            for hook in self._aggregate_forward_hooks.values():
                hook_kwargs = dict(
                res = hook(self, (hook_kwargs, ), out)
                out = res if res is not None else out
        # End Aggregate Forward Hook ###########################################

        out = self.update(

    # Begin Propagate Forward Hook ############################################
    if not torch.jit.is_scripting() and not is_compiling():
        for hook in self._propagate_forward_hooks.values():
            hook_kwargs = dict(
            res = hook(self, (edge_index, mutable_size, hook_kwargs), out)
            out = res if res is not None else out
    # End Propagate Forward Hook ##############################################

    return out
rusty1s commented 6 months ago

I think I fixed this in a PR (the link was not preserved in this copy). Can you confirm by adding `tmp.flush()` to the linked file (as done in the PR) and trying again? If this still does not work, can you show me the output of `print(module.__dict__)`?

viktortnk commented 6 months ago

@rusty1s Unfortunately, it still doesn't work.

  ret = run_job(
{'__name__': 'torch_geometric.nn.conv.gatv2_conv_GATv2Conv_edge_updater', '__doc__': None, '__package__': 'torch_geometric.nn.conv', '__loader__': <_frozen_importlib_external.SourceFileLoader object at 0x7f6ddaf8f5d0>, '__spec__': ModuleSpec(name='torch_geometric.nn.conv.gatv2_conv_GATv2Conv_edge_updater', loader=<_frozen_importlib_external.SourceFileLoader object at 0x7f6ddaf8f5d0>, origin='/tmp/'), '__file__': '/tmp/', '__cached__': '/tmp/__pycache__/torch_geometric.nn.conv.gatv2_conv_GATv2Conv_edge_updater_91zbt5hq.cpython-311.pyc', '__builtins__': {'__name__': 'builtins', '__doc__': "Built-in functions, types, exceptions, and other objects.\n\nThis module provides direct access to all 'built-in'\nidentifiers of Python; for example, builtins.len is\nthe full name for the built-in function len().\n\nThis module is not normally accessed explicitly by most\napplications, but can be useful in modules that provide\nobjects with the same name as a built-in value, but in\nwhich the built-in of that name is also needed.", '__package__': '', '__loader__': <class '_frozen_importlib.BuiltinImporter'>, '__spec__': ModuleSpec(name='builtins', loader=<class '_frozen_importlib.BuiltinImporter'>, origin='built-in'), '__build_class__': <built-in function __build_class__>, '__import__': <built-in function __import__>, 'abs': <built-in function abs>, 'all': <built-in function all>, 'any': <built-in function any>, 'ascii': <built-in function ascii>, 'bin': <built-in function bin>, 'breakpoint': <built-in function breakpoint>, 'callable': <built-in function callable>, 'chr': <built-in function chr>, 'compile': <built-in function compile>, 'delattr': <built-in function delattr>, 'dir': <built-in function dir>, 'divmod': <built-in function divmod>, 'eval': <built-in function eval>, 'exec': <built-in function exec>, 'format': <built-in function format>, 'getattr': <built-in function getattr>, 'globals': <built-in function globals>, 'hasattr': <built-in function hasattr>, 'hash': 
<built-in function hash>, 'hex': <built-in function hex>, 'id': <built-in function id>, 'input': <built-in function input>, 'isinstance': <built-in function isinstance>, 'issubclass': <built-in function issubclass>, 'iter': <built-in function iter>, 'aiter': <built-in function aiter>, 'len': <built-in function len>, 'locals': <built-in function locals>, 'max': <built-in function max>, 'min': <built-in function min>, 'next': <built-in function next>, 'anext': <built-in function anext>, 'oct': <built-in function oct>, 'ord': <built-in function ord>, 'pow': <built-in function pow>, 'print': <built-in function print>, 'repr': <built-in function repr>, 'round': <built-in function round>, 'setattr': <built-in function setattr>, 'sorted': <built-in function sorted>, 'sum': <built-in function sum>, 'vars': <built-in function vars>, 'None': None, 'Ellipsis': Ellipsis, 'NotImplemented': NotImplemented, 'False': False, 'True': True, 'bool': <class 'bool'>, 'memoryview': <class 'memoryview'>, 'bytearray': <class 'bytearray'>, 'bytes': <class 'bytes'>, 'classmethod': <class 'classmethod'>, 'complex': <class 'complex'>, 'dict': <class 'dict'>, 'enumerate': <class 'enumerate'>, 'filter': <class 'filter'>, 'float': <class 'float'>, 'frozenset': <class 'frozenset'>, 'property': <class 'property'>, 'int': <class 'int'>, 'list': <class 'list'>, 'map': <class 'map'>, 'object': <class 'object'>, 'range': <class 'range'>, 'reversed': <class 'reversed'>, 'set': <class 'set'>, 'slice': <class 'slice'>, 'staticmethod': <class 'staticmethod'>, 'str': <class 'str'>, 'super': <class 'super'>, 'tuple': <class 'tuple'>, 'type': <class 'type'>, 'zip': <class 'zip'>, '__debug__': True, 'BaseException': <class 'BaseException'>, 'BaseExceptionGroup': <class 'BaseExceptionGroup'>, 'Exception': <class 'Exception'>, 'GeneratorExit': <class 'GeneratorExit'>, 'KeyboardInterrupt': <class 'KeyboardInterrupt'>, 'SystemExit': <class 'SystemExit'>, 'ArithmeticError': <class 'ArithmeticError'>, 
'AssertionError': <class 'AssertionError'>, 'AttributeError': <class 'AttributeError'>, 'BufferError': <class 'BufferError'>, 'EOFError': <class 'EOFError'>, 'ImportError': <class 'ImportError'>, 'LookupError': <class 'LookupError'>, 'MemoryError': <class 'MemoryError'>, 'NameError': <class 'NameError'>, 'OSError': <class 'OSError'>, 'ReferenceError': <class 'ReferenceError'>, 'RuntimeError': <class 'RuntimeError'>, 'StopAsyncIteration': <class 'StopAsyncIteration'>, 'StopIteration': <class 'StopIteration'>, 'SyntaxError': <class 'SyntaxError'>, 'SystemError': <class 'SystemError'>, 'TypeError': <class 'TypeError'>, 'ValueError': <class 'ValueError'>, 'Warning': <class 'Warning'>, 'FloatingPointError': <class 'FloatingPointError'>, 'OverflowError': <class 'OverflowError'>, 'ZeroDivisionError': <class 'ZeroDivisionError'>, 'BytesWarning': <class 'BytesWarning'>, 'DeprecationWarning': <class 'DeprecationWarning'>, 'EncodingWarning': <class 'EncodingWarning'>, 'FutureWarning': <class 'FutureWarning'>, 'ImportWarning': <class 'ImportWarning'>, 'PendingDeprecationWarning': <class 'PendingDeprecationWarning'>, 'ResourceWarning': <class 'ResourceWarning'>, 'RuntimeWarning': <class 'RuntimeWarning'>, 'SyntaxWarning': <class 'SyntaxWarning'>, 'UnicodeWarning': <class 'UnicodeWarning'>, 'UserWarning': <class 'UserWarning'>, 'BlockingIOError': <class 'BlockingIOError'>, 'ChildProcessError': <class 'ChildProcessError'>, 'ConnectionError': <class 'ConnectionError'>, 'FileExistsError': <class 'FileExistsError'>, 'FileNotFoundError': <class 'FileNotFoundError'>, 'InterruptedError': <class 'InterruptedError'>, 'IsADirectoryError': <class 'IsADirectoryError'>, 'NotADirectoryError': <class 'NotADirectoryError'>, 'PermissionError': <class 'PermissionError'>, 'ProcessLookupError': <class 'ProcessLookupError'>, 'TimeoutError': <class 'TimeoutError'>, 'IndentationError': <class 'IndentationError'>, 'IndexError': <class 'IndexError'>, 'KeyError': <class 'KeyError'>, 
'ModuleNotFoundError': <class 'ModuleNotFoundError'>, 'NotImplementedError': <class 'NotImplementedError'>, 'RecursionError': <class 'RecursionError'>, 'UnboundLocalError': <class 'UnboundLocalError'>, 'UnicodeError': <class 'UnicodeError'>, 'BrokenPipeError': <class 'BrokenPipeError'>, 'ConnectionAbortedError': <class 'ConnectionAbortedError'>, 'ConnectionRefusedError': <class 'ConnectionRefusedError'>, 'ConnectionResetError': <class 'ConnectionResetError'>, 'TabError': <class 'TabError'>, 'UnicodeDecodeError': <class 'UnicodeDecodeError'>, 'UnicodeEncodeError': <class 'UnicodeEncodeError'>, 'UnicodeTranslateError': <class 'UnicodeTranslateError'>, 'ExceptionGroup': <class 'ExceptionGroup'>, 'EnvironmentError': <class 'OSError'>, 'IOError': <class 'OSError'>, 'open': <built-in function open>, 'quit': Use quit() or Ctrl-D (i.e. EOF) to exit, 'exit': Use exit() or Ctrl-D (i.e. EOF) to exit, 'copyright': Copyright (c) 2001-2023 Python Software Foundation.
All Rights Reserved.

Copyright (c) 2000
All Rights Reserved.

Copyright (c) 1995-2001 Corporation for National Research Initiatives.
All Rights Reserved.

Copyright (c) 1991-1995 Stichting Mathematisch Centrum, Amsterdam.
All Rights Reserved., 'credits':     Thanks to CWI, CNRI,, Zope Corporation and a cast of thousands
    for supporting Python development.  See for more information., 'license': Type license() to see the full license text, 'help': Type help() for interactive help, or help(object) for help about object., '__pybind11_internals_v4_gcc_libstdcpp_cxxabi1011__': <capsule object NULL at 0x7f6e3c705b00>, '__pybind11_internals_v4_gcc_libstdcpp_cxxabi1017__': <capsule object NULL at 0x7f6dee886940>}, 'typing': <module 'typing' from '/home/vltaranenko/miniforge3/envs/qc22/lib/python3.11/'>, 'Union': typing.Union, 'torch': <module 'torch' from '/home/vltaranenko/miniforge3/envs/qc22/lib/python3.11/site-packages/torch/'>, 'Tensor': <class 'torch.Tensor'>, 'torch_geometric': <module 'torch_geometric' from '/home/vltaranenko/miniforge3/envs/qc22/lib/python3.11/site-packages/torch_geometric/'>, 'is_compiling': <function is_compiling at 0x7f6df5169e40>, 'is_sparse': <function is_sparse at 0x7f6deb492b60>, 'Size': typing.Optional[typing.Tuple[int, int]], 'SparseTensor': <class 'torch_sparse.tensor.SparseTensor'>, 'Optional': typing.Optional, 'Tuple': typing.Tuple, 'F': <module 'torch.nn.functional' from '/home/vltaranenko/miniforge3/envs/qc22/lib/python3.11/site-packages/torch/nn/'>, 'Parameter': <class 'torch.nn.parameter.Parameter'>, 'MessagePassing': <class 'torch_geometric.nn.conv.message_passing.MessagePassing'>, 'Linear': <class 'torch_geometric.nn.dense.linear.Linear'>, 'glorot': <function glorot at 0x7f6de3853420>, 'zeros': <function zeros at 0x7f6de3853600>, 'Adj': typing.Union[torch.Tensor, torch_sparse.tensor.SparseTensor], 'NoneType': typing.Optional[torch.Tensor], 'OptTensor': typing.Optional[torch.Tensor], 'PairTensor': typing.Tuple[torch.Tensor, torch.Tensor], 'torch_sparse': <module 'torch_sparse' from '/home/vltaranenko/miniforge3/envs/qc22/lib/python3.11/site-packages/torch_sparse/'>, 'add_self_loops': <function add_self_loops at 0x7f6deb483b00>, 'is_torch_sparse_tensor': <function is_torch_sparse_tensor at 0x7f6deb492ac0>, 'remove_self_loops': <function 
remove_self_loops at 0x7f6deb493560>, 'softmax': <function softmax at 0x7f6deb490e00>, 'set_sparse_value': <function set_sparse_value at 0x7f6deb492fc0>, 'overload': <function _overload_method at 0x7f6df6e793a0>, 'GATv2Conv': <class 'torch_geometric.nn.conv.gatv2_conv.GATv2Conv'>, 'List': typing.List, 'NamedTuple': <function NamedTuple at 0x7f6ea0ceede0>, 'ptr2index': <function ptr2index at 0x7f6deb493060>, 'CollectArgs': <class 'torch_geometric.nn.conv.gatv2_conv_GATv2Conv_edge_updater.CollectArgs'>, 'edge_collect': <function edge_collect at 0x7f6dcf261e40>, 'edge_updater': <function edge_updater at 0x7f6dcf262c00>}
rusty1s commented 6 months ago

Mh, what a bummer. I'll take a look. Thanks.

rusty1s commented 5 months ago

Just pushed 2.5.3 with a fix :)

aJay0422 commented 1 month ago

I upgraded torch_geometric to version 2.5.3 with `pip install torch_geometric --upgrade`, but the issue remains. Here is the error:

Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/home/shangjie/miniconda3/envs/env_ai/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py", line 466
        edge_index: Adj,
        size: Size = None,
        **kwargs: Any,
         ~~~~~~~ <--- HERE
    ) -> Tensor:
        r"""The initial call to start propagating messages.

Here is the line in my code that triggers the issue: `scripted_model = torch.jit.script(model)`