jameskermode opened 1 year ago
Just after posting I realised that cond_gp.mean_function() should be manually vmap-ed over the new test set, e.g. the following:
>>> np.abs(jax.vmap(cond_gp.mean_function)(X_test) - cond_gp.mean).max()
2.488494e-06
So there's no need for a change there. But maybe the second part of my question is still valid: finding a way of making new predictions with a conditioned GP a bit more convenient.
Ha! I was in the process of explaining exactly that as you posted :D
> perhaps GaussianProcess.predict() could be modified to check if the GP has already been conditioned and if so use the stored alpha values? This would mean making the y argument optional. There might be a better way to expose the functionality.
There are strong reasons why this isn't really possible as proposed, because of the way JAX relies on pure functional/side-effect-free assumptions. So I'd say that documenting the use of cond_gp.mean_function, as you've found here, would be a much better approach. Perhaps we could add a tutorial showing how you might use tinygp in a context like this.
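For instance, such a tutorial might show something along these lines (only a sketch; X_new is an illustrative name for test inputs that differ from the points used when conditioning):

```python
import jax
import numpy as np

# Evaluate the conditioned mean at fresh inputs by vmapping the stored mean
# function over them; this reuses the work cached at conditioning time.
X_new = np.linspace(-2.0, 12.0, 25)  # illustrative new test inputs
mean_new = jax.vmap(cond_gp.mean_function)(X_new)
```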
Thanks for the quick response!
Understood on the pure functions. I'll make a PR with a small documentation change as you suggest.
I was just experimenting to see if there's an analogous way to predict the (co)variance for new test points. There we can't avoid one new linear solve, since we need to compute k^T * K^{-1} * k, where k is the vector of covariances between test and training points and K is the covariance matrix for the training data, but we could at least reuse the Cholesky factorisation. I'll have a play with that and see if I can come up with anything worth contributing.
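For reference, these are just the standard GP predictive equations (generic notation, not tinygp-specific, zero prior mean assumed for brevity):

$$
\mu_* = k_*^\top K^{-1} y, \qquad
\Sigma_* = K_{**} - k_*^\top K^{-1} k_*
$$

With the Cholesky factorisation $K = L L^\top$, the quadratic term becomes $k_*^\top K^{-1} k_* = v^\top v$ with $v = L^{-1} k_*$, so each new test set needs only one new triangular solve, while $\alpha = K^{-1} y$ can be cached from the original conditioning.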
Awesome - thank you!!
This is what I have come up with for predicting at new points given a prior GP gp and an already-conditioned GP cond_gp. (It would have been nicer to require only cond_gp, but I couldn't find a way to recover the original Cholesky factor needed to predict variance at new points from the conditioned GP; let me know if I missed something.)

Can you envisage that a cleaned-up and documented version of this function could be added to the base GP class? I used a different name to maintain backwards compatibility for the current GaussianProcess.predict(), but it would also be possible to add a cond_gp optional argument to that if you think that would be neater, and then check that either y or cond_gp is present but not both.
import jax.numpy as jnp

def repredict(gp, cond_gp, X_test, return_var=False, return_cov=False):
    # sanity check: cond_gp must have been conditioned from this prior gp
    assert gp.X is cond_gp.mean_function.X
    assert gp.kernel is cond_gp.mean_function.kernel
    k_star = gp.kernel(gp.X, X_test)
    mean = k_star.T @ cond_gp.mean_function.alpha  # reuse previous alpha
    if return_var or return_cov:
        # one new triangular solve against the cached Cholesky factor
        v = gp.solver.solve_triangular(k_star)
        cov = gp.kernel(X_test, X_test) - v.T @ v
        if return_var:
            var = jnp.diag(cov)
            return mean, var
        if return_cov:
            return mean, cov
    return mean  # mean-only prediction when neither flag is set

mu1, cov1 = repredict(gp, cond_gp, X_test, return_cov=True)
mu2, cov2 = gp.predict(y_train, X_test, return_cov=True)
print(jnp.abs(mu1 - mu2).max(), jnp.abs(cov1 - cov2).max())  # check for consistency
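As a rough sketch (assuming the gp, cond_gp, and repredict objects above; repredict is not part of the tinygp API), the mean-only path could also be jitted so it can be called cheaply and repeatedly at new inputs, e.g. inside an optimisation loop:

```python
import jax

# Jit a mean-only surrogate that closes over the already-conditioned GP; the
# cached alpha coefficients become constants in the traced function, so each
# call only builds the cross-covariance against the new inputs.
predict_mean = jax.jit(lambda X_new: repredict(gp, cond_gp, X_new))
mu = predict_mean(X_test)
```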
Thanks for tinygp, I find it very instructive and am planning to use it in my teaching and research.
One feature I am so far missing is the ability to make predictions on a series of new test inputs after conditioning a GP. The current suggested usage seems to be to call GaussianProcess(kernel, X_train).condition(y_train, X_test) repeatedly for each X_test, but this is a little wasteful if the training set is large, since it is possible to reuse the alpha coefficients. I appreciate that sometimes one can combine all the predictions that will ever be needed into a single X_test matrix, but this is not always the case, e.g. when using a GP surrogate in an optimisation problem.

Here is an MWE showing a possible workaround:
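A hedged sketch of the kind of workaround being described (mirroring the alpha-reuse trick from the repredict helper above; the kernel choice, data, and sizes are purely illustrative):

```python
import numpy as np
from tinygp import kernels, GaussianProcess

# Illustrative training/test data (not the original post's data).
rng = np.random.default_rng(0)
X_train = np.sort(rng.uniform(0.0, 10.0, 50))
y_train = np.sin(X_train) + 0.01 * rng.normal(size=X_train.shape)
X_test = np.linspace(0.0, 10.0, 100)

kernel = kernels.ExpSquared(scale=1.5)
gp = GaussianProcess(kernel, X_train, diag=1e-4)
_, cond_gp = gp.condition(y_train, X_test)

# Workaround: reuse the alpha coefficients stored on the conditioned mean
# function to rebuild the predictive mean from the cross-covariance, instead
# of calling condition() again for every new batch of test points.
k_star = kernel(X_train, X_test)
mean_manual = k_star.T @ cond_gp.mean_function.alpha
```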
As expected, this calculation reproduces the stored mean:
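(Continuing the illustrative sketch above; the exact numbers will differ from the original post.)

```python
# The manually reconstructed mean should match the stored conditioned mean up
# to floating-point round-off.
print(np.abs(mean_manual - cond_gp.mean).max())
```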
In my view it would be nice to expose this functionality a little more conveniently. Browsing the source, it looks like tinygp.means.Conditioned.__call__() should allow this, but for this example it gives a vector of length 1 rather than length 100, I think due to a problem with the vmap axes. I'd be happy to make a PR to change this function to work similarly to my example if you agree?
Once that is done, perhaps GaussianProcess.predict() could be modified to check if the GP has already been conditioned and if so use the stored alpha values? This would mean making the y argument optional. There might be a better way to expose the functionality.