google-deepmind / graph_nets

Build Graph Nets in Tensorflow
https://arxiv.org/abs/1806.01261
Apache License 2.0
5.34k stars 783 forks source link

_build() takes 2 positional arguments but 3 were given #134

Closed tribustale closed 3 years ago

tribustale commented 3 years ago

Hi, really interesting library! I am trying to modify one of your modules, modules.RelationNetwork in particular.

So, after importing everything needed, I copied the definition of `class RelationNetwork` into my notebook (renaming it `RelationNetwork2`):

class RelationNetwork2(_base.AbstractModule):
    """Local copy of graph_nets' RelationNetwork, taken from the GitHub master.

    An edge block (configured to read sender and receiver node features only)
    is followed by a global block that reduces the new edges with `reducer`;
    only the resulting globals are written back onto the input graph.

    NOTE: `_build` hands `edge_model_kwargs` / `global_model_kwargs` to the
    blocks as a second positional argument. Blocks whose `_build` accepts only
    `graph` reject this with
    "TypeError: _build() takes 2 positional arguments but 3 were given".
    """

    def __init__(self, edge_model_fn, global_model_fn,
                 reducer=tf.math.unsorted_segment_sum,
                 name="relation_network"):
        super(RelationNetwork2, self).__init__(name=name)

        # Input selection for each block.
        edge_inputs = dict(use_edges=False,
                           use_receiver_nodes=True,
                           use_sender_nodes=True,
                           use_globals=False)
        global_inputs = dict(use_edges=True,
                             use_nodes=False,
                             use_globals=False)

        with self._enter_variable_scope():
            self._edge_block = blocks.EdgeBlock(
                edge_model_fn=edge_model_fn, **edge_inputs)
            self._global_block = blocks.GlobalBlock(
                global_model_fn=global_model_fn,
                edges_reducer=reducer,
                **global_inputs)

    def _build(self, graph, edge_model_kwargs=None, global_model_kwargs=None):
        """Return `graph` with its globals replaced by the global block output."""
        with_edges = self._edge_block(graph, edge_model_kwargs)
        reduced = self._global_block(with_edges, global_model_kwargs)
        return graph.replace(globals=reduced.globals)

I initiate the class:

# Instantiate the module. `tf.math.unsorted_segment_sum` is the namespaced
# endpoint (also the constructor's default reducer); the bare
# `tf.unsorted_segment_sum` alias is only exported under `tf.compat.v1` in TF2.
graph_network2 = RelationNetwork2(
    edge_model_fn=lambda: snt.nets.MLP([32, 32]),
    global_model_fn=lambda: snt.nets.MLP([32, 3]),
    reducer=tf.math.unsorted_segment_sum,
)

I then pass the input_graphs to the model and try to build it:

# Calling the module dispatches to RelationNetwork2._build via the base
# module's __call__ (assumes `input_graphs` was built earlier in the notebook).
output_graphs = graph_network2(input_graphs)

but it raises the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-94-7776a47429c6> in <module>
----> 1 output_graphs = graph_network2(input_graphs)

~/.virtualenvs/dl4cv/lib/python3.7/site-packages/sonnet/src/utils.py in _decorate_unbound_method(self, *args, **kwargs)
     87       def _decorate_unbound_method(self, *args, **kwargs):
     88         bound_method = f.__get__(self, self.__class__)  # pytype: disable=attribute-error
---> 89         return decorator_fn(bound_method, self, args, kwargs)
     90 
     91       return _decorate_unbound_method

~/.virtualenvs/dl4cv/lib/python3.7/site-packages/sonnet/src/base.py in wrap_with_name_scope(method, instance, args, kwargs)
    270     # snt.Module enters the module name scope for all methods. To disable this
    271     # for a particular method annotate it with `@snt.no_name_scope`.
--> 272     return method(*args, **kwargs)
    273 
    274 

