JaxGaussianProcesses / GPJax

Gaussian processes in JAX.
https://docs.jaxgaussianprocesses.com/
Apache License 2.0

feat(gpjax/kernels/base.py): add diagonal #429

Closed stephen-huan closed 1 month ago

stephen-huan commented 8 months ago

Description

AbstractKernel has wrappers for cross_covariance and gram but not for diagonal, even though all three are defined by AbstractKernelComputation (which in fact provides a concrete implementation of diagonal).

This PR adds diagonal to AbstractKernel.

Although the change is trivial, I believe it's important for the following reasons.

Currently, the efficient way to get the diagonal of a kernel matrix is code like

import jax.numpy as jnp
from gpjax.kernels import DenseKernelComputation, Matern52

kernel = Matern52()
# diagonal lives on the computation object, not on the kernel itself
diagonal = DenseKernelComputation().diagonal
# prints diag([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.])
print(diagonal(kernel, jnp.zeros((10, 3))))

Besides being needlessly verbose, this obscures the fact that most of the *Computation classes (DiagonalKernelComputation, etc.) could be used in place of DenseKernelComputation and would still run the same underlying implementation inherited from AbstractKernelComputation. Instead, code like

kernel.diagonal(jnp.zeros((10, 3)))

is cleaner and would allow specific kernels to provide a more efficient, specialized implementation (for example, the Matern family has a unit diagonal no matter the input data).
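For illustration, the delegation pattern described above can be sketched with toy classes. This is a minimal sketch, not GPJax's implementation: ToyComputation, ToyKernel, ToyMatern52, and ToyUnitDiagonalMatern52 are hypothetical names, the kernel has fixed unit variance and lengthscale, and the methods return plain NumPy arrays rather than the operator types GPJax uses.

```python
import numpy as np


class ToyComputation:
    """Hypothetical stand-in for a kernel computation engine (not GPJax's class)."""

    def cross_covariance(self, kernel, x, y):
        # Evaluate the kernel on every pair of rows: n * m evaluations.
        return np.array([[kernel(xi, yj) for yj in y] for xi in x])

    def gram(self, kernel, x):
        return self.cross_covariance(kernel, x, x)

    def diagonal(self, kernel, x):
        # Only the n pairs (x_i, x_i) are needed, not the full n x n matrix.
        return np.array([kernel(xi, xi) for xi in x])


class ToyKernel:
    """Hypothetical kernel base class that forwards to its computation engine."""

    def __init__(self, compute_engine=None):
        self.compute_engine = compute_engine or ToyComputation()

    def __call__(self, x, y):
        raise NotImplementedError

    def cross_covariance(self, x, y):
        return self.compute_engine.cross_covariance(self, x, y)

    def gram(self, x):
        return self.compute_engine.gram(self, x)

    def diagonal(self, x):
        # The kind of wrapper this PR adds, mirroring the two methods above.
        return self.compute_engine.diagonal(self, x)


class ToyMatern52(ToyKernel):
    """Matern-5/2 with unit variance and unit lengthscale (illustrative only)."""

    def __call__(self, x, y):
        r = np.sqrt(5.0) * np.linalg.norm(x - y)
        return (1.0 + r + r**2 / 3.0) * np.exp(-r)


class ToyUnitDiagonalMatern52(ToyMatern52):
    def diagonal(self, x):
        # k(x, x) = 1 for this kernel, so no kernel evaluations are needed.
        return np.ones(x.shape[0])


x = np.random.default_rng(0).normal(size=(10, 3))
generic = ToyMatern52().diagonal(x)
specialized = ToyUnitDiagonalMatern52().diagonal(x)
```

Because callers go through kernel.diagonal(x), the specialized subclass can replace the generic path without any change at the call site.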

thomaspinder commented 6 months ago

Hi @stephen-huan - thanks for adding this. Would you be able to include a unit test for this function please?

stephen-huan commented 5 months ago

@thomaspinder I've added tests, sorry for the long delay!

stephen-huan commented 2 months ago

@thomaspinder any progress on this PR?

I added consistency tests that are tangentially related to this PR: kernel.gram(x) is now checked against kernel.cross_covariance(x, x), kernel.diagonal(x) against diagonal(kernel.gram(x)), and kernel.cross_covariance(a, b) against the corresponding block of kernel.gram(c), where c is vstack((a, b)).
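The three consistency checks can be sketched in NumPy as follows. This is an illustrative sketch rather than the tests in the PR: the squared-exponential kernel, the free functions cross_covariance, gram, and diagonal, and the tolerance are all assumptions chosen to keep the example self-contained.

```python
import numpy as np


def cross_covariance(a, b):
    # Squared-exponential kernel with unit lengthscale and variance (illustrative).
    sq_dists = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq_dists)


def gram(x):
    # Independent code path: expand ||x_i - x_j||^2 using inner products.
    sq_norms = np.sum(x**2, axis=-1)
    sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * x @ x.T
    # Clamp tiny negative values caused by floating-point cancellation.
    return np.exp(-0.5 * np.maximum(sq_dists, 0.0))


def diagonal(x):
    # k(x_i, x_i) = exp(0) = 1 for every row of x.
    return np.ones(x.shape[0])


rng = np.random.default_rng(0)
a = rng.normal(size=(4, 3))
b = rng.normal(size=(6, 3))
c = np.vstack((a, b))
n = a.shape[0]
tol = 1e-10

# gram(x) should match cross_covariance(x, x).
gram_ok = np.max(np.abs(gram(a) - cross_covariance(a, a))) < tol
# diagonal(x) should match the diagonal of gram(x).
diag_ok = np.max(np.abs(diagonal(a) - np.diag(gram(a)))) < tol
# cross_covariance(a, b) should match the top-right block of gram(vstack((a, b))).
block_ok = np.max(np.abs(cross_covariance(a, b) - gram(c)[:n, n:])) < tol
```

The block check works because rows 0 to n-1 of c are a and the remaining rows are b, so the top-right n-by-m block of gram(c) holds exactly the a-against-b covariances.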

I also noticed a check of max(a - b) < tol, when it should probably be max(abs(a - b)) < tol.
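A small example shows why the signed form can pass when it should fail. The arrays and tolerance here are made up for illustration and are not taken from the GPJax test suite.

```python
import numpy as np

expected = np.array([1.0, 2.0, 3.0])
actual = np.array([0.0, 2.0, 3.0])  # first entry is off by a full unit
tol = 1e-6

diff = actual - expected  # [-1., 0., 0.]

# Signed check: the largest signed difference is 0, so the -1 error is missed.
signed_check = np.max(diff) < tol

# Absolute check: the largest magnitude is 1, so the error is caught.
absolute_check = np.max(np.abs(diff)) < tol
```

Here signed_check evaluates to True and absolute_check to False: any error where actual is below expected is invisible to the signed version.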

thomaspinder commented 2 months ago

Thanks @stephen-huan - nice to see this included in GPJax!