Chris: thoughts?
Markus: consider filing an issue on github to track the request?
Brian Patton | Software Engineer | bjp@google.com
On Wed, Jan 22, 2020 at 7:16 AM Markus Kaiser notifications@github.com wrote:
Hi all,
The VariationalGaussianProcess was reformulated using a VariationalKernel, similar to the SchurComplement kernel. This kernel is defined with respect to the inducing points Z and contains a Cholesky decomposition of the kernel.matrix(Z, Z) matrix. It is computed here: https://github.com/tensorflow/probability/blob/e0346910ecca45f6ee83a7deccb3fe7d690e296a/tensorflow_probability/python/distributions/variational_gaussian_process.py#L127
This decomposition does not add a diagonal jitter, which sometimes leads to the decomposition failing, for example for initializations that happen to be suboptimal. Could the jitter be added using _add_diagonal_shift, or is there a specific reason not to?
Thanks!
@brianwa84 this is a github issue :)
@mrksr While adding a jitter here might allow the Cholesky to proceed, I worry a bit about sprinkling jitter in too many places, especially in a non-user-configurable manner (and I'd also be concerned about demanding separate jitter parameters for each of the various places we take Cholesky decompositions).
I wonder if a viable workaround would be to ensure that the inducing points aren't too close together (in the kernel-induced metric). In my experience, initializing them sufficiently far apart often helps here; e.g., one might use QMC (such as a Halton sequence) to initialize the inducing points, rather than something like uniform random samples over some hypercube.
As an addendum: you could consider adding a regularization term to the overall loss that discourages inducing points from getting too close. Whereas adding diagonal jitter amounts to a change in the model (hence my trepidation), such a regularization leaves the model fixed but changes the optimization procedure. It's also an easier lever for the user to wield freely, without any change to the VGP APIs.
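Concretely, such a penalty could look something like the sketch below. The function name, distance threshold, and weight are placeholders for illustration, and how the term gets hooked into the training loss is up to the user; this is not an existing TFP API.

import tensorflow as tf

def inducing_point_repulsion(inducing_points, min_distance=0.1, weight=1.0):
    # Penalize pairs of inducing points whose squared Euclidean distance is
    # below `min_distance ** 2`. Both constants are placeholder values.
    # `inducing_points` has shape [num_inducing_points, feature_dim].
    diffs = inducing_points[:, None, :] - inducing_points[None, :, :]
    sq_dists = tf.reduce_sum(tf.square(diffs), axis=-1)
    # Ignore the diagonal (each point's distance to itself is zero).
    off_diagonal = 1. - tf.eye(tf.shape(inducing_points)[0], dtype=sq_dists.dtype)
    penalty = tf.nn.relu(min_distance**2 - sq_dists) * off_diagonal
    return weight * tf.reduce_sum(penalty)

# This would be added to the usual variational objective, e.g.:
#   loss = negative_elbo + inducing_point_repulsion(inducing_index_points)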
Thoughts welcome!
Lol whoops, I thought it was on the tfprobability mailing list. :-)
I have now found some time to play with the model and generate a small example. Thank you @csuter for the pointer to the Halton sequence; that's definitely a more principled way to initialize inducing points.
Let's take a look at this relatively simple example using tfp.layers for interop with Keras. It's meant to be somewhat realistic, so I'm aware it is not quite minimal.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp


class RBFLayer(tf.keras.layers.Layer):
    """Provides an RBF kernel with trainable, softplus-constrained parameters."""

    def __init__(self, feature_dimensions, **kwargs):
        super().__init__(**kwargs)
        # Store raw parameters such that softplus maps them to 1 initially.
        softplus_inverse_of_one = tfp.bijectors.Softplus().inverse(1.).numpy()
        self.amplitude = self.add_weight(
            name='amplitude',
            shape=[1],
            initializer=tf.keras.initializers.Constant(softplus_inverse_of_one),
        )
        self.length_scale = self.add_weight(
            name='length_scale',
            shape=feature_dimensions,
            initializer=tf.keras.initializers.Constant(softplus_inverse_of_one),
        )

    @property
    def kernel(self):
        return tfp.math.psd_kernels.ExponentiatedQuadratic(
            amplitude=tf.math.softplus(self.amplitude),
            length_scale=tf.math.softplus(self.length_scale),
        )


num_inducing_points = 50
# Initialize the inducing points with a Halton sequence rescaled to [-2, 2].
halton_sequence = tfp.mcmc.sample_halton_sequence(1, num_inducing_points).numpy()
halton_sequence = (-2) + (2 - (-2)) * halton_sequence

tf.keras.backend.set_floatx('float64')
model = tf.keras.Sequential([
    tfp.layers.VariationalGaussianProcess(
        num_inducing_points=num_inducing_points,
        kernel_provider=RBFLayer([1]),
        inducing_index_points_initializer=tf.keras.initializers.Constant(halton_sequence),
    ),
])

model(np.array([[0.]]))
# Cholesky decomposition unsuccessful most of the time.
In this example, we formulate a 1D GP on both the input and output side using an RBF kernel. Keras assumes roughly standardized data, so we initialize the RBF parameters to 1. The length scale is slightly on the longer side, but I would call this the standard initialization for GPs. It's hard for me to reproduce exactly when the initialization is bad, but the problem definitely gets more pronounced in low dimensions, which is not very surprising, as there is more chance of inducing points being close together. Also note that 50 inducing points is not very many; it would be good if 500 also worked.
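One way to see how marginal a given initialization is: inspect the smallest eigenvalue of the inducing kernel matrix. A quick diagnostic sketch, reusing RBFLayer and halton_sequence from the example above (the threshold for "too small" depends on the dtype):

z = halton_sequence.astype(np.float64)
k_zz = RBFLayer([1]).kernel.matrix(z, z)
# If the smallest eigenvalue is at or below roughly machine precision,
# a plain Cholesky of k_zz is likely to fail.
print(tf.reduce_min(tf.linalg.eigvalsh(k_zz)))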
It's quite common to add jitter to the Cholesky of the inducing kernel matrix (see, for example, the corresponding code in other GP libraries).
In my own experience, using a "direct" Cholesky is not numerically stable. The failure modes are hard to identify and can be surprising, especially for non-expert users. I guess that's why most implementations accept adding jitter as the lesser of two evils.
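For illustration, the pattern these implementations use boils down to something like the following (a generic sketch, not TFP's internal _add_diagonal_shift; the jitter value is just a common default):

import tensorflow as tf

def cholesky_with_jitter(kernel_matrix, jitter=1e-6):
    # Add a small constant to the diagonal before decomposing; this keeps
    # the matrix numerically positive definite even when inducing points
    # nearly coincide.
    shifted = tf.linalg.set_diag(
        kernel_matrix, tf.linalg.diag_part(kernel_matrix) + jitter)
    return tf.linalg.cholesky(shifted)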
I understand the dislike of changing the model, and I agree that it is not the best solution. Your proposal of adding some prior (or regularization) that prevents inducing points from being too close together is interesting. I am, however, not sure whether the end effect might not also be that the model is altered, just in a less obvious way.
I can see multiple ways forward:
(1) Stick with the status quo and do not add jitter to the K_mm Cholesky.
(2) Add a diagonal jitter to the K_mm Cholesky, as other implementations do.
(3) Apply the same jitter to the K_mm Cholesky as to the decompositions in the posterior calculation.
(4) Use a more robust decomposition algorithm for K_mm instead of a plain Cholesky.
I'm not sure about (3), as that jitter does not seem to be significantly different from all the other ones. (2) is what I would call the standard approach in other implementations, and in my experience it works fine, though I have never really tried to understand what might go wrong. I have briefly looked into (4) in the past, and I seem to remember that while better algorithms exist, they tend to be hard to implement and computationally expensive, so I'm not sure they are worth it.
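For completeness, some libraries avoid a fixed jitter by retrying with increasing jitter only when the plain decomposition fails. Roughly, in an eager-mode sketch with made-up constants (not code from any particular library):

def retrying_cholesky(kernel_matrix, initial_jitter=1e-10, max_tries=6):
    # Try a plain Cholesky first; on failure, retry with exponentially
    # growing diagonal jitter. In eager mode on CPU a failed decomposition
    # raises tf.errors.InvalidArgumentError.
    jitter = 0.
    for _ in range(max_tries):
        shifted = tf.linalg.set_diag(
            kernel_matrix, tf.linalg.diag_part(kernel_matrix) + jitter)
        try:
            return tf.linalg.cholesky(shifted)
        except tf.errors.InvalidArgumentError:
            jitter = initial_jitter if jitter == 0. else jitter * 10.
    raise ValueError('Cholesky failed even with jitter of {}.'.format(jitter))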
Sticking with the status quo, (1), might also be an option. Is there something we could recommend (and add to the layer implementation) that works at least most of the time?
@csuter Have you had time to take a look? I'm also happy to open a (rather simple) PR if that helps :).
Thanks for the thoughtful analysis, @mrksr. I think the simplest path forward would be to add the jitter as you originally suggested. A PR would be most welcome.
This should be fixed now via #813.