google-deepmind / graph_nets

Build Graph Nets in Tensorflow
https://arxiv.org/abs/1806.01261
Apache License 2.0

error #42

Closed luyifanlu closed 5 years ago

luyifanlu commented 5 years ago

SMALL_GRAPH_1 = {
    "globals": [-1.1, -1.2, -1.3],
    # "nodes": [[-10.1, -10.2], [-20.1, -20.2], [-30.1, -30.2]],
    "nodes": np.zeros((3, 4, 4, 4), dtype=np.float32),
    "edges": [[-101., -102., -103., -104.]],
    "senders": [1,],
    "receivers": [2,],
}

input_graphs = utils_tf.data_dicts_to_graphs_tuple([SMALL_GRAPH_1])
graph_net_module = gn.modules.GraphNetwork(
    edge_model_fn=lambda: snt.nets.MLP([32, 32]),
    node_model_fn=functools.partial(snt.Conv2D, output_channels=10,
                                    kernel_shape=[3, 3], padding='VALID'),
    global_model_fn=lambda: snt.nets.MLP([32, 32]))
output_graphs = graph_net_module(input_graphs)

When I run the code, there is an error (Shape must be rank 2 but is rank 4 for 'graph_network_9_1). I'd like to know whether only the snt.nets.MLP() function can be used.

alvarosg commented 5 years ago

A graph network is made of an EdgeBlock, a NodeBlock, and a GlobalBlock, connected sequentially, each of them using the edge_model_fn, node_model_fn and global_model_fn respectively.

The input to the first edge_model consists of each edge concatenated with the globals and the nodes connected by that edge. The problem in your example is that each node is rank 3 (shape [4, 4, 4]), but each edge and the globals are rank 1, so they cannot be concatenated together along the last axis.
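To make the shape mismatch concrete, here is a minimal pure-Python sketch of the concatenation check. The helper `concat_last_axis_shape` is hypothetical (it is not part of graph_nets or TensorFlow) and only mimics the rule that tensors concatenated along the last axis must agree in rank and in all leading dimensions. The shapes are the per-edge inputs from the question, which has a single edge.

```python
def concat_last_axis_shape(*shapes):
    """Return the shape of a last-axis concatenation, or raise ValueError.

    Hypothetical helper for illustration only: it mimics the shape rule,
    it does not call TensorFlow.
    """
    ranks = sorted({len(shape) for shape in shapes})
    if len(ranks) != 1:
        raise ValueError("cannot concatenate tensors of ranks %s" % ranks)
    leading_dims = {tuple(shape[:-1]) for shape in shapes}
    if len(leading_dims) != 1:
        raise ValueError("leading dimensions differ: %s" % sorted(leading_dims))
    return tuple(shapes[0][:-1]) + (sum(shape[-1] for shape in shapes),)


# Per-edge inputs to the edge model in the question (num_edges = 1).
edges = (1, 4)                 # [num_edges, edge_features]
receiver_nodes = (1, 4, 4, 4)  # image-like node gathered for each edge
sender_nodes = (1, 4, 4, 4)
globals_per_edge = (1, 3)      # globals broadcast to each edge

try:
    concat_last_axis_shape(edges, receiver_nodes, sender_nodes, globals_per_edge)
    question_error = None
except ValueError as error:
    question_error = str(error)  # rank 2 and rank 4 tensors cannot be joined

# With flat nodes of 2 features each (the commented-out "nodes" entry),
# the same concatenation is well defined: 4 + 2 + 2 + 3 = 11 features.
flat_shape = concat_last_axis_shape((1, 4), (1, 2), (1, 2), (1, 3))
```

This is why the MLP-only version works: every field is a flat feature vector, so the per-edge concatenation has a single well-defined width.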

If you just want to apply a function to the nodes, globals and edges independently, you could use GraphIndependent:

import functools
import graph_nets as gn
from graph_nets import utils_tf
import sonnet as snt
import numpy as np

SMALL_GRAPH_1 = {
    "globals": [-1.1, -1.2, -1.3],
    "nodes": np.zeros((3, 4, 4, 4), dtype=np.float32),
    "edges": [[-101., -102., -103., -104.]],
    "senders": [1,],
    "receivers": [2,],
}

input_graphs = utils_tf.data_dicts_to_graphs_tuple([SMALL_GRAPH_1])
graph_net_module = gn.modules.GraphIndependent(
    edge_model_fn=functools.partial(snt.nets.MLP, output_sizes=[32, 32]),
    node_model_fn=functools.partial(snt.Conv2D, output_channels=10, 
                                    kernel_shape=[3, 3], padding='VALID'),
    global_model_fn=functools.partial(snt.nets.MLP, output_sizes=[32, 32]))
graph_net_module(input_graphs)

As you would expect, in this case the output is still rank 4 for the nodes and rank 2 for the edges and globals.
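As a rough check of those ranks, the output shapes can be worked out by hand. The sketch below assumes stride 1 and the NHWC layout ([batch, height, width, channels]); the helpers `conv2d_output_shape` and `mlp_output_shape` are hypothetical and only do shape arithmetic, they do not call Sonnet.

```python
def conv2d_output_shape(input_shape, output_channels, kernel_shape, padding):
    """Shape arithmetic for a stride-1 2D convolution on NHWC input (sketch)."""
    batch, height, width, _ = input_shape
    if padding == "VALID":
        # No padding: the kernel must fit entirely inside the image.
        out_height = height - kernel_shape[0] + 1
        out_width = width - kernel_shape[1] + 1
    elif padding == "SAME":
        # Zero padding keeps the spatial size unchanged at stride 1.
        out_height, out_width = height, width
    else:
        raise ValueError("unknown padding: %r" % (padding,))
    return (batch, out_height, out_width, output_channels)


def mlp_output_shape(input_shape, output_sizes):
    """An MLP maps [batch, features] to [batch, output_sizes[-1]]."""
    return (input_shape[0], output_sizes[-1])


# Shapes from the GraphIndependent example above.
nodes_out = conv2d_output_shape((3, 4, 4, 4), 10, (3, 3), "VALID")  # 3 nodes
edges_out = mlp_output_shape((1, 4), [32, 32])                      # 1 edge
globals_out = mlp_output_shape((1, 3), [32, 32])                    # 1 graph
```

Under these assumptions the nodes come out as [3, 2, 2, 10] and the edges and globals as [1, 32].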

Alternatively, you can first calculate embeddings of the nodes using convnets+flatten, and then use the nodes with the embeddings in a full graph network:

def node_embedding_model_fn():
  return snt.Sequential([snt.Conv2D(output_channels=10, 
                                    kernel_shape=[3, 3], padding='VALID'),
                         snt.BatchFlatten()])

nodes_encoder = gn.modules.GraphIndependent(
    node_model_fn=node_embedding_model_fn)

graph_net_module = gn.modules.GraphNetwork(
    edge_model_fn=functools.partial(snt.nets.MLP, output_sizes=[32, 32]),
    node_model_fn=functools.partial(snt.nets.MLP, output_sizes=[32, 32]),
    global_model_fn=functools.partial(snt.nets.MLP, output_sizes=[32, 32]))

graphs_with_node_embeddings = nodes_encoder(input_graphs)
graph_net_module(graphs_with_node_embeddings)
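For reference, here is a small sketch of the feature sizes each model sees in this setup. It assumes the default GraphNetwork block options as I understand them: the edge model gets edges, receiver nodes, sender nodes and globals; the node model gets nodes, aggregated received edges and globals; the global model gets globals plus aggregated edges and nodes. The helper `batch_flatten_shape` is hypothetical and only does shape arithmetic.

```python
from functools import reduce
from operator import mul


def batch_flatten_shape(shape):
    """Shape after flattening everything except the leading batch axis."""
    return (shape[0], reduce(mul, shape[1:], 1))


# Conv2D with VALID padding turns [3, 4, 4, 4] nodes into [3, 2, 2, 10].
node_embeddings = batch_flatten_shape((3, 2, 2, 10))
node_size = node_embeddings[1]

edge_size, global_size, mlp_output_size = 4, 3, 32

# Assumed default inputs of each block, concatenated along the last axis.
edge_model_inputs = edge_size + 2 * node_size + global_size
node_model_inputs = node_size + mlp_output_size + global_size
global_model_inputs = global_size + 2 * mlp_output_size
```

Once the nodes are flattened, every field is rank 2, so all three concatenations are valid.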

Another possible use case is to have the globals and edges also be image-like tensors of the right sizes (both before and after the model), so that they can be concatenated along the channel axis:

SMALL_GRAPH_1 = {
    "globals": np.zeros((4, 4, 4), dtype=np.float32),
    "nodes": np.zeros((3, 4, 4, 4), dtype=np.float32),
    "edges": np.zeros((1, 4, 4, 4), dtype=np.float32),
    "senders": [1,],
    "receivers": [2],
}

input_graphs = utils_tf.data_dicts_to_graphs_tuple([SMALL_GRAPH_1])

model_fn = functools.partial(snt.Conv2D, output_channels=10, 
                            kernel_shape=[3, 3], padding='SAME')

graph_net_module = gn.modules.GraphNetwork(
    edge_model_fn=model_fn,
    node_model_fn=model_fn,
    global_model_fn=model_fn)

graph_net_module(input_graphs)
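The "both before and after the model" part is why this example uses padding='SAME'. A rough sketch of the reasoning, assuming stride 1 and the default node block inputs (nodes, aggregated received edges, globals); the helper names are hypothetical and only do shape arithmetic:

```python
def conv_spatial_size(size, kernel, padding):
    """Spatial size after a stride-1 convolution (sketch)."""
    return size if padding == "SAME" else size - kernel + 1


def node_model_input_channels(padding, spatial=4, kernel=3,
                              in_channels=4, out_channels=10):
    """Channels seen by the node model, or ValueError if the sizes clash."""
    # The edge model runs first, so the node block sees its output.
    edge_spatial = conv_spatial_size(spatial, kernel, padding)
    if edge_spatial != spatial:
        raise ValueError(
            "updated edges are %dx%d but nodes are %dx%d"
            % (edge_spatial, edge_spatial, spatial, spatial))
    # nodes + aggregated updated edges + globals, along the channel axis.
    return in_channels + out_channels + in_channels


same_channels = node_model_input_channels("SAME")

try:
    node_model_input_channels("VALID")
    valid_error = None
except ValueError as error:
    valid_error = str(error)  # 2x2 edges cannot be joined with 4x4 nodes
```

With 'SAME' the node model sees 18 channels on a 4x4 grid; with 'VALID' the updated edges would shrink to 2x2 and the concatenation in the node block would fail.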

You may also choose to set some of the fields to None, but in that case you will need to customize the edge/node/global_block options to indicate that they should ignore those fields.
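For instance, a sketch for graphs whose "globals" field is None might look like the following. The scenario is an assumption, and the code is untested against a real graph: it relies on the `edge_block_opt`, `node_block_opt` and `global_block_opt` arguments of `gn.modules.GraphNetwork` and on the `use_globals` flag of each block. The function name is hypothetical.

```python
# Assumed scenario: the input graphs were built with "globals": None,
# so no block should try to read the globals.
edge_block_opt = {"use_globals": False}
node_block_opt = {"use_globals": False}
global_block_opt = {"use_globals": False}


def build_graph_network_without_globals():
    """Builds a GraphNetwork whose blocks ignore the input globals (sketch)."""
    # Imported lazily so the option dicts above can be inspected
    # without TensorFlow, Sonnet and graph_nets installed.
    import functools
    import graph_nets as gn
    import sonnet as snt

    mlp_fn = functools.partial(snt.nets.MLP, output_sizes=[32, 32])
    return gn.modules.GraphNetwork(
        edge_model_fn=mlp_fn,
        node_model_fn=mlp_fn,
        global_model_fn=mlp_fn,
        edge_block_opt=edge_block_opt,
        node_block_opt=node_block_opt,
        global_block_opt=global_block_opt)
```

The global block still produces output globals here, computed only from the aggregated edges and nodes.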

What exactly are you trying to achieve?

alvarosg commented 5 years ago

Assuming solved due to inactivity, and closing.