ott-jax / ott

Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
https://ott-jax.readthedocs.io
Apache License 2.0

Vmap-based parallelization possible with variable sized inputs? #573

Closed: sankethvedula closed this issue 1 month ago

sankethvedula commented 2 months ago

Hi,

Apologies if this is a trivial question. I couldn't find a simple answer in the docs/tutorials, so I thought I would ask for your input.

I would like to solve variably-sized linear entropic OT problems in a batch.

Assume I have $N$ cost matrices $\mathbf{C}=[\mathbf{C}_1,\ldots,\mathbf{C}_N]$ that are not all of the same shape. Similarly, I have appropriately-sized marginals $[a_1,\ldots,a_N]$ and $[b_1,\ldots,b_N]$ for each of these OT problems.

Would it be possible to vmap the Sinkhorn solver to solve these problems in parallel with ott?

I read this tutorial. If I understand correctly, it shows how to vmap the solver over several same-sized OT problems, but I have not seen an example that solves differently-sized OT problems. Could you please let me know if I missed something?

Thanks!

marcocuturi commented 2 months ago

Hi @sankethvedula, sorry for the very delayed answer.

I think you want to look at `segment_sinkhorn`: https://ott-jax.readthedocs.io/en/latest/_autosummary/ott.tools.segment_sinkhorn.segment_sinkhorn.html

We haven't written a tutorial for this functionality, but as usual you can find proto-tutorials in the tests: https://github.com/ott-jax/ott/blob/main/tests/tools/segment_sinkhorn_test.py
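To make the underlying idea concrete, here is a small standalone sketch in plain JAX of one common way to batch differently-sized problems: pad every cost matrix to a common $(n_{\max}, m_{\max})$ shape, give padded rows and columns zero marginal weight, and `jax.vmap` a log-domain Sinkhorn over the stacked batch. Zero weights become $-\infty$ log-weights, so padded entries receive no mass and the real block of each plan matches the unpadded solution. This is an illustration, not ott's API: `sinkhorn_cost`, `pad_problem` and `batched_sinkhorn_costs` are hypothetical helper names, `eps=0.1` and `num_iters=200` are arbitrary choices, and the code does not call `segment_sinkhorn` nor claim to mirror its implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_cost(C, a, b, eps=0.1, num_iters=200):
    """Hypothetical helper: log-domain Sinkhorn returning (<P, C>, P).

    Zero entries in `a` / `b` mark padded rows / columns: their log-weights
    are -inf, so they get zero mass and leave the real block unchanged.
    """
    log_a = jnp.log(a)  # -inf on padded rows
    log_b = jnp.log(b)  # -inf on padded columns
    f = jnp.zeros_like(a)
    # Padded columns start at -inf so the first f-update already ignores them.
    g = jnp.where(b > 0, jnp.zeros_like(b), -jnp.inf)

    def body(_, fg):
        f, g = fg
        f = eps * log_a - eps * logsumexp((g[None, :] - C) / eps, axis=1)
        g = eps * log_b - eps * logsumexp((f[:, None] - C) / eps, axis=0)
        return f, g

    f, g = jax.lax.fori_loop(0, num_iters, body, (f, g))
    P = jnp.exp((f[:, None] + g[None, :] - C) / eps)  # transport plan
    return jnp.sum(P * C), P


def pad_problem(C, a, b, n_max, m_max):
    """Hypothetical helper: zero-pad one problem to shape (n_max, m_max)."""
    n, m = C.shape
    C_pad = jnp.pad(C, ((0, n_max - n), (0, m_max - m)))
    return C_pad, jnp.pad(a, (0, n_max - n)), jnp.pad(b, (0, m_max - m))


def batched_sinkhorn_costs(problems, eps=0.1, num_iters=200):
    """Pad a list of (C, a, b) triples to a common shape and vmap the solver."""
    n_max = max(C.shape[0] for C, _, _ in problems)
    m_max = max(C.shape[1] for C, _, _ in problems)
    padded = [pad_problem(C, a, b, n_max, m_max) for C, a, b in problems]
    Cs, As, Bs = (jnp.stack(arrs) for arrs in zip(*padded))
    solve = lambda C, a, b: sinkhorn_cost(C, a, b, eps, num_iters)
    return jax.vmap(solve)(Cs, As, Bs)


# Two problems of different shapes (3x4 and 5x2) with uniform marginals.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
problems = [
    (jax.random.uniform(k1, (3, 4)), jnp.ones(3) / 3, jnp.ones(4) / 4),
    (jax.random.uniform(k2, (5, 2)), jnp.ones(5) / 5, jnp.ones(2) / 2),
]
costs, plans = batched_sinkhorn_costs(problems)  # shapes (2,) and (2, 5, 4)
```

Each problem's plan lives in the top-left `plans[i][:n, :m]` block, and everything outside it is exactly zero. Padding wastes compute when sizes differ a lot, and gradients through the zero-weight padding are not handled in this sketch.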

Thanks again, and apologies for not being more responsive!

sankethvedula commented 2 months ago

Thanks @marcocuturi! This looks like exactly what I need. I'll give it a shot and post here if it works, so that it may be useful to future users.

marcocuturi commented 2 months ago

@sankethvedula, all good on this, or shall we close the issue?