ott-jax / ott

Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
https://ott-jax.readthedocs.io
Apache License 2.0
479 stars, 77 forks

Timing comparison with POT #556

Closed: zwei-beiner closed this issue 1 week ago

zwei-beiner commented 2 weeks ago

I'm trying to calculate the W2 distance with OTT, and it is roughly 10x slower than POT (about 130 s vs. about 11 s per call, run on CPU).

Is there any way to speed up the calculation with OTT?

Also, please let me know whether this is the correct way to calculate the W2 distance with OTT. I compute it from the transport plan, since I couldn't find a simple way to access it as an attribute of the Sinkhorn solver output.

import time
import numpy as np

import torch
import ot

import jax
import jax.numpy as jnp
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

torch.set_default_dtype(torch.float64)
jax.config.update("jax_enable_x64", True)

def make_samples(nsamples, ndims):
    key1, key2 = jax.random.split(jax.random.PRNGKey(0))
    samples1 = jax.random.ball(key1, ndims, shape=(nsamples,))
    samples2 = jax.random.ball(key2, ndims, shape=(nsamples,))
    return samples1, samples2
samples1, samples2 = make_samples(3000, 10)

def W_pytorch(x, y):
    nsamples = x.shape[0]
    cost_matrix = ot.utils.dist(x, y)
    a = torch.ones(nsamples) / nsamples
    b = torch.ones(nsamples) / nsamples
    loss = ot.sinkhorn2(a=a, b=b, M=cost_matrix, reg=1e-2, stopThr=1e-06)
    return loss
print("POT:")
torch_samples1 = torch.from_numpy(np.asarray(samples1))
torch_samples2 = torch.from_numpy(np.asarray(samples2))
tic = time.time()
print(W_pytorch(torch_samples1, torch_samples2))
print("Time:", time.time() - tic)

@jax.jit
def W_jax(x, y):
    geom = pointcloud.PointCloud(x, y, epsilon=1e-2)
    ot_prob = linear_problem.LinearProblem(geom)
    solver = sinkhorn.Sinkhorn(threshold=1e-6)
    out = solver(ot_prob)  # named `out` to avoid shadowing the POT module `ot`
    return jnp.sum(out.matrix * out.geom.cost_matrix)
print("OTT:")
for _ in range(3):  # Run several times: the first call also includes JIT compilation time
    tic = time.time()
    print(W_jax(samples1, samples2))
    print("Time:", time.time() - tic)

Output:

POT:
tensor(0.2506)
Time: 11.448308944702148
OTT:
0.2506479110728857
Time: 130.14876246452332
0.2506479110728857
Time: 133.8663191795349
0.2506479110728857
Time: 130.24367785453796
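For reference, the quantity returned by `W_jax` above is the transport part <P, C> of the entropic solution, i.e. the plan multiplied elementwise by the cost matrix and summed. Below is a minimal NumPy sketch of the same computation with the plain kernel (scaling) Sinkhorn iteration, assuming uniform weights and a fixed number of iterations instead of a convergence threshold; `sinkhorn_scaling` is a hypothetical helper, not part of POT or OTT. Note that both `ot.utils.dist` and OTT's `PointCloud` default to the squared Euclidean cost, so this value estimates $W_2^2$ (with entropic bias) rather than $W_2$ itself.

```python
import numpy as np


def sinkhorn_scaling(x, y, epsilon, n_iters=1000):
    """Kernel (scaling-domain) Sinkhorn with uniform weights.

    Hypothetical helper: returns the entropic plan P and the cost matrix C.
    """
    a = np.full(x.shape[0], 1.0 / x.shape[0])
    b = np.full(y.shape[0], 1.0 / y.shape[0])
    # Squared Euclidean cost, the default of ot.utils.dist and OTT's PointCloud
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    kernel = np.exp(-cost / epsilon)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iters):
        u = a / (kernel @ v)    # rescale rows to match marginal a
        v = b / (kernel.T @ u)  # rescale columns to match marginal b
    plan = u[:, None] * kernel * v[None, :]
    return plan, cost


rng = np.random.default_rng(0)
x = rng.uniform(size=(50, 2))
y = rng.uniform(size=(60, 2)) + 0.5
plan, cost = sinkhorn_scaling(x, y, epsilon=1.0)
# <P, C> = sum_ij P_ij * C_ij: plan times cost matrix, summed
transport_cost = np.sum(plan * cost)
print(transport_cost)
```

Taking `np.sqrt(transport_cost)` would give the corresponding estimate on the $W_2$ scale.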
michalk8 commented 2 weeks ago

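Hi @zwei-beiner, there are 2 differences between the two solvers as configured in the above benchmark:

  1. By default, we run our computations in LSE (log-sum-exp) mode (lse_mode=True in sinkhorn.Sinkhorn) for better numerical stability, whereas POT's ot.sinkhorn2 defaults to method='sinkhorn', which iterates on the kernel directly
  2. By default, we use at most 2k iterations (max_iterations=2000), whereas POT's default is numItermax=1000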
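The difference in item 1 can be sketched in NumPy. Kernel mode multiplies by a precomputed matrix $K = e^{-C/\varepsilon}$, which is cheap per iteration but underflows when $\varepsilon$ is small relative to the costs; LSE mode updates dual potentials with log-sum-exp reductions, which stays finite but re-evaluates exponentials over the whole cost matrix at every iteration. That extra work is consistent with POT's `method='sinkhorn_log'` being slower than `method='sinkhorn'` in the timings in this thread. The functions below are our own illustration, not the code either library runs; the loop length reuses the name and the 2000 default of OTT's `max_iterations`, but there is no convergence check.

```python
import numpy as np


def logsumexp(z, axis):
    """Stable log(sum(exp(z))) along `axis`."""
    z_max = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(z_max, axis=axis) + np.log(np.sum(np.exp(z - z_max), axis=axis))


def sinkhorn_log(cost, a, b, epsilon, max_iterations=2000):
    """Log-domain Sinkhorn on potentials f, g; plan P_ij = exp((f_i + g_j - C_ij) / eps)."""
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(max_iterations):
        f = epsilon * (np.log(a) - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (np.log(b) - logsumexp((f[:, None] - cost) / epsilon, axis=0))
    return np.exp((f[:, None] + g[None, :] - cost) / epsilon)


rng = np.random.default_rng(0)
x = rng.uniform(size=(5, 2))
y = rng.uniform(size=(5, 2)) + 2.0
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)  # every entry > 2
a = b = np.full(5, 0.2)
epsilon = 1e-3

# Kernel mode: exp(-C / eps) underflows to exactly 0.0 in float64
kernel = np.exp(-cost / epsilon)
with np.errstate(divide="ignore", invalid="ignore"):
    u = a / (kernel @ np.ones(5))  # division by zero gives inf
    v = b / (kernel.T @ u)         # v ends up non-finite (inf or nan)
print(kernel.max(), np.isfinite(v).any())  # 0.0 False

# LSE mode never forms exp(-C / eps) on its own, so the plan stays finite
plan = sinkhorn_log(cost, a, b, epsilon, max_iterations=200)
print(np.isfinite(plan).all())  # True
```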

Modifying the above code as follows:

def W_pytorch(x, y):
    nsamples = x.shape[0]
    cost_matrix = ot.utils.dist(x, y)
    a = torch.ones(nsamples) / nsamples
    b = torch.ones(nsamples) / nsamples
    loss = ot.sinkhorn2(a=a, b=b, M=cost_matrix, method='sinkhorn_log', reg=1e-2, stopThr=1e-06)
    return loss

@jax.jit
def W_jax(x, y):
    geom = pointcloud.PointCloud(x, y, epsilon=1e-2)
    ot_prob = linear_problem.LinearProblem(geom)
    solver = sinkhorn.Sinkhorn(threshold=1e-6, lse_mode=True, max_iterations=1000)
    out = solver(ot_prob)  # named `out` to avoid shadowing the POT module `ot`
    return out.primal_cost

I get:

tensor(0.2506)
Time: 44.20695209503174
OTT:
0.25064818126813293
Time: 29.805460929870605
0.25064818126813293
Time: 29.549583911895752
0.25064818126813293
Time: 29.669023990631104
zwei-beiner commented 2 weeks ago

Thanks for the reply! I can confirm that OTT is faster now.

I have another question: suppose that I want to calculate $W_2$ for a sequence of distributions $\mu_n$ which converge to some distribution $\mu$, and I can only access the distributions through samples. I want to show that $W_2\rightarrow 0$ as $n\rightarrow \infty$. The problem is that the estimated $W_2$ converges to zero very slowly as the sample size grows, even when I compute it between two sets of samples drawn from the same distribution. Is there any way around this, i.e. can I force $W_2\rightarrow 0$?
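Two separate effects are mixed here. First, the entropic plan spreads mass, so <P, C> is strictly positive even for two identical point clouds. A commonly used correction (not discussed in this thread) is the Sinkhorn divergence $S(x, y) = \mathrm{OT}_\varepsilon(x, y) - \tfrac12 \mathrm{OT}_\varepsilon(x, x) - \tfrac12 \mathrm{OT}_\varepsilon(y, y)$; OTT ships a `sinkhorn_divergence` module under `ott.tools` (check its documentation for the exact signature). Second, the sampling error between two independent samples is a property of $W_2$ itself and decays roughly like $n^{-1/d}$ when $d > 4$ (e.g. $d = 10$ in the benchmark above); debiasing does not remove that. The NumPy sketch below, with hypothetical helpers `entropic_ot` and `sinkhorn_divergence`, only illustrates the first effect, assuming uniform weights and squared Euclidean cost.

```python
import numpy as np


def logsumexp(z, axis):
    z_max = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(z_max, axis=axis) + np.log(np.sum(np.exp(z - z_max), axis=axis))


def entropic_ot(x, y, epsilon, n_iters=500):
    """Log-domain Sinkhorn; returns (regularized OT value, transport part <P, C>)."""
    a = np.full(len(x), 1.0 / len(x))
    b = np.full(len(y), 1.0 / len(y))
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iters):
        f = epsilon * (np.log(a) - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (np.log(b) - logsumexp((f[:, None] - cost) / epsilon, axis=0))
    plan = np.exp((f[:, None] + g[None, :] - cost) / epsilon)
    # Dual objective; at convergence it equals min <P, C> + eps * sum P (log P - 1).
    # Its constant offset from other entropy conventions cancels in S(x, y).
    reg_value = f @ a + g @ b - epsilon * plan.sum()
    return reg_value, np.sum(plan * cost)


def sinkhorn_divergence(x, y, epsilon):
    """Debiased S(x, y) = OT(x, y) - (OT(x, x) + OT(y, y)) / 2."""
    ot_xy, _ = entropic_ot(x, y, epsilon)
    ot_xx, _ = entropic_ot(x, x, epsilon)
    ot_yy, _ = entropic_ot(y, y, epsilon)
    return ot_xy - 0.5 * (ot_xx + ot_yy)


rng = np.random.default_rng(0)
x = rng.uniform(size=(30, 2))
_, transport_xx = entropic_ot(x, x, epsilon=0.1)
print(transport_xx)                            # > 0: the entropic plan blurs mass
print(sinkhorn_divergence(x, x, epsilon=0.1))  # 0.0: the bias terms cancel exactly
```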

marcocuturi commented 2 weeks ago

Thanks a lot @michalk8 and thanks @zwei-beiner for the question!

Comparing OTT and POT is challenging because some default parameters differ. Also, POT offers several Sinkhorn implementations (e.g. method='sinkhorn' and method='sinkhorn_log'), as @michalk8 mentioned above.

I would also add that the error used to check convergence is, by default, a 1-norm in OTT, whereas it is a 2-norm in POT. To match POT, you may therefore have to pass norm_error=2.0 to your Sinkhorn solver, as done in the tutorial https://ott-jax.readthedocs.io/en/latest/tutorials/OTT_%26_POT.html
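A small worked example of why the norm matters: if each of the $n$ marginal entries is off by the same $\delta$, the 1-norm of the violation is $n\delta$ while the 2-norm is $\sqrt{n}\,\delta$. With $n = 3000$ samples as in the benchmark, the same `threshold=1e-6` is therefore roughly 55 times stricter under a 1-norm, which can cost extra iterations. `marginal_error` below is a hypothetical helper; OTT's internal error computation may differ in details.

```python
import numpy as np


def marginal_error(marginal, target, norm_error=1.0):
    """p-norm of the marginal violation |marginal - target| (hypothetical helper)."""
    residual = np.abs(marginal - target)
    return np.sum(residual ** norm_error) ** (1.0 / norm_error)


n = 3000  # number of samples in the benchmark above
target = np.full(n, 1.0 / n)
marginal = target + 1e-8  # every entry off by the same small amount

err_1 = marginal_error(marginal, target, norm_error=1.0)  # n * 1e-8
err_2 = marginal_error(marginal, target, norm_error=2.0)  # sqrt(n) * 1e-8
threshold = 1e-6
print(err_1, err_1 < threshold)  # about 3.0e-05, False: keep iterating
print(err_2, err_2 < threshold)  # about 5.5e-07, True: stop
```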

michalk8 commented 1 week ago

@zwei-beiner did you manage to reproduce the timing benchmarks with the above suggestions?

zwei-beiner commented 1 week ago

@michalk8 With the above suggestions, I get the following timings:

tensor(0.2392)
Time: 128.88550853729248
OTT:
0.23915911736256118
Time: 67.72269916534424
0.23915911736256118
Time: 69.8149642944336
0.23915911736256118
Time: 69.45572185516357

However, when I set lse_mode=False (OTT) and method='sinkhorn' (POT), POT speeds up significantly while OTT's timing is essentially unchanged:

tensor(0.2392)
Time: 10.499557495117188
OTT:
0.23915911736256149
Time: 72.12179160118103
0.23915911736256149
Time: 72.88275790214539
0.23915911736256149
Time: 74.41572165489197
michalk8 commented 1 week ago

As mentioned above, you should use method="sinkhorn_log" in POT when comparing against our lse_mode=True.
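To compare like with like, the settings discussed in this thread can be collected in one place. `matched_settings` below is a hypothetical helper, not part of POT or OTT; it only builds keyword-argument dictionaries from POT's `reg`, `method`, `stopThr` and `numItermax` arguments of `ot.sinkhorn2`, the `epsilon` argument of OTT's `PointCloud`, and the `lse_mode`, `threshold`, `max_iterations` and `norm_error` arguments of `sinkhorn.Sinkhorn`. Matching these does not guarantee identical iteration counts or timings, since other implementation details can still differ.

```python
def matched_settings(epsilon, threshold, max_iterations, log_domain):
    """Hypothetical helper: kwargs intended to make POT and OTT run comparable solvers."""
    pot_kwargs = {
        "reg": epsilon,
        "method": "sinkhorn_log" if log_domain else "sinkhorn",
        "stopThr": threshold,
        "numItermax": max_iterations,
    }
    ott_geom_kwargs = {"epsilon": epsilon}
    ott_solver_kwargs = {
        "lse_mode": log_domain,
        "threshold": threshold,
        "max_iterations": max_iterations,
        "norm_error": 2.0,  # POT measures the marginal error with a 2-norm
    }
    return pot_kwargs, ott_geom_kwargs, ott_solver_kwargs


pot_kw, geom_kw, solver_kw = matched_settings(1e-2, 1e-6, 1000, log_domain=True)
# Assuming a, b, cost_matrix, x, y as in the benchmark above:
#   ot.sinkhorn2(a=a, b=b, M=cost_matrix, **pot_kw)
#   solver = sinkhorn.Sinkhorn(**solver_kw)
#   out = solver(linear_problem.LinearProblem(pointcloud.PointCloud(x, y, **geom_kw)))
```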