google-deepmind / graph_nets

Build Graph Nets in Tensorflow
https://arxiv.org/abs/1806.01261
Apache License 2.0

Question about repeat implementation #141

Closed Padarn closed 3 years ago

Padarn commented 3 years ago

Hello!

I am reading through the source code for graph_nets to understand better what goes into building a GNN library, so sorry for some questions that may seem pointless.

I'm curious about the repeat implementation. The function signature is:

def repeat(tensor, repeats, axis=0, name="repeat", sum_repeats_hint=None):

I can see how sum_repeats_hint is used, but it seems like a very small optimization that might even be taken care of by the TensorFlow graph optimizer (Grappler). If not, perhaps it should be!

Also, is there a reason not to use tf.repeat? https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/repeat

alvarosg commented 3 years ago

Thank you for your message. If I recall correctly, we originally had a separate implementation because tf.repeat was using a CPU-only op that required moving a large array between GPU and CPU, which was hurting performance. I think that was fixed over time, but we kept our implementation so that we could pass sum_repeats_hint. The hint gives the output of repeat a statically fixed shape known at graph compilation time, which is (or at least was until recently) a requirement for the TPU compiler.
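To make the static-shape point concrete, here is a minimal NumPy sketch of one way a repeat can be written so that the output length comes from a caller-supplied hint rather than from the data. It is not the graph_nets code: the name repeat_with_hint is hypothetical, it only handles axis 0, and it assumes the hint equals the true sum of repeats. NumPy has no notion of static shapes, so this only illustrates the indexing idea (mark block starts, take a cumulative sum, gather).

```python
import numpy as np


def repeat_with_hint(values, repeats, sum_repeats_hint=None):
    """Repeat values[i] repeats[i] times along axis 0 (hypothetical helper)."""
    values = np.asarray(values)
    repeats = np.asarray(repeats, dtype=np.int64)

    # Output length: taken from the hint when given, otherwise read from the data.
    # Assumption: a supplied hint equals repeats.sum().
    if sum_repeats_hint is None:
        total = int(repeats.sum())
    else:
        total = int(sum_repeats_hint)

    # Start offset of each block: exclusive cumulative sum of repeats.
    offsets = np.cumsum(repeats) - repeats

    # Mark every block start in a buffer of the output length. Offsets equal to
    # `total` come from trailing zero repeats and are dropped.
    indicators = np.zeros(total, dtype=np.int64)
    np.add.at(indicators, offsets[offsets < total], 1)

    # A running count of block starts, minus one, is the source index per slot.
    gather_indices = np.cumsum(indicators) - 1
    return values[gather_indices]


result = repeat_with_hint([10, 20, 30], [2, 3, 1], sum_repeats_hint=6)
print(result.tolist())  # [10, 10, 20, 20, 20, 30]
```

Every intermediate array here has length either len(repeats) or the hinted total, which is the property a compiler that needs fixed shapes would care about. Without the hint, the total has to be computed from the values in repeats, so it is only known at run time.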

Padarn commented 3 years ago

Thanks @alvarosg for your response. I had to look into the TPU compiler a bit to understand what you meant, but now it's clear. I appreciate you taking the time; I will close this.