JaxGaussianProcesses / GPJax

Gaussian processes in JAX.
https://docs.jaxgaussianprocesses.com/
Apache License 2.0

bug: Matern kernels often return NaN during optimization #150

Closed dirmeier closed 1 year ago

dirmeier commented 1 year ago

Bug Report

GPJax version: 0.5.0

Hi, in virtually all cases that I have tested, the objective becomes NaN at some point during optimization when using Matern kernels. A simple solution would be to exp-transform all variables that have a lower bound (in case you are not already doing this).
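To illustrate the idea, here is a minimal sketch (hypothetical code, not GPJax's API) of a Matern-1/2 kernel that stores unconstrained raw parameters and exponentiates them on use, so the effective lengthscale and variance stay strictly positive no matter what the optimizer proposes:

import jax.numpy as jnp

# Hypothetical sketch: keep the raw parameters unconstrained and
# exponentiate them inside the kernel, so the effective values are
# strictly positive for any real-valued raw input.
def matern12(x, y, raw_lengthscale, raw_variance):
    lengthscale = jnp.exp(raw_lengthscale)  # > 0 by construction
    variance = jnp.exp(raw_variance)        # > 0 by construction
    tau = jnp.sqrt(jnp.sum(((x - y) / lengthscale) ** 2))
    return variance * jnp.exp(-tau)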

See this minimal reproducible example [1] demonstrating the bug, together with a kernel implementation that exponentiates its scaling parameters.

Cheers, S

[1] https://colab.research.google.com/drive/1jS-36i7AL6ulNo4K_yBbERFhg4Gu2oIX?usp=sharing

dirmeier commented 1 year ago

Sorry, I closed this because I thought I had just overlooked it in the manual, but it is actually not clear to me how to constrain parameters. If I initialize parameters as in the introductory example on regression, which bijectors are created?

parameter_state = gpx.initialise(
    posterior, key, kernel={"lengthscale": jnp.array([0.5])}
)
params, trainable, bijectors = parameter_state.unpack()
print(bijectors)

{'kernel': {'lengthscale': <distrax._src.bijectors.lambda_bijector.Lambda at 0x7f503289e490>,
  'variance': <distrax._src.bijectors.lambda_bijector.Lambda at 0x7f503289e490>},
 'likelihood': {'obs_noise': <distrax._src.bijectors.lambda_bijector.Lambda at 0x7f503289e490>},
 'mean_function': {}}
daniel-dodd commented 1 year ago

Hi @dirmeier, there was a stability issue with the Euclidean distance in the Matern kernels. This has been addressed in the v0.5.1 release.
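For context, a general sketch of this failure mode and a common fix (not necessarily the exact change shipped in v0.5.1): the gradient of jnp.sqrt is infinite at zero, so differentiating a Euclidean distance between a point and itself multiplies 0 by inf and yields NaN. Flooring the squared distance keeps the backward pass finite:

import jax
import jax.numpy as jnp

# Sketch of a gradient-safe Euclidean distance. jnp.sqrt has no
# finite gradient at 0, which produces NaN whenever x == y during
# backprop; clamping the squared distance away from zero avoids this.
def safe_euclidean_distance(x, y, eps=1e-36):
    sq = jnp.sum((x - y) ** 2)
    return jnp.sqrt(jnp.maximum(sq, eps))

# e.g. jax.grad(safe_euclidean_distance)(x, x) is now finite.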

daniel-dodd commented 1 year ago

> Sorry, I closed this because I thought I had just overlooked it in the manual, but it is actually not clear to me how to constrain parameters. If I initialize parameters as in the introductory example on regression, which bijectors are created?

@dirmeier Thanks for the feedback, we ought to add an example to the docs! With regards to performing transformations, we have two functions: gpx.unconstrain, which maps parameters from their original constrained space to an unconstrained one, and gpx.constrain, which maps them back.

So for optimisation you would do the following:

parameter_state = gpx.initialise(
    posterior, key, kernel={"lengthscale": jnp.array([0.5])}
)
params, trainable, bijectors = parameter_state.unpack()

# 1. Unconstrain your parameters:
unconstrained_init_params = gpx.unconstrain(params, bijectors)

# 2. Run a training loop on the unconstrained space (a sketch follows below):
unconstrained_learned_params = ...

# 3. Constrain your learned parameters:
learned_params = gpx.constrain(unconstrained_learned_params, bijectors)
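For step 2, a minimal sketch of what the loop could look like, assuming optax and a hypothetical negative_mll objective (neither is prescribed by GPJax here):

import jax
import optax

# Hypothetical sketch of step 2. `negative_mll` stands in for your
# objective, e.g. a negative marginal log likelihood, which expects
# parameters in the original constrained space.
def loss(unconstrained_params):
    return negative_mll(gpx.constrain(unconstrained_params, bijectors))

optimiser = optax.adam(learning_rate=0.01)
opt_state = optimiser.init(unconstrained_init_params)

unconstrained_learned_params = unconstrained_init_params
for _ in range(500):
    grads = jax.grad(loss)(unconstrained_learned_params)
    updates, opt_state = optimiser.update(grads, opt_state)
    unconstrained_learned_params = optax.apply_updates(
        unconstrained_learned_params, updates
    )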

But note that gpx.fit automatically handles these transformations for you, so you should only pass "constrained" parameters in the original space to this method.
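For example (a sketch only; the exact fit signature has varied between GPJax versions, so the arguments below are assumptions rather than a pinned API):

import optax

# Sketch only: hand gpx.fit the constrained parameter_state and let it
# manage the unconstrain/constrain round trip internally.
inference_state = gpx.fit(
    objective,        # e.g. a negative marginal log likelihood (hypothetical)
    parameter_state,  # constrained parameters, as initialised above
    optax.adam(0.01),
)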

Hope this helps. :)

dirmeier commented 1 year ago

Great, thanks for the help.