secondmind-labs / GPflux

Deep GPs built on top of TensorFlow/Keras and GPflow
https://secondmind-labs.github.io/GPflux/
Apache License 2.0

Bug when computing the mean of marginal variational q(f) #55

Open · jmaronas opened this issue 2 years ago

jmaronas commented 2 years ago

Describe the bug

I think there is a bug when computing the mean of q(f) = \int p(f|u) q(u) du, because the mean function evaluated at the inducing point locations is not taken into account.

To reproduce

When computing q(f), the code is given by this line here: https://github.com/secondmind-labs/GPflux/blob/e05d7ba86aa3739a34dc6615de8f9ff810642605/gpflux/layers/gp_layer.py#L257

Note that the function conditional computes the mean as mean_cond = K_xz K_zz_inv m. Afterwards, \mu(x) is added in the return statement: https://github.com/secondmind-labs/GPflux/blob/e05d7ba86aa3739a34dc6615de8f9ff810642605/gpflux/layers/gp_layer.py#L268

However, I am quite sure that the variational mean is in fact given by:

K_xz K_zz_inv m + \mu(x) - K_xz K_zz_inv \mu(Z),

which means that for any mean function other than the zero mean function, the term -K_xz K_zz_inv \mu(Z) has to be added to the variational mean.
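
For reference, this is how I arrive at that expression (standard sparse-GP algebra for the non-whitened case, with q(u) = N(m, S)):

    mean of p(f(x) | u)  =  \mu(x) + K_xz K_zz_inv (u - \mu(Z))
    E_{q(u)}[u]          =  m
    =>  mean of q(f(x))  =  \mu(x) + K_xz K_zz_inv (m - \mu(Z))
                         =  K_xz K_zz_inv m + \mu(x) - K_xz K_zz_inv \mu(Z)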

This could easily be solved (I think) by calling the conditional as follows:

mean_cond, cov = conditional(
    inputs,
    self.inducing_variable,
    self.kernel,
    self.q_mu - self.mean_function(self.inducing_variable.Z),  ## HERE IS THE MODIFICATION: subtract \mu(Z); assumes .Z holds the inducing locations
    q_sqrt=self.q_sqrt,
    full_cov=full_cov,
    full_output_cov=full_output_cov,
    white=self.whiten,
)
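
As a quick numerical sanity check (an illustrative sketch only: toy shapes, single output, non-whitened, and it calls GPflow's conditional directly rather than going through a GPLayer), the modified call matches the explicit formula above, while keeping q_mu unchanged does not once the mean function is non-zero:

    import numpy as np
    import gpflow
    from gpflow.conditionals import conditional

    # Toy setup: 1D inputs, 3 inducing points, a non-zero (linear) mean function.
    rng = np.random.default_rng(0)
    X = rng.standard_normal((10, 1))
    Z = rng.standard_normal((3, 1))
    kernel = gpflow.kernels.SquaredExponential()
    mean_function = gpflow.mean_functions.Linear(A=np.ones((1, 1)), b=np.zeros(1))
    inducing_variable = gpflow.inducing_variables.InducingPoints(Z)
    q_mu = rng.standard_normal((3, 1))
    q_sqrt = np.eye(3)[None, :, :]  # shape [1, M, M]

    mu_X = mean_function(X).numpy()  # \mu(X)
    mu_Z = mean_function(Z).numpy()  # \mu(Z)

    # Current behaviour: conditional on q_mu, then add \mu(X).
    mean_now, _ = conditional(X, inducing_variable, kernel, q_mu,
                              q_sqrt=q_sqrt, white=False)
    mean_now = mean_now.numpy() + mu_X

    # Proposed modification: subtract \mu(Z) from q_mu before calling conditional.
    mean_fix, _ = conditional(X, inducing_variable, kernel, q_mu - mu_Z,
                              q_sqrt=q_sqrt, white=False)
    mean_fix = mean_fix.numpy() + mu_X

    # Explicit formula: K_xz K_zz_inv (m - \mu(Z)) + \mu(X).
    Kzz = kernel(Z).numpy() + gpflow.config.default_jitter() * np.eye(3)
    Kxz = kernel(X, Z).numpy()
    mean_explicit = Kxz @ np.linalg.solve(Kzz, q_mu - mu_Z) + mu_X

    print(np.max(np.abs(mean_fix - mean_explicit)))  # ~0 (up to jitter)
    print(np.max(np.abs(mean_now - mean_explicit)))  # clearly non-zero in this toy example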

I have been checking the GPflow source code and it does not look like the term K_xz K_zz_inv \mu(Z) is being taken into account.

In case this bug is confirmed, I will open an issue in GPflow as well, because the SVGP model, for example, has the same problem.

jmaronas commented 2 years ago

I should have pointed out that what I describe is only needed when we don't use the whitened representation of the inducing points. However, for this general case I still don't see where GPflow's conditional takes mean_function(Z) into account.

st-- commented 2 years ago

We discussed this question on slack, copied here for posterity:

Is that not just a reparametrization? We learn the mean of q(u), so it shouldn't make a difference [except perhaps for initialization] whether we split it up into "K_xz K_zz_inv (m - mean_function(Z))" or "K_xz K_zz_inv m'" (if m [or m'] is just a learned parameter)?

So in the DGP example with an identity mean function, the way it's coded at the moment should mean that if you initialize q_mu = zeros, then the approximate posterior mean is exactly the identity function (i.e., q_mu encodes the "change away from the mean function", not the "function value itself").
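
A minimal sketch of what I mean (pure NumPy, toy shapes, single output, non-whitened; rbf and the identity lambda below stand in for the real kernel/mean-function objects): as currently coded, the approximate posterior mean is mean_function(X) + K_xz K_zz_inv q_mu, so q_mu = zeros gives back the mean function exactly:

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.standard_normal((5, 1))
    Z = rng.standard_normal((3, 1))

    def rbf(A, B):
        # squared-exponential kernel with unit variance and lengthscale
        sq_dist = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
        return np.exp(-0.5 * sq_dist)

    mean_function = lambda x: x  # identity mean function
    Kzz = rbf(Z, Z) + 1e-9 * np.eye(len(Z))
    Kxz = rbf(X, Z)

    # Posterior mean as currently coded (non-whitened): mean_function(X) + K_xz K_zz_inv q_mu
    q_mu = np.zeros((3, 1))
    posterior_mean = mean_function(X) + Kxz @ np.linalg.solve(Kzz, q_mu)
    print(np.allclose(posterior_mean, X))  # True: q_mu = 0 reproduces the identity exactly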

st-- commented 2 years ago

@jmaronas replied:

Okay, I see there is a reason behind this and it is not just a bug (which was my initial worry).

Regarding your point, I agree that with mean functions that do not have learnable parameters (zero or identity) one can see this as a reparameterization. However, I am not quite sure this is correct when the mean function has learnable parameters. If we think, e.g., of a linear mean function, then \mu(Z) = WZ, with Z being the inducing points (variational parameters) and W being a model hyperparameter. Then letting m' = m - WZ (so that K_xz K_zz_inv m' matches the split above) would mean that a variational parameter m' has to learn something that belongs to a model hyperparameter, and I am not 100% sure we can do this, since the goal of the variational parameter is to target the posterior, which is not the goal of the model hyperparameter. We could think of the variational EM algorithm: the variational updates just target the posterior for fixed hyperparameters W.
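
To spell out my concern for the linear case \mu(x) = Wx:

    textbook:      mean of q(f(x)) = \mu(x) + K_xz K_zz_inv (m - W Z)    (m variational, W a hyperparameter)
    current code:  mean of q(f(x)) = \mu(x) + K_xz K_zz_inv m'           (m' variational)
    relation:      m' = m - W Z

So if W (or Z) is updated while m' is held fixed, the implied "textbook" variational mean m = m' + W Z moves as well, whereas in a variational EM scheme the M-step would update W while keeping q(u), and hence m, fixed.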

Anyway, I guess that this small bias in the learning algorithm might not affect the final predictive performance (although I haven't thought closely about this claim). If I had to implement it, I wouldn't assume this reparameterization, for the reason above. That said, I understand that for the SVGP to work with multitask GPs, the GPflow source code would need a bit of refactoring if we add the \mu(Z) computation, since \mu(Z) \in R^p (with p being the number of latent GPs) and \mu(x) \in R^d (with d being the number of output tasks). Also, to be honest, I am not sure if anyone uses the non-whitened sparse GP representation, which is the only case where computing \mu(Z) is needed.