Closed: gw265981 closed this issue 5 hours ago
It appears the root of the problem is catastrophic cancellation, which happens when subtracting two approximations that are very close in value. The result of the subtraction can have a very large relative error even if the individual approximations are good. For example, take two numbers with 8 significant figures that agree in the first 5: 0.20802966 and 0.20802812. Their difference is 0.00000154, so we went from 8 significant figures to 3. If the original numbers are approximations, the result of the subtraction can vary a lot.
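A small NumPy illustration of the same effect, using the two numbers from the example and rounding them to float32 (roughly 7 significant decimal digits, the default precision in JAX). This is only a sketch of the arithmetic, not code from the package: the rounding error in each input is tiny relative to the input, but it becomes a much larger relative error in the difference.

```python
import numpy as np

# The two nearby values from the example above.
a = 0.20802966
b = 0.20802812

# Reference difference computed in float64 (about 1.54e-6).
diff64 = a - b

# Round each input to float32 first, then subtract.
diff32 = np.float32(a) - np.float32(b)

# Relative error introduced by rounding one input to float32 ...
rel_err_inputs = abs(float(np.float32(a)) - a) / a
# ... versus the relative error that ends up in the difference.
rel_err_diff = abs(float(diff32) - diff64) / diff64

print(f"input rel. error: {rel_err_inputs:.1e}")
print(f"difference rel. error: {rel_err_diff:.1e}")
```

The input error stays below about 1e-7, while the error in the difference is on the order of tenths of a percent, because only about 100 float32 steps separate the two inputs.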
In our case we have MMD^2 = K_xx + K_yy - 2*K_xy, where K_xx, K_yy and K_xy are the mean kernel evaluations within x, within y, and between x and y. MMD measures a distance between distributions, so it will be close to 0 when the distributions are close, which is exactly the regime where catastrophic cancellation occurs. To make matters worse, MMD seems to decrease as the number of data points, N, grows, even if the coreset stays at a constant proportion of N.
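A minimal NumPy sketch of the quantity being discussed, assuming a Gaussian (RBF) kernel with a hypothetical `length_scale` parameter and plain unweighted means; the package's own kernels and `MMD.compute` may differ. It shows where the cancellation happens: three O(1) kernel means are combined into a result that is close to 0 when the two point sets come from similar distributions.

```python
import numpy as np

def rbf_kernel_matrix(x, y, length_scale=1.0):
    """Gaussian kernel matrix k(x_i, y_j) = exp(-||x_i - y_j||^2 / (2 * l^2))."""
    sq_dists = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale**2))

def mmd_squared(x, y, length_scale=1.0):
    """Biased estimate of MMD^2 = K_xx + K_yy - 2*K_xy from kernel means."""
    k_xx = rbf_kernel_matrix(x, x, length_scale).mean()
    k_yy = rbf_kernel_matrix(y, y, length_scale).mean()
    k_xy = rbf_kernel_matrix(x, y, length_scale).mean()
    # Each mean is O(1); their combination is close to 0 when x and y are
    # similar, so this subtraction is the cancellation-prone step.
    return k_xx + k_yy - 2.0 * k_xy

rng = np.random.default_rng(0)
data = rng.normal(size=(1000, 2))
# A random 10% subset standing in for a coreset.
coreset = data[rng.choice(len(data), size=100, replace=False)]

for dtype in (np.float32, np.float64):
    value = mmd_squared(data.astype(dtype), coreset.astype(dtype))
    print(dtype.__name__, value)
```

Since MMD is the square root of this expression, any round-off that pushes it below zero turns into `nan`.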
I am still not sure what the best solution here would be. Some ideas:
1) Rewrite the MMD algorithm so that the difference is computed more directly. I don't know how feasible or easy this would be, so it would likely take some time to research.
2) Increase the precision of the computation. Not very feasible, as it appears precision has to be set globally in JAX: https://github.com/jax-ml/jax/discussions/19443. We could warn users to enable this config in scripts where precision is important.
3) Increase the precision threshold, or set the floor to 0 for MMD. A temporary hack if we want to postpone thinking about the full solution.
Logs for the example above:
K_xx:
32-bit: float32, mean time: 0.6503s, std: 0.2882
64-bit: float64, mean time: 1.3937s, std: 0.3300
Relative difference: 0.000273%
K_xy:
32-bit: float32, mean time: 0.0779s, std: 0.0117
64-bit: float64, mean time: 0.0956s, std: 0.0179
Relative difference: 0.000242%
K_yy:
32-bit: float32, mean time: 0.0615s, std: 0.0022
64-bit: float64, mean time: 0.0583s, std: 0.0080
Relative difference: 0.000027%
K_xx + K_yy - 2*K_xy:
float32: -1.341104507446289e-07
float64: 6.73454663216444e-08
Relative difference: 299.138053%
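The 299% figure can be reproduced from the two printed values, assuming the "relative difference" in these logs is |float32 - float64| / |float64|. That definition is inferred from the numbers rather than taken from the script that produced the logs. Because the float32 result has the wrong sign, the error is larger than the value itself.

```python
# Values of K_xx + K_yy - 2*K_xy copied from the logs above.
mmd_sq_f32 = -1.341104507446289e-07
mmd_sq_f64 = 6.73454663216444e-08

# Assumed definition: difference relative to the float64 result, in percent.
rel_diff_pct = abs(mmd_sq_f32 - mmd_sq_f64) / abs(mmd_sq_f64) * 100
print(f"Relative difference: {rel_diff_pct:.6f}%")  # prints "Relative difference: 299.138053%"
```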
Setting the floor to zero (option 3) seems like the obvious option to me when numerical analysis explains this: we can't keep setting thresholds as if we were passing laws against physics. We may want some safety check in place to guard against a coding bug, though.
I agree that if we go that route, just truncating to 0 is the best option. I assume the threshold was there originally to allow some precision tolerance while still detecting if something weird was happening, and I think we now understand what is happening.
@tp832944 @gw265981 I agree with setting the floor to zero - happy for code updates to be made to implement this.
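A minimal NumPy sketch of the agreed approach (floor at zero plus a safety check): negative values small enough to be round-off are clamped to 0 before the square root, while clearly negative values raise an error so that a coding bug is not silently masked. The function name `mmd_from_kernel_means`, the `rtol` parameter and its default are hypothetical choices for illustration, not the package's API. Scaling the tolerance by machine epsilon and by the size of the kernel means is one possible design, so that float32 gets a looser bound than float64.

```python
import numpy as np

def mmd_from_kernel_means(k_xx, k_yy, k_xy, rtol=None):
    """Return MMD = sqrt(K_xx + K_yy - 2*K_xy), flooring round-off negatives at 0.

    Hypothetical helper: negatives within a tolerance scaled to the magnitude
    of the kernel means are treated as cancellation error; anything more
    negative is assumed to be a bug and raises ValueError.
    """
    dtype = np.result_type(k_xx, k_yy, k_xy)
    if rtol is None:
        # Assumed heuristic: allow a few hundred ulps relative to the largest term.
        rtol = 256 * np.finfo(dtype).eps
    mmd_sq = k_xx + k_yy - 2 * k_xy
    scale = max(abs(k_xx), abs(k_yy), abs(2 * k_xy))
    if mmd_sq < -rtol * scale:
        raise ValueError(f"MMD^2 = {mmd_sq} is too negative to be round-off error")
    return np.sqrt(np.maximum(mmd_sq, 0))

print(mmd_from_kernel_means(0.5, 0.3, 0.2))
print(mmd_from_kernel_means(0.3, 0.3, 0.3 + 1e-15))
```

The second call produces a tiny negative MMD^2 purely from round-off and returns 0 instead of `nan`; a call such as `mmd_from_kernel_means(0.3, 0.3, 0.4)` would raise.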
What's the problem?
The MMD metric between a dataset and a coreset sometimes returns `nan` when the dataset is relatively large (on the order of 10^5 points). This is likely a precision error: `kernel_xx_mean + kernel_yy_mean - 2 * kernel_xy_mean` in `MMD.compute` evaluates to a small negative number in these cases, one whose magnitude exceeds the precision threshold that was put in place to catch this. The square root then turns it into `nan`. The issue goes away when double precision is used:
`from jax import config; config.update("jax_enable_x64", True)`
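A minimal NumPy sketch of the failure mode, assuming the existing check clamps only negatives that fall within a fixed tolerance. The `PRECISION_THRESHOLD` value and the `sqrt_with_fixed_threshold` helper are hypothetical stand-ins, not the package's code. The float32 value of K_xx + K_yy - 2*K_xy from the logs in this thread falls outside such a tolerance, so the square root returns `nan`, while the float64 value is positive.

```python
import numpy as np

# Hypothetical fixed tolerance standing in for the threshold in MMD.compute.
PRECISION_THRESHOLD = 1e-8

def sqrt_with_fixed_threshold(mmd_squared):
    """Clamp tiny negatives to 0, otherwise take the square root directly."""
    if -PRECISION_THRESHOLD < mmd_squared < 0:
        mmd_squared = 0.0
    with np.errstate(invalid="ignore"):  # silence the RuntimeWarning for sqrt(<0)
        return np.sqrt(mmd_squared)

# K_xx + K_yy - 2*K_xy as logged for float32 and float64.
print(sqrt_with_fixed_threshold(np.float32(-1.341104507446289e-07)))  # nan
print(sqrt_with_fixed_threshold(np.float64(6.73454663216444e-08)))
```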
The simplest solution is to increase the precision tolerance threshold to catch these cases, but I am not sure what an appropriate value would be.
We could just truncate the expression at 0, but that could mask a bug down the line. Also, this can happen even when the coreset is relatively small (<10% of the dataset), and it would be strange to have an MMD of 0 in that case.
Another solution is to increase precision to float64, but I am not sure about the performance and other implications of this.
How can we reproduce the issue?
Python version
3.12
Package version
0.3.0
Operating system
Windows 10
Other packages
No response
Relevant log output
No response