Closed abhiagwl closed 4 months ago
Hi, I'm not affiliated with OTT, but I stumbled upon your issue while looking for mine! Hope they won't mind if I answer :)
First of all, Sinkhorn computes a fast, entropy-regularized approximation of the Wasserstein distance. Scipy's function corresponds to epsilon=0 in this implementation. OTT does not allow epsilon=0, but you can get pretty close by using a small epsilon.
Second, Scipy's formula corresponds to the ott.geometry.costs.Euclidean() cost function:
import jax
import ott
from ott.geometry import pointcloud
from ott.solvers import linear
# Two point clouds: 100 samples each, in 100 dimensions
X = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 100))
Y = jax.random.normal(jax.random.PRNGKey(1), shape=(100, 100))
# Use the Euclidean cost with a small epsilon
geom = pointcloud.PointCloud(X, Y, cost_fn=ott.geometry.costs.Euclidean(), epsilon=1e-3)
prob = ott.problems.linear.linear_problem.LinearProblem(geom)
ot = linear.sinkhorn.Sinkhorn()(prob)
ot.reg_ot_cost
Array(12.4431305, dtype=float32)
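To see how much the choice of cost function matters on its own, here is a minimal sketch that leaves OTT out entirely and solves the unregularized problem exactly with NumPy and SciPy. It assumes both point clouds have the same number of samples and uniform weights, in which case an optimal coupling can be taken to be a one-to-one matching. The helper name exact_ot_cost and the sample sizes are made up for illustration.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def exact_ot_cost(X, Y, metric):
    """Exact OT cost between two equal-size, uniformly weighted point clouds.

    Hypothetical helper: with n points on each side and weights 1/n,
    the problem reduces to an assignment problem.
    """
    C = cdist(X, Y, metric=metric)          # pairwise cost matrix
    rows, cols = linear_sum_assignment(C)   # optimal one-to-one matching
    return C[rows, cols].mean()             # average cost of matched pairs


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 5))
Y = rng.normal(size=(50, 5))

w1 = exact_ot_cost(X, Y, "euclidean")       # Euclidean cost (1-Wasserstein)
w2_sq = exact_ot_cost(X, Y, "sqeuclidean")  # squared Euclidean cost
print("Euclidean cost:        ", w1)
print("squared Euclidean cost:", w2_sq)
```

The two numbers are on different scales, so a solver running with the squared Euclidean cost should not be expected to reproduce a 1-Wasserstein value, whatever epsilon is used.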
Thanks so much for your help @gjhuizing!!
Yes, it's just what @gjhuizing said: OTT uses the SqEuclidean cost function by default. It's unfortunate that the scipy solver calls it wasserstein and does not reference 1-Wasserstein more explicitly!
Also, if we want to compare the outputs of the scipy function with OTT's, it's safer to look at ot.primal_cost (ot.reg_ot_cost includes the entropy of the coupling).
Hi there!
Thank you for the great library. I have a very simple use case: I want to calculate the 1-Wasserstein distance between two sets of samples, as done in scipy's wasserstein_distance_nd. Here is what I tried:
But none of the 4 costs at the end match the value I would get from scipy.stats.wasserstein_distance_nd(X, Y).
I understand that it is possible that the values are different between the two implementations; however, they are very different right now. What am I missing?