JaxGaussianProcesses / GPJax

Gaussian processes in JAX.
https://docs.jaxgaussianprocesses.com/
Apache License 2.0

bug: triangular constraint not enforced in WhitenedVariationalGaussian #415

Closed. meta-inf closed this issue 11 months ago.

meta-inf commented 11 months ago

Bug Report

GPJax version: 0.7.2 (the issue is also present at the head of main)

Current behavior: For WhitenedVariationalGaussian, the triangular constraint on its variational_root_covariance is not enforced, even though the constraint is registered in the field definition of the parent class, VariationalGaussian, and is implicitly assumed elsewhere in the code (compare this line with this one). As a result, the prior KL term is computed incorrectly when evaluating the ELBO.
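To see why a non-triangular root matters, here is a minimal sketch in plain JAX (not GPJax's actual `prior_kl` code) of the KL term for a whitened family, KL(N(m, LL^T) || N(0, I)). It assumes, as the report suggests the library code does, that log det(LL^T) is read off the diagonal of L, which is only valid when L is triangular. `whitened_kl` and its `assume_triangular` flag are hypothetical names used for illustration.

```python
import jax.numpy as jnp

def whitened_kl(mean, root, assume_triangular):
    """KL( N(mean, root @ root.T) || N(0, I) ), the prior KL of a whitened family.

    With assume_triangular=True the log-determinant is read off diag(root),
    which is only correct when `root` is triangular.
    """
    n = mean.shape[0]
    cov = root @ root.T
    trace_term = jnp.trace(cov)        # tr(S), the squared Frobenius norm of root
    quad_term = jnp.sum(mean ** 2)     # m^T m, since the whitened prior is N(0, I)
    if assume_triangular:
        logdet = 2.0 * jnp.sum(jnp.log(jnp.abs(jnp.diag(root))))
    else:
        logdet = jnp.linalg.slogdet(cov)[1]  # exact log|S| for any root
    return 0.5 * (trace_term + quad_term - n - logdet)

mean = jnp.array([0.1, -0.2, 0.3])
lower = jnp.array([[1.0, 0.0, 0.0],
                   [0.5, 0.8, 0.0],
                   [-0.3, 0.2, 0.9]])
# Same matrix, but with the upper triangle allowed to drift to 0.4 (no constraint).
full = lower + jnp.triu(jnp.full((3, 3), 0.4), k=1)

# Triangular root: the diagonal shortcut and the exact log-determinant agree.
print(whitened_kl(mean, lower, True), whitened_kl(mean, lower, False))
# Non-triangular root: the diagonal shortcut gives a different (wrong) KL.
print(whitened_kl(mean, full, True), whitened_kl(mean, full, False))
```

The trace term is unaffected by the upper triangle's presence in the formula (it is just a sum of squares), so in this sketch the error comes entirely from the log-determinant shortcut.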

Expected behavior:

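In the code below, the matrix-valued parameter variational_root_covariance of WhitenedVariationalGaussian should have a tfp.bijectors.FillTriangular bijector registered with it, as it does for VariationalGaussian.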
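Registering a bijector such as FillTriangular means the optimizer works on an unconstrained vector of n(n+1)/2 entries and the matrix is rebuilt as lower-triangular on every forward pass, so the upper triangle cannot drift away from zero. The sketch below reproduces that idea in plain JAX with two hypothetical helpers, `fill_lower` and `unfill_lower`. TFP's FillTriangular packs entries in a different order, so this illustrates the constraint rather than TFP's exact bijector.

```python
import jax
import jax.numpy as jnp

def fill_lower(vec, n):
    """Map an unconstrained vector of length n(n+1)/2 to an n x n lower-triangular matrix."""
    rows, cols = jnp.tril_indices(n)
    return jnp.zeros((n, n), dtype=vec.dtype).at[rows, cols].set(vec)

def unfill_lower(mat):
    """Inverse map: read the lower triangle of `mat` into a vector (upper triangle ignored)."""
    rows, cols = jnp.tril_indices(mat.shape[-1])
    return mat[rows, cols]

def loss_from_vec(vec, n=3):
    # Any loss on the rebuilt matrix; gradients only ever reach the 6 free entries.
    root = fill_lower(vec, n)
    return jnp.sum((root @ root.T - jnp.eye(n)) ** 2)

init = jnp.arange(1.0, 10.0).reshape(3, 3) / 10.0   # a full 3x3 matrix
vec = unfill_lower(init)                             # 6 unconstrained parameters
vec = vec - 0.01 * jax.grad(loss_from_vec)(vec)      # one gradient step in vector space
root = fill_lower(vec, 3)                            # still lower-triangular
print(root)
```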

Steps to reproduce:

See below.

Related code:

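```python
import gpjax as gpx
import jax.numpy as jnp

# A GP prior with a zero mean and a Matern-3/2 kernel, plus a Gaussian likelihood.
prior = gpx.Prior(
    mean_function=gpx.mean_functions.Zero(),
    kernel=gpx.kernels.Matern32(lengthscale=1., variance=1.))
likelihood = gpx.Gaussian(num_datapoints=10, obs_stddev=1.)
posterior = prior * likelihood

# Parent class: variational_root_covariance is paired with a FillTriangular bijector.
variational = gpx.VariationalGaussian(posterior=posterior, inducing_inputs=jnp.ones((3, 1)))
print(gpx.base.module.meta_leaves(variational))

# Whitened subclass: per this report, the same parameter has no FillTriangular bijector.
variational = gpx.WhitenedVariationalGaussian(posterior=posterior, inducing_inputs=jnp.ones((3, 1)))
print(gpx.base.module.meta_leaves(variational))
```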

Other information:

This looks like an inheritance-related bug, since the parent class, VariationalGaussian, does not have this issue.
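The thread does not say what the actual root cause was (the fix is in #416). As a purely hypothetical illustration, if a constraint were stored in dataclass field metadata, one generic way a subclass could lose it is by redeclaring the field, which replaces the parent's field object and its metadata. `Parent`, `Child`, `bijector_of`, and the string "FillTriangular" are placeholders, not GPJax code.

```python
from dataclasses import dataclass, field, fields

@dataclass
class Parent:
    # The constraint lives in the field's metadata, not on the value itself.
    root: object = field(default=None, metadata={"bijector": "FillTriangular"})

@dataclass
class Child(Parent):
    # Redeclaring the field replaces the parent's Field object, metadata included.
    root: object = None

def bijector_of(cls, name):
    """Return the 'bijector' metadata entry of field `name` on dataclass `cls`, if any."""
    for f in fields(cls):
        if f.name == name:
            return f.metadata.get("bijector")
    return None

print(bijector_of(Parent, "root"))  # FillTriangular
print(bijector_of(Child, "root"))   # None
```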

daniel-dodd commented 11 months ago

Thanks, good spot @meta-inf. I've opened a quick PR to fix this.

daniel-dodd commented 11 months ago

This should be resolved in #416.