sooheon closed this 3 years ago
Motivated by Richter and Wattenhofer (2020), which uses normalization in place of softmax to good effect.
I changed the arg names of segment_mean to match segment_softmax and the other segment fns from jax, but I realize that may be an annoying change. Let me know if you want that reverted or have any other feedback.
Yep, I have signed the CLA.
By the way, do you have any performance tips regarding implementation of the scatter functions?
Thanks! I've merged.
What performance concerns do you have specifically? Some accelerators are more efficient than others with scatters/gathers (e.g. TPUs are less efficient than others). Your mileage may also vary depending on how XLA optimizes your jitted code: it may be that the scatters/gathers get fused with surrounding ops.
If there is something specific, I can take a look.
segment_normalize is an alternative to segment_softmax for normalizing attention weights across incoming edges.
Add tests
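For readers skimming the thread, here is a minimal sketch of what per-segment normalization can look like, assuming "normalize" means standardizing each edge value to zero mean and unit variance within its segment (e.g. among the incoming edges of one receiver node). The name `segment_normalize_sketch`, its signature, and the `eps` default are hypothetical and chosen for illustration; the merged code may differ. Only `jax.ops.segment_sum` and `jax.lax.rsqrt` are real JAX calls here.

```python
import jax
import jax.numpy as jnp


def segment_normalize_sketch(data, segment_ids, num_segments, eps=1e-5):
    """Standardize `data` within each segment (hypothetical helper, not the PR's code)."""
    # Count the entries in each segment with a scatter-add of ones.
    counts = jax.ops.segment_sum(jnp.ones_like(data), segment_ids, num_segments)
    counts = jnp.maximum(counts, 1.0)  # guard against empty segments

    # Per-segment mean, gathered back to one value per entry.
    means = jax.ops.segment_sum(data, segment_ids, num_segments) / counts
    centered = data - means[segment_ids]

    # Per-segment (biased) variance, gathered back the same way.
    variances = jax.ops.segment_sum(centered**2, segment_ids, num_segments) / counts

    # Clamp the variance at eps so constant segments do not divide by zero.
    return centered * jax.lax.rsqrt(jnp.maximum(variances[segment_ids], eps))


# Five edges: the first two point at node 0, the last three at node 1.
logits = jnp.array([1.0, 3.0, 2.0, 2.0, 5.0])
receivers = jnp.array([0, 0, 1, 1, 1])
weights = segment_normalize_sketch(logits, receivers, num_segments=2)
```

Unlike a segment softmax, the resulting weights can be negative and do not sum to one per receiver; they sum to zero instead.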