google-deepmind / graph_nets

Build Graph Nets in Tensorflow
https://arxiv.org/abs/1806.01261
Apache License 2.0

Distributed training #121

Closed krzysztofrusek closed 3 years ago

krzysztofrusek commented 4 years ago

As far as I understand, graph_nets represents a batch of graphs as one large disconnected graph. This is very efficient for single-device training, but I think this approach does not work in a distributed training environment. My understanding is that a Strategy shards the batch along the first dimension, so all observations along that dimension must be independent. This is obviously not the case for a disconnected graph: parts of the same graph could end up being processed by different devices, and message passing would then be incorrect. Are there any plans to use e.g. RaggedTensors in graph_nets?
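To illustrate the concern: a minimal plain-Python sketch (not the actual GraphsTuple implementation; graphs here are just node lists plus edge index pairs) of how a batch is merged into one disconnected graph, and why slicing the concatenated node array along the first axis can separate nodes from the edges that reference them:

```python
def merge_graphs(graphs):
    # Each graph: {"nodes": [...], "edges": [(sender, receiver), ...]}.
    # Batching concatenates all nodes and offsets edge indices so the
    # result is one big disconnected graph.
    nodes, edges, offset = [], [], 0
    for g in graphs:
        nodes.extend(g["nodes"])
        edges.extend((s + offset, r + offset) for s, r in g["edges"])
        offset += len(g["nodes"])
    return {"nodes": nodes, "edges": edges}

g1 = {"nodes": [0.1, 0.2], "edges": [(0, 1)]}
g2 = {"nodes": [0.3, 0.4, 0.5], "edges": [(0, 2), (1, 2)]}
batch = merge_graphs([g1, g2])
print(batch["edges"])  # -> [(0, 1), (2, 4), (3, 4)]
# Sharding batch["nodes"] naively along the first axis (e.g. nodes 0-2
# to one device, nodes 3-4 to another) splits g2, yet its edges
# (2, 4) and (3, 4) reference nodes on both devices, so message
# passing across that edge can no longer be computed locally.
```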

alvarosg commented 4 years ago

Hi, thanks for your message!

We currently have no plans to move towards RaggedTensors, due to limitations in GPU and TPU support for ragged tensors. However, there are two ways in which we often run distributed training successfully that do not require any changes to the library:
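One pattern compatible with this idea (a hypothetical plain-Python sketch, not necessarily one of the approaches referred to above) is to shard at the level of whole graphs before merging, so that each replica builds its own disconnected-graph batch and no single graph is ever split across devices:

```python
def shard_graphs(graphs, num_replicas):
    # Round-robin whole graphs across replicas. Each replica later
    # merges only its own shard into one disconnected graph, so all
    # message passing for a given graph happens on a single device.
    shards = [[] for _ in range(num_replicas)]
    for i, g in enumerate(graphs):
        shards[i % num_replicas].append(g)
    return shards

graphs = [{"nodes": [float(n)], "edges": []} for n in range(8)]
shards = shard_graphs(graphs, num_replicas=2)
print([len(s) for s in shards])  # -> [4, 4]
```

Per-replica gradients can then be averaged as usual, since each replica's loss depends only on its own complete graphs.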

Hope this helps!

charlinergr commented 4 years ago

Hi!

Could you provide an example of distributed training?

alvarosg commented 4 years ago

We do not currently have any examples in the library, however some users seem to be doing distributed training with our library: