google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Why is the covariance a 2D matrix? #160

Open LeonhardStorm opened 2 years ago

LeonhardStorm commented 2 years ago

Hi there, thanks for making and maintaining this excellent project!

I was wondering why the covariance output of the predict function is a matrix with the same shape as the kernel. As far as I can tell, only the diagonal values of this matrix are used in all of the examples, so what do the other values represent? Are they relevant/useful at all?

Thanks!

romanngg commented 2 years ago

Indeed, in examples and visualizations we have only used the diagonal entries (marginal variances), but in general the outputs are described by a Gaussian process (GP) with non-zero off-diagonal covariance entries (see the covariance expressions in eq. 13, 15, 16 in https://arxiv.org/pdf/1902.06720.pdf), hence we return the full covariance matrix of this GP. As for any GP, the off-diagonal entries just represent the covariance between the outputs of your GP at two different input points.
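
For concreteness, here is a minimal sketch of pulling out the full test-set covariance with `nt.predict.gradient_descent_mse_ensemble`; the architecture, widths, and data shapes below are just illustrative assumptions, not anything required by the library:

```python
import neural_tangents as nt
from jax import random
from neural_tangents import stax

# Illustrative data: 20 train / 7 test points with 5 features, scalar targets.
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
x_train = random.normal(k1, (20, 5))
y_train = random.normal(k2, (20, 1))
x_test = random.normal(k3, (7, 5))

# Any analytic kernel works; this architecture is just a placeholder.
_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train)

# With `compute_cov=True`, the prediction is a Gaussian(mean, covariance).
mean, cov = predict_fn(x_test=x_test, get='nngp', compute_cov=True)

print(mean.shape)  # (7, 1): posterior mean at each test point
print(cov.shape)   # (7, 7): cov[i, j] is the covariance between the GP
                   # outputs at x_test[i] and x_test[j]
```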

If you want to sample outputs on x_test from your (posterior, post-training) GP, you would need the full covariance on x_test for it. You may also need the full covariance if, for example, you want to estimate / tune the predictive probability p(y_test | x_train, y_train, x_test) (pdf of the GP with the posterior mean and covariance on x_test that you obtain from predict). It can also be used in ensembling, if you use inverse-variance weighting (https://en.wikipedia.org/wiki/Inverse-variance_weighting#Multivariate_Case) to generate ensemble predictions on x_test given several GPs (see section E in https://arxiv.org/pdf/2007.15801.pdf).
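
As a hedged sketch of that sampling step (reusing the `mean` and `cov` variables from the snippet above, which are illustrative names rather than library API): joint draws of the function values on x_test need the full covariance, while the diagonal alone only gives independent per-point draws.

```python
import jax.numpy as jnp
from jax import random

# Joint samples of the posterior GP on x_test require the full covariance.
n_samples, n_test = 5, cov.shape[0]
jitter = 1e-6 * jnp.eye(n_test)  # small diagonal term for numerical stability
samples = random.multivariate_normal(
    random.PRNGKey(1),
    mean[:, 0],          # (n_test,) posterior mean
    cov + jitter,        # (n_test, n_test) full posterior covariance
    shape=(n_samples,))  # -> samples of shape (n_samples, n_test)

# Using only the diagonal samples each test point independently and loses
# the correlations between outputs at different test inputs:
marginal_std = jnp.sqrt(jnp.diag(cov))
independent = mean[:, 0] + marginal_std * random.normal(
    random.PRNGKey(2), (n_samples, n_test))
```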

HaozhenZhao commented 1 month ago

Hi @romanngg, thanks for such an amazing library! Regarding the "sample outputs on x_test from the posterior GP" part, could you please explain why we need the full $n \times n$ matrix for sampling? I thought we would need a $k \times k$ matrix for that (where $k$ is the number of features). Also, when I try to get the full matrix with `trace_axes=()`, I get the error `TypeError: cannot reshape array of shape (100, 10) (size 1000) into shape (100,) (size 100)`, where I have 100 samples with 10 labels. However, with `trace_axes=(-1,)` it's fine. Could you please explain the reason? Many thanks!