GUDHI / gudhi-devel

The GUDHI library is a generic open source C++ library, with a Python interface, for Topological Data Analysis (TDA) and Higher Dimensional Geometry Understanding.
https://gudhi.inria.fr/
MIT License

Inconsistent Wasserstein distance #465

Open mglisse opened 3 years ago

mglisse commented 3 years ago

We sometimes get wildly different results with POT (gudhi.wasserstein) and Hera (gudhi.hera):

import gudhi.wasserstein
import gudhi.hera
import numpy as np
dgm1 = np.array([[2.23201301e+00, 2.30992361],
 [2.79212053e+00, 3.53268734],
 [9.42383758e-01, 4.08248747],
 [5.22799989e-01, 1.00000000e+16]])
dgm2 = np.array([[5.72533791e+00, 1.81300169e+01],
 [3.67182222e+00, 2.16658844e+01],
 [1.49374957e+01, 5.14375875e+01],
 [2.81675603e+00, 1.00000000e+16]])
# POT backend
print(gudhi.wasserstein.wasserstein_distance(dgm1, dgm2, order=2, internal_p=np.inf))
# Hera backend
print(gudhi.hera.wasserstein_distance(dgm1, dgm2, order=2, internal_p=np.inf))

53.24944113844402
21.455667863335233

It looks like the first one (POT) did not match any point to the diagonal, whereas matching the first 3 points of each diagram to the diagonal seems better.
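One way to check this independently of both backends is to score the two candidate matchings by hand. The following NumPy-only sketch is not GUDHI code, and the helpers linf, diag_dist and matching_cost are hypothetical. It assumes the usual conventions for internal_p=np.inf: a matched pair costs its L∞ distance raised to order, and an unmatched point (b, d) costs its distance to the diagonal, (d - b) / 2, raised to order.

```python
import numpy as np

dgm1 = np.array([[2.23201301e+00, 2.30992361],
                 [2.79212053e+00, 3.53268734],
                 [9.42383758e-01, 4.08248747],
                 [5.22799989e-01, 1.00000000e+16]])
dgm2 = np.array([[5.72533791e+00, 1.81300169e+01],
                 [3.67182222e+00, 2.16658844e+01],
                 [1.49374957e+01, 5.14375875e+01],
                 [2.81675603e+00, 1.00000000e+16]])


def linf(p, q):
    """L-infinity distance between two diagram points (hypothetical helper)."""
    return np.max(np.abs(p - q))


def diag_dist(p):
    """L-infinity distance from a point (birth, death) to the diagonal."""
    return (p[1] - p[0]) / 2


def matching_cost(d1, d2, pairs, order=2):
    """Cost of a partial matching: `pairs` lists the matched (i, j) indices,
    every other point of either diagram is sent to the diagonal."""
    used1 = {i for i, _ in pairs}
    used2 = {j for _, j in pairs}
    total = sum(linf(d1[i], d2[j]) ** order for i, j in pairs)
    total += sum(diag_dist(p) ** order for k, p in enumerate(d1) if k not in used1)
    total += sum(diag_dist(q) ** order for k, q in enumerate(d2) if k not in used2)
    return total ** (1 / order)


# First 3 points of each diagram go to the diagonal; the two far points are matched.
to_diagonal = matching_cost(dgm1, dgm2, [(3, 3)])
# Each point is matched to the point with the same index; nothing goes to the diagonal.
no_diagonal = matching_cost(dgm1, dgm2, [(0, 0), (1, 1), (2, 2), (3, 3)])
print(to_diagonal, no_diagonal)
```

Under these conventions, to_diagonal comes out at about 21.4557, in line with Hera's value, while the identity matching is more than twice as expensive.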

mglisse commented 3 years ago

The problem seems to be in POT, likely a numerical issue due to the very large coordinates of the 4th point.
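To see why the 4th points can cause trouble, here is a hypothetical reconstruction of a cost matrix with the same layout as M in this comment. cost_matrix is an illustrative helper, not a GUDHI function. Its layout is: L∞ distances raised to order between points, an extra last row and column for the diagonal, and weights giving every point mass 1/(n+m) while the diagonal gets the remaining mass. With order=2, the 1e16 deaths turn into costs of about 1e32, while the smallest costs are around 1e-3.

```python
import numpy as np

dgm1 = np.array([[2.23201301e+00, 2.30992361],
                 [2.79212053e+00, 3.53268734],
                 [9.42383758e-01, 4.08248747],
                 [5.22799989e-01, 1.00000000e+16]])
dgm2 = np.array([[5.72533791e+00, 1.81300169e+01],
                 [3.67182222e+00, 2.16658844e+01],
                 [1.49374957e+01, 5.14375875e+01],
                 [2.81675603e+00, 1.00000000e+16]])


def cost_matrix(d1, d2, order=2):
    """Hypothetical helper: (n+1) x (m+1) cost matrix, last row/column = diagonal."""
    n, m = len(d1), len(d2)
    M = np.zeros((n + 1, m + 1))
    # point-to-point costs: L-infinity distance raised to `order`
    M[:n, :m] = np.max(np.abs(d1[:, None, :] - d2[None, :, :]), axis=2) ** order
    # point-to-diagonal costs: ((death - birth) / 2) ** order
    M[:n, m] = ((d1[:, 1] - d1[:, 0]) / 2) ** order
    M[n, :m] = ((d2[:, 1] - d2[:, 0]) / 2) ** order
    # diagonal-to-diagonal costs nothing, so M[n, m] stays 0
    a = np.append(np.ones(n), m) / (n + m)
    b = np.append(np.ones(m), n) / (n + m)
    return M, a, b


M, a, b = cost_matrix(dgm1, dgm2)
nonzero = M[M > 0]
# ratio between the largest and the smallest nonzero costs
print(nonzero.max() / nonzero.min())
# in float64, a cost of 250 added to 1e32 vanishes entirely
print(1e32 + 250.0 == 1e32)
```

Since float64 carries only about 16 significant digits, any intermediate quantity that mixes a 1e32 cost with the small ones loses the small ones. This is one plausible mechanism for the wrong plan, not a confirmed diagnosis of POT's internals.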

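import numpy as np
import ot

# Cost matrix for the diagrams above (order=2, internal_p=inf): rows 0-3 are the
# points of dgm1, columns 0-3 the points of dgm2, the last row/column is the diagonal.
M = np.array(
    [
        [2.50275352e02, 3.74653218e02, 2.41352736e03, 1.00000000e32, 1.51751540e-03],
        [2.13082030e02, 3.28812836e02, 2.29487946e03, 1.00000000e32, 1.37109800e-01],
        [1.97333083e02, 3.09175848e02, 2.24250550e03, 1.00000000e32, 2.46506283e00],
        [1.00000000e32, 1.00000000e32, 1.00000000e32, 5.26223432e00, 2.50000000e31],
        [3.84690152e01, 8.09465684e01, 3.33064175e02, 2.50000000e31, 0.00000000e00],
    ]
)

# Each point has mass 1/8; the diagonal gets the remaining 4/8.
a = np.array([0.125, 0.125, 0.125, 0.125, 0.5])
b = np.array([0.125, 0.125, 0.125, 0.125, 0.5])
# Transport plan computed by POT.
P = ot.emd(a=a, b=b, M=M, numItermax=2000000)
# Hand-built plan: the first 3 points of each diagram go to the diagonal,
# and the two points with huge death coordinates are matched together.
Q = np.array(
    [
        [0, 0, 0, 0, 0.125],
        [0, 0, 0, 0, 0.125],
        [0, 0, 0, 0, 0.125],
        [0, 0, 0, 0.125, 0],
        [0.125, 0.125, 0.125, 0, 0.125],
    ]
)
# Both plans are feasible: row sums equal a, column sums equal b.
assert (P.sum(axis=1) == a).all()
assert (P.sum(axis=0) == b).all()
assert (Q.sum(axis=1) == a).all()
assert (Q.sum(axis=0) == b).all()
# Compare the cost of POT's plan with the cost of the hand-built plan.
print(np.sum(np.multiply(P, M)))
print(np.sum(np.multiply(Q, M)))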
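To confirm that Q is actually optimal without relying on POT, the same problem can be solved exactly by brute force. Because every weight is a multiple of 1/8, multiplying the masses by 8 turns the transport problem into an assignment problem in which the diagonal row and column are each repeated 4 times; with only 8 rows, all 8! permutations can be enumerated. This is a hypothetical sanity check, not part of GUDHI or POT.

```python
import itertools

import numpy as np

M = np.array([
    [2.50275352e02, 3.74653218e02, 2.41352736e03, 1.00000000e32, 1.51751540e-03],
    [2.13082030e02, 3.28812836e02, 2.29487946e03, 1.00000000e32, 1.37109800e-01],
    [1.97333083e02, 3.09175848e02, 2.24250550e03, 1.00000000e32, 2.46506283e00],
    [1.00000000e32, 1.00000000e32, 1.00000000e32, 5.26223432e00, 2.50000000e31],
    [3.84690152e01, 8.09465684e01, 3.33064175e02, 2.50000000e31, 0.00000000e00],
])
Q = np.array([
    [0, 0, 0, 0, 0.125],
    [0, 0, 0, 0, 0.125],
    [0, 0, 0, 0, 0.125],
    [0, 0, 0, 0.125, 0],
    [0.125, 0.125, 0.125, 0, 0.125],
])

# Masses scaled by 8: one unit per point, four units for the diagonal (index 4).
rows = [0, 1, 2, 3, 4, 4, 4, 4]
cols = [0, 1, 2, 3, 4, 4, 4, 4]
C = M.tolist()  # plain Python floats make the inner loop faster

# rows[i] is assigned to cols[perm[i]]; keep the cheapest assignment.
best = min(
    sum(C[r][cols[k]] for r, k in zip(rows, perm))
    for perm in itertools.permutations(range(len(cols)))
)
print(best / 8, np.sum(Q * M))  # optimal cost vs. cost of Q, same normalization
print(np.sqrt(best))  # rescaled by n + m = 8, then the order-2 root
```

The two printed costs should agree, and the square root of the rescaled optimum lands on Hera's 21.4557, which suggests that only the POT result is off.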

tlacombe commented 3 years ago

OK, I could reproduce the issue with your MWE. What should we do, forward this to POT? In your initial example, did the points with 1e16 coordinates represent "essential" parts?
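If the 1e16 deaths are stand-ins for "never dies" (essential points), they could be split off before the optimal transport step. This sketch assumes a common convention: essential points can only be matched to each other, so diagrams with different numbers of essential points are at infinite distance. With equal counts, the optimal matching pairs births in sorted order. split_essential, essential_cost and the 1e15 threshold are hypothetical choices made for illustration. The sketch only computes the essential contribution; the finite parts would still go through the usual computation.

```python
import numpy as np

dgm1 = np.array([[2.23201301e+00, 2.30992361],
                 [2.79212053e+00, 3.53268734],
                 [9.42383758e-01, 4.08248747],
                 [5.22799989e-01, 1.00000000e+16]])
dgm2 = np.array([[5.72533791e+00, 1.81300169e+01],
                 [3.67182222e+00, 2.16658844e+01],
                 [1.49374957e+01, 5.14375875e+01],
                 [2.81675603e+00, 1.00000000e+16]])

SENTINEL = 1e15  # hypothetical threshold: deaths at or above it count as infinite


def split_essential(dgm, threshold=SENTINEL):
    """Return (finite points, sorted births of essential points)."""
    essential = dgm[:, 1] >= threshold
    return dgm[~essential], np.sort(dgm[essential, 0])


def essential_cost(births1, births2, order=2):
    """Cost (before the 1/order root) of matching essential points by sorted birth."""
    if len(births1) != len(births2):
        return np.inf  # some essential point would have to match a finite one
    return np.sum(np.abs(births1 - births2) ** order)


fin1, ess1 = split_essential(dgm1)
fin2, ess2 = split_essential(dgm2)
print(len(fin1), len(fin2), essential_cost(ess1, ess2))
```

Here each diagram keeps 3 finite points, and the essential part contributes about 5.262, the same value as the entry M[3, 3] in the MWE.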

mglisse commented 3 years ago

forward this to POT?

I think so, although there is a small chance they may consider it "normal".

mglisse commented 3 years ago

See https://github.com/PythonOT/POT/issues/229 .