Open matthewcarbone opened 2 months ago
@matthewcarbone - Thanks for bringing it up. I'm aware of this issue. It is caused by the randomness in drawing samples, which is not consistent across different batches. Specifically, this line: https://github.com/ziatdinovmax/gpax/blob/fad59bc61bcf5de3d3e301c81b826b25c0716002/gpax/models/gp.py#L292. The refactored version will fix it, since it will replace sampling with exact computation of the (co)variance.
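For reference, the "exact computation of (co)variance" mentioned here amounts to the standard GP posterior equations. A minimal NumPy sketch (a toy 1-D example with made-up data and hyperparameters, not the actual gpax refactor):

```python
# Minimal NumPy sketch (not the actual gpax refactor) of computing the GP
# posterior mean and covariance exactly, with no sampling involved.
import numpy as np

def rbf_kernel(x1, x2, lengthscale=0.2, variance=1.0):
    """Squared-exponential kernel on 1-D inputs."""
    d = x1[:, None] - x2[None, :]
    return variance * np.exp(-0.5 * (d / lengthscale) ** 2)

def posterior_moments(X_train, y_train, X_test, noise=1e-4):
    """Standard exact GP posterior mean and covariance."""
    K = rbf_kernel(X_train, X_train) + noise * np.eye(len(X_train))
    Ks = rbf_kernel(X_train, X_test)
    Kss = rbf_kernel(X_test, X_test)
    mean = Ks.T @ np.linalg.solve(K, y_train)
    cov = Kss - Ks.T @ np.linalg.solve(K, Ks)
    return mean, cov

X = np.linspace(0, 1, 20)
y = np.sin(2 * np.pi * X)
Xt = np.linspace(0, 1, 50)
mean, cov = posterior_moments(X, y, Xt)
# mean and cov are deterministic, so batching cannot introduce the
# sample-to-sample inconsistency that random draws do.
```

Because nothing here is random, splitting `Xt` into batches can only change the covariance (by dropping cross-batch entries), never the mean.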
@ziatdinovmax how are you doing the exact sampling batch-wise? I'm really curious, because I can't figure out anything other than various approximate methods, e.g. LOVE (used via gpytorch's pred_fast method or whatever).
So during my attempts to run `predict` on very large "testing set" grids, I ran into the classic memory-overflow issues that `predict_in_batches` tries to protect against. However, I don't think the function is behaving as intended.

Consider this simple example. Let's use the data from the `simpleGP.ipynb` notebook, and fit a simple fully Bayesian GP.
Now, let's compare the result of predicting using `gp_model.predict` with `gp_model.predict_in_batches`. First, we can draw samples from the GP using `predict`. Similarly, we can do the "same" in batches:
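Since the notebook code isn't reproduced here, the following is a self-contained NumPy stand-in for the two procedures (a toy exact GP with a hypothetical kernel, data, and batch count — not gpax's actual implementation):

```python
# Toy NumPy stand-in (hypothetical data, kernel, and batch count; not
# gpax's code) comparing a joint posterior draw with per-batch draws.
import numpy as np

def rbf(x1, x2, ell=0.2):
    """Squared-exponential kernel on 1-D inputs."""
    return np.exp(-0.5 * ((x1[:, None] - x2[None, :]) / ell) ** 2)

def posterior(X, y, Xt, noise=1e-3):
    """Exact GP posterior mean and covariance at test points Xt."""
    K = rbf(X, X) + noise * np.eye(len(X))
    Ks, Kss = rbf(X, Xt), rbf(Xt, Xt)
    mean = Ks.T @ np.linalg.solve(K, y)
    cov = Kss - Ks.T @ np.linalg.solve(K, Ks)
    return mean, cov

rng = np.random.default_rng(0)
X_train = np.linspace(0, 1, 10)
y_train = np.sin(2 * np.pi * X_train)
X_test = np.linspace(0, 1, 100)

# "predict": one joint draw over the full test grid
mean, cov = posterior(X_train, y_train, X_test)
L = np.linalg.cholesky(cov + 1e-6 * np.eye(len(X_test)))
samples = mean + L @ rng.standard_normal(len(X_test))

# "predict_in_batches": each chunk is conditioned on the training data
# separately, and each chunk's randomness is drawn independently
samples_b = []
for chunk in np.array_split(X_test, 5):
    m_c, c_c = posterior(X_train, y_train, chunk)
    L_c = np.linalg.cholesky(c_c + 1e-6 * np.eye(len(chunk)))
    samples_b.append(m_c + L_c @ rng.standard_normal(len(chunk)))
samples_b = np.concatenate(samples_b)
```

Plotting `samples` against `samples_b` on the same axes makes the difference obvious: the batched draw loses all correlation across chunk boundaries.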
At least at first glance, one would expect `samples` and `samples_b` to be the same. They're not, though.

The behavior of `predict_in_batches` is equivalent to conditioning the GP and sampling over each batch of the test grid separately, which can be confirmed empirically via plotting. For example, consider the 3rd sample of the kernel parameters, and the 3rd sample of the GP for that set of kernel parameters (arbitrary choices):
This produces the plot. One can immediately see that `samples` is correct, but the batch sampling method is not.

Why?
The equivalence between `samples_b` and `samples_b_check` above showcases why this is the case. The first "chunk" of the input is drawn from the GP conditioned over only that range. The second chunk is conditioned separately and, despite using the same seed, is not conditional (as it should be) on the previous chunk's draw. As such, we end up with this "janky" plot, where each chunk of (in this case, 10) samples is correlated within its chunk but de-correlated from the other chunks, when all chunks should be correlated as per the kernel function.

Suggestions to fix this
I have an idea of how to fix this problem. At the end of the day, the expensive step is computing the mean and covariance matrix (probably due to the expensive matrix inversion).

Can we not calculate the posterior mean in batches exactly? Although we cannot calculate the posterior covariance matrix in batches exactly, we can approximate it as a block-diagonal matrix, where each block is `batch_size` x `batch_size`. The off-block components would be approximated as 0. Although this is highly unlikely to work well for long-range (e.g. periodic) kernels, I think it would be better than the current implementation.

@ziatdinovmax what do you think about this?
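A sketch of what that could look like (a toy NumPy GP with a hypothetical batch size, not a drop-in gpax patch): the per-batch posterior means concatenate to exactly the full posterior mean, while the covariance keeps only the diagonal blocks.

```python
# Sketch of the block-diagonal idea (toy NumPy GP, hypothetical batch size;
# not a drop-in gpax patch). The batched mean is exact; the covariance
# keeps only the batch_size x batch_size diagonal blocks.
import numpy as np

def rbf(x1, x2, ell=0.2):
    return np.exp(-0.5 * ((x1[:, None] - x2[None, :]) / ell) ** 2)

def posterior(X, y, Xt, noise=1e-3):
    K = rbf(X, X) + noise * np.eye(len(X))
    Ks, Kss = rbf(X, Xt), rbf(Xt, Xt)
    return Ks.T @ np.linalg.solve(K, y), Kss - Ks.T @ np.linalg.solve(K, Ks)

X_train = np.linspace(0, 1, 10)
y_train = np.sin(2 * np.pi * X_train)
X_test = np.linspace(0, 1, 100)
batch_size = 25

full_mean, full_cov = posterior(X_train, y_train, X_test)

means, blocks = [], []
for chunk in np.array_split(X_test, len(X_test) // batch_size):
    m_c, c_c = posterior(X_train, y_train, chunk)
    means.append(m_c)
    blocks.append(c_c)
mean_b = np.concatenate(means)

# assemble the block-diagonal approximation; off-block entries stay 0
cov_b = np.zeros((len(X_test), len(X_test)))
i = 0
for c in blocks:
    n = len(c)
    cov_b[i:i + n, i:i + n] = c
    i += n
# mean_b matches full_mean exactly; cov_b matches full_cov on the
# diagonal blocks and simply drops the off-block (long-range) terms.
```

The design trade-off is visible directly in `cov_b`: the dropped off-block entries are exactly the cross-batch correlations, which is why the approximation should be fine for short-lengthscale kernels but poor for long-range (e.g. periodic) ones.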