GilgameshD opened this issue 1 day ago
This is interesting. Could it come from the `dist_bures_squared` function, which might not be exactly 0 on the diagonal, @eloitanguy?
That is interesting, I'll look into it.
Hi, thanks for your issue, I managed to reproduce it.
The problem stems from the fact that in this example (with `np.random.seed(0)`), as @rflamary suggested, `torch.diag(ot.gmm.dist_bures_squared(mu0, mu1, S0, S1))` is not the zero vector it should be, and it turns out that this is because `ot.dist(mu0, mu1)` has nonzero diagonal entries (around 10^(-5), which is consistent with the final GMM distance of roughly 10^(-5) instead of a numerical 0).
If you take `torch.float64` instead of `torch.float32`, the 10^(-5) diagonal entries in `ot.dist(mu0, mu1)` drop to 10^(-14), which is acceptable. So the error seems to come from numerical imprecision in `ot.dist` when using `torch.float32`.
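The cancellation described above can be illustrated with plain NumPy. This is a sketch, not POT's actual code: it assumes `ot.dist` computes squared Euclidean distances via the expanded formula `||a||^2 + ||b||^2 - 2 a.b` (a common implementation choice), so the self-distance diagonal is a difference of nearly equal terms and inherits float32 rounding error:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 3))

def sqdist_expanded(a, b):
    # Expanded squared-distance formula ||a||^2 + ||b||^2 - 2 a.b.
    # It reduces to one matrix product, at the cost of catastrophic
    # cancellation on the diagonal when a and b hold the same points.
    a2 = (a ** 2).sum(axis=1)[:, None]
    b2 = (b ** 2).sum(axis=1)[None, :]
    return a2 + b2 - 2.0 * a @ b.T

X32 = X.astype(np.float32)
diag32 = np.abs(np.diag(sqdist_expanded(X32, X32)))
diag64 = np.abs(np.diag(sqdist_expanded(X, X)))
# The float32 residual is typically several orders of magnitude
# larger than the float64 one, matching the 1e-5 vs 1e-14 observation.
print(diag32.max(), diag64.max())
```

The diagonal should be exactly zero mathematically; the nonzero residual is purely a floating-point artifact, and its size scales with the precision of the dtype used.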
Thanks for identifying the problem! Is there any solution other than using `torch.float64`?
Hi, I don't really have other ideas, but maybe @rflamary would know?
I know that `ot.dist` performs a check to verify whether the two data matrices are the same object, in which case it forces the diagonal to be 0. This does not solve your issue, but it's closely related, so I'm bringing it up anyway.
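A minimal sketch of what such a same-object check could look like (this is a hypothetical helper, not POT's actual implementation):

```python
import numpy as np

def pairwise_sq_dist(a, b):
    # Hypothetical helper sketching the same-object check mentioned above:
    # when `a is b`, the diagonal is known to be exactly 0, so we enforce it
    # instead of trusting the floating-point result.
    a2 = (a ** 2).sum(axis=1)[:, None]
    b2 = (b ** 2).sum(axis=1)[None, :]
    d = np.maximum(a2 + b2 - 2.0 * a @ b.T, 0.0)  # clamp tiny negatives
    if a is b:
        np.fill_diagonal(d, 0.0)
    return d

X = np.random.default_rng(0).standard_normal((4, 2)).astype(np.float32)
print(np.diag(pairwise_sq_dist(X, X)))  # exactly zero diagonal
```

Note that the check only fires when the two arguments are literally the same object, which is why it does not help here: `mu0` and `mu1` are distinct tensors that merely hold equal values.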
Describe the bug
The distance between two copies of the same GMM is not 0. Sometimes the distance can be as large as 1e-3 when I use my own data. Is this due to a numerical problem?
To Reproduce
The output distance of the above code is 1.2001e-05.
Expected behavior
Environment (please complete the following information):
How was POT installed (source, pip, conda): source