ziatdinovmax / gpax

Gaussian Processes for Experimental Sciences
http://gpax.rtfd.io
MIT License

predict_in_batches does not appear to behave correctly #116

Open matthewcarbone opened 2 months ago

matthewcarbone commented 2 months ago

So during my attempts to run predict on very large "testing set" grids, I ran into the classic memory overflow issues that predict_in_batches tries to protect against. However, I don't think the function is behaving as intended.

Consider this simple example. Let's use the data from the simpleGP.ipynb notebook,

import numpy as np
import matplotlib.pyplot as plt
import gpax

np.random.seed(0)

NUM_INIT_POINTS = 25 # number of observation points
NOISE_LEVEL = 0.1 # noise level

# Generate noisy data from a known function
f = lambda x: np.sin(10*x)

X = np.random.uniform(-1., 1., NUM_INIT_POINTS)
y = f(X) + np.random.normal(0., NOISE_LEVEL, NUM_INIT_POINTS)

and fit a simple fully Bayesian GP

key1, key2 = gpax.utils.get_keys()
gp_model = gpax.ExactGP(1, kernel='RBF')
gp_model.fit(key1, X, y, num_chains=1)

Now, let's compare the result of predicting using gp_model.predict with gp_model.predict_in_batches. First, we can draw samples from the GP using predict via

X_test = np.linspace(-1, 1, 100)
noiseless = True
_, samples = gp_model.predict(key2, X_test, n=200, noiseless=noiseless)

Similarly, we can do the "same" in batches:

_, samples_b = gp_model.predict_in_batches(key2, X_test, n=200, batch_size=10, noiseless=noiseless)

At least at first glance, one would expect samples and samples_b to be the same. They're not though.

The behavior of predict_in_batches is equivalent to the following:

from gpax.utils.utils import split_in_batches
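chunks = []  # separate name, so samples_b from predict_in_batches is not overwritten
for xx in split_in_batches(X_test, batch_size=10):
    # Each chunk is predicted on its own, with the same key
    _, s = gp_model.predict(key2, xx, n=200, noiseless=noiseless)
    chunks.append(s)
samples_b_check = np.concatenate(chunks, axis=-1)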

which can be confirmed empirically via plotting. For example, consider the kernel-parameter sample at index 3, and the GP draw at index 3 for that set of kernel parameters (arbitrary choices):

ii = 3
plt.plot(X_test, samples[ii, ii, :])
plt.plot(X_test, samples_b_check[ii, ii, :], "r")
plt.plot(X_test, samples_b[ii, ii, :], "k--")
plt.show()

This produces the following plot:

[image: samples, samples_b_check (red) and samples_b (black dashed) plotted against X_test]

One can immediately see that samples is correct, but the batch sampling method is not.

Why?

The equivalence between samples_b and samples_b_check above showcases why this is the case. The first "chunk" of the input is drawn from the GP posterior over only that chunk's test points. The second chunk is drawn separately, and despite using the same seed, its draw is not conditioned (as it should be) on the previous chunk's draw. As such, we end up with this "janky" plot where each chunk of (in this case, 10) points is correlated within its chunk, but de-correlated from the other chunks, when they should be correlated as per the kernel function.
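The effect can be reproduced without gpax. The following is a minimal NumPy sketch, not the library's code: it assumes a zero-mean GP with a fixed RBF kernel as a stand-in for the posterior, and the names rbf, draw and boundary_corr are made up for illustration. It compares a joint draw, independent per-chunk draws with a reused seed, and per-chunk draws that are conditioned on everything drawn so far.

```python
import numpy as np

def rbf(a, b, length=0.3):
    """Squared-exponential kernel between two 1-D input arrays."""
    d = a[:, None] - b[None, :]
    return np.exp(-0.5 * (d / length) ** 2)

def draw(mean, cov, n, rng):
    """Draw n samples from N(mean, cov) using a Cholesky factor of cov."""
    chol = np.linalg.cholesky(cov)
    return mean + rng.standard_normal((n, len(cov))) @ chol.T

x = np.linspace(-1, 1, 40)
K = rbf(x, x) + 1e-4 * np.eye(len(x))  # jitter keeps the factorizations stable
n, size = 2000, 10
starts = list(range(0, len(x), size))

# (a) Joint draw over the whole grid: the reference behaviour.
joint = draw(0.0, K, n, np.random.default_rng(0))

# (b) Each chunk drawn on its own, reusing the same seed for every chunk.
independent = np.concatenate(
    [draw(0.0, K[s:s + size, s:s + size], n, np.random.default_rng(0))
     for s in starts],
    axis=-1,
)

# (c) Each chunk drawn conditionally on all previously drawn chunks.
rng = np.random.default_rng(0)
sequential = draw(0.0, K[:size, :size], n, rng)
for s in starts[1:]:
    K_pp, K_cp = K[:s, :s], K[s:s + size, :s]
    W = np.linalg.solve(K_pp, K_cp.T)          # K_pp^{-1} K_pc
    cond_mean = sequential @ W                 # E[f_c | f_p], one row per sample
    cond_cov = K[s:s + size, s:s + size] - K_cp @ W
    cond_cov = 0.5 * (cond_cov + cond_cov.T)   # remove round-off asymmetry
    new_chunk = draw(cond_mean, cond_cov, n, rng)
    sequential = np.concatenate([sequential, new_chunk], axis=-1)

def boundary_corr(samples):
    """Correlation between the two grid points on either side of the first chunk boundary."""
    return np.corrcoef(samples[:, size - 1], samples[:, size])[0, 1]

print(boundary_corr(joint), boundary_corr(independent), boundary_corr(sequential))
```

In this sketch the two neighbouring points across the first chunk boundary are strongly correlated for the joint and sequential draws, but not for the independently drawn chunks. Note that variant (c) is only an illustration of what "conditional on the previous chunk" means: the matrix it solves against grows with every chunk, so it does not address the memory problem by itself.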

Suggestions to fix this

I have an idea of how to fix this problem. At the end of the day, the expensive step is computing the mean and covariance matrix (due to the expensive inversion, probably).

Can we not calculate the posterior mean in batches exactly? Although we cannot calculate the posterior covariance matrix in batches exactly, we can approximate it as a block-diagonal matrix, where each block is batch_size x batch_size. The non-block components could be approximated as 0. Although highly unlikely to work for long-range (e.g. periodic) kernels, I think this will be better than the current implementation.
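A rough NumPy sketch of this suggestion is shown below. It is not gpax code: it assumes a single fixed set of RBF hyperparameters (rather than the fully Bayesian model above), a known noise variance, and hypothetical helper names rbf and posterior. The posterior mean is assembled batch by batch, and the covariance is kept only as per-batch diagonal blocks.

```python
import numpy as np

def rbf(a, b, length=0.3):
    """Squared-exponential kernel between two 1-D input arrays."""
    d = a[:, None] - b[None, :]
    return np.exp(-0.5 * (d / length) ** 2)

rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, 25)
y = np.sin(10 * X) + rng.normal(0.0, 0.1, 25)
X_test = np.linspace(-1, 1, 100)
noise_var = 0.1 ** 2  # assumed known here

# Factorize the training covariance once; every batch reuses it.
chol = np.linalg.cholesky(rbf(X, X) + noise_var * np.eye(len(X)))
alpha = np.linalg.solve(chol.T, np.linalg.solve(chol, y))

def posterior(xs):
    """Noiseless posterior mean and covariance at the test points xs."""
    Ks = rbf(xs, X)
    v = np.linalg.solve(chol, Ks.T)
    return Ks @ alpha, rbf(xs, xs) - v.T @ v

# Reference: everything at once.
mean_full, cov_full = posterior(X_test)

# Batched: the mean is exact, the covariance is only known block by block.
batches = np.array_split(X_test, 10)
means, blocks = zip(*(posterior(b) for b in batches))
mean_batched = np.concatenate(means)

# Assemble the block-diagonal approximation (done here only for comparison;
# in practice one would keep the small blocks and never build the full matrix).
cov_block = np.zeros_like(cov_full)
start = 0
for blk in blocks:
    stop = start + len(blk)
    cov_block[start:stop, start:stop] = blk
    start = stop

print(np.abs(mean_batched - mean_full).max())
print(np.abs(cov_block - cov_full).max())
```

Under these assumptions the batched mean and the marginal variances (the diagonal of the covariance) agree with the full computation, and only the covariances between points in different batches are dropped. Samples drawn from the block-diagonal covariance would therefore still be uncorrelated across batches.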

@ziatdinovmax what do you think about this?

ziatdinovmax commented 2 months ago

@matthewcarbone - Thanks for bringing it up. I'm aware of this issue. It is caused by the randomness in drawing samples, which is not consistent across different batches. Basically, this line: https://github.com/ziatdinovmax/gpax/blob/fad59bc61bcf5de3d3e301c81b826b25c0716002/gpax/models/gp.py#L292. The refactored version will fix it since it will replace sampling with exact computation of (co)variance.

matthewcarbone commented 2 months ago

@ziatdinovmax how are you doing the exact sampling batch-wise? I'm really curious because I can't figure out anything other than various approximate methods, e.g. LOVE (used via gpytorch's fast_pred_var setting or whatever).