ott-jax / ott

Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
https://ott-jax.readthedocs.io
Apache License 2.0

UOT: how to use LRSinkhorn for UOT? #428

Open feiyang-k opened 1 year ago

feiyang-k commented 1 year ago

Hi,

I really love LRSinkhorn. Not just for its speed (and yes, it is really fast): I think the method is natural and may have broader implications for practical problems as well.

The current problem is that I cannot use it for UOT problems (even though the code does seem to support it!).

I got a warning message that UOT is not supported for LRSinkhorn. But the code does seem to include functionality for computing UOT. And the documentation also includes pages on ott.solvers.linear.lr_utils.unbalanced_dykstra_lse and ott.solvers.linear.lr_utils.unbalanced_dykstra_kernel. I don't know how to use these methods.

(I can compute UOT with a batch-wise method, which gives me the dual solution. But in this case, I want the transport map, which tells me which samples in the larger distribution each sample is mapped to.)

Thank you!!

marcocuturi commented 1 year ago

Thanks for your interest in unbalanced LR!

We've laid out the ideas here: https://arxiv.org/abs/2305.19727 and @michalk8 has been pushing these modifications. We haven't written a proper tutorial for it yet, though, and we might need a few more weeks.

michalk8 commented 1 year ago

Hi @feiyang-k ,

> Is it possible to calculate UOT with LRSinkhorn with the current ott package? Or is there any alternative way to do that?

We've just released ott-jax==0.4.4; could you please try updating?

> And the documentation also includes pages on ott.solvers.linear.lr_utils.unbalanced_dykstra_lse and ott.solvers.linear.lr_utils.unbalanced_dykstra_kernel. I don't know how to use these methods.

These aren't meant to be called by the user: just use LRSinkhorn, as in this tutorial, and pass tau_a or tau_b to the LinearProblem. The two functions mentioned above are exposed in the docs so that users know which arguments can be passed via kwargs_dys in LRSinkhorn.

> (I can compute UOT with a batch-wise method, which gives me the dual solution. But in this case, I want the transport map, which tells me which samples in the larger distribution each sample is mapped to.)

Both UOT and ULOT give you access to the transport map: you can either materialize it (which, however, defeats the purpose of the low-rank approach) or apply it to a vector/matrix.

We haven't added a tutorial for unbalanced LR solvers yet, but as @marcocuturi says, we will add one in the near future. P.S.: if the resulting coupling is not good (by whatever metric you use to evaluate it), consider using the k-means initializer for LRSinkhorn (the default is random) and look at the convergence/cost curves stored in out = solver(prob); out.errors; out.costs.

feiyang-k commented 1 year ago

Hi @marcocuturi and @michalk8 ,

Thanks so much for the detailed info! The referenced paper looks exciting. This is exactly what I've been looking for.

I updated to ott-jax==0.4.4 with jax==0.4.6 and jaxlib==0.4.6+cuda11+cudnn82 and tried a problem of size 1k by 10k. LRSinkhorn with rank=200 solves it in a few seconds. If I change it to ot_prob = linear_problem.LinearProblem(geom, tau_a=0.1), the solver does not finish even after several minutes. Changing the initializer to k-means does not help.

I tried the docs branch with ott-jax==0.4.5dev and also updated to jax==0.4.10 and jaxlib==0.4.10+cuda12+cudnn88. The situation is the same. The GPU is NVIDIA RTX A6000 with Driver Version: 530.30.02 CUDA Version: 12.1 and Python 3.9.0.

Any idea what might be causing this? Which jaxlib version are you using while developing these functions?

Thanks!


This block (a 10k × 10k LR-OT problem) solves in 6.8 s:


from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr

# cld10k: a (10000, d) array of points (user data, not shown)
geom = pointcloud.PointCloud(cld10k, cld10k, epsilon=1e-3)
ot_prob = linear_problem.LinearProblem(geom)

solver = sinkhorn_lr.LRSinkhorn(rank=200)
ot_lr = solver(ot_prob)
transp_cost = ot_lr.compute_reg_ot_cost(ot_prob, use_danskin=True)
transp_cost

This block (a 1k × 10k LR-UOT problem) does not complete after 10 minutes:


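import jax
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr

# cld1k, cld10k: (1000, d) and (10000, d) arrays of points (user data, not shown)
geom = pointcloud.PointCloud(cld1k, cld10k, epsilon=1e-3)
# tau_a < 1 relaxes the source-side marginal constraint (unbalanced problem)
ot_prob = linear_problem.LinearProblem(geom, tau_a=0.1)

solver = jax.jit(sinkhorn_lr.LRSinkhorn(rank=200, initializer="k-means"))
ot_lr = solver(ot_prob)
transp_cost = ot_lr.compute_reg_ot_cost(ot_prob)
transp_cost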

michalk8 commented 1 year ago

Here are some of my thoughts:

feiyang-k commented 1 year ago

Thanks @michalk8! This helped a lot. I tried the suggestions one by one.

solver = sinkhorn_lr.LRSinkhorn(rank=200, initializer="random", progress_fn=progress_fn, gamma=0.1, kwargs_dys={"max_iter":1000})

Thanks so much!

michalk8 commented 1 year ago

> I added a gamma parameter to the solver function. This time it converges in a reasonable time! gamma=1 and gamma=0.1 both work well; gamma=10 won't converge. This may be helpful to add to the tutorial :)

True, we should add this to the tutorial; I will create an issue for it. We should also mention there (and in the docs) that gamma * epsilon < 1 should hold (when also using entropic regularization).

duzc-Repos commented 6 months ago

Hi, I want to use the ULOT algorithm for our single-cell sequencing data analysis, not only for its speed but also for its potential biological interpretation under low-rank constraints. The question is whether the LOT or ULOT algorithm is suitable for data that do not lie on the unit simplex. More specifically, let $x, y \in \mathbb{R}^n$ be two (non-negative) gene expression vectors: can I get the transport plan between $x$ and $y$ when $\|x\|_1=a$ and $\|y\|_1=b$ with $a, b>1$, $a \neq b$? I don't want to normalize these two vectors, because the expression level itself is meaningful information (so, in a sense, this is an unbalanced optimal mass transport problem?).
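For reference, one standard way to pose this as unbalanced OT uses the unnormalized vectors $x, y$ directly as marginal weights. It replaces the hard constraints $P\mathbf{1} = x$, $P^\top\mathbf{1} = y$, which would require $a = b$, with generalized KL penalties. The ground cost $C \in \mathbb{R}^{n \times n}$ between genes is an assumption here, since the question does not specify one:

$$
\min_{P \in \mathbb{R}_{+}^{n \times n}} \; \langle C, P \rangle
+ \varepsilon \, \mathrm{KL}(P \,\|\, x y^\top)
+ \rho_a \, \mathrm{KL}(P \mathbf{1}_n \,\|\, x)
+ \rho_b \, \mathrm{KL}(P^\top \mathbf{1}_n \,\|\, y),
\qquad
\mathrm{KL}(p \,\|\, q) = \sum_i \Big( p_i \log \frac{p_i}{q_i} - p_i + q_i \Big).
$$

Because the marginal terms are penalties, $\|x\|_1 = a \neq b = \|y\|_1$ is admissible, and the transported mass $\mathbf{1}^\top P \mathbf{1}$ is generally different from both $a$ and $b$. In ott's entropic solvers the penalties are usually set through `tau_a`, `tau_b` with $\tau = \rho / (\rho + \varepsilon)$ (assumed here, check the `LinearProblem` docs). ULOT additionally restricts $P$ to the factorized form $P = Q\,\mathrm{diag}(1/g)\,R^\top$. Whether LRSinkhorn accepts weights that do not sum to one is not confirmed in this thread.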

Thank you !!