ott-jax / ott

Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
https://ott-jax.readthedocs.io
Apache License 2.0

segment_topk and segment_sort #237

Status: Open. swamidass opened this issue 1 year ago.

swamidass commented 1 year ago

Is your feature request related to a problem? Please describe.

Graph networks often use sort pooling (sort_pooling), which requires sorting all of the elements within each segment of a vector. The segment sizes are variable. There is currently no easy way to do this with ott-jax.
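For reference, the hard (non-differentiable) operation that sort pooling needs can be written in plain JAX with `jnp.lexsort`. This is a sketch under stated assumptions, not an ott-jax API: `hard_segment_sort` is a hypothetical name, and it assumes entries should be ordered by descending `sorter` within each segment.

```python
import jax
import jax.numpy as jnp


@jax.jit
def hard_segment_sort(sorter, segment_ids, values):
    """Reorders `values` so that segments are contiguous and, within each
    segment, rows appear in descending order of `sorter`.

    This is a hard sort: it is not differentiable with respect to `sorter`.
    """
    # jnp.lexsort treats the LAST key as the primary key, so this sorts by
    # segment id first and by -sorter (i.e. descending sorter) second.
    order = jnp.lexsort((-sorter, segment_ids))
    # Works for 1d values and for 2d values (gathers whole rows).
    return values[order]
```

For example, with `sorter = [0.1, 0.9, 0.5, 0.3, 0.7]`, `segment_ids = [0, 0, 1, 1, 1]` and `values = jnp.arange(5)`, the result groups segment 0 as `[1, 0]` followed by segment 1 as `[4, 2, 3]`.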

Describe the solution you'd like

It would be great for ott-jax to add an efficient way to do soft top-k and/or soft sorting of segments (analogous to segment_sum or segment_max, but segment_topk or segment_sort instead). The signatures of such functions might be:

def segment_topk(k: int, sorter: <1d vector>, segment_ids: <1d int vector>, num_segments: int, values: <1d or 2d vector>, fill_value) -> <2d or 3d array: first dimension = num_segments, second dimension = k, third dimension (if values is 2d) = values.shape[1]>

def segment_sort(sorter: <1d vector>, segment_ids: <1d int vector>, num_segments: int, values: <1d or 2d vector>) -> <1d or 2d array: first dimension = sorter.shape[0], second dimension (if values is 2d) = values.shape[1]>

In segment_topk, fill_value would pad the output wherever k exceeds the length of a segment.
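A minimal hard (non-soft) sketch of the proposed segment_topk semantics, including the fill_value padding, might look like the following. It assumes descending order and that `k` and `num_segments` are static Python ints; `hard_segment_topk` is a hypothetical name, not an existing ott-jax function. Each element is ranked within its own segment, and scatter updates whose rank is `>= k` fall outside the `k` output columns and are discarded by `mode="drop"`.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("k", "num_segments"))
def hard_segment_topk(k, sorter, segment_ids, num_segments, values, fill_value=0):
    """Returns an array of shape (num_segments, k) + values.shape[1:] holding,
    for each segment, the values of its k largest `sorter` entries."""
    # Group by segment; within a segment, order by descending sorter.
    order = jnp.lexsort((-sorter, segment_ids))
    sorted_ids = segment_ids[order]
    sorted_vals = values[order]
    # Offset at which each segment starts in the sorted order.
    counts = jnp.bincount(segment_ids, length=num_segments)
    starts = jnp.cumsum(counts) - counts
    # Position of each element inside its segment (0 = largest sorter).
    ranks = jnp.arange(sorter.shape[0]) - starts[sorted_ids]
    out = jnp.full((num_segments, k) + values.shape[1:], fill_value, dtype=values.dtype)
    # Updates with rank >= k are out of bounds along axis 1 and are dropped.
    return out.at[sorted_ids, ranks].set(sorted_vals, mode="drop")
```

With `k=2`, `num_segments=3` and `fill_value=-1` on the same inputs as a 2-segment example (`segment_ids = [0, 0, 1, 1, 1]`), the third segment is empty, so its row is entirely `fill_value`. This is jittable but, like the hard implementation mentioned below, not a soft top-k.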

Describe alternatives you've considered

I have an implementation of segment_topk that works in JAX but does not use OTT. It is jittable, but it is not a "soft" top-k, and it is not memory efficient.

There are also other architectures for these sorts of problems, based on attention, that can work well in practice. However, without a good OTT implementation, there is no way to benchmark which ones work best.

Additional context

A soft segment sort or segment top-k primitive makes the most sense to add to ott-jax. Implementations of these functions would likely be useful to the jraph library (JAX for graph networks).

swamidass commented 1 year ago

Let me also note that there is already a segmented Sinkhorn solver:

https://ott-jax.readthedocs.io/en/latest/_autosummary/ott.tools.segment_sinkhorn.segment_sinkhorn.html?highlight=segment#ott.tools.segment_sinkhorn.segment_sinkhorn

So perhaps this is a matter of creating a front-end API function that makes use of this solver?
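One way such a front end could handle variable segment sizes (an assumption for illustration, not a description of ott-jax internals) is to scatter the flat, segmented input into a dense, padded `(num_segments, max_size)` array plus a boolean mask, and then `vmap` a per-segment routine over the rows. `pad_segments` is a hypothetical helper; `max_size` must be at least the size of the largest segment, or the extra entries are silently dropped.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("num_segments", "max_size"))
def pad_segments(x, segment_ids, num_segments, max_size, pad_value=0.0):
    """Scatters a flat 1d array into a dense (num_segments, max_size) array.

    Returns (dense, mask), where mask[s, j] is True for real entries.
    The original order of entries inside each segment is preserved.
    """
    n = x.shape[0]
    # Sort by segment id, breaking ties by original position.
    order = jnp.lexsort((jnp.arange(n), segment_ids))
    sorted_ids = segment_ids[order]
    counts = jnp.bincount(segment_ids, length=num_segments)
    starts = jnp.cumsum(counts) - counts
    # Column index of each element inside its segment's row.
    ranks = jnp.arange(n) - starts[sorted_ids]
    dense = jnp.full((num_segments, max_size), pad_value, dtype=x.dtype)
    dense = dense.at[sorted_ids, ranks].set(x[order], mode="drop")
    mask = jnp.arange(max_size)[None, :] < counts[:, None]
    return dense, mask
```

Padding trades memory (every segment costs `max_size` slots) for a fixed shape that works under `jit` and `vmap`.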

marcocuturi commented 1 year ago

Hi @swamidass! Thanks for your interest in OTT. At first sight, this does indeed look like a matter of "upgrading" the current soft_sort to use the segment_sinkhorn API. We'll keep this in mind, or maybe someone can try to implement it; I think it's reasonably difficult.
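To make the "upgrade" idea concrete, here is a pure-JAX sketch of a soft ascending sort on one padded segment via entropic OT, vmapped over segments. It is a simplified variant in the spirit of Sinkhorn-based soft sorting, not ott-jax's soft_sort implementation: it min-max rescales the real entries to [0, 1] (an assumption; ott-jax may squash inputs differently), uses `n` evenly spaced target anchors, and gives padded slots a log-weight of `-inf` so they drop out of the log-domain Sinkhorn updates. `masked_soft_sort`, `epsilon` and `num_iters` are hypothetical names and defaults.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def masked_soft_sort(x, mask, epsilon=1e-2, num_iters=100, fill_value=0.0):
    """Soft ascending sort of the real entries of one padded segment.

    x: (m,) padded values; mask: (m,) bool, True for real entries (>= 1).
    Returns (m,): the first n slots hold soft-sorted values, the rest fill_value.
    """
    m = x.shape[0]
    n = jnp.sum(mask)
    # Rescale real entries to [0, 1] so epsilon has a consistent scale.
    lo = jnp.min(jnp.where(mask, x, jnp.inf))
    hi = jnp.max(jnp.where(mask, x, -jnp.inf))
    z = jnp.where(mask, (x - lo) / jnp.maximum(hi - lo, 1e-12), 0.0)
    # n evenly spaced targets in [0, 1]; target slots >= n are padding.
    idx = jnp.arange(m)
    tgt_mask = idx < n
    y = idx / jnp.maximum(n - 1, 1)
    # Uniform weights 1/n on real entries, zero (log = -inf) on padding.
    log_a = jnp.where(mask, -jnp.log(n), -jnp.inf)
    log_b = jnp.where(tgt_mask, -jnp.log(n), -jnp.inf)
    cost = (z[:, None] - y[None, :]) ** 2

    def body(_, fg):
        f, g = fg
        # Log-domain Sinkhorn: alternately match row and column marginals.
        f = -epsilon * logsumexp(log_b[None, :] + (g[None, :] - cost) / epsilon, axis=1)
        g = -epsilon * logsumexp(log_a[:, None] + (f[:, None] - cost) / epsilon, axis=0)
        return f, g

    f, g = jax.lax.fori_loop(0, num_iters, body, (jnp.zeros(m), jnp.zeros(m)))
    plan = jnp.exp(log_a[:, None] + log_b[None, :] + (f[:, None] + g[None, :] - cost) / epsilon)
    # Column j's barycenter: sum_i P_ij x_i / b_j, with b_j = 1/n.
    soft = n * jnp.sum(plan * jnp.where(mask, x, 0.0)[:, None], axis=0)
    return jnp.where(tgt_mask, soft, fill_value)


# Apply independently to every row of a padded (num_segments, max_size) batch.
batched_soft_sort = jax.vmap(masked_soft_sort)
```

Smaller `epsilon` pushes the output toward the hard sort; larger values give smoother outputs that may be easier to differentiate through. Soft top-k could then be read off the last `k` real slots of each row, though how best to do that is left open here.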