google-deepmind / jraph

A Graph Neural Network Library in Jax
https://jraph.readthedocs.io/en/latest/
Apache License 2.0

Add segment_variance and segment_normalize to utils.py #5

Closed: sooheon closed this 3 years ago

sooheon commented 3 years ago

segment_normalize is an alternative to segment_softmax for normalizing attention weights across incoming edges.
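A minimal sketch of how segment_variance and segment_normalize could be built on `jax.ops.segment_sum`. This is not jraph's actual code. The function bodies, the local `segment_mean` helper and the `eps` default are assumptions. The argument names follow the `data, segment_ids, num_segments` convention of `jax.ops.segment_sum`. Under `jax.jit`, `num_segments` has to be a static Python int.

```python
import jax
import jax.numpy as jnp


def segment_mean(data, segment_ids, num_segments):
    """Mean of `data` within each segment (0 for empty segments)."""
    totals = jax.ops.segment_sum(data, segment_ids, num_segments=num_segments)
    counts = jax.ops.segment_sum(
        jnp.ones_like(data), segment_ids, num_segments=num_segments)
    # Clamp counts so empty segments divide by 1 instead of 0.
    return totals / jnp.maximum(counts, 1.0)


def segment_variance(data, segment_ids, num_segments):
    """Population variance of `data` within each segment."""
    means = segment_mean(data, segment_ids, num_segments)
    # Gather each element's segment mean back to edge level, then average
    # the squared deviations per segment.
    centered = data - means[segment_ids]
    return segment_mean(centered ** 2, segment_ids, num_segments)


def segment_normalize(data, segment_ids, num_segments, eps=1e-8):
    """Standardize `data` to zero mean and unit variance within each segment."""
    means = segment_mean(data, segment_ids, num_segments)
    variances = segment_variance(data, segment_ids, num_segments)
    # eps keeps single-valued or constant segments from dividing by zero.
    return (data - means[segment_ids]) / jnp.sqrt(variances[segment_ids] + eps)


# Attention logits for 5 edges; receivers[i] is the node edge i points to.
logits = jnp.array([1.0, 2.0, 3.0, 5.0, 5.0])
receivers = jnp.array([0, 0, 0, 1, 1])
weights = segment_normalize(logits, receivers, num_segments=2)
```

With these inputs, node 0's three incoming scores become roughly `[-1.22, 0.0, 1.22]`. Node 1's two equal scores both map to 0, because `eps` keeps the division finite.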

Add tests

sooheon commented 3 years ago

Motivated by Richter and Wattenhofer (2020), which uses normalization in place of softmax to good effect.
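Roughly, the two options for turning raw edge scores $e_{ij}$ into weights over node $i$'s incoming edges $\mathcal{N}(i)$ compare as follows. The second line is the per-segment standardization a segment_normalize as described above would compute. The notation and the $\epsilon$ stabilizer are assumptions, not taken from the PR or the paper.

```latex
\text{softmax:}\quad
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}

\text{normalize:}\quad
\alpha_{ij} = \frac{e_{ij} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}},\qquad
\mu_i = \frac{1}{|\mathcal{N}(i)|} \sum_{k \in \mathcal{N}(i)} e_{ik},\qquad
\sigma_i^2 = \frac{1}{|\mathcal{N}(i)|} \sum_{k \in \mathcal{N}(i)} (e_{ik} - \mu_i)^2
```

Unlike softmax weights, the normalized weights can be negative and do not sum to 1. Within each segment they sum to 0.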

sooheon commented 3 years ago

I changed the argument names of segment_mean to match segment_softmax and the other segment functions in JAX, but I realize that may be an annoying change. Let me know if you want it reverted or have any other feedback.

sooheon commented 3 years ago

Yep, I have signed the CLA.

sooheon commented 3 years ago

By the way, do you have any performance tips regarding the implementation of the scatter functions?

jg8610 commented 3 years ago

Thanks! I've merged.

What performance concerns do you have specifically? Some accelerators handle scatters/gathers more efficiently than others (e.g. TPUs are less efficient at them than other accelerators). Your mileage may also vary depending on how XLA optimizes your jitted code; it may fuse them with the surrounding operations.
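To make the reply concrete, here is a generic sketch of the gather/scatter pattern that segment functions produce inside a jitted message-passing step. The names `aggregate_messages` and `sorted_receivers` are hypothetical and not part of jraph. `indices_are_sorted` is a real argument of `jax.ops.segment_sum`, but it is only a hint to XLA. Whether it, or any fusion, actually helps depends on the backend, as noted above.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames=("num_nodes", "sorted_receivers"))
def aggregate_messages(node_features, senders, receivers, num_nodes,
                       sorted_receivers=False):
    """One sum-aggregation message-passing step (a generic sketch)."""
    # Gather: pull each edge's sender features (one row per edge).
    messages = node_features[senders]
    # Scatter: sum messages into their receiving nodes. Only pass
    # sorted_receivers=True if `receivers` really is sorted ascending.
    return jax.ops.segment_sum(
        messages, receivers, num_segments=num_nodes,
        indices_are_sorted=sorted_receivers)


nodes = jnp.array([[1.0], [2.0], [3.0]])
senders = jnp.array([2, 0, 1, 0])
receivers = jnp.array([0, 1, 2, 2])  # already sorted
out = aggregate_messages(nodes, senders, receivers, num_nodes=3,
                         sorted_receivers=True)
```

To see what XLA did with the gather and scatter on your backend, for example whether they were fused, you can print the optimized HLO with `aggregate_messages.lower(nodes, senders, receivers, num_nodes=3).compile().as_text()`.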

If there is something specific, I can take a look.