~/.virtualenvs/dl4cv/lib/python3.7/site-packages/sonnet/src/utils.py in _decorate_unbound_method(self, *args, **kwargs)
     87       def _decorate_unbound_method(self, *args, **kwargs):
     88         bound_method = f.__get__(self, self.__class__)  # pytype: disable=attribute-error
---> 89         return decorator_fn(bound_method, self, args, kwargs)
     90 
     91       return _decorate_unbound_method

~/.virtualenvs/dl4cv/lib/python3.7/site-packages/sonnet/src/base.py in wrap_with_name_scope(method, instance, args, kwargs)
    270     # snt.Module enters the module name scope for all methods. To disable this
    271     # for a particular method annotate it with `@snt.no_name_scope`.
--> 272     return method(*args, **kwargs)
    273 
    274 

~/.virtualenvs/dl4cv/lib/python3.7/site-packages/graph_nets/_base.py in __call__(self, *args, **kwargs)
     76 
     77     def __call__(self, *args, **kwargs):
---> 78       return self._build(*args, **kwargs)
     79 
     80     @abc.abstractmethod

~/.virtualenvs/dl4cv/lib/python3.7/site-packages/sonnet/src/utils.py in _decorate_unbound_method(self, *args, **kwargs)
     87       def _decorate_unbound_method(self, *args, **kwargs):
     88         bound_method = f.__get__(self, self.__class__)  # pytype: disable=attribute-error
---> 89         return decorator_fn(bound_method, self, args, kwargs)
     90 
     91       return _decorate_unbound_method

~/.virtualenvs/dl4cv/lib/python3.7/site-packages/sonnet/src/base.py in wrap_with_name_scope(method, instance, args, kwargs)
    270     # snt.Module enters the module name scope for all methods. To disable this
    271     # for a particular method annotate it with `@snt.no_name_scope`.
--> 272     return method(*args, **kwargs)
    273 
    274 

<ipython-input-89-ef75c00a2016> in _build(self, edge_model_kwargs, global_model_kwargs, *graph)
     38 
     39         output_graph = self._global_block(
---> 40             self._edge_block(graph, edge_model_kwargs), global_model_kwargs)
     41         return graph.replace(globals=output_graph.globals)

~/.virtualenvs/dl4cv/lib/python3.7/site-packages/sonnet/src/utils.py in _decorate_unbound_method(self, *args, **kwargs)
     87       def _decorate_unbound_method(self, *args, **kwargs):
     88         bound_method = f.__get__(self, self.__class__)  # pytype: disable=attribute-error
---> 89         return decorator_fn(bound_method, self, args, kwargs)
     90 
     91       return _decorate_unbound_method

~/.virtualenvs/dl4cv/lib/python3.7/site-packages/sonnet/src/base.py in wrap_with_name_scope(method, instance, args, kwargs)
    270     # snt.Module enters the module name scope for all methods. To disable this
    271     # for a particular method annotate it with `@snt.no_name_scope`.
--> 272     return method(*args, **kwargs)
    273 
    274 

~/.virtualenvs/dl4cv/lib/python3.7/site-packages/sonnet/src/utils.py in _decorate_unbound_method(self, *args, **kwargs)
     87       def _decorate_unbound_method(self, *args, **kwargs):
     88         bound_method = f.__get__(self, self.__class__)  # pytype: disable=attribute-error
---> 89         return decorator_fn(bound_method, self, args, kwargs)
     90 
     91       return _decorate_unbound_method

~/.virtualenvs/dl4cv/lib/python3.7/site-packages/sonnet/src/base.py in wrap_with_name_scope(method, instance, args, kwargs)
    270     # snt.Module enters the module name scope for all methods. To disable this
    271     # for a particular method annotate it with `@snt.no_name_scope`.
--> 272     return method(*args, **kwargs)
    273 
    274 

~/.virtualenvs/dl4cv/lib/python3.7/site-packages/graph_nets/_base.py in __call__(self, *args, **kwargs)
     76 
     77     def __call__(self, *args, **kwargs):
---> 78       return self._build(*args, **kwargs)
     79 
     80     @abc.abstractmethod

