JaxGaussianProcesses / GPJax

Gaussian processes in JAX.
https://docs.jaxgaussianprocesses.com/
Apache License 2.0
438 stars · 51 forks

feat: Heteroscedastic noise #113

Closed · patel-zeel closed this issue 2 weeks ago

patel-zeel commented 2 years ago

Feature Request

Describe the Feature Request

I think the current implementation does not support heteroscedastic noise variance.

https://github.com/thomaspinder/GPJax/blob/db40b9cb20103a5f7104b1ccd0ad12713f44bc06/gpjax/gps.py#L163

It could be tweaked with a few lines to support both homoscedastic and heteroscedastic noise (provided this does not break checks elsewhere in the code).

Describe Preferred Solution

A generalized method of adding noise to the diagonal of the covariance matrix would solve the problem.

Related Code

```python
import jax.numpy as jnp

def get_noisy_covariance(noise):
    # `noise` may be a scalar (homoscedastic) or a length-n vector
    # (heteroscedastic); broadcasting handles both cases.
    covariance = jnp.ones((3, 3))
    print("before\n", covariance)
    rows, columns = jnp.diag_indices_from(covariance)
    covariance = covariance.at[rows, columns].add(noise)
    print("after\n", covariance)

noise = jnp.array(3.0)  # homoscedastic: one shared noise variance
get_noisy_covariance(noise)
# before
#  [[1. 1. 1.]
#  [1. 1. 1.]
#  [1. 1. 1.]]
# after
#  [[4. 1. 1.]
#  [1. 4. 1.]
#  [1. 1. 4.]]

noise = jnp.arange(3)  # heteroscedastic: one noise variance per input
get_noisy_covariance(noise)
# before
#  [[1. 1. 1.]
#  [1. 1. 1.]
#  [1. 1. 1.]]
# after
#  [[1. 1. 1.]
#  [1. 2. 1.]
#  [1. 1. 3.]]
```

If the feature request is approved, would you be willing to submit a PR? Yes

thomaspinder commented 2 years ago

Thanks @patel-zeel for raising this. Supporting heteroscedastic likelihood functions is certainly something I'd like to add. In theory, it should be quite straightforward: make the Gaussian likelihood function return a vector of noise terms rather than the scalar value it currently returns.
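As a rough sketch of that idea in plain JAX (not the GPJax API; `marginal_log_likelihood` is a hypothetical name), the marginal log-likelihood only needs the noise term to broadcast onto the diagonal, so a scalar and a per-point vector can share one code path:

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

def marginal_log_likelihood(y, Kxx, noise):
    """Gaussian marginal log-likelihood with flexible observation noise.

    `noise` may be a scalar (homoscedastic, the current behaviour) or a
    length-n vector of per-point variances (heteroscedastic).
    """
    n = y.shape[0]
    # Broadcasting a scalar to (n,) makes both cases identical from here on.
    Sigma = Kxx + jnp.diag(jnp.broadcast_to(noise, (n,)))
    return multivariate_normal.logpdf(y, jnp.zeros(n), Sigma)

y = jnp.zeros(3)
Kxx = jnp.eye(3)
homo = marginal_log_likelihood(y, Kxx, 1.0)                          # scalar noise
hetero = marginal_log_likelihood(y, Kxx, jnp.array([0.5, 1.0, 1.5]))  # vector noise
```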

Unfortunately, I don't see myself being able to tend to this for at least another 6-8 weeks. However, I'd be very happy to support you in making a PR, if it's something you'd like to see in GPJax.

matthewrhysjones commented 1 year ago

Hi - I'm also very keen on being able to use a heteroscedastic noise model. Is there any update on this? Thanks :)

patel-zeel commented 1 year ago

Hi @matthewrhysjones, I am glad to know you are interested in this. @thomaspinder and @Daniel-Dodd, I have some initial thoughts on how to go about this. Please let me know your thoughts:

  1. The most trivial approach is to learn a noise variance at each training input, but this does not let us predict the noise variance at test time.
  2. For my research project, I am modelling heteroscedastic noise as well as non-stationarity with input-dependent hyperparameters, similar to Heinonen et al., where modelling the noise alone is the easier of the two. I am unsure of the best way to implement this method in GPJax and would love some pointers to get started. For context, I am referring only to the MAP part of Heinonen et al., where a latent GP models the GP hyperparameters. At train time, we obtain the non-whitened values of the noise variance at each training input, the prior $\log p(\boldsymbol{\omega})$ over them, and the likelihood $\log p(\boldsymbol{y}|\boldsymbol{\omega})$. We then optimize $\log p(\boldsymbol{\omega}) + \log p(\boldsymbol{y}|\boldsymbol{\omega})$ to obtain a MAP estimate of $\boldsymbol{\omega}$ (the same idea extends to the other hyperparameters). At test time, we predict the noise variance at new locations by conditioning on the noise variance at the training points.
  3. Andersson et al. suggest a simpler method: use a parametric model (basis-function regression with Gaussian basis functions) for the input-dependent lengthscale of the Gibbs kernel (see section C in the appendix). The same trick can be used to model the noise variance. The idea is to place the means of $m$ basis functions on a uniform grid and fix the standard deviations to the distance between two grid points. We then learn $\theta$ multipliers on these basis functions by optimizing the log marginal likelihood as usual. At both train and test time, the fitted basis-function regressor predicts the noise variance at the inputs.
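To make option 3 concrete, here is a minimal sketch of such a basis-function noise model in plain JAX. All names (`gaussian_basis`, `noise_variance`) and the grid setup are illustrative assumptions; in practice $\theta$ would be learned jointly with the kernel hyperparameters by maximising the marginal likelihood:

```python
import jax.numpy as jnp

def gaussian_basis(x, centers, width):
    # Evaluate m Gaussian basis functions at inputs x -> (n, m) design matrix.
    return jnp.exp(-0.5 * ((x[:, None] - centers[None, :]) / width) ** 2)

def noise_variance(x, theta, centers, width):
    # Parametric log-noise model: usable at both train and test inputs.
    # The exp keeps the predicted noise variance positive.
    return jnp.exp(gaussian_basis(x, centers, width) @ theta)

# Illustrative setup: m = 5 basis means on a uniform grid over [0, 1],
# widths fixed to the grid spacing, as in the Andersson et al. recipe.
centers = jnp.linspace(0.0, 1.0, 5)
width = centers[1] - centers[0]
theta = jnp.zeros(5)  # placeholder; to be optimised with the other hyperparameters
x_test = jnp.array([0.1, 0.6, 0.9])
sigma2 = noise_variance(x_test, theta, centers, width)  # one variance per input
```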

I'd also love to know and discuss simpler and/or better methods that you may know/prefer to model heteroscedastic noise.

thomaspinder commented 1 year ago

Hi @patel-zeel and @matthewrhysjones,

Right now, it would be possible to simply use a Gaussian likelihood object and manually set the variance parameter to be a vector (it defaults to a scalar). GPJax will then learn each element of the vector by type-II MLE. However, this only works if you do not need test-time inference: calling predict on a set of inputs whose shape differs from your conditioned dataset will throw a shape error. This is essentially option 1) of @patel-zeel's response.
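The shape problem can be seen in plain JAX (an illustration of the issue, not GPJax code): a noise vector learned for the $n$ training points cannot be broadcast onto the covariance of a differently sized test set.

```python
import jax.numpy as jnp

noise_vec = jnp.full((5,), 0.1)  # per-point variances learned for n = 5 training inputs
K_train = jnp.eye(5)
_ = K_train + jnp.diag(noise_vec)  # fine: (5, 5) + (5, 5)

K_test = jnp.eye(3)  # m = 3 test inputs
try:
    _ = K_test + jnp.diag(noise_vec)  # (3, 3) + (5, 5): incompatible shapes
except (TypeError, ValueError) as e:
    print("shape error:", e)
```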

I’ve not read the two papers you linked, but I can take a look. More generally, I’d be very open to discussing a way to support heteroscedastic likelihoods, as there is a range of alternative implementations (e.g., Lázaro-Gredilla & Titsias (2011) and Saul et al. (2016)). It would be nice to have a flexible implementation that can easily accommodate some, if not all, of the aforementioned methods. Would either of you be keen to set up a time when we can discuss what such a framework might look like?

patel-zeel commented 1 year ago

Happily @thomaspinder! Thank you for adding these papers; I'll take a look at them, though it may take me a day or two to go through them. My time zone is IST, so if you could share something like a Doodle poll, we can find a time that works for everyone. We could also continue this discussion on the GPJax Slack to avoid cluttering the GitHub issue history, and later add a summary to this thread for future reference.

github-actions[bot] commented 3 weeks ago

There has been no recent activity on this issue. To keep our issues log clean, we remove old and inactive issues. Please update to the latest version of GPJax and check if that resolves the issue. Let us know if that works for you by leaving a comment. This issue is now marked as stale and will be closed if no further activity occurs. If you believe that this is incorrect, please comment. Thank you!

github-actions[bot] commented 2 weeks ago

There has been no activity on this PR for some time. Therefore, we will be automatically closing the PR if no new activity occurs within the next seven days. Thank you for your contributions.