Open matthewrhysjones opened 8 months ago
Hey Matthew, I just ran into the same problem. I think the issue is in the __post_init__ of the Combination kernel class.
    def __post_init__(self):
        # Add kernels to a list, flattening out instances of this class therein, as in GPFlow kernels.
        kernels_list: List[AbstractKernel] = []
        for kernel in self.kernels:
            if not isinstance(kernel, AbstractKernel):
                raise TypeError("can only combine Kernel instances")  # pragma: no cover
            if isinstance(kernel, self.__class__):
                kernels_list.extend(kernel.kernels)
            else:
                kernels_list.append(kernel)
        self.kernels = kernels_list
Here it computes a flattened list of kernels and saves it to the kernels attribute. When the kernel is called, it applies its own operator across every kernel in that list:
    return self.operator(jnp.stack([k(x, y) for k in self.kernels]))
So the nesting structure of the operations is lost: the outer operation (e.g. sum) is blindly applied to all sub-kernels, including those that were originally combined with a different operation. This explains why the results are consistent if all kernel operations are the same.
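To make the effect concrete, here is a minimal sketch in plain Python. The `Combination` class and the scalar kernels `rbf`, `linear` and `constant` are hypothetical toy stand-ins that only mirror the flattening logic quoted above; they are not GPJax code.

```python
import math


# Hypothetical scalar kernels, used only for illustration.
def rbf(x, y):
    return math.exp(-0.5 * (x - y) ** 2)


def linear(x, y):
    return x * y


def constant(x, y):
    return 2.0


class Combination:
    """Toy stand-in that mirrors the flattening logic quoted above."""

    def __init__(self, kernels, operator):
        self.operator = operator
        flat = []
        for k in kernels:
            if isinstance(k, self.__class__):
                # The inner kernel's own operator is discarded here.
                flat.extend(k.kernels)
            else:
                flat.append(k)
        self.kernels = flat

    def __call__(self, x, y):
        return self.operator([k(x, y) for k in self.kernels])


x, y = 0.5, 1.5

# Product of a sum: intended to be (rbf + linear) * constant.
mixed = Combination([Combination([rbf, linear], sum), constant], math.prod)
intended = (rbf(x, y) + linear(x, y)) * constant(x, y)
flattened = rbf(x, y) * linear(x, y) * constant(x, y)

# Sum of a sum: flattening is harmless because addition is associative.
same = Combination([Combination([rbf, linear], sum), constant], sum)
intended_sum = (rbf(x, y) + linear(x, y)) + constant(x, y)

print(math.isclose(mixed(x, y), intended))     # False
print(math.isclose(mixed(x, y), flattened))    # True
print(math.isclose(same(x, y), intended_sum))  # True
```

In this toy version the mixed case evaluates rbf * linear * constant instead of (rbf + linear) * constant, while the sum-of-sum case is unaffected, which matches the behaviour described in the report.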
I assume the easy fix would be to have two attributes, self.kernels and self.flattened_kernels.
This is indeed a bug, thank you for spotting it!
I don't think we need to have a separate flattened_kernels; I would either
a) change SumKernel and ProductKernel to be actual subclasses of CombinationKernel (in which case the test on self.__class__ would only allow combining when the operation matches), or
b) explicitly add an additional check that self.operator is kernel.operator.
Personally I'd prefer a) ... @thomaspinder @daniel-dodd ?
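As a rough sketch of what option a) could look like, assuming toy classes rather than the real GPJax ones (the constructor signature, the `staticmethod` operator attribute and the scalar kernels below are all hypothetical):

```python
import math


class CombinationKernel:
    """Toy base class; subclasses fix the operator (hypothetical, not GPJax)."""

    operator = None  # set by subclasses

    def __init__(self, kernels):
        flat = []
        for k in kernels:
            # self.__class__ is now SumKernel or ProductKernel, so only
            # children that use the same operation are flattened.
            if isinstance(k, self.__class__):
                flat.extend(k.kernels)
            else:
                flat.append(k)
        self.kernels = flat

    def __call__(self, x, y):
        return self.operator([k(x, y) for k in self.kernels])


class SumKernel(CombinationKernel):
    operator = staticmethod(sum)


class ProductKernel(CombinationKernel):
    operator = staticmethod(math.prod)


# Hypothetical scalar kernels, used only for illustration.
def rbf(x, y):
    return math.exp(-0.5 * (x - y) ** 2)


def linear(x, y):
    return x * y


def constant(x, y):
    return 2.0


x, y = 0.5, 1.5

# A sum nested in a product is kept as a single child.
product_of_sum = ProductKernel([SumKernel([rbf, linear]), constant])
expected = (rbf(x, y) + linear(x, y)) * constant(x, y)

# A sum nested in a sum is still flattened.
sum_of_sum = SumKernel([SumKernel([rbf, linear]), constant])

print(len(product_of_sum.kernels))                    # 2
print(math.isclose(product_of_sum(x, y), expected))   # True
print(len(sum_of_sum.kernels))                        # 3
```

Option b) would presumably keep a single class and extend the flattening condition with the `self.operator is kernel.operator` check instead.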
This issue has been marked as stale because it has been open for 7 days with no activity.
Bug Report
GPJax version: 0.8.0
Current behavior:
This may be a problem with how I am interpreting how GPJax handles combination kernels, so sorry if I've missed something.
It seems that kernels which are combinations of combination kernels are not handled as expected when more than one type of combination operator is used (e.g. the kernel is a sum of product kernels, or a product of sum kernels). There doesn't appear to be a problem if both combination operators are identical (a sum of sum kernels, or a product of product kernels).
Expected behavior:
When using a combination of combination kernels, the predictive mean should be identical whether using GPJax or computing it manually.
Steps to reproduce: see below
Related code:
There is a discrepancy between "mu_manual_pos" and "mu_pos" when I don't believe there should be. The same is true if we use a kernel that is a sum of individual product kernels. However, if the combination operators are identical (sum of sums, product of products), the results agree, so it appears there is a problem with the way GPJax handles combinations of combinations that contain multiple operators.
Other information:
I found this issue while working with kernels that are combinations of combinations for a personal project, where I am seeing drastic differences between GPJax and manual computation. I've tried to simplify the problem for this post to make it as clear as possible.