google-deepmind / graph_nets

Build Graph Nets in Tensorflow
https://arxiv.org/abs/1806.01261
Apache License 2.0

custom node_model_fn and dropout #109

Closed TianrenWang closed 4 years ago

TianrenWang commented 4 years ago
graph_network = modules.InteractionNetwork(
        edge_model_fn=lambda: snt.nets.MLP(output_sizes=[depth], dropout_rate=0.2),
        node_model_fn=lambda: snt.nets.MLP(output_sizes=[depth], dropout_rate=0.2))

This code only uses an MLP for each block, but I want each block to apply an additional normalization after the MLP (like layer norm). Also, how do you tell the graph_network that it is training, so that dropout is applied?

alvarosg commented 4 years ago

LayerNorm is easy to add by combining the MLP with snt.Sequential:

def model_fn():
  return snt.Sequential([
      snt.nets.MLP(output_sizes=[64, 64]),
      snt.LayerNorm(axis=-1, create_scale=True, create_offset=True)])  # normalize over the feature axis

graph_network = modules.InteractionNetwork(
  edge_model_fn=model_fn,
  node_model_fn=model_fn)

BatchNorm and Dropout are a bit more cumbersome to implement, because both of them require passing an is_training flag when the module is called inside of the GraphNet. A potential way around this is to create the MLPs/other modules outside of the InteractionNetwork, and then use them in two separate InteractionNetworks, one for training and one for testing:


import functools

def model_fn():
  return snt.nets.MLP(output_sizes=[64, 64], dropout_rate=0.2)

# Build the MLPs once, outside the modules, so both networks share them.
edge_model = model_fn()
node_model = model_fn()

graph_network_train = modules.InteractionNetwork(
  edge_model_fn=lambda: functools.partial(edge_model, is_training=True),
  node_model_fn=lambda: functools.partial(node_model, is_training=True))

graph_network_test = modules.InteractionNetwork(
  edge_model_fn=lambda: functools.partial(edge_model, is_training=False),
  node_model_fn=lambda: functools.partial(node_model, is_training=False))

Then use a different module at train and test time; both will internally use the same MLPs and the same variables.

In this case the variables are "owned" by neither graph_network_train nor graph_network_test (those two are essentially pure functions of the shared models), but by edge_model and node_model, so you should use edge_model.variables and node_model.variables directly to retrieve the model variables.
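
For example, a training step could look roughly like this (a minimal sketch assuming TF2 eager execution; input_graph, targets, loss_fn and optimizer are placeholders for whatever you already use):

with tf.GradientTape() as tape:
  output_train = graph_network_train(input_graph)  # dropout active
  loss = loss_fn(output_train.nodes, targets)

# The trainable variables live on the shared MLPs, not on the two wrappers.
variables = edge_model.trainable_variables + node_model.trainable_variables
gradients = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(gradients, variables))  # e.g. a tf.keras optimizer

# At evaluation time, the other wrapper reuses the same variables with dropout off.
output_test = graph_network_test(input_graph)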

Alternatively, you could modify EdgeBlock, NodeBlock and GlobalBlock in blocks.py, and InteractionNetwork and GraphNetwork in modules.py, to receive an is_training parameter and pass it through to the model functions, although of course this will only work if the function expects it.
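
A rough sketch of what the call site would then look like (note that this is_training keyword is hypothetical; it only exists if you make those modifications yourself):

output_train = graph_network(input_graph, is_training=True)   # hypothetical keyword
output_eval = graph_network(input_graph, is_training=False)   # hypothetical keyword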

Hope this helps!

TianrenWang commented 4 years ago

I ended up just manually adding dropout with the TensorFlow API after every pass through the graph network, but your method could work too.
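
Roughly along these lines (a sketch; the rate of 0.2 and the is_training flag are just for illustration):

output_graph = graph_network(input_graph)
if is_training:
  # Apply dropout to the updated features outside the module, then rebuild the GraphsTuple.
  output_graph = output_graph.replace(
      nodes=tf.nn.dropout(output_graph.nodes, rate=0.2),
      edges=tf.nn.dropout(output_graph.edges, rate=0.2))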

I have one question about the MLP though. I created MLP models for the edge and node blocks with output sizes 1 and 300, respectively (with input sizes 1 and 300). The weight matrices created for these MLPs have shapes (601, 1) and (601, 300). Why are the first dimensions of size 601? Shouldn't they be 1 and 300, respectively, since those are the input sizes?

EDIT: Nevermind, I got the answer for my question.

alvarosg commented 4 years ago

The input to the edge model consists of (for each edge): the input edge features (size=1), the node features of the sender node (size=300), and the node features of the receiver node (size=300). Hence 601, when they are concatenated together.

For the node MLP, the inputs (for each node) are: the aggregated edge outputs from the edge function (size=1, since you mention an output size of 1 for the edges) and the input node features (size=300), so I would expect that to be 301, not 601.
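
In code terms, the concatenations inside the blocks amount to this (sizes from your example; the names are only for illustration):

# EdgeBlock input, per edge: [edge features, sender node features, receiver node features]
edge_block_input_size = 1 + 300 + 300   # = 601 -> edge MLP weight matrix of shape (601, 1)

# NodeBlock input, per node: [aggregated edge outputs, node features]
node_block_input_size = 1 + 300         # = 301 -> node MLP weight matrix of shape (301, 300)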

I would not recommend having an output size of 1 for the edges, as this is essentially the message size, so that is going to dramatically reduce the capacity of the interactions between the nodes.

If instead you use modules.GraphIndependent, the inputs are not concatenated, so the MLPs will have the input sizes you expect, but there won't be any message passing.
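
For example (a minimal sketch using the sizes from this thread):

# GraphIndependent applies each model only to its own features: no concatenation of
# sender/receiver/aggregated features, and therefore no message passing.
graph_independent = modules.GraphIndependent(
    edge_model_fn=lambda: snt.nets.MLP(output_sizes=[1]),     # edge weight shape: (1, 1)
    node_model_fn=lambda: snt.nets.MLP(output_sizes=[300]))   # node weight shape: (300, 300)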

There are more details in our paper "Relational inductive biases, deep learning, and graph networks".