ott-jax / ott

Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
https://ott-jax.readthedocs.io
Apache License 2.0

How to calculate 1-Wasserstein distance between two sets of samples? #550

Closed by abhiagwl 4 months ago

abhiagwl commented 4 months ago

Hi there!

Thank you for the great library. I have a very simple use case. I want to calculate the 1-Wasserstein distance between two sets of samples as done in scipy's wasserstein_distance_nd. Here is what I tried:

import jax.numpy as jnp
import ott
from ott.geometry import pointcloud
from ott.solvers import linear

# X and Y are the two (100, 100) sample arrays defined in the snippet below.
geom = pointcloud.PointCloud(X, Y, cost_fn=ott.geometry.costs.PNormP(1))
prob = ott.problems.linear.linear_problem.LinearProblem(geom)
ot = linear.sinkhorn.Sinkhorn()(prob)
print([ot.primal_cost, ot.dual_cost, ot.reg_ot_cost, jnp.sum(ot.matrix * ot.geom.cost_matrix)])
#[Array(106.526634, dtype=float32), Array(57.67239, dtype=float32), Array(109.89263, dtype=float32), Array(106.526634, dtype=float32)]

But none of the four costs printed above matches the value I get from scipy.stats.wasserstein_distance_nd(X, Y) on the same samples:

import scipy.stats
import jax

X = jax.random.normal(jax.random.PRNGKey(0), shape = (100, 100))
Y = jax.random.normal(jax.random.PRNGKey(1), shape = (100, 100))

scipy.stats.wasserstein_distance_nd(X, Y)
#12.43855146073213

I understand that it is possible that the values are different between the two implementations; however, they are very different right now. What am I missing?
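For reference, when X and Y contain the same number of points with uniform weights, some optimal transport plan is a permutation matrix divided by n, so the exact 1-Wasserstein distance reduces to a minimum-cost matching. The sketch below is an illustration, not OTT or SciPy code: `exact_w1_equal_size` is a hypothetical helper, it assumes equal-size inputs, and it uses `scipy.optimize.linear_sum_assignment` with the Euclidean ground cost.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def exact_w1_equal_size(x, y):
    """Exact 1-Wasserstein distance between two point clouds with the same
    number of points and uniform weights, using the Euclidean ground cost."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.shape != y.shape:
        raise ValueError("this sketch assumes x and y have the same shape")
    # Pairwise Euclidean distances ||x_i - y_j||_2, shape (n, n).
    cost = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)
    # With mass 1/n on every point, some optimal coupling is a permutation
    # matrix divided by n, so OT reduces to a minimum-cost assignment.
    rows, cols = linear_sum_assignment(cost)
    return float(cost[rows, cols].mean())
```

On equal-size, uniformly weighted inputs such as the X and Y above, `exact_w1_equal_size(np.asarray(X), np.asarray(Y))` should agree with `scipy.stats.wasserstein_distance_nd(X, Y)` up to solver tolerance, since both target the same linear program.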

gjhuizing commented 4 months ago

Hi, I'm not affiliated with OTT, but I stumbled upon your issue while looking for mine! Hope they won't mind if I answer :)

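First of all, Sinkhorn solves an entropy-regularized version of the OT problem, so it returns a fast approximation of the Wasserstein distance whose accuracy depends on the regularization strength epsilon. SciPy's function solves the unregularized problem exactly, which corresponds to epsilon=0 in this implementation. OTT does not allow epsilon=0, but with a small epsilon you can get pretty close.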
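To see how epsilon controls the gap, here is a minimal log-domain Sinkhorn sketch in plain NumPy. It is not OTT's implementation: `sinkhorn_transport_cost` is a hypothetical helper, and uniform weights and a fixed iteration count are assumptions. It returns the transport cost <P, C> of the entropic coupling P, which approaches the exact optimum as epsilon shrinks.

```python
import numpy as np


def _logsumexp(z, axis):
    # Numerically stable log(sum(exp(z))) along one axis.
    z_max = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(z_max, axis=axis) + np.log(np.sum(np.exp(z - z_max), axis=axis))


def sinkhorn_transport_cost(cost, epsilon, n_iters=2000):
    """Run log-domain Sinkhorn with uniform weights and return <P, C>,
    the transport cost of the resulting entropic coupling P."""
    n, m = cost.shape
    log_a = np.full(n, -np.log(n))  # uniform source weights, in log space
    log_b = np.full(m, -np.log(m))  # uniform target weights, in log space
    f, g = np.zeros(n), np.zeros(m)  # dual potentials
    for _ in range(n_iters):
        # Update f so that the rows of P sum to a, then g so that the columns sum to b.
        f = -epsilon * _logsumexp((g[None, :] - cost) / epsilon + log_b[None, :], axis=1)
        g = -epsilon * _logsumexp((f[:, None] - cost) / epsilon + log_a[:, None], axis=0)
    # P_ij = a_i * b_j * exp((f_i + g_j - C_ij) / epsilon)
    log_p = (f[:, None] + g[None, :] - cost) / epsilon + log_a[:, None] + log_b[None, :]
    return float(np.sum(np.exp(log_p) * cost))


# Tiny 1D example: the exact W1 between {0, 1} and {0.1, 1.1} is 0.1.
x = np.array([[0.0], [1.0]])
y = np.array([[0.1], [1.1]])
cost = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)  # Euclidean ground cost
for eps in (1.0, 0.1):
    print(eps, sinkhorn_transport_cost(cost, eps))
```

In this toy case the cost is roughly 0.36 at epsilon = 1 and close to the exact value 0.1 at epsilon = 0.1. Smaller epsilon values generally need more iterations to converge.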

Second, SciPy's function uses the Euclidean distance ||x - y||_2 as the ground cost, which corresponds to ott.geometry.costs.Euclidean() rather than PNormP(1).
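To make the difference between ground costs concrete, the NumPy sketch below computes, for the same points, the Euclidean distance used by SciPy's W1, the squared Euclidean distance that OTT uses by default, and the L1 distance, which is how we read `PNormP(1)` (an assumption; OTT's cost documentation is the authority). `pairwise_costs` is a hypothetical helper.

```python
import numpy as np


def pairwise_costs(x, y):
    """Return Euclidean, squared-Euclidean and L1 cost matrices between the
    rows of x (shape (n, d)) and the rows of y (shape (m, d))."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    diff = x[:, None, :] - y[None, :, :]        # shape (n, m, d)
    euclidean = np.linalg.norm(diff, axis=-1)   # ||x - y||_2, SciPy's W1 ground cost
    sq_euclidean = np.sum(diff ** 2, axis=-1)   # ||x - y||_2^2, OTT's default cost
    l1 = np.sum(np.abs(diff), axis=-1)          # ||x - y||_1, our reading of PNormP(1)
    return euclidean, sq_euclidean, l1


euc, sq, l1 = pairwise_costs([[0.0, 0.0]], [[3.0, 4.0]])
print(euc[0, 0], sq[0, 0], l1[0, 0])  # 5.0 25.0 7.0
```

The scale gap matters in high dimension: for independent standard Gaussian samples in d = 100, a typical pairwise L1 distance is about 100 * 2/sqrt(pi), roughly 113, while a typical Euclidean distance is about sqrt(200), roughly 14, which is consistent with the orders of magnitude of the two results in this thread.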

import jax
import ott
from ott.geometry import pointcloud
from ott.solvers import linear

X = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 100))
Y = jax.random.normal(jax.random.PRNGKey(1), shape=(100, 100))

# Euclidean ground cost as in SciPy; a small epsilon approaches the unregularized problem.
geom = pointcloud.PointCloud(X, Y, cost_fn=ott.geometry.costs.Euclidean(), epsilon=1e-3)
prob = ott.problems.linear.linear_problem.LinearProblem(geom)
ot = linear.sinkhorn.Sinkhorn()(prob)
ot.reg_ot_cost

Array(12.4431305, dtype=float32)

marcocuturi commented 4 months ago

Thanks so much for your help @gjhuizing!!

Yes, it's just what @gjhuizing said: OTT uses the SqEuclidean cost function by default. It's unfortunate that the SciPy solver calls it Wasserstein without saying more explicitly that it computes the 1-Wasserstein distance!

Also, when comparing the output of the SciPy function with OTT's, it's safer to look at ot.primal_cost, i.e. the transport cost of the computed coupling, since ot.reg_ot_cost also includes the entropy of the coupling.
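Why the two numbers differ can be illustrated with a small NumPy sketch. It is not OTT's code: here the regularized objective is assumed to be the standard entropic OT objective <P, C> + epsilon * KL(P || a b^T), and OTT's exact convention for reg_ot_cost may differ in details. `primal_and_regularized_cost` is a hypothetical helper. Because the KL term is non-negative, the regularized value is never below the transport cost <P, C>; in the first post, reg_ot_cost (109.89) is indeed larger than primal_cost (106.53).

```python
import numpy as np


def primal_and_regularized_cost(coupling, cost, a, b, epsilon):
    """Return (<P, C>, <P, C> + epsilon * KL(P || a b^T)) for a coupling P
    with total mass 1 and marginal weights a, b."""
    coupling = np.asarray(coupling, dtype=float)
    cost = np.asarray(cost, dtype=float)
    reference = np.outer(a, b)  # independent coupling a b^T
    primal = float(np.sum(coupling * cost))
    mask = coupling > 0  # entries with P_ij = 0 contribute 0 to the KL term
    kl = float(np.sum(coupling[mask] * np.log(coupling[mask] / reference[mask])))
    return primal, primal + epsilon * kl


a = b = np.array([0.5, 0.5])
cost = np.array([[0.1, 1.1], [0.9, 0.1]])
sharp = np.array([[0.5, 0.0], [0.0, 0.5]])  # the optimal, unregularized plan here
print(primal_and_regularized_cost(sharp, cost, a, b, epsilon=0.1))
```

For the sharp plan the regularized value exceeds the transport cost by 0.1 * log(2); for the independent plan a b^T the KL term vanishes and both values coincide.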