Open feiyang-k opened 1 year ago
Thanks for your interest in unbalanced LR!
We've laid out the ideas here: https://arxiv.org/abs/2305.19727 and @michalk8 has been pushing these modifications. We haven't written a proper tutorial for it yet, though, and we might need a few more weeks.
Hi @feiyang-k ,
Is it possible to calculate UOT with LRSinkhorn with the current ott package? Or is there any alternative way to do that?
We've just pushed ott-jax==0.4.4; could you please try updating?
And also, the documentation includes some pages on ott.solvers.linear.lr_utils.unbalanced_dykstra_lse and ott.solvers.linear.lr_utils.unbalanced_dykstra_kernel. I don't know how to use these methods.
These aren't supposed to be called directly by the user; you can just use LRSinkhorn, as in this tutorial, and pass tau_a or tau_b into the LinearProblem. The two functions mentioned above are exposed in the docs so that users know which arguments can be passed via kwargs_dys in LRSinkhorn.
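As a rough illustration of what tau_a and tau_b control, here is a minimal NumPy sketch of the textbook full-rank unbalanced Sinkhorn scaling iterations. It is not the low-rank Dykstra scheme that LRSinkhorn runs, and it assumes the common convention tau = rho / (rho + epsilon), where rho weights the KL penalty on a marginal, so tau = 1 enforces that marginal exactly. The function name and the toy data are made up for the example.

```python
import numpy as np


def unbalanced_sinkhorn(cost, a, b, epsilon, tau_a=1.0, tau_b=1.0, n_iter=500):
    """Textbook scaling iterations; returns the dense (n, m) coupling."""
    kernel = np.exp(-cost / epsilon)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        # tau < 1 damps the update, which relaxes the matching marginal.
        u = (a / (kernel @ v)) ** tau_a
        v = (b / (kernel.T @ u)) ** tau_b
    return u[:, None] * kernel * v[None, :]


rng = np.random.default_rng(0)
x = rng.random((5, 2))  # toy source points
y = rng.random((7, 2))  # toy target points
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
cost /= cost.max()  # rescale so that epsilon is on a comparable scale
a = np.full(5, 1 / 5)
b = np.full(7, 1 / 7)

balanced = unbalanced_sinkhorn(cost, a, b, epsilon=0.5)
relaxed = unbalanced_sinkhorn(cost, a, b, epsilon=0.5, tau_a=0.5)

# Largest deviation of the row marginal from a, in both settings.
print(np.abs(balanced.sum(axis=1) - a).max())
print(np.abs(relaxed.sum(axis=1) - a).max())
```

With tau_a = 0.5 the row sums drift away from a, while the column sums still match b because tau_b stays at 1. Under the assumed convention, rho = epsilon * tau / (1 - tau), so a small tau corresponds to a very weak marginal penalty.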
(I can calculate the UOT with batch-wise method, which will give me the dual solution. But in this case, I wish to get the transport map which can tell me which samples in the larger distribution is mapped to).
Both UOT and ULOT give you access to the transport map: you can materialize it (which, however, defeats the purpose of the low-rank solver) or apply it to a vector/matrix.
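To make the difference between materializing and applying concrete, here is a minimal NumPy sketch. It assumes the low-rank coupling has the factored form P = Q diag(1/g) R^T; the factors below are random placeholders rather than solver output, and both helpers are hypothetical names, not ott functions.

```python
import numpy as np


def apply_lr_coupling(q, r, g, vec):
    """Compute P @ vec for P = q @ diag(1/g) @ r.T without forming P."""
    # Costs O((n + m) * rank) instead of O(n * m).
    return q @ ((r.T @ vec) / g)


def row_of_coupling(q, r, g, i):
    """Return row i of P, i.e. how source sample i spreads over all targets."""
    return (q[i] / g) @ r.T


rng = np.random.default_rng(0)
n, m, rank = 6, 40, 3
q = rng.random((n, rank))    # placeholder factor, shape (n, rank)
r = rng.random((m, rank))    # placeholder factor, shape (m, rank)
g = rng.random(rank) + 0.5   # placeholder positive inner weights, shape (rank,)
vec = rng.random(m)

fast = apply_lr_coupling(q, r, g, vec)
dense = (q / g) @ r.T        # materialized (n, m) coupling, for comparison only
assert np.allclose(fast, dense @ vec)

# Target that receives the most mass from source sample 2, one row at a time.
best_target = int(row_of_coupling(q, r, g, 2).argmax())
```

Reading the coupling one row at a time keeps memory at O(m) per query, which matters when the larger point cloud has many samples.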
We haven't added a tutorial for unbalanced LR solvers yet, but as @marcocuturi says, we will add one in the near future.
P.S.: if the resulting coupling is not good (by whatever metric you're using to evaluate it), consider using the k-means initializer for LRSinkhorn (the default is random), and look into the convergence/cost curves stored in the output: out = solver(prob); out.errors; out.costs.
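A minimal sketch of how such an error trace could be inspected, using NumPy only. It assumes that slots the solver never filled are marked with -1; `summarize_errors` is a hypothetical helper, and the trace below is made-up data rather than real solver output.

```python
import numpy as np


def summarize_errors(errors, threshold):
    """Return (n_checks, last_error, converged) from a solver error trace."""
    errors = np.asarray(errors, dtype=float)
    computed = errors[errors > -1.0]  # assumption: -1 marks unused slots
    if computed.size == 0:
        return 0, float("nan"), False
    last = float(computed[-1])
    return int(computed.size), last, bool(last < threshold)


trace = np.array([0.8, 0.3, 0.05, 0.004, -1.0, -1.0])  # made-up example trace
print(summarize_errors(trace, threshold=1e-2))  # prints (4, 0.004, True)
```

In practice one would pass out.errors together with whatever threshold the solver was configured with. A trace that flattens out above the threshold suggests the solver is stalling rather than still making progress.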
Hi @marcocuturi and @michalk8 ,
Thanks so much for the detailed info! The referenced paper looks exciting. This is exactly what I've been looking for.
I updated to ott-jax==0.4.4 with jax==0.4.6 and jaxlib==0.4.6+cuda11+cudnn82. I tried a problem of size 1k by 10k. LRSinkhorn with rank=200 solves it in a few seconds. If I change the problem to ot_prob = linear_problem.LinearProblem(geom, tau_a=0.1), the solver does not finish within minutes. Changing the initializer to k-means does not help.
I also tried the docs branch with ott-jax==0.4.5dev and updated to jax==0.4.10 and jaxlib==0.4.10+cuda12+cudnn88. The situation is the same. The GPU is an NVIDIA RTX A6000 with Driver Version: 530.30.02, CUDA Version: 12.1, and Python 3.9.0.
Do you have any idea what is going on? Which jaxlib version are you using while developing these functions?
Thanks!
This block, a 10k*10k LR-OT problem, solves in 6.8s:
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr

# cld10k: point cloud with 10k samples, defined elsewhere
geom = pointcloud.PointCloud(cld10k, cld10k, epsilon=1e-3)
ot_prob = linear_problem.LinearProblem(geom)
solver = sinkhorn_lr.LRSinkhorn(rank=200)
ot_lr = solver(ot_prob)
transp_cost = ot_lr.compute_reg_ot_cost(ot_prob, use_danskin=True)
transp_cost
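This block, a 1k*10k LR-UOT problem, does not complete after 10 minutes:
import jax
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr

# cld1k, cld10k: point clouds with 1k and 10k samples, defined elsewhere
geom = pointcloud.PointCloud(cld1k, cld10k, epsilon=1e-3)
# tau_a < 1 relaxes the constraint on the first marginal
ot_prob = linear_problem.LinearProblem(geom, tau_a=0.1)
solver = jax.jit(sinkhorn_lr.LRSinkhorn(rank=200, initializer="k-means"))
ot_lr = solver(ot_prob)
transp_cost = ot_lr.compute_reg_ot_cost(ot_prob)
transp_cost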
Here are some of my thoughts:
- I doubt jaxlib is the cause of the issues; as long as jax can correctly use it, it should be fine.
- tau_a might be too low; I would try increasing it (unless you really want such an unbalanced problem in your application).
- The inner Dykstra iterations can be tuned via LRSinkhorn(..., kwargs_dys={"min_iter": ...}).
- Setting tau_b=0.999 may improve the convergence of the inner Dykstra iterations.
- Check the errors (ot_lr.errors); maybe it converges but never reaches the threshold (which would result in doing the full 2k iterations).
- There is no need to call ot_lr.compute_reg_ot_cost in the end; you can just access the pre-computed property: transp_cost = ot_lr.reg_ot_cost
Thanks @michalk8! This helped a lot. I tried these one by one.
In this particular case, it seems k-means isn't particularly helpful, as the initial error is even larger.
In my experience, k-means clustering with scikit-learn for 10k samples and 200 clusters usually finishes in seconds.
The callback functions are super useful, especially for parameter tuning!
I added a gamma parameter to the solver function. This time it converges timely! gamma=1 or gamma=0.1 both work well; gamma=10 won't converge. This may be helpful to add to the tutorial :)
solver = sinkhorn_lr.LRSinkhorn(rank=200, initializer="random", progress_fn=progress_fn, gamma=0.1, kwargs_dys={"max_iter":1000})
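For readers wondering what such a progress_fn might look like, here is a minimal sketch of a callback that records the latest error. The argument layout, a status tuple of (iteration, inner_iterations, total_iter, state) where state carries an errors array padded with -1, is an assumption and should be checked against the installed ott version. `make_recording_progress_fn` and `FakeState` are hypothetical names used only for this example.

```python
from collections import namedtuple

import numpy as np


def make_recording_progress_fn(history):
    """Build a callback that appends (iteration, last_error) to `history`."""

    def progress_fn(status, *args):
        # Assumed layout: (iteration, inner_iterations, total_iter, state).
        iteration, _inner_iterations, _total_iter, state = status
        errors = np.asarray(state.errors, dtype=float).ravel()
        computed = errors[errors > -1.0]  # assumption: -1 marks unused slots
        last_error = float(computed[-1]) if computed.size else float("nan")
        history.append((int(iteration) + 1, last_error))

    return progress_fn


# Stand-in for the solver state, used only to exercise the callback here.
FakeState = namedtuple("FakeState", ["errors"])

history = []
progress_fn = make_recording_progress_fn(history)
status = (np.array(9), np.array(10), np.array(2000), FakeState(np.array([0.5, -1.0])))
progress_fn(status)
print(history)  # prints [(10, 0.5)]
```

Keeping the trace in a plain Python list makes it easy to compare runs with different gamma values side by side.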
Thanks so much!
I added a gamma parameter to the solver function. This time it converges timely! gamma=1 or gamma=0.1 both work well; gamma=10 won't converge. This may be helpful to add to the tutorial :)
True, we should add this to the tutorial; I will create an issue for this. We should also mention there (and in the docs) to keep gamma * epsilon < 1 (when also using the entropic regularization).
Hi, I want to use the ULOT algorithm for our single-cell sequencing data analysis, not only for its speed but also for its potential biological interpretation under low-rank constraints. The question is whether the LOT or ULOT algorithm is suitable for data that do not lie on the unit simplex. More specifically, let $x, y \in \mathbb{R}^n$ be two non-negative gene expression vectors; can I get the transport plan between $x$ and $y$ when $\|x\|_1 = a$ and $\|y\|_1 = b$, where $a, b > 1$ and $a \neq b$? I don't want to normalise these two vectors, because the expression level itself is meaningful information (so this is presumably an unbalanced optimal mass transport problem?).
Thank you !!
Hi,
I really love the LRSinkhorn. Not just for its speed (but yeah, it is really fast), I think this method is natural and may have broader implications in practical problems as well.
The current problem is that I cannot use it for UOT problems (but the codes do seem to have the support for that!)
I got the warning message that UOT is not supported for LRSinkhorn. But it seems the codes do include the functionality for computing UOT. And also, the documentation includes some pages on ott.solvers.linear.lr_utils.unbalanced_dykstra_lse and ott.solvers.linear.lr_utils.unbalanced_dykstra_kernel. I don't know how to use these methods.
(I can calculate the UOT with batch-wise method, which will give me the dual solution. But in this case, I wish to get the transport map which can tell me which samples in the larger distribution is mapped to).
Thank you!!