pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

contrib.hsgp: support vector-valued kernel hyperparameters #1805

Closed (brendancooley closed 5 months ago)

brendancooley commented 6 months ago

#1803 implements support for multidimensional Hilbert Space Gaussian Process approximations. However, it supports estimating only a single set of kernel hyperparameters (e.g. one squared exponential lengthscale shared by all input dimensions). In principle, the lengthscale can vary across dimensions of the input space (see Riutort-Mayol 2022 eq 1, 2, and 3 for the associated spectral density functions).
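For reference, with one lengthscale per input dimension the squared-exponential spectral density takes the standard form below. It is written here from the Fourier transform of the anisotropic kernel rather than copied from the paper, and the notation is an assumption: $\alpha$ is the kernel's marginal variance, $D$ the input dimension, and $\ell_i$ the lengthscale of dimension $i$.

```latex
S(\boldsymbol{\omega}) = \alpha \, (2\pi)^{D/2} \prod_{i=1}^{D} \ell_i \,
  \exp\!\left( -\frac{1}{2} \sum_{i=1}^{D} \ell_i^{2} \, \omega_i^{2} \right)
```

Setting every $\ell_i$ to a common $\ell$ recovers the isotropic case, $\alpha \, (\sqrt{2\pi}\,\ell)^{D} \exp(-\tfrac{1}{2} \ell^{2} \lVert \boldsymbol{\omega} \rVert^{2})$.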

Implementation requires updating spectral_density_matern and spectral_density_squared_exponential (in numpyro.contrib.hsgp.spectral_densities) to accept and process array-valued inputs for the length parameter. This may also require upstream changes to the vmap in diag_spectral_density_squared_exponential and diag_spectral_density_matern. The test models test_squared_exponential_gp_model and test_matern_gp_model (test.contrib.hsgp.test_approximation) should be updated to optionally sample vector-valued lengthscales, and test cases demonstrating the functionality should be added.

brendancooley commented 5 months ago

Notes to self as I start working on this...

We might consider first adding some tests to ensure that our kernel approximations come close to the exact kernels when m is large enough. Something like the following:

from sklearn.gaussian_process.kernels import RBF

import jax.numpy as jnp

from numpyro.contrib.hsgp.laplacian import eigenfunctions, sqrt_eigenvalues
from numpyro.contrib.hsgp.spectral_densities import (
    diag_spectral_density_squared_exponential,
)

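# Two test points in a 2-D input space, each with shape (1, 2).
x1 = jnp.array([1.0, 1.0])[None, ...]
x2 = jnp.array([0.0, 0.0])[None, ...]
m = 10  # number of basis functions per dimension
ell = 3  # half-width of the approximation domain [-ell, ell]
sqrt_eig_v = sqrt_eigenvalues(ell=ell, m=m, dim=2)  # computed for reference; not used below
eig_f1 = eigenfunctions(x1, ell=ell, m=m)
eig_f2 = eigenfunctions(x2, ell=ell, m=m)
# Spectral density (alpha=1.0, length=1.0) at the square roots of the eigenvalues.
spd = diag_spectral_density_squared_exponential(1.0, 1.0, ell, m, 2)[None, ...]
# k(x1, x2) is approximated by sum_j S(sqrt(lambda_j)) * phi_j(x1) * phi_j(x2).
approx = (eig_f1 * eig_f2 * spd).sum(axis=1)
exact = RBF(1.0)(x1, x2)
# The truncated expansion is only approximate, so the default isclose tolerances
# are likely too strict here; atol=1e-4 is an assumed, loose value.
assert jnp.allclose(approx, exact, atol=1e-4)

samanklesaria commented 5 months ago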

I'd be interested in finishing this up if you're not too far along! For the squared exponential at least, I'd assume we can do something like the following:

def spectral_density_squared_exponential(
    dim: int, w: ArrayImpl, alpha: float, length: float | ArrayImpl
) -> float:
    ...
    length = jnp.broadcast_to(length, dim)
    c = alpha * jnp.prod(jnp.sqrt(2 * jnp.pi) * length)
    e = jnp.exp(-0.5 * jnp.sum(w**2 * length**2))
    return c * e

This would allow for the current behavior, but also let us have different length-scales for each dimension.
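The broadcasting behaviour described here can be checked with a small standalone sketch. It is a plain-Python analogue of the snippet above, not the numpyro implementation: the helper names `broadcast_length` and `spectral_density_se` are hypothetical, and the standard library stands in for `jnp` so the example runs without extra dependencies.

```python
import math
from typing import Sequence, Union

Length = Union[float, Sequence[float]]


def broadcast_length(length: Length, dim: int) -> list:
    """Expand a scalar lengthscale to `dim` entries; validate a vector one."""
    if isinstance(length, (int, float)):
        return [float(length)] * dim
    values = [float(v) for v in length]
    if len(values) != dim:
        raise ValueError(f"expected {dim} lengthscales, got {len(values)}")
    return values


def spectral_density_se(dim: int, w: Sequence[float], alpha: float, length: Length) -> float:
    """Squared-exponential spectral density at the frequency vector `w`."""
    ells = broadcast_length(length, dim)
    # Normalising constant: alpha * prod_i sqrt(2 * pi) * ell_i.
    c = alpha * math.prod(math.sqrt(2 * math.pi) * ell for ell in ells)
    # Exponent: -0.5 * sum_i (w_i * ell_i) ** 2.
    e = math.exp(-0.5 * sum((wi * ell) ** 2 for wi, ell in zip(w, ells)))
    return c * e


w = [0.3, 1.2]
scalar = spectral_density_se(2, w, alpha=1.5, length=0.7)         # current behaviour
vector = spectral_density_se(2, w, alpha=1.5, length=[0.7, 0.7])  # same lengthscale, as a vector
mixed = spectral_density_se(2, w, alpha=1.5, length=[0.7, 2.0])   # one lengthscale per dimension
```

A scalar lengthscale and a vector of identical entries give the same density, so existing callers would be unaffected, while `mixed` differs because the second dimension uses its own lengthscale.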

I could add tests like you describe above to the test/contrib/hsgp/test_approximation.py file.
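A test of that kind, extended to per-dimension lengthscales, might look like the sketch below. It is self-contained plain Python rather than code against the numpyro API: every function name is hypothetical, the eigenfunctions are the usual sine basis on [-ell, ell], and the domain half-width, number of basis functions, and tolerance are assumed values picked so the truncated expansion is accurate at these two points.

```python
import itertools
import math


def sqrt_eigenvalue(j, ell):
    """Square root of the j-th Laplacian eigenvalue on [-ell, ell] (j starts at 1)."""
    return j * math.pi / (2 * ell)


def eigenfunction(x, j, ell):
    """j-th Laplacian eigenfunction on [-ell, ell] with Dirichlet boundaries."""
    return math.sqrt(1 / ell) * math.sin(sqrt_eigenvalue(j, ell) * (x + ell))


def spectral_density_se(w, alpha, lengths):
    """Squared-exponential spectral density with one lengthscale per dimension."""
    c = alpha * math.prod(math.sqrt(2 * math.pi) * scale for scale in lengths)
    return c * math.exp(-0.5 * sum((wi * scale) ** 2 for wi, scale in zip(w, lengths)))


def approx_kernel(x1, x2, alpha, lengths, ell, m):
    """Truncated Hilbert-space expansion of k(x1, x2) with m terms per dimension."""
    dim = len(x1)
    total = 0.0
    # Sum over every tuple of per-dimension basis indices (m ** dim terms).
    for idx in itertools.product(range(1, m + 1), repeat=dim):
        w = [sqrt_eigenvalue(j, ell) for j in idx]
        phi1 = math.prod(eigenfunction(x, j, ell) for x, j in zip(x1, idx))
        phi2 = math.prod(eigenfunction(x, j, ell) for x, j in zip(x2, idx))
        total += spectral_density_se(w, alpha, lengths) * phi1 * phi2
    return total


def exact_kernel(x1, x2, alpha, lengths):
    """Exact squared-exponential kernel with per-dimension lengthscales."""
    r2 = sum(((a - b) / scale) ** 2 for a, b, scale in zip(x1, x2, lengths))
    return alpha * math.exp(-0.5 * r2)


x1, x2 = [1.0, 1.0], [0.0, 0.0]
lengths = [1.0, 1.5]  # a different lengthscale for each input dimension
approx = approx_kernel(x1, x2, alpha=1.0, lengths=lengths, ell=5.0, m=30)
exact = exact_kernel(x1, x2, alpha=1.0, lengths=lengths)
assert math.isclose(approx, exact, abs_tol=1e-6)
```

In a real test the same comparison would presumably be parametrized over kernels, dimensions, and lengthscale shapes, with the tolerance tied to m and ell.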

brendancooley commented 5 months ago

@samanklesaria go for it! Perhaps we can swap ideas and merge implementations. I have a little bit of WIP here. I would like to try to support batch dimensions on the lengthscale to enable batched approximate GPs, in addition to lengthscale heterogeneity within a single GP. I just need to work out the API a bit. I have a working example with a few tests on that branch. I still need to do the Matern case, and maybe the periodic one as well.

For a use case on the batching, see hsgp_lvm.ipynb on this branch