google-deepmind / jraph

A Graph Neural Network Library in Jax
https://jraph.readthedocs.io/en/latest/
Apache License 2.0

Ideas on Message Passing #2

Closed reshinthadithyan closed 3 years ago

reshinthadithyan commented 3 years ago

I) I'm looking into implementing the convolution operation from the Graph Isomorphism Network (GIN), which transforms the node features with a set of dense layers. Can a haiku.Module object be called inside the node update function? If not, how should that be done?

node_update_fn = hk.Sequential(...)(node_feature) + hk.Sequential(...)(incoming_edge_feature)
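For reference, the GIN paper (Xu et al., "How Powerful are Graph Neural Networks?", ICLR 2019) updates node $v$ at layer $k$ by summing its neighbours' features with its own (scaled by $1 + \epsilon^{(k)}$, where $\epsilon^{(k)}$ is a fixed or learnable scalar) and then applying a single MLP, with $\mathcal{N}(v)$ the neighbourhood of $v$:

```latex
h_v^{(k)} = \mathrm{MLP}^{(k)}\Big( \big(1 + \epsilon^{(k)}\big)\, h_v^{(k-1)} + \sum_{u \in \mathcal{N}(v)} h_u^{(k-1)} \Big)
```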

jg8610 commented 3 years ago

Hi there, thanks for your question.

This question is mainly about how Haiku works. The answer may differ if you switch to another NN framework such as Flax, but the good news is that you can keep using Jraph either way 👍

Haiku nets must be created and called inside another function, which is then transformed with hk.transform. For example, from the Haiku docs:

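import haiku as hk
import jax
import jax.numpy as jnp

def softmax_cross_entropy(logits, labels):
  one_hot = jax.nn.one_hot(labels, logits.shape[-1])
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)

def loss_fn(images, labels):
  mlp = hk.Sequential([
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
  ])
  logits = mlp(images)
  return jnp.mean(softmax_cross_entropy(logits, labels))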

So in your case, write node_update_fn as a function that builds and calls the Haiku nets inside it:

def node_update_fn(nodes, incoming_edge_feature):
  node_net = hk.Sequential([
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
  ])
  edge_net = hk.Sequential([
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
  ])
  return node_net(nodes) + edge_net(incoming_edge_feature)
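Note that the function above adds two separate MLP outputs, whereas the GIN layer sums the neighbour features first and applies one MLP to the result. A minimal pure-JAX sketch of that GIN update, assuming messages flow from sender to receiver; the helpers init_mlp, mlp and gin_layer are hypothetical names for illustration, not Jraph or Haiku API:

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
  """Hypothetical helper: random weights for a dense stack, e.g. sizes=[3, 8, 3]."""
  keys = jax.random.split(key, len(sizes) - 1)
  return [(jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in), jnp.zeros(n_out))
          for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])]

def mlp(params, x):
  # ReLU between layers, no activation after the final layer.
  for w, b in params[:-1]:
    x = jax.nn.relu(x @ w + b)
  w, b = params[-1]
  return x @ w + b

def gin_layer(mlp_params, eps, nodes, senders, receivers):
  # Sum the features of each node's neighbours (edges point sender -> receiver).
  neighbour_sum = jax.ops.segment_sum(nodes[senders], receivers,
                                      num_segments=nodes.shape[0])
  # GIN update: MLP((1 + eps) * h_v + sum of neighbour features).
  return mlp(mlp_params, (1.0 + eps) * nodes + neighbour_sum)

nodes = jnp.eye(3)               # 3 nodes with one-hot features
senders = jnp.array([0, 1, 2])   # edges 0->1, 1->2, 2->0
receivers = jnp.array([1, 2, 0])
params = init_mlp(jax.random.PRNGKey(0), [3, 8, 3])
out = gin_layer(params, 0.0, nodes, senders, receivers)
```

Here eps is passed in as a plain number; to make it learnable, keep it in the parameter tree alongside the MLP weights.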

It's convenient to build the whole graph net inside one function; that way you only need to apply Haiku's transform to the outermost function.

import jraph

def forward_pass_graph_net(graph):
  # The Haiku nets are created when node_update_fn is called inside net(graph).
  net = jraph.GraphNetwork(update_node_fn=node_update_fn, ... )
  return net(graph)

You can then use hk.transform to turn this function into a pair of pure functions (init and apply, with no side effects) that can be used with JAX:

forward_pass_graph_net_t = hk.transform(forward_pass_graph_net)

Note: if you are using jraph.GraphNetwork, by default your node update function also receives the global features and the aggregated features of the edges for which your node is the sender, in addition to the node features and the received edge features. So, for completeness, you will have to handle those.

def node_update_fn(nodes, unused_outgoing_edge_feature, incoming_edge_feature, unused_global_feature):
  # jraph.GraphNetwork passes (nodes, sent/outgoing, received/incoming, globals).
  del unused_outgoing_edge_feature
  del unused_global_feature
  node_net = ...  # same Haiku nets as above
  edge_net = ...
  return node_net(nodes) + edge_net(incoming_edge_feature)
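To see which edge features land in which argument, here is a pure-JAX sketch of the two aggregations described in the note above, assuming the default summation (jraph's default edge-to-node aggregation is a segment sum). The toy graph is hypothetical: edges 0 -> 1 and 1 -> 2 with scalar features 10 and 20.

```python
import jax
import jax.numpy as jnp

edges = jnp.array([[10.0], [20.0]])  # one feature per edge
senders = jnp.array([0, 1])
receivers = jnp.array([1, 2])
n_nodes = 3

# Summed features of the edges each node sends (outgoing).
sent = jax.ops.segment_sum(edges, senders, num_segments=n_nodes)
# Summed features of the edges each node receives (incoming).
received = jax.ops.segment_sum(edges, receivers, num_segments=n_nodes)
```

Node 0 only sends, so its received row is zero; node 2 only receives, so its sent row is zero. Passing these two arrays in the wrong order would silently swap the message direction.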

Hope that's helpful; if you have any more questions, let me know!