google-deepmind / graph_nets

Build Graph Nets in Tensorflow
https://arxiv.org/abs/1806.01261
Apache License 2.0

issue with passing *_model_kwargs parameter #139

Closed: fminiati closed this issue 3 years ago

fminiati commented 3 years ago

Hi,

Thanks for building Sonnet and graph_nets.

There seems to be an issue with passing edge_model_kwargs to a GraphNetwork module, and even directly to its _edge_block. The same issue may occur with the other *_model_kwargs parameters.

The error message I get is:

TypeError: _build() got an unexpected keyword argument 'edge_model_kwargs'
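
For context, Python raises this kind of TypeError whenever a keyword argument is forwarded to a function whose signature does not declare it. The sketch below is plain Python and does not use graph_nets; `OldStyleModule` and `try_call` are hypothetical names, and the assumption that the module's `__call__` simply forwards its keyword arguments to `_build` is inferred from the error message rather than taken from the library source.

```python
class OldStyleModule:
    """Hypothetical stand-in for a module whose _build has no *_model_kwargs."""

    def _build(self, graph):
        return graph

    def __call__(self, *args, **kwargs):
        # Assumption: call-time keyword arguments are forwarded to _build unchanged.
        return self._build(*args, **kwargs)


def try_call(module, graph, **kwargs):
    """Returns the TypeError message if the call fails, otherwise None."""
    try:
        module(graph, **kwargs)
    except TypeError as err:
        return str(err)
    return None


# The unknown keyword reaches _build, which rejects it.
message = try_call(OldStyleModule(), {"edges": []},
                   edge_model_kwargs={"is_training": True})
print(message)
```

If that assumption holds, the error points at the installed `_build` signature rather than at the calling code.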

The code below reproduces the issue for me (the graph construction part is from your webpage).

Many thanks! fm

import numpy as np
import networkx as nx
import sonnet as snt
import graph_nets as gn

# Graph.
graph_nx = nx.OrderedMultiDiGraph()

# Globals.
graph_nx.graph["features"] = np.array([0.6, 0.7, 0.8])

# Nodes.
graph_nx.add_node(0, features=np.array([0.3, 1.3]))
graph_nx.add_node(1, features=np.array([0.4, 1.4]))
graph_nx.add_node(2, features=np.array([0.5, 1.5]))
graph_nx.add_node(3, features=np.array([0.6, 1.6]))

# Edges.
graph_nx.add_edge(0, 1, features=np.array([3.6, 3.7]))
graph_nx.add_edge(2, 0, features=np.array([5.6, 5.7]))
graph_nx.add_edge(3, 0, features=np.array([6.6, 6.7]))

# Convert the networkx graph into a GraphsTuple.
input_graphs = gn.utils_tf.data_dicts_to_graphs_tuple(
    [gn.utils_np.networkx_to_data_dict(graph_nx)])

# Create the graph network.
# Enabling dropout in the edge MLP, e.g. snt.nets.MLP([32, 32], dropout_rate=0.1),
# requires passing is_training through edge_model_kwargs.
graph_net_module = gn.modules.GraphNetwork(
    edge_model_fn=lambda: snt.nets.MLP([32, 32]),
    node_model_fn=lambda: snt.nets.MLP([32, 32]),
    global_model_fn=lambda: snt.nets.MLP([32, 32]))

# output_graphs = graph_net_module(input_graphs)  # <---- This works
output_graphs = graph_net_module(input_graphs, edge_model_kwargs={'is_training': True})  # <---- This doesn't
# output_graphs = graph_net_module._edge_block(input_graphs, edge_model_kwargs={'is_training': True})  # <---- This doesn't
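
The behaviour the call above is expecting can be sketched in plain Python: a per-block kwargs dictionary passed at call time is unpacked into the call to the corresponding model. All class names below (`ToyEdgeModel`, `ToyEdgeBlock`, `ToyGraphNetwork`) are hypothetical, and the forwarding logic is an assumption about the general pattern, not a copy of the graph_nets implementation.

```python
class ToyEdgeModel:
    """Hypothetical edge model that records the is_training flag it receives."""

    def __call__(self, features, is_training=False):
        return {"features": list(features), "is_training": is_training}


class ToyEdgeBlock:
    """Hypothetical block that forwards edge_model_kwargs to its model."""

    def __init__(self, edge_model_fn):
        self._edge_model = edge_model_fn()

    def __call__(self, edge_features, edge_model_kwargs=None):
        # Treat a missing dictionary as "no extra keyword arguments".
        if edge_model_kwargs is None:
            edge_model_kwargs = {}
        return self._edge_model(edge_features, **edge_model_kwargs)


class ToyGraphNetwork:
    """Hypothetical network that passes the dictionary on to its edge block."""

    def __init__(self, edge_model_fn):
        self._edge_block = ToyEdgeBlock(edge_model_fn)

    def __call__(self, edge_features, edge_model_kwargs=None):
        return self._edge_block(edge_features, edge_model_kwargs)


net = ToyGraphNetwork(edge_model_fn=ToyEdgeModel)
with_flag = net([3.6, 3.7], edge_model_kwargs={"is_training": True})
without_flag = net([3.6, 3.7])
```

Keeping one dictionary per block lets the edge, node, and global models take different call-time arguments without sharing a single flat keyword namespace.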

alvarosg commented 3 years ago

Hi, thanks for your message! We have not made a library release since this feature was added, so the version you get with pip install does not have it. You can instead install the library directly from GitHub, following the instructions here.

Long story short, install with: pip install git+https://github.com/deepmind/graph_nets.git :)

Hope this helps!
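
One way to tell whether an installed version already has the feature is to inspect the signature of the method that raised the error. The helper below is a minimal sketch using only the standard library; `accepts_keyword`, `old_build`, and `new_build` are hypothetical names introduced for illustration.

```python
import inspect


def accepts_keyword(fn, name):
    """Returns True if `fn` can be called with `name` as a keyword argument."""
    params = inspect.signature(fn).parameters
    by_keyword = (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                  inspect.Parameter.KEYWORD_ONLY)
    if name in params and params[name].kind in by_keyword:
        return True
    # A **kwargs parameter accepts any keyword name.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


# Toy stand-ins for a _build method before and after the feature was added.
def old_build(self, graph):
    return graph


def new_build(self, graph, edge_model_kwargs=None):
    return graph


print(accepts_keyword(old_build, "edge_model_kwargs"))
print(accepts_keyword(new_build, "edge_model_kwargs"))
```

With the library installed, the same check could be applied as accepts_keyword(gn.modules.GraphNetwork._build, "edge_model_kwargs"), assuming `_build` is the method to inspect, as the error message suggests.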

fminiati commented 3 years ago

Thanks for the quick reply. That fixed the issue!

fm

(Sent by email in reply to Alvaro's comment of Wednesday, 17 March 2021, 20:39.)