tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Jittered/safe Cholesky Operation #1291

Open markvdw opened 3 years ago

markvdw commented 3 years ago

Hi TFP team!

Overview

The most common annoyance among Gaussian process users is the failure of the Cholesky decomposition. I believe this really is a major issue for both seasoned and new users. One solution is to implement a jittered/safe Cholesky op. Such an op would attempt a Cholesky decomposition and, if it fails due to non-positive definiteness, repeatedly try again with a progressively larger "jitter" term.

I believe such an op would have a large positive impact for GP users, and that the best place for it would be in TF or TFP. I would appreciate the TFP team's opinion on the points below.

In practice, this is 80% a feature request for the specific jittered Cholesky op. More broadly, though, we are putting significant effort into making GPs easier to use, which could benefit TFP as well. This could be a neat collaboration that leads to a paper (see "Ideal outcome" below).

The Problem

Sparse variational GPs are super elegant because they have a clear mathematical answer for setting every free parameter in the algorithm. However, training them is still a pain, and Cholesky errors are the cause of 90% of the annoyances. They are easily fixed by adding a small "jitter" term to the diagonal. This works great, and is the advice that seasoned users often give (see e.g. https://github.com/tensorflow/probability/issues/195 or the GPflow issue page).
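For concreteness, the manual workaround typically looks something like this minimal sketch (the toy matrix and the jitter value of 1e-6 are placeholders, not recommendations):

```python
import tensorflow as tf

# Toy kernel matrix that is nearly singular (two inputs almost coincide),
# which is exactly the situation where a plain Cholesky tends to struggle.
x = tf.constant([[0.0], [1e-6], [1.0]], dtype=tf.float64)
kernel_matrix = tf.exp(-0.5 * tf.square(x - tf.transpose(x)))

# The usual manual fix: add a small "jitter" to the diagonal before factorising.
jitter = 1e-6  # hand-picked; the "right" value depends on data and hyperparameters
chol = tf.linalg.cholesky(kernel_matrix + jitter * tf.eye(3, dtype=tf.float64))
```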

However, this is not really satisfactory. Adjusting the jitter is a manual process that requires effort, and no single value is optimal across all hyperparameter settings. Often the failure occurs during a bad line search when running BFGS, or at a strange proposal during HMC. In the current situation, the jitter has to be set to the worst-case value that keeps the code running even for extreme hyperparameter settings that are only visited occasionally and would be discarded immediately anyway. This is a poor state of affairs.

This causes problems for seasoned users and newcomers alike.

The solution

The solution is simple: we need to transparently adjust the amount of jitter to the situation at hand. GPy did this years ago, in NumPy code where it was easy to catch exceptions: you try with small or zero jitter, then increase the jitter over a few attempts until the Cholesky succeeds, or until a maximum jitter is reached, at which point you do actually raise an exception. This would eliminate the most frequently asked question about GPs.
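A minimal sketch of that retry loop in NumPy, in the spirit of what GPy does (illustrative only, not GPy's actual implementation; the schedule and defaults are placeholders):

```python
import numpy as np


def jittered_cholesky(matrix, max_tries=6):
  """Retry the Cholesky with increasing jitter, GPy-style (illustrative sketch)."""
  try:
    return np.linalg.cholesky(matrix)  # first attempt: no jitter at all
  except np.linalg.LinAlgError:
    pass
  # Scale the initial jitter to the typical magnitude of the diagonal.
  jitter = np.mean(np.diag(matrix)) * 1e-8
  for _ in range(max_tries):
    try:
      return np.linalg.cholesky(matrix + jitter * np.eye(matrix.shape[-1]))
    except np.linalg.LinAlgError:
      jitter *= 10.0  # increase the jitter and try again
  raise np.linalg.LinAlgError(
      f"Cholesky failed after {max_tries} attempts with increasing jitter")
```

The awkward part is doing the equivalent inside a TF graph, where runtime failures cannot be caught like this, which is why a dedicated op (or something built on tf.while_loop) would help.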

For sparse variational GPs (Titsias 2009) there is a very neat mathematical justification for doing this: increasing the jitter always penalises your optimisation objective.
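For reference, the objective in question is the collapsed bound of Titsias (2009), which in my notation reads

$$
\mathcal{L} \;=\; \log \mathcal{N}\!\left(\mathbf{y} \mid \mathbf{0},\; \mathbf{Q}_{ff} + \sigma^2 \mathbf{I}\right) \;-\; \frac{1}{2\sigma^2}\,\mathrm{tr}\!\left(\mathbf{K}_{ff} - \mathbf{Q}_{ff}\right), \qquad \mathbf{Q}_{ff} = \mathbf{K}_{fu}\mathbf{K}_{uu}^{-1}\mathbf{K}_{uf}.
$$

Jitter usually enters by replacing $\mathbf{K}_{uu}$ with $\mathbf{K}_{uu} + \varepsilon\mathbf{I}$, and the claim above is that increasing $\varepsilon$ can only decrease $\mathcal{L}$, so the jittered objective stays a valid lower bound on the log marginal likelihood.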

Ideal outcome

Given our recent work on provably good inducing point selection [1], we now have a recipe for an efficient algorithm that selects all of a sparse variational GP's parameters. With some additional tricks, we can make sparse variational GPs as easy to train as normal GPs, without the annoying Cholesky errors, and with guarantees on performance.

This will be impactful for two reasons:

I envisage a paper coming out of this work that:

  1. develops a full procedure for setting all parameters in a sparse variational GP,
  2. shows how easy it is to train GPs with the new procedure, and
  3. sets a reliable baseline for future GP and Bayesian Deep Learning research.

Currently, the lack of a jittered Cholesky op is the main impediment to making this a reality. It's a small change that would make a huge difference for GPs.

If anyone wants to collaborate on making this paper a reality, let me know. I would appreciate either help coding up such an op, or guidance on how to do it.

I'm curious to hear what you think!

[1] http://proceedings.mlr.press/v97/burt19a.html

srvasude commented 2 years ago

Hi, just wanted to respond with a few updates here:

We have https://www.tensorflow.org/probability/api_docs/python/tfp/experimental/linalg/simple_robustified_cholesky, which uses an LDL decomposition internally.
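A quick usage sketch, assuming (from the API page name and description) that the matrix is the first argument and a Cholesky-like factor is returned; the toy matrix is just for illustration:

```python
import tensorflow as tf
import tensorflow_probability as tfp

# Toy near-singular kernel matrix of the kind that trips up a plain Cholesky.
x = tf.constant([[0.0], [1e-6], [1.0]], dtype=tf.float64)
kernel_matrix = tf.exp(-0.5 * tf.square(x - tf.transpose(x)))

# Assumes the matrix is passed positionally and a lower-triangular,
# Cholesky-like factor comes back; see the linked docs for the exact signature.
chol = tfp.experimental.linalg.simple_robustified_cholesky(kernel_matrix)
```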

I also want to point out retrying_cholesky and friends: https://www.tensorflow.org/probability/api_docs/python/tfp/experimental/distributions/marginal_fns/retrying_cholesky

This essentially computes a Cholesky and retries with increased jitter on failure.

Happy to chat and collaborate more on this!