JaxGaussianProcesses / GPJax

Gaussian processes in JAX.
https://docs.jaxgaussianprocesses.com/
Apache License 2.0

bug: instabilities with float32 #207

Closed: felixchalumeau closed this issue 1 year ago

felixchalumeau commented 1 year ago

Bug Report

Hi guys! First, thanks for developing this package for the JAX community; it is very useful. I'd like to ask a question about the instabilities I encountered when using float32.

GPJax version: 0.5.9

Current behavior:

When running the regression example with float32 precision, the first sample from the prior is full of NaNs, which is not the case with float64.

The current solution is to use float64 precision, but this is not very convenient when trying to use GPJax in a broader context.
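For reference, the workaround I mean is the standard JAX double-precision switch, roughly as below (a minimal sketch, not GPJax-specific code):

```python
import jax
import jax.numpy as jnp

# Flip JAX's global 64-bit switch before creating any arrays. This affects
# the whole program, which is exactly what I would like to avoid.
jax.config.update("jax_enable_x64", True)

x = jnp.linspace(0.0, 1.0, 100)
print(x.dtype)  # float64 once x64 mode is enabled
```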

Expected behavior:

I am trying to use GPJax as a component of a broader codebase with quite large matrices, so I can't afford to move all my components to float64 precision. More generally, it would not be good practice to change the behavior of the rest of my code just for my GP.

I would hence like to know if you have any idea of another way to avoid these instabilities (and hence the NaNs) while keeping float32. For instance, by adding small offsets where the instabilities can happen?

At the same time, could you give a bit more explanation (documentation) about the bijectors? It was very easy to get on board with GPs using your documentation, but as a non-expert in the field, bijectors are very unclear to me and I did not find enough resources in the documentation to really understand them. Additionally, the way they are defined in the code is quite complex to follow. Would it be possible to give me a bit more detail about them? Thanks!

Other information:

Thanks for your time, Felix

daniel-dodd commented 1 year ago

Hi @felixchalumeau,

Thanks for opening this issue.

On your first point:

This issue arises from the algorithm we use to solve the matrix inverse, the Cholesky decomposition. This is the traditional way to invert a positive definite matrix; however, it is unstable at low precisions such as float32, so this behaviour is expected. We plan to broaden the scope of alternative methods, e.g. conjugate gradients, in the future, which may provide better stability depending on the situation. Is it possible for your work to convert the mean and kernel functions of your GP from float32 so that they return float64? If not, let me know and let's work on a solution.
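To make the failure mode and the local cast concrete, here is a rough sketch (not our internal code; the kernel, sizes and jitter value are just placeholders):

```python
import jax
import jax.numpy as jnp

# x64 mode must be enabled for float64 to exist at all in JAX; the rest of
# your code can still build float32 arrays explicitly, as below.
jax.config.update("jax_enable_x64", True)


def rbf_gram(x, lengthscale=1.0):
    # Squared-exponential Gram matrix; densely spaced inputs make it
    # ill-conditioned, which is where float32 Cholesky tends to break down.
    sq_dists = (x[:, None] - x[None, :]) ** 2
    return jnp.exp(-0.5 * sq_dists / lengthscale**2)


x = jnp.linspace(0.0, 1.0, 500, dtype=jnp.float32)
K = rbf_gram(x)
jitter = 1e-6 * jnp.eye(x.shape[0])  # the "small offset" you mention

L32 = jnp.linalg.cholesky(K + jitter.astype(jnp.float32))
L64 = jnp.linalg.cholesky(K.astype(jnp.float64) + jitter)

print(jnp.isnan(L32).any())  # may well be True at float32
print(jnp.isnan(L64).any())  # typically False once the solve is done in float64
```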

On your second point:

For our use case, think of bijectors as a convenience for defining the domain of parameters. For example, the obs_noise parameter of a Gaussian likelihood must be positive, and thus has domain $(0, \infty)$. However, when we conduct gradient optimisation, we must take gradient steps on the real line $(-\infty, \infty)$. To solve this, we could e.g. re-parameterise our obs_noise parameter as log_obs_noise in our model and optimise this instead. A bijector does exactly this: it lets us define maps between the constrained space $(0, \infty)$ and the unconstrained space $(-\infty, \infty)$, so that parameters can be re-parameterised in this way for optimisation purposes (see also this issue). We hope this process will be much cleaner in our v1.0 update, which will be based on the Mytree construction.
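As a toy sketch of the idea (plain JAX, not our actual bijector machinery; softplus is just one possible choice of map):

```python
import jax
import jax.numpy as jnp


def forward(unconstrained):
    # softplus maps the unconstrained real line (-inf, inf) to (0, inf)
    return jax.nn.softplus(unconstrained)


def inverse(constrained):
    # inverse softplus maps (0, inf) back to (-inf, inf)
    return jnp.log(jnp.expm1(constrained))


obs_noise = jnp.array(0.1)            # constrained parameter, must stay positive
raw_obs_noise = inverse(obs_noise)    # what the optimiser actually sees

# Gradient steps happen on the unconstrained value; whatever the optimiser
# proposes still maps back to a valid positive obs_noise.
raw_obs_noise = raw_obs_noise - 0.1   # e.g. one (dummy) gradient step
print(forward(raw_obs_noise))         # always > 0
```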

Hope this helps. Do let me know if you have any followup questions / if the above is not clear.

Cheers, Dan

felixchalumeau commented 1 year ago

Thanks for the quick and clear answer @daniel-dodd. Yes, the solution of converting to float64 locally is the one I am using at the moment.

Ok, thank you :+1:

mathDR commented 1 year ago

Just adding this here as a possible solution (it would be a LOT of work, though, and is most likely out of scope for the GPJax library).

Apparently there is a way of doing the Cholesky factorisation in single precision, with a correction in double precision every few iterations. It was implemented here for a Cell processor; I would be very interested in seeing it work on an NVIDIA GPU, but for the question above it should also be possible (if non-trivial) to implement it on a pure CPU, given the algorithm in the paper.
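As a very rough sketch of the general idea (classic mixed-precision iterative refinement, not the paper's Cell-specific algorithm; function names are mine):

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl

jax.config.update("jax_enable_x64", True)  # float64 must be available for the correction


def refined_solve(A, b, iters=3):
    """Solve A x = b with a float32 Cholesky factor and float64 corrections."""
    A32 = A.astype(jnp.float32)
    L32 = jnp.linalg.cholesky(A32)  # cheap low-precision factorisation
    # (assumes the float32 factorisation itself succeeds, e.g. after adding jitter)

    def cheap_solve(rhs):
        # Both triangular solves reuse the float32 factor.
        y = jsl.solve_triangular(L32, rhs.astype(jnp.float32), lower=True)
        return jsl.solve_triangular(L32.T, y, lower=False)

    x = cheap_solve(b).astype(jnp.float64)
    for _ in range(iters):
        r = b - A @ x                             # residual computed in float64
        x = x + cheap_solve(r).astype(jnp.float64)
    return x
```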

thomaspinder commented 1 year ago

Thanks for flagging @mathDR. Unfortunately, implementing such an algorithm is out of the scope of GPJax, but it's cool to see such work all the same!