cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License
3.46k stars 545 forks source link

[Bug] Erroneous detaching with (custom?) mean #2521

Open villetan opened 2 months ago

villetan commented 2 months ago

🐛 Bug

Hi,

First of all, thank you for developing such a versatile and efficient library! I suspect that I came across a bug when working with BoTorch, but I believe it is originating from GPyTorch.

When working with a custom mean function that depends on parameters $\theta$ that we wish to optimize (e.g. a neural-network feature extractor), the derivative of the predictive mean with respect to $\theta$ is wrong.

The predictive mean $\mu(\cdot)$ (given a custom mean function $m_\theta$) at $x^*$ is given by

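$$ \mu(x^* \mid X, y) = m_\theta(x^*) + K_{*}^T (K + \sigma I)^{-1} \big(y - m_\theta(X)\big)\text{,} $$

where $K = k(X, X)$, $K_* = k(X, x^*)$ and $\sigma$ is the observation-noise variance.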

An easy test of the derivative is to predict at the observed data points $X$. There $K_* = K$, so when the observation noise is small ($\sigma \approx 0$) we get

$$ \mu(X \mid X, y) \approx m_\theta(X) + y - m_\theta(X) = y\text{,} $$

whose derivative w.r.t. the mean module's parameters $\theta$ should therefore be zero.

It appears that, at least in the specific case below, this does not happen. The cause seems to be incorrect detaching at one place (see below for a hypothetical location where this happens).

To reproduce

Code snippet to reproduce

from botorch.models import SingleTaskGP
import torch
# Seed the global RNG so that the random draws below are reproducible.
torch.manual_seed(123)

# Linear mean function m_theta(x) = sum_d beta_d * x_d with a learnable beta.
class LinearModel(torch.nn.Module):
    """Linear mean module whose only parameter is ``beta`` of shape (1, D)."""

    def __init__(self, D=1, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Drawn from the global RNG, so construction order affects its value.
        initial_beta = torch.randn(1, D)
        self.beta = torch.nn.Parameter(initial_beta)

    def forward(self, X):
        """Return the last-axis weighted sum ``(X * beta).sum(-1)``."""
        weighted = X * self.beta
        return weighted.sum(dim=-1)

lm = LinearModel(1)
# The optimizer is only used below to zero lm's gradients between checks.
opt = torch.optim.Adam(lm.parameters(), lr=0.424242)

# Synthetic 1-D regression data: a slightly noisy sine wave on [0, 1].
N = 50
x_data = torch.linspace(0, 1, N).unsqueeze(-1)
y_data = 2 * torch.sin(10 * x_data) + 0.01 * torch.randn_like(x_data)

#gp prediction manually
def gp_pred(xstar, obs_X, obs_Y, prior_mean, gp, detach_bug=False):
    """Compute the GP posterior mean and variance at ``xstar`` by hand.

    All kernel and noise quantities are detached, so gradients of the returned
    mean flow only through ``prior_mean`` (e.g. the parameters of a custom
    mean module).

    Args:
        xstar: Test inputs, passed to ``prior_mean`` and ``gp.covar_module``.
        obs_X: Observed (training) inputs.
        obs_Y: Observed targets. Assumed to have shape (N, 1) so that it lines
            up with ``prior_mean(obs_X).unsqueeze(-1)`` -- TODO confirm for
            other target shapes.
        prior_mean: Callable prior mean function m_theta.
        gp: Model exposing ``covar_module`` and ``likelihood.noise_covar``.
        detach_bug: If True, detach the data-fit term
            ``K_*^T (K + noise I)^{-1} (y - m(X))`` before adding the prior
            mean at ``xstar``, reproducing the suspected erroneous detaching.

    Returns:
        Tuple ``(pred_mean, pred_var)`` with the posterior mean (prior mean at
        ``xstar`` plus the data-fit term) and the posterior variance at
        ``xstar``.
    """
    samples_prior_mean = prior_mean(xstar)
    obs_prior_mean = prior_mean(obs_X)
    # Residuals of the observations w.r.t. the prior mean; these keep the
    # autograd connection to prior_mean's parameters.
    gp_y = obs_Y - obs_prior_mean.unsqueeze(-1)

    k_obs = gp.covar_module(obs_X, obs_X).to_dense().detach()
    # Use the constrained noise value: ``raw_noise`` is the unconstrained
    # parameter and equals ``noise`` only for an identity constraint transform.
    noise = gp.likelihood.noise_covar.noise.detach()
    eye = torch.eye(k_obs.shape[-1], dtype=k_obs.dtype, device=k_obs.device)
    k_plus_noise = k_obs + noise * eye
    k_star = gp.covar_module(xstar, obs_X).to_dense().detach()

    pred_mean = k_star @ torch.linalg.solve(k_plus_noise, gp_y)
    if detach_bug:  # reproduces the suspected bug
        pred_mean = pred_mean.detach()
    # Only drop the trailing singleton column so batch dims are preserved.
    pred_mean_og_scale = pred_mean.squeeze(-1) + samples_prior_mean

    # Prior variance at xstar from the kernel diagonal; valid for any kernel,
    # not only a ScaleKernel around a stationary base kernel.
    prior_var = gp.covar_module(xstar, diag=True).detach()
    # diag(K_* (K + noise I)^{-1} K_*^T) via a linear solve (no explicit
    # inverse) and a row-wise sum (no full (N*, N*) matrix).
    solved = torch.linalg.solve(k_plus_noise, k_star.transpose(-2, -1))
    pred_var = prior_var - (k_star * solved.transpose(-2, -1)).sum(-1)
    return pred_mean_og_scale, pred_var

#define the Gpytorch model
import gpytorch
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP whose mean and covariance modules are supplied by the caller."""

    def __init__(self, train_x, train_y, likelihood, mean_module, covar_module):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = mean_module
        self.covar_module = covar_module

    def forward(self, x):
        """Return the prior ``MultivariateNormal(mean_module(x), covar_module(x))``."""
        # Arguments are evaluated left to right: mean first, then covariance.
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )

# BoTorch model whose prior mean is the custom LinearModel `lm`.
gp_botorch = SingleTaskGP(x_data, y_data, mean_module=lm)
gp_botorch.likelihood.noise_covar.noise = 0.0001  # to simulate near-noiseless prediction
gp_botorch.eval()
gp_botorch.likelihood.eval()

# Plain GPyTorch model sharing BoTorch's likelihood and covariance module and
# the same mean module `lm`, so both models describe the same GP.
gp_gpytorch = ExactGPModel(
    x_data, y_data.squeeze(), gp_botorch.likelihood, lm, gp_botorch.covar_module
)
gp_gpytorch.eval()

# Test that the manual, BoTorch and GPyTorch predictions actually agree.
x_star = torch.linspace(0, 1, 100).view(-1, 1)
pred_mean, pred_var = gp_pred(x_star, x_data, y_data, lm, gp_botorch)
pred_botorch = gp_botorch(x_star)
pred_gpytorch = gp_gpytorch(x_star)
# Check that BoTorch and GPyTorch are equal.
print("botorch vs. gpytorch")
print(torch.abs(pred_botorch.mean - pred_gpytorch.mean).max())
print(torch.abs(pred_botorch.variance - pred_gpytorch.variance).max())
# Check that BoTorch and manual predictions are approximately equal; reuse
# pred_botorch instead of re-evaluating the model for each statistic.
print("botorch vs manual")
print(torch.abs(pred_mean - pred_botorch.mean).max())
print(torch.abs(pred_var - pred_botorch.variance).max())

# Plotting
# import matplotlib.pyplot as plt
# plt.scatter(x_data, y_data, label="data")
# plt.plot(x_star, pred_mean.detach(), label="manual")
# plt.plot(x_star, pred_botorch.mean.detach(), label="botorch")
# plt.plot(x_star, pred_gpytorch.mean.detach(), label="gpytorch")
# plt.legend()
# plt.show()

# Prediction at the training points: with near-zero noise the predictive mean
# is ~y there, so its gradient w.r.t. lm.beta should be ~0.
# NOTE(review): the non-zero BoTorch/GPyTorch gradients below may come from
# GPyTorch detaching its cached test-time quantities
# (gpytorch.settings.detach_test_caches) -- unconfirmed.
opt.zero_grad()
pred_mean, pred_var = gp_pred(x_data, x_data, y_data, lm, gp_botorch)
pred_mean.mean().backward()
print(lm.beta.grad)  # correct: ≈ 0

# Same manual prediction, but with the data-fit term detached.
opt.zero_grad()
pred_mean, pred_var = gp_pred(x_data, x_data, y_data, lm, gp_botorch, detach_bug=True)
pred_mean.mean().backward()
print(lm.beta.grad)  # incorrect: ≠ 0

opt.zero_grad()
pred_botorch = gp_botorch(x_data)
pred_botorch.mean.mean().backward()
print(lm.beta.grad)  # incorrect: ≠ 0

opt.zero_grad()
pred_gpytorch = gp_gpytorch(x_data)
pred_gpytorch.mean.mean().backward()
print(lm.beta.grad)  # incorrect: ≠ 0

The last three predictions agree with each other in both outputs and gradients (the gradients are, in my opinion, incorrect). The first one matches them in outputs but produces the correct (≈ 0) gradient.

Expected Behavior

The gradient of the predictive mean at the observation locations $X$ should be $0$ (see the line marked `#correct` in the snippet above).

System information

Please complete the following information: