google-deepmind / jraph

A Graph Neural Network Library in Jax
https://jraph.readthedocs.io/en/latest/
Apache License 2.0

Message Passing with edge updates #20

Closed tisabe closed 3 years ago

tisabe commented 3 years ago

Hi there,

I am looking to implement a message passing network with edge updates, as described in https://arxiv.org/abs/1806.03146. The Jraph paper explains that, in the GraphNetwork from the model zoo, the messages M_t for each edge should be computed with the edge update function phi_e. As I understand it, however, this prevents me from implementing a separate function that only updates the edges, based on the edge features and the sending and receiving nodes.

Is there a workaround using the current model zoo to separate edge updates from edge-wise messages, or is this a known problem?

Thanks!

jg8610 commented 3 years ago

Hey! Thanks for waiting for the response.

I was just looking at the paper, and I believe the existing GraphNetwork should do what you want. It follows this pseudocode:

updated_edges = edge_update_fn(previous_edges, senders, receivers, globals_)
updated_nodes = node_update_fn(nodes, updated_edges_senders, updated_edges_receivers, globals_)
return GraphsTuple(nodes=updated_nodes, edges=updated_edges)

Is this different to what is described in the paper?

tisabe commented 3 years ago

Yes, this is also how I intended to write the network. However, it looks to me like the inputs of the node_update_fn are too limited to compute the node update as described in the paper. In the paper, each node update depends on the node's own features and an aggregation of the incoming messages. The messages (defined for every edge) depend on the edge features and on the sending and receiving nodes.

I'll try to put it into pseudocode, as it is a bit hard to describe:

messages = message_fn(edges, senders, receivers)
aggregate_messages_per_node = aggregate_message_fn(messages)
updated_nodes = node_update_fn(nodes, aggregate_messages_per_node)

I don't think this fits within the update_node_fn in GraphNetwork. That function only receives the aggregated edges, whereas the message function needs the individual edge features.
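The three steps in the pseudocode above can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not jraph's implementation: the body of message_fn, the residual node update, the sum aggregation over receivers, and the toy graph are all hypothetical choices made only to show the data flow.

```python
import jax
import jax.numpy as jnp


def message_fn(edges, sender_nodes, receiver_nodes):
    # Hypothetical edge-wise message: uses the individual edge features
    # together with the features of the sending and receiving nodes.
    return edges * (sender_nodes + receiver_nodes)


def message_passing_step(nodes, edges, senders, receivers):
    n_nodes = nodes.shape[0]
    # 1. One message per edge, computed from per-edge inputs.
    messages = message_fn(edges, nodes[senders], nodes[receivers])
    # 2. Aggregate (sum) the incoming messages at each receiving node.
    aggregated = jax.ops.segment_sum(messages, receivers, num_segments=n_nodes)
    # 3. Hypothetical node update: add the aggregate to the node features.
    return nodes + aggregated


# Toy graph (assumed): 3 nodes, edges 0 -> 2 and 1 -> 2.
nodes = jnp.array([[1.0], [2.0], [3.0]])
edges = jnp.array([[1.0], [1.0]])
senders = jnp.array([0, 1])
receivers = jnp.array([2, 2])

updated_nodes = message_passing_step(nodes, edges, senders, receivers)
```

Only node 2 has incoming edges here, so it is the only node whose features change.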

jg8610 commented 3 years ago

Ah, thanks for clarifying!

I think the easiest way to accomplish this is to use structured messages (pseudocode):

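# Initialize an edge message with zeros.
# 'latent' stands for the 'update' in the paper; renamed to avoid confusion with jraph's update functions.
edge_message = {'message': jnp.zeros(...), 'latent': jnp.zeros(...)}

def update_edge_fn(edges, senders, receivers, globals_):
  # globals_ is unused here; it is included to match the edge_update_fn signature in the pseudocode above.
  # The 'update_fn' in the paper.
  latent = update_latent_fn(edges['latent'], senders, receivers)
  # The message is computed from the updated latent.
  message = update_message_fn(latent, senders, receivers)
  # Return a new dict rather than mutating the input.
  return {'message': message, 'latent': latent}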

In the node update function, you just need to make sure to use only the 'message', not the 'latent'.
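To make the structured-edge idea concrete, here is a minimal sketch in plain JAX that mimics the data flow without calling jraph itself. The bodies of update_latent_fn, update_message_fn and update_node_fn, the step helper, and the toy graph are hypothetical placeholders. The only point is that each edge carries both a 'latent' and a 'message', and the node update reads just the 'message'.

```python
import jax
import jax.numpy as jnp


def update_latent_fn(latent, sender_nodes, receiver_nodes):
    # Hypothetical edge update (the 'update' in the paper).
    return latent + sender_nodes + receiver_nodes


def update_message_fn(latent, sender_nodes, receiver_nodes):
    # Hypothetical message, computed from the updated latent.
    return latent * sender_nodes


def update_edge_fn(edges, sender_nodes, receiver_nodes):
    latent = update_latent_fn(edges['latent'], sender_nodes, receiver_nodes)
    message = update_message_fn(latent, sender_nodes, receiver_nodes)
    return {'message': message, 'latent': latent}


def update_node_fn(nodes, received_edges):
    # Use only the aggregated 'message'; the 'latent' is ignored here.
    return nodes + received_edges['message']


def step(nodes, edges, senders, receivers):
    new_edges = update_edge_fn(edges, nodes[senders], nodes[receivers])
    # Sum every field of the edge dict over the receiving nodes.
    received = jax.tree_util.tree_map(
        lambda e: jax.ops.segment_sum(e, receivers, num_segments=nodes.shape[0]),
        new_edges)
    return update_node_fn(nodes, received), new_edges


# Toy graph (assumed): 3 nodes, edges 0 -> 2 and 1 -> 2.
nodes = jnp.array([[1.0], [2.0], [3.0]])
senders = jnp.array([0, 1])
receivers = jnp.array([2, 2])
edges = {'message': jnp.zeros((2, 1)), 'latent': jnp.zeros((2, 1))}

new_nodes, new_edges = step(nodes, edges, senders, receivers)
```

Because the updated 'latent' is returned alongside the 'message', it persists on the edges across steps even though the nodes never see it.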

tisabe commented 3 years ago

Yes, this looks like it solves my problem nicely. Thanks!