~/.virtualenvs/dl4cv/lib/python3.7/site-packages/sonnet/src/utils.py in _decorate_unbound_method(self, *args, **kwargs)
     87       def _decorate_unbound_method(self, *args, **kwargs):
     88         bound_method = f.__get__(self, self.__class__)  # pytype: disable=attribute-error
---> 89         return decorator_fn(bound_method, self, args, kwargs)
     90 
     91       return _decorate_unbound_method

~/.virtualenvs/dl4cv/lib/python3.7/site-packages/sonnet/src/base.py in wrap_with_name_scope(method, instance, args, kwargs)
    270     # snt.Module enters the module name scope for all methods. To disable this
    271     # for a particular method annotate it with `@snt.no_name_scope`.
--> 272     return method(*args, **kwargs)
    273 
    274 

TypeError: _build() takes 2 positional arguments but 3 were given

Can you help me solve this?

Thanks

alvarosg commented 3 years ago

I think you copied the code from the current master branch. That branch has not yet been released on PyPI, and it contains features that are missing from the latest PyPI release, which is probably the version you have installed. Specifically, the released blocks do not accept the `edge_model_kwargs` and `global_model_kwargs` arguments.

If you don't need those arguments, you can simply do:

class RelationNetwork2(_base.AbstractModule):
    """Relation network that only updates the graph's globals.

    The edge block is configured to read sender and receiver node features
    only (no existing edges, no globals). The global block reads only the new
    edges, reduced with `reducer`. The returned graph is the input graph with
    its `globals` replaced by the global block's output.
    """

    def __init__(self,
                 edge_model_fn,
                 global_model_fn,
                 reducer=tf.math.unsorted_segment_sum,
                 name="relation_network"):
        """Create the edge and global blocks.

        Args:
          edge_model_fn: Callable returning the edge model, passed to EdgeBlock.
          global_model_fn: Callable returning the global model, passed to
            GlobalBlock.
          reducer: Function used by the global block to aggregate edges.
          name: Module name.
        """
        super(RelationNetwork2, self).__init__(name=name)

        with self._enter_variable_scope():
            self._edge_block = blocks.EdgeBlock(
                edge_model_fn=edge_model_fn,
                use_edges=False,
                use_receiver_nodes=True,
                use_sender_nodes=True,
                use_globals=False)

            self._global_block = blocks.GlobalBlock(
                global_model_fn=global_model_fn,
                use_edges=True,
                use_nodes=False,
                use_globals=False,
                edges_reducer=reducer)

    def _build(self,
               graph,
               edge_model_kwargs=None,
               global_model_kwargs=None):
        """Apply the edge block, then the global block, and update globals.

        Args:
          graph: Input graph; its `globals` are replaced in the output.
          edge_model_kwargs: Optional argument forwarded to the edge block.
            It is only forwarded when not None.
          global_model_kwargs: Optional argument forwarded to the global
            block. It is only forwarded when not None.

        Returns:
          `graph` with `globals` taken from the global block's output.
        """
        # Forward the per-model kwargs only when provided. Released graph_nets
        # blocks accept just `graph`, so passing them unconditionally raises
        # "TypeError: _build() takes 2 positional arguments but 3 were given".
        # Previously they were accepted and silently dropped.
        if edge_model_kwargs is None:
            edges_graph = self._edge_block(graph)
        else:
            edges_graph = self._edge_block(graph, edge_model_kwargs)

        if global_model_kwargs is None:
            output_graph = self._global_block(edges_graph)
        else:
            output_graph = self._global_block(edges_graph, global_model_kwargs)

        return graph.replace(globals=output_graph.globals)

If you need those arguments, you can install the library directly from the GitHub repository:

# GitHub no longer serves the unauthenticated git:// protocol; install over HTTPS.
pip install git+https://github.com/google-deepmind/graph_nets.git

Hope this helps!

tribustale commented 3 years ago

Thank you very much for the fast and correct answer! Everything working fine now!