Closed: zwei-beiner closed this issue 1 week ago
Hi @zwei-beiner, there are 2 differences between the above benchmarks (e.g., passing lse_mode=True in sinkhorn.Sinkhorn, to have better numerical stability). Modifying some of the above code to:
import ot
import torch

def W_pytorch(x, y):
    nsamples = x.shape[0]
    cost_matrix = ot.utils.dist(x, y)
    # Uniform marginals over the samples
    a = torch.ones(nsamples) / nsamples
    b = torch.ones(nsamples) / nsamples
    loss = ot.sinkhorn2(a=a, b=b, M=cost_matrix, method='sinkhorn_log', reg=1e-2, stopThr=1e-06)
    return loss
import jax
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

@jax.jit
def W_jax(x, y):
    geom = pointcloud.PointCloud(x, y, epsilon=1e-2)
    ot_prob = linear_problem.LinearProblem(geom)
    solver = sinkhorn.Sinkhorn(threshold=1e-6, lse_mode=True, max_iterations=1000)
    ot = solver(ot_prob)
    return ot.primal_cost
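One caveat when reproducing timings like the ones below: with jax.jit, the first call includes tracing and compilation and should be excluded from the measurement. A minimal timing sketch (the helper name time_fn is hypothetical, not part of OTT or POT):

```python
import time

def time_fn(fn, *args, warmup=1, repeats=3):
    """Time fn(*args), excluding warm-up calls (e.g. jax.jit compilation)."""
    for _ in range(warmup):
        fn(*args)  # warm-up: JIT tracing/compilation happens here, untimed
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn(*args)
        times.append(time.perf_counter() - t0)
    return times
```

For JAX, one would also force the result inside the timed call (e.g. wrap it in float(...)), since JAX dispatches asynchronously and the raw call can return before the computation finishes.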
I get:
POT:
tensor(0.2506)
Time: 44.20695209503174
OTT:
0.25064818126813293
Time: 29.805460929870605
0.25064818126813293
Time: 29.549583911895752
0.25064818126813293
Time: 29.669023990631104
Thanks for the reply! I can confirm that OTT is faster now.
I have another question: Suppose that I want to calculate $W_2$ for a sequence of distributions $\mu_n$ which converge to some distribution $\mu$ and I can only access the distributions through samples. I want to show that $W_2\rightarrow 0$ as $n\rightarrow \infty$. The problem is that $W_2$ converges very slowly to zero with increasing sample size, even when I calculate it on two sets of samples from the same distribution. Is there any way around this, i.e. can I force $W_2\rightarrow 0$?
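The finite-sample effect described in this question is easy to reproduce in a small self-contained sketch: even for two independent samples drawn from the same distribution, the (entropic) OT cost between the empirical measures stays well above zero at moderate sample sizes. This is a plain-numpy log-domain Sinkhorn for illustration only, not OTT's or POT's implementation; all names are made up here:

```python
import numpy as np

def _lse(M, axis):
    # Numerically stable log-sum-exp along `axis`.
    m = M.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(M - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def entropic_cost(x, y, eps=0.05, n_iter=300):
    # Log-domain Sinkhorn between 1-D empirical measures, squared cost.
    # Plan parametrization: P_ij = exp((f_i + g_j - C_ij) / eps).
    C = (x[:, None] - y[None, :]) ** 2
    n, m = C.shape
    loga, logb = np.full(n, -np.log(n)), np.full(m, -np.log(m))
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iter):
        f = eps * loga - eps * _lse((g[None, :] - C) / eps, axis=1)
        g = eps * logb - eps * _lse((f[:, None] - C) / eps, axis=0)
    P = np.exp((f[:, None] + g[None, :] - C) / eps)  # transport plan
    return float(np.sum(P * C))

rng = np.random.default_rng(0)
x, y = rng.normal(size=200), rng.normal(size=200)  # same distribution
cost = entropic_cost(x, y)  # noticeably > 0 even though the laws coincide
```

The gap only decays at the statistical rate of empirical OT (plus an entropic bias term), which is why the raw value does not go to zero quickly with sample size.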
Thanks a lot @michalk8 and thanks @zwei-beiner for the question!
Comparing OTT and POT is challenging because some of the default parameters are not the same. Also, POT has multiple implementations, as mentioned above by @michalk8.
I would also add that the error used to control convergence is, by default, a 1-norm in OTT, whereas it is a 2-norm in POT. As a consequence, you may have to pass norm_error=2.0 to your Sinkhorn solver, as done in the tutorial: https://ott-jax.readthedocs.io/en/latest/tutorials/OTT_%26_POT.html
@zwei-beiner did you manage to reproduce the timing benchmarks with the above suggestions?
@michalk8 With the above suggestions, I get the following timings:
POT:
tensor(0.2392)
Time: 128.88550853729248
OTT:
0.23915911736256118
Time: 67.72269916534424
0.23915911736256118
Time: 69.8149642944336
0.23915911736256118
Time: 69.45572185516357
However, when I set lse_mode=False and method='sinkhorn', POT gets a significant speedup but OTT is essentially unchanged:
POT:
tensor(0.2392)
Time: 10.499557495117188
OTT:
0.23915911736256149
Time: 72.12179160118103
0.23915911736256149
Time: 72.88275790214539
0.23915911736256149
Time: 74.41572165489197
As mentioned above, you should be using method='sinkhorn_log' when comparing to our lse_mode=True.
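A self-contained sketch of why lse_mode=True / method='sinkhorn_log' matters: in kernel mode the Gibbs kernel exp(-C/eps) underflows to zero for small eps, so the Sinkhorn scaling update divides by zero, while the log-sum-exp form stays finite. This is illustrative numpy, not either library's internals:

```python
import numpy as np

rng = np.random.default_rng(0)
C = rng.random((5, 5)) + 0.1     # cost matrix, entries in [0.1, 1.1)
eps = 1e-4                       # small entropic regularization

# Kernel mode: every entry of K = exp(-C/eps) is below the smallest
# representable double (exp(-745)), so K underflows to exactly 0.0 and
# the update u = a / (K @ v) would divide by zero.
K = np.exp(-C / eps)

# lse mode: the same row sums are kept in log space and stay finite.
M = -C / eps
m = M.max(axis=1, keepdims=True)
log_row_sums = (m + np.log(np.exp(M - m).sum(axis=1, keepdims=True))).squeeze(1)
```

This is also why the lse-mode/sinkhorn_log variants are slower per iteration: a log-sum-exp costs more than a matrix-vector product, which matches the kernel-mode speedup reported above.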
I'm trying to calculate the W2 distance with OTT, and it seems to be ~10x slower than POT. This was run on CPU.
Is there any way to speed up the calculation with OTT?
Also, please let me know whether this is the correct way of calculating the W2 distance with OTT (i.e., computing it from the transport plan, since there is no simple way of accessing it as an attribute of the Sinkhorn solver output).
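For reference, computing the cost from the transport plan is just the Frobenius inner product of the plan with the cost matrix; a toy numpy sketch with made-up values:

```python
import numpy as np

P = np.full((3, 3), 1.0 / 9.0)        # toy transport plan (uniform, hypothetical)
C = np.arange(9.0).reshape(3, 3)      # toy cost matrix
primal_cost = float(np.sum(P * C))    # <P, C> = 4.0 here
```

Recent OTT versions also expose this quantity directly as primal_cost on the Sinkhorn output, as used in the W_jax snippet above.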
Output: