dfm / tinygp

The tiniest of Gaussian Process libraries
https://tinygp.readthedocs.io
MIT License

Predicting on new test inputs after conditioning #163

Open jameskermode opened 1 year ago

jameskermode commented 1 year ago

Thanks for tinygp, I find it very instructive and am planning to use it in my teaching and research.

One feature I am so far missing is the ability to make predictions on a series of new test inputs after conditioning a GP. The currently suggested usage seems to be to call GaussianProcess(kernel, X_train).condition(y_train, X_test) repeatedly for each X_test, but this is a little wasteful if the training set is large, since the alpha coefficients could be reused.

I appreciate that sometimes one can combine all the predictions that will ever be needed into a single X_test matrix, but this is not always the case, e.g. when using a GP surrogate in an optimisation problem.
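To make the cost concrete, here is a rough sketch of the kind of loop I have in mind (the loop itself is purely illustrative); every query re-conditions from scratch and therefore re-factorises the training covariance:

import numpy as np
import tinygp

# Toy training data
X_train = np.linspace(0, 1, 10)
y_train = np.sin(X_train)
kernel = tinygp.kernels.ExpSquared()

# Every new query point re-conditions from scratch, redoing the expensive
# factorisation of the training covariance even though X_train and y_train
# never change
for x_query in np.linspace(0, 1, 5):
    gp = tinygp.GaussianProcess(kernel, X_train)
    cond_gp = gp.condition(y_train, np.atleast_1d(x_query)).gp
    mu = cond_gp.mean  # prediction at this single query point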

Here is a MWE showing a possible workaround:

import numpy as np
import tinygp

# Toy data: the test inputs are distinct from the training inputs
X_train = np.linspace(0, 1, 10)
y_train = np.sin(X_train)
X_test = np.linspace(0, 1, 100)

kernel = tinygp.kernels.ExpSquared()
gp = tinygp.GaussianProcess(kernel, X_train)
cond_gp = gp.condition(y_train, X_test).gp

# Reuse the alpha coefficients stored in the conditioned mean function
y2 = cond_gp.mean_function.kernel.matmul(X_test, cond_gp.mean_function.X, cond_gp.mean_function.alpha)

As expected, this calculation reproduces the stored mean:

>>> np.abs(y2 - cond_gp.mean).max()
2.488494e-06

In my view it would be nice to expose this functionality a little more conveniently. Browsing the source, it looks like tinygp.means.Conditioned.__call__() should allow this, but for this example it returns a scalar rather than a vector of length 100, I think due to a problem with the vmap axes:

>>> cond_gp.mean_function(X_test)
Array(-0.0139983, dtype=float32)

I'd be happy to make a PR to change this function to work similarly to my example if you agree?

Once that is done, perhaps GaussianProcess.predict() could be modified to check if the GP has already been conditioned and if so use the stored alpha values? This would mean making the y argument optional. There might be a better way to expose the functionality.

jameskermode commented 1 year ago

Just after posting I realised that cond_gp.mean_function() should be manually vmap-ed over the new test set, e.g. the following

>>> np.abs(jax.vmap(cond_gp.mean_function)(X_test) - cond_gp.mean).max()
2.488494e-06

So there's no need for a change there. But maybe the second part of my question is still valid: finding a way of making new predictions with a conditioned GP a bit more convenient.

dfm commented 1 year ago

Ha! I was in the process of explaining exactly that as you posted :D

perhaps GaussianProcess.predict() could be modified to check if the GP has already been conditioned and if so use the stored alpha values? This would mean making the y argument optional. There might be a better way to expose the functionality.

There are strong reasons why this isn't really possible as proposed, because of the way JAX relies on pure functional, side-effect-free assumptions. So I'd say that documenting the use of cond_gp.mean_function, as you've found here, would be a much better approach. Perhaps we could add a tutorial showing how you might use tinygp in a context like this.
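For example, a tutorial snippet along these lines might do the job (this just packages the vmap trick from above behind jax.jit/jax.vmap; it is a sketch rather than a new API):

import jax
import numpy as np
import tinygp

X_train = np.linspace(0, 1, 10)
y_train = np.sin(X_train)
kernel = tinygp.kernels.ExpSquared()

# Condition once on the training data
gp = tinygp.GaussianProcess(kernel, X_train)
cond_gp = gp.condition(y_train, X_train).gp

# A reusable predictor for the conditioned mean: the stored alpha
# coefficients are reused, so there is no new solve against the training data
predict_mean = jax.jit(jax.vmap(cond_gp.mean_function))

mu = predict_mean(np.linspace(0, 1, 100))  # evaluate at arbitrary new inputs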

jameskermode commented 1 year ago

Thanks for the quick response!

Understood on the pure functions. I'll make a PR with a small documentation change as you suggest.

I was just experimenting to see if there's an analogous way to predict the (co)variance for new test points. There we can't avoid one new linear solve since we need to compute k^T * K^{-1} * k where k is the vector of covariances between test and training points and K is the covariance matrix for training data, but we could at least reuse the Cholesky factorisation. I'll have a play with that and see if I can come up with anything worth contributing.
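For reference, the identities I have in mind are just the standard GP predictive equations; writing L for the Cholesky factor of the training covariance K = L L^T, each new prediction needs only triangular solves:

\mu_* = k_*^\top K^{-1} y = k_*^\top \alpha \quad (\alpha \text{ is already stored by the conditioned GP})
v = L^{-1} k_* \quad (\text{one triangular solve per query})
\mathrm{cov}_* = k(X_*, X_*) - v^\top v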

dfm commented 1 year ago

Awesome - thank you!!

jameskermode commented 1 year ago

This is what I have come up with for predicting at new points given a prior GP gp and an already-conditioned GP cond_gp. (It would have been nicer to require only cond_gp, but I couldn't find a way to recover the original Cholesky factor needed to predict variance at new points from the conditioned GP; let me know if I missed something.)

Can you envisage that a cleaned-up and documented version of this function could be added to the base GP class? I used a different name to maintain backwards compatibility for the current GaussianProcess.predict(), but it would also be possible to add an optional cond_gp argument to that if you think that would be neater, and then check that exactly one of y and cond_gp is given.

import jax.numpy as jnp

def repredict(gp, cond_gp, X_test, return_var=False, return_cov=False):
    # gp is the prior GP built on the training inputs; cond_gp is the GP
    # returned by gp.condition(...), whose mean function stores X and alpha
    assert gp.X is cond_gp.mean_function.X
    assert gp.kernel is cond_gp.mean_function.kernel
    k_star = gp.kernel(gp.X, X_test)
    mean = k_star.T @ cond_gp.mean_function.alpha  # reuse previous alpha
    if return_var or return_cov:
        # reuse the prior GP's factorisation: one triangular solve per call
        v = gp.solver.solve_triangular(k_star)
        cov = gp.kernel(X_test, X_test) - v.T @ v
    if return_var:
        return mean, jnp.diag(cov)
    if return_cov:
        return mean, cov
    return mean

mu1, cov1 = repredict(gp, cond_gp, X_test, return_cov=True)
mu2, cov2 = gp.predict(y_train, X_test, return_cov=True)

print(jnp.abs(mu1 - mu2).max(), jnp.abs(cov1 - cov2).max()) # check for consistency