tisabe closed this issue 3 years ago.
Hey! Thanks for waiting for the response.
I was just looking at the paper, and I believe the existing GraphNetwork should do what you want. It follows this pseudocode:
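```python
# One GraphNetwork step (pseudocode)
updated_edges = edge_update_fn(previous_edges, senders, receivers, globals_)
# node_update_fn only sees the updated edges *aggregated* per sending and per receiving node
updated_nodes = node_update_fn(nodes, aggregated_sent_edges, aggregated_received_edges, globals_)
return GraphsTuple(nodes=updated_nodes, edges=updated_edges)
```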
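To make this concrete, here is a minimal sketch of one such step in plain JAX (not jraph itself, and with globals omitted). `graph_network_step`, `edge_fn` and `node_fn` are hypothetical names, and the edge aggregation is assumed to be a sum, which I believe matches jraph's default `segment_sum`:

```python
import jax.numpy as jnp
from jax.ops import segment_sum


def graph_network_step(nodes, edges, senders, receivers, edge_update_fn, node_update_fn):
    """One GraphNetwork-style step (globals omitted for brevity)."""
    num_nodes = nodes.shape[0]
    # Edge update: each edge sees its own features plus its sender and receiver node features.
    updated_edges = edge_update_fn(edges, nodes[senders], nodes[receivers])
    # Node update: each node only sees the *sum* of its outgoing and incoming updated edges.
    sent = segment_sum(updated_edges, senders, num_segments=num_nodes)
    received = segment_sum(updated_edges, receivers, num_segments=num_nodes)
    updated_nodes = node_update_fn(nodes, sent, received)
    return updated_nodes, updated_edges


# Toy graph: 3 nodes with 1 feature each, 2 directed edges 0 -> 1 and 1 -> 2.
nodes = jnp.array([[1.0], [2.0], [3.0]])
edges = jnp.array([[10.0], [20.0]])
senders = jnp.array([0, 1])
receivers = jnp.array([1, 2])

# Hypothetical update functions, chosen only so the arithmetic is easy to follow.
edge_fn = lambda e, s, r: e + s + r
node_fn = lambda n, sent, received: n + received

new_nodes, new_edges = graph_network_step(nodes, edges, senders, receivers, edge_fn, node_fn)
```

Here the edges become `[[13], [25]]` and the nodes `[[1], [15], [28]]`: node 1 only ever sees the summed edge value 13, never the individual edge.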
Is this different to what is described in the paper?
Yes, this is also how I intended to write the network. However, the inputs of node_update_fn look too limited to compute the node update described in the paper. There, each node update depends on the node's own features and an aggregation of its incoming messages, and each message (defined per edge) depends on the edge features and the features of the sending and receiving nodes.
I'll try to put it into pseudocode, as it is a bit hard to describe:
```python
messages = message_fn(edges, senders, receivers)  # one message per edge
aggregate_messages_per_node = aggregate_message_fn(messages, receivers)  # e.g. sum over incoming edges
updated_nodes = node_update_fn(nodes, aggregate_messages_per_node)
```
I don't think this fits into the update_node_fn of GraphNetwork: update_node_fn only receives aggregated edges, whereas the message function needs the individual edge features.
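A minimal plain-JAX sketch of the message passing step described above, assuming summation as the aggregation; `message_passing_step`, `message_fn` and `node_update_fn` are hypothetical stand-ins, not jraph API. The message `e * s` mixes each edge with its own sender, so the aggregated messages cannot be reconstructed from the summed edges alone:

```python
import jax.numpy as jnp
from jax.ops import segment_sum


def message_passing_step(nodes, edges, senders, receivers, message_fn, node_update_fn):
    """Message passing where messages are computed per edge before aggregation."""
    num_nodes = nodes.shape[0]
    # One message per edge, from the individual edge plus its sender and receiver features.
    messages = message_fn(edges, nodes[senders], nodes[receivers])
    # Sum the messages arriving at each node.
    aggregated = segment_sum(messages, receivers, num_segments=num_nodes)
    # The node update sees its own state and its aggregated incoming messages.
    return node_update_fn(nodes, aggregated)


# Toy graph: 3 nodes, 2 directed edges 0 -> 1 and 1 -> 2.
nodes = jnp.array([[1.0], [2.0], [3.0]])
edges = jnp.array([[10.0], [20.0]])
senders = jnp.array([0, 1])
receivers = jnp.array([1, 2])

# Hypothetical functions: the message multiplies the edge by its sender's features.
message_fn = lambda e, s, r: e * s
node_update_fn = lambda n, m: n + m

new_nodes = message_passing_step(nodes, edges, senders, receivers, message_fn, node_update_fn)
```

The messages are `[[10], [40]]`, so the updated nodes are `[[1], [12], [43]]`.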
Ah, thanks for clarifying!
I think the easiest way to accomplish this is to use structured edge features (pseudocode):
```python
import jax.numpy as jnp

# Initialize each edge's features as a dict holding a message and a latent, both zero.
# I'm using 'latent' for the 'update' in the paper, to avoid confusion with jraph's update functions.
initial_edges = {'message': jnp.zeros(...), 'latent': jnp.zeros(...)}

def update_edge_fn(edges, sent_nodes, received_nodes, globals_):
    latent = update_latent_fn(edges['latent'], sent_nodes, received_nodes)  # the 'update_fn' in the paper
    message = update_message_fn(latent, sent_nodes, received_nodes)
    return {'message': message, 'latent': latent}
```
In the node update function, just make sure to use only the aggregated 'message' entries, not the 'latent' ones.
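A runnable plain-JAX sketch of the structured-edge idea, assuming the edge dict is aggregated leaf by leaf with `jax.tree_util.tree_map` (I believe jraph's GraphNetwork treats edge pytrees this way, but this sketch does not use jraph). The latent and message functions are hypothetical placeholders for the paper's learned functions:

```python
import jax
import jax.numpy as jnp
from jax.ops import segment_sum


def update_edge_fn(edges, sent_nodes, received_nodes):
    # Hypothetical stand-in for the paper's edge update (the 'latent').
    latent = edges["latent"] + sent_nodes + received_nodes
    # Hypothetical message function, computed from the freshly updated latent.
    message = latent * sent_nodes
    return {"latent": latent, "message": message}


def update_node_fn(nodes, received):
    # Only the aggregated 'message' entries are used; 'latent' is ignored.
    return nodes + received["message"]


def step(nodes, edges, senders, receivers):
    num_nodes = nodes.shape[0]
    new_edges = update_edge_fn(edges, nodes[senders], nodes[receivers])
    # Aggregate every leaf of the edge dict per receiving node.
    received = jax.tree_util.tree_map(
        lambda e: segment_sum(e, receivers, num_segments=num_nodes), new_edges)
    return update_node_fn(nodes, received), new_edges


# Toy graph: 3 nodes, 2 directed edges 0 -> 1 and 1 -> 2.
nodes = jnp.array([[1.0], [2.0], [3.0]])
edges = {"latent": jnp.array([[10.0], [20.0]]), "message": jnp.zeros((2, 1))}
senders = jnp.array([0, 1])
receivers = jnp.array([1, 2])

new_nodes, new_edges = step(nodes, edges, senders, receivers)
```

The latents become `[[13], [25]]`, the messages `[[13], [50]]`, and the nodes `[[1], [15], [53]]`. The latent is carried along on the edges for the next step, while only the messages affect the nodes.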
Yes, this looks like it solves my problem nicely. Thanks!
Hi there,
I was looking to implement a message passing network with edge updates as described in https://arxiv.org/abs/1806.03146. The Jraph paper explains that the messages M_t for each edge should be computed with the edge update function phi_e of the GraphNetwork from the model zoo. However, as I understand it, this prevents me from also implementing a function that only updates the edges, based on the edge features and the sending and receiving nodes.
Is there a workaround using the current model zoo to separate edge updates from edge-wise messages, or is this a known problem?
Thanks!