JaxGaussianProcesses / GPJax

Gaussian processes in JAX.
https://docs.jaxgaussianprocesses.com/
Apache License 2.0

bug: possible bug in handling kernels that are combinations of combinations #428

Open matthewrhysjones opened 8 months ago

matthewrhysjones commented 8 months ago

Bug Report

GPJax version: 0.8.0

Current behavior:

This may be a problem with my interpretation of how GPJax handles combination kernels, so apologies if I've missed something.

It seems that kernels which are combinations of combination kernels are not handled as expected when more than one type of combination operator is used (e.g. the kernel is a sum of product kernels, or a product of sum kernels). There doesn't appear to be a problem if both combination operators are identical (a sum of sum kernels, or a product of product kernels).

Expected behavior:

When using a kernel that is a combination of combination kernels, the predictive mean should be identical whether it is computed with GPJax or manually.
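For reference, these are the standard zero-mean GP regression predictive equations that the manual calculation in the report implements, written for the product-of-sums kernel with Gaussian noise variance σ² (σ² = 1, matching the noise used in the report). Here K_{XX}, K_{X*} and K_{**} correspond to kxx, kxt and ktt in the code, and ⊙ is the elementwise (Hadamard) product:

```latex
\begin{aligned}
k(x, x') &= \big(k_1(x, x') + k_2(x, x')\big)\,\big(k_1(x, x') + k_2(x, x')\big) \\
K_{XX} &= \big(K^{(1)}_{XX} + K^{(2)}_{XX}\big) \odot \big(K^{(1)}_{XX} + K^{(2)}_{XX}\big) \\
\boldsymbol{\mu}_* &= K_{X*}^{\top} \big(K_{XX} + \sigma^2 I\big)^{-1} \mathbf{y} \\
\boldsymbol{\sigma}^2_* &= \operatorname{diag}\big(K_{**} - K_{X*}^{\top} (K_{XX} + \sigma^2 I)^{-1} K_{X*}\big) + \sigma^2
\end{aligned}
```

Any implementation that evaluates the same kernel should reproduce these values up to jitter and floating-point error.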

Steps to reproduce: see below

Related code:

```python
import jax
jax.config.update("jax_enable_x64", True)  # double precision, as used throughout the GPJax docs

import jax.numpy as jnp
import matplotlib.pyplot as plt
import gpjax as gpx

xall = jnp.linspace(-5, 5, 1000)
toy_fun = lambda x: 1/5*x**2 + jnp.sin(x*5)**3 + jnp.cos(x*3)**2

xtrain = xall[0::25][:, None]  # 40 training inputs, shape (40, 1)
ytrain = toy_fun(xtrain)
xtest = xall[:, None]
ytest = toy_fun(xtest)

D = gpx.Dataset(xtrain, ytrain)

kernel1 = gpx.kernels.RBF()
kernel2 = gpx.kernels.Matern32()
sum_kernel = kernel1 + kernel2

# using GPJax
pos_kernel = sum_kernel * sum_kernel  # pos = product of sums

pos_prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=pos_kernel)
pos_posterior = pos_prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

latent_dist_pos = pos_posterior.likelihood(pos_posterior(xtest, train_data=D))
mu_pos = latent_dist_pos.mean()
std_pos = latent_dist_pos.stddev()

# manual calculation of the predictive distribution for k = (k1 + k2) * (k1 + k2)
K1xx, K2xx = kernel1.gram(xtrain).to_dense(), kernel2.gram(xtrain).to_dense()
K1xt, K2xt = kernel1.cross_covariance(xtrain, xtest), kernel2.cross_covariance(xtrain, xtest)
K1tt, K2tt = kernel1.gram(xtest).to_dense(), kernel2.gram(xtest).to_dense()

kxx = (K1xx + K2xx) * (K1xx + K2xx)  # elementwise product of the two sum-kernel matrices
kxt = (K1xt + K2xt) * (K1xt + K2xt)
ktt = (K1tt + K2tt) * (K1tt + K2tt)

L = jnp.linalg.cholesky(kxx + 1*jnp.eye(D.n))  # 1 matches the obs noise variance in the GPJax likelihood
alpha = jnp.linalg.solve(L.T, jnp.linalg.solve(L, ytrain))
v = jnp.linalg.solve(L, kxt)

mu_manual_pos = kxt.T @ alpha
cov_manual_pos = ktt - v.T @ v
var_manual_pos = jnp.diag(cov_manual_pos) + 1  # add obs variance to match the GPJax stddev output

plt.plot(xtest, mu_manual_pos, ':', label='manual')
plt.plot(xtest, mu_pos, '--', label='GPJax')
plt.legend()
plt.show()
```

There is a discrepancy between "mu_manual_pos" and "mu_pos" where I don't believe there should be one. The same is true if the kernel is a sum of individual product kernels. However, if the combination operators are identical (a sum of sums, or a product of products), the results agree, so it appears that GPJax mishandles combinations of combinations that mix operators.

Other information:

I found this issue while working with kernels that are combinations of combinations for a personal project, where I am seeing drastic differences between GPJax and manual computation. I've tried to simplify the problem as much as possible for this post.

ChrisBoettner commented 8 months ago

Hey Matthew, I just ran into the same problem. I think the issue is in the __post_init__ of the CombinationKernel class:

    def __post_init__(self):
        # Add kernels to a list, flattening out instances of this class therein, as in GPFlow kernels.
        kernels_list: List[AbstractKernel] = []

        for kernel in self.kernels:
            if not isinstance(kernel, AbstractKernel):
                raise TypeError("can only combine Kernel instances")  # pragma: no cover

            if isinstance(kernel, self.__class__):
                kernels_list.extend(kernel.kernels)
            else:
                kernels_list.append(kernel)

        self.kernels = kernels_list

Here it builds a flattened list of kernels and saves it to the kernels attribute. When the kernel is called, it applies its operator across every kernel in that list:

  return self.operator(jnp.stack([k(x, y) for k in self.kernels]))

So the structure of the kernel expression is lost: the current operation (e.g. sum) is blindly applied to all sub-kernels. This explains why the results are consistent when all kernel operations are the same.
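A minimal numeric sketch of this effect in plain JAX, not GPJax: the two scalar functions below are illustrative stand-ins for a unit-hyperparameter RBF and Matern-3/2, not GPJax's implementations. Flattening (k1 + k2) * (k1 + k2) under the outer product yields prod([k1, k2, k1, k2]), whereas flattening a sum of sums is harmless because addition is associative.

```python
import jax.numpy as jnp


# Illustrative stand-in kernels on scalar inputs (unit lengthscale and variance).
def rbf(x, y):
    return jnp.exp(-0.5 * (x - y) ** 2)


def matern32(x, y):
    r = jnp.sqrt(3.0) * jnp.abs(x - y)
    return (1.0 + r) * jnp.exp(-r)


x, y = 0.3, 1.1
k1, k2 = rbf(x, y), matern32(x, y)

# Intended structure: a product of two sums.
intended = (k1 + k2) * (k1 + k2)

# After flattening, both inner sums are spliced into the outer product's list,
# so the outer operator (prod) is applied to [k1, k2, k1, k2].
flattened = jnp.prod(jnp.stack([k1, k2, k1, k2]))

# With a single operator, flattening does not change the result.
same_op_nested = (k1 + k2) + (k1 + k2)
same_op_flat = jnp.sum(jnp.stack([k1, k2, k1, k2]))

print(intended, flattened)        # (k1 + k2)**2 vs (k1 * k2)**2: different values
print(same_op_nested, same_op_flat)  # equal up to float rounding
```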

I assume the easy fix would be to have two attributes, self.kernels and self.flattened_kernels.

st-- commented 7 months ago

This is indeed a bug, thank you for spotting it!

I don't think we need a separate flattened_kernels; I would either a) change SumKernel and ProductKernel to be actual subclasses of CombinationKernel (in which case the test on self.__class__ would only flatten when the operation matches), or b) explicitly add a check that self.operator is kernel.operator.
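A minimal sketch of option (b) on a hypothetical toy kernel hierarchy: Kernel, Constant and this CombinationKernel are illustrative stand-ins written for this example, not GPJax classes. A child combination is flattened only when its operator is the same object as the parent's, so a product of sums keeps its inner sums intact.

```python
from dataclasses import dataclass
from typing import Callable, List

import jax.numpy as jnp


class Kernel:
    """Hypothetical minimal base class (not GPJax's AbstractKernel)."""

    def __add__(self, other):
        return CombinationKernel([self, other], operator=jnp.sum)

    def __mul__(self, other):
        return CombinationKernel([self, other], operator=jnp.prod)


@dataclass
class Constant(Kernel):
    """Toy kernel that returns a fixed value, so results are easy to check."""

    value: float

    def __call__(self, x, y):
        return jnp.asarray(self.value)


@dataclass
class CombinationKernel(Kernel):
    kernels: List[Kernel]
    operator: Callable

    def __post_init__(self):
        kernels_list: List[Kernel] = []
        for kernel in self.kernels:
            # Option (b): only splice in a child's kernels when it combines them
            # with the *same* operator; otherwise keep the child as one unit.
            if isinstance(kernel, CombinationKernel) and kernel.operator is self.operator:
                kernels_list.extend(kernel.kernels)
            else:
                kernels_list.append(kernel)
        self.kernels = kernels_list

    def __call__(self, x, y):
        return self.operator(jnp.stack([k(x, y) for k in self.kernels]))


a, b = Constant(2.0), Constant(3.0)
product_of_sums = (a + b) * (a + b)
print(product_of_sums(0.0, 0.0))  # (2 + 3) * (2 + 3) = 25, not 2 * 3 * 2 * 3 = 36
```

Because jnp.sum and jnp.prod are module-level functions, the identity check on the operator is stable across instances.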

Personally I'd prefer a) ... @thomaspinder @daniel-dodd ?
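A sketch of option (a) on a hypothetical toy hierarchy (Kernel, Constant, CombinationKernel, SumKernel and ProductKernel below are illustrative stand-ins, not GPJax's classes): once SumKernel and ProductKernel are real subclasses, the existing isinstance(kernel, self.__class__) test flattens only children of the same type, with no extra operator check.

```python
from typing import List

import jax.numpy as jnp


class Kernel:
    """Hypothetical minimal base class (not GPJax's AbstractKernel)."""

    def __add__(self, other):
        return SumKernel([self, other])

    def __mul__(self, other):
        return ProductKernel([self, other])


class Constant(Kernel):
    """Toy kernel that returns a fixed value."""

    def __init__(self, value: float):
        self.value = value

    def __call__(self, x, y):
        return jnp.asarray(self.value)


class CombinationKernel(Kernel):
    def __init__(self, kernels: List[Kernel]):
        flat: List[Kernel] = []
        for kernel in kernels:
            # Option (a): self.__class__ is SumKernel or ProductKernel, so this
            # only flattens children built with the same kind of combination.
            if isinstance(kernel, self.__class__):
                flat.extend(kernel.kernels)
            else:
                flat.append(kernel)
        self.kernels = flat

    def __call__(self, x, y):
        return self.combine(jnp.stack([k(x, y) for k in self.kernels]))


class SumKernel(CombinationKernel):
    def combine(self, values):
        return jnp.sum(values)


class ProductKernel(CombinationKernel):
    def combine(self, values):
        return jnp.prod(values)


a, b = Constant(2.0), Constant(3.0)
print(((a + b) * (a + b))(0.0, 0.0))  # 25: the inner sums are kept intact
```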

github-actions[bot] commented 1 week ago

This issue has been marked as stale because it has been open for 7 days with no activity.