PythonOT / POT

POT : Python Optimal Transport
https://PythonOT.github.io/
MIT License

The distance between two of the same GMMs is not 0 #695

Open GilgameshD opened 1 day ago

GilgameshD commented 1 day ago

Describe the bug

The distance between two identical GMMs is not 0. With my own data, the distance can be as large as 1e-3. Is this a numerical issue?

To Reproduce

import numpy as np
import torch
import ot

if __name__ == "__main__":
    K = 10   # number of mixture components
    D = 300  # dimension
    # Random weights, random means, identity covariances.
    pi0 = np.random.rand(K)
    pi0 /= np.sum(pi0)
    mu0 = np.random.rand(K, D)
    S0 = np.eye(D)[None].repeat(K, axis=0)

    pi0 = torch.as_tensor(pi0, dtype=torch.float32)
    mu0 = torch.as_tensor(mu0, dtype=torch.float32)
    S0 = torch.as_tensor(S0, dtype=torch.float32)

    # Second GMM: exact copies of the first one.
    pi1 = pi0.clone()
    mu1 = mu0.clone()
    S1 = S0.clone()

    # All three checks print tensor(True).
    print((pi0 == pi1).all())
    print((mu0 == mu1).all())
    print((S0 == S1).all())

    # OT loss between the two identical GMMs.
    dist = ot.gmm.gmm_ot_loss(mu0, mu1, S0, S1, pi0, pi1)
    print(dist)

The output distance of the above code is 1.2001e-05.

Expected behavior

The distance should be 0 (up to floating-point precision), since both GMMs are identical.


rflamary commented 1 day ago

This is interesting. Could it come from the dist_bures_squared function, which might not be exactly 0 on the diagonal, @eloitanguy?
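To make the suspected diagonal issue concrete, here is a minimal NumPy sketch of a pairwise squared Bures-Wasserstein cost between Gaussian components. It assumes (as we understand it, not verified against the POT source) that `ot.gmm.dist_bures_squared` computes the mean term ||m_i - m_j||^2 plus the Bures term tr(S_i + S_j - 2 (S_j^{1/2} S_i S_j^{1/2})^{1/2}). The helpers `sqrtm_psd` and `bures_squared_matrix` are hypothetical and not part of POT. For identical components both terms vanish in exact arithmetic, so any nonzero diagonal entry has to come from floating-point error in one of them; in this sketch the mean term is computed from explicit differences, so it is exactly 0 for equal rows.

```python
import numpy as np


def sqrtm_psd(S):
    """Square root of a symmetric PSD matrix via eigendecomposition."""
    w, V = np.linalg.eigh(S)
    w = np.clip(w, 0.0, None)  # guard against tiny negative eigenvalues
    return (V * np.sqrt(w)) @ V.T  # V diag(sqrt(w)) V^T


def bures_squared_matrix(m0, m1, S0, S1):
    """Hypothetical helper: pairwise squared Bures-Wasserstein costs.

    Entry (i, j) = ||m0[i] - m1[j]||^2
                   + tr(S0[i] + S1[j] - 2 (S1[j]^{1/2} S0[i] S1[j]^{1/2})^{1/2})
    """
    # Mean term from explicit differences: exactly 0 for identical rows.
    diff = m0[:, None, :] - m1[None, :, :]
    mean_term = np.sum(diff ** 2, axis=-1)

    K0, K1 = m0.shape[0], m1.shape[0]
    cov_term = np.zeros((K0, K1), dtype=mean_term.dtype)
    sqrt_S1 = [sqrtm_psd(S) for S in S1]
    for j in range(K1):
        for i in range(K0):
            cross = sqrtm_psd(sqrt_S1[j] @ S0[i] @ sqrt_S1[j])
            cov_term[i, j] = np.trace(S0[i] + S1[j] - 2.0 * cross)
    return mean_term + cov_term


# Usage: identical GMMs with random SPD covariances.
rng = np.random.default_rng(0)
K, D = 4, 20
mu = rng.random((K, D))
A = rng.random((K, D, D))
S = A @ A.transpose(0, 2, 1) + D * np.eye(D)  # random SPD covariances
M = bures_squared_matrix(mu, mu.copy(), S, S.copy())
print(np.abs(np.diag(M)).max())  # small, from the covariance term only
```

The double loop is deliberately naive for readability; it costs K0 * K1 matrix square roots.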

eloitanguy commented 1 day ago

That is interesting, I'll look into it.

eloitanguy commented 1 day ago

Hi, thanks for the issue, I managed to reproduce it (with np.random.seed(0)). As @rflamary suggested, torch.diag(ot.gmm.dist_bures_squared(mu0, mu1, S0, S1)) is not the zero vector, as it should be. It turns out this is because ot.dist(mu0, mu1) has nonzero diagonal entries of order 10^(-5), which is consistent with the final GMM distance of roughly 10^(-5) instead of numerical 0. If you use torch.float64 instead of torch.float32, these diagonal entries of ot.dist(mu0, mu1) drop to about 10^(-14), which is acceptable. It seems that this imprecision comes from numerical error in ot.dist when using torch.float32.
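The float32 effect described above is consistent with catastrophic cancellation. Assuming, as we believe is the case, that `ot.dist` with its default `metric="sqeuclidean"` uses the expansion ||x||^2 + ||y||^2 - 2 x.y, each diagonal entry is the difference of numbers around 100 (D = 300 features drawn uniformly from [0, 1)), and float32 carries only about 7 significant digits, so residues around 10^(-5) are plausible. The NumPy sketch below (function names are hypothetical, not POT's) compares that expansion with a direct sum of squared differences, which is exactly 0 for equal rows in any precision.

```python
import numpy as np


def sqeuclidean_expanded(X, Y):
    """||x||^2 + ||y||^2 - 2 x.y: cheap, but cancels badly when x is close to y."""
    x2 = np.einsum("ij,ij->i", X, X)
    y2 = np.einsum("ij,ij->i", Y, Y)
    return x2[:, None] + y2[None, :] - 2.0 * (X @ Y.T)


def sqeuclidean_direct(X, Y):
    """sum_k (x_k - y_k)^2 via broadcasting: exactly 0 when x == y."""
    diff = X[:, None, :] - Y[None, :, :]
    return np.sum(diff ** 2, axis=-1)


np.random.seed(0)
mu0 = np.random.rand(10, 300)  # same shapes as the reproduction script
for dtype in (np.float32, np.float64):
    X = mu0.astype(dtype)
    Y = X.copy()  # equal values, but a different object
    d_exp = np.abs(np.diag(sqeuclidean_expanded(X, Y))).max()
    d_dir = np.abs(np.diag(sqeuclidean_direct(X, Y))).max()
    print(dtype.__name__, d_exp, d_dir)
```

One would expect the float32 expansion to leave a diagonal residue several orders of magnitude larger than the float64 one, while the direct version prints 0.0 for both dtypes. The trade-off is that the direct version allocates a K0 x K1 x D intermediate array instead of using a single matrix product.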

GilgameshD commented 17 hours ago

Thanks for identifying the problem! Is there any solution other than using torch.float64?

eloitanguy commented 16 hours ago

Hi, I don't really have other ideas, but maybe @rflamary would know? I do know that ot.dist checks whether the two data matrices are the same object, in which case it forces the diagonal to be 0. That does not solve your issue (mu1 is a clone of mu0, not the same object), but it is closely related, so I'm mentioning it anyway.
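The identity check mentioned above only helps when both arguments are literally the same object, whereas in the reproduction mu1 is a clone. A hypothetical variant (not a POT function, just a NumPy sketch) would compare values instead and force the entries for bit-identical row pairs to 0. It only repairs the mean term for exact duplicates, not for nearly equal rows, and it leaves any error in the covariance term untouched.

```python
import numpy as np


def sqeuclidean_exact_matches(X, Y):
    """Hypothetical helper (not part of POT): expanded squared distances,
    clipped at 0, with entries for bit-identical row pairs forced to exactly 0."""
    x2 = np.einsum("ij,ij->i", X, X)
    y2 = np.einsum("ij,ij->i", Y, Y)
    M = x2[:, None] + y2[None, :] - 2.0 * (X @ Y.T)
    M = np.maximum(M, 0.0)  # cancellation can also produce small negative values
    # Value-based check instead of an identity ("same object") check:
    # same[i, j] is True when row i of X equals row j of Y in every coordinate.
    same = (X[:, None, :] == Y[None, :, :]).all(axis=-1)
    M[same] = 0.0
    return M


X = np.random.rand(10, 300).astype(np.float32)
print(np.diag(sqeuclidean_exact_matches(X, X.copy())))  # all exactly 0.0
```

The equality mask costs O(K0 * K1 * D) comparisons, which is negligible for GMMs with a handful of components.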