[Bug] Erroneous detaching with (custom?) mean #2521

Open villetan opened 2 months ago

First of all, thank you for developing such a versatile and efficient library! I suspect I have come across a bug while working with BoTorch, but I believe it originates in GPyTorch.

When working with a custom mean function that depends on parameters $\theta$ that we wish to optimize (e.g. an NN feature extractor), the derivative of the predictive mean w.r.t. $\theta$ is wrong.

The predictive mean $\mu(\cdot)$ (given a custom mean function $m_\theta$) at $x^*$ is given by

$$ \mu(x^* | X, y) = m_\theta(x^*) + K_*^T (K + \sigma I)^{-1} (y - m_\theta(X))\text{.} $$

An easy test of the derivative is to predict at the observed data points $X$, where $K_*^T = K$. When the observational noise is small ($\sigma \approx 0$), this gives

$$ \mu(X | X, y) \approx m_\theta(X) + y - m_\theta(X) = y $$

whose derivative w.r.t. the mean module's parameters $\theta$ should therefore be (approximately) zero.
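
To spell out why, here is a sketch of the derivative, under the assumption that $K$ and $\sigma$ are held fixed (so that $\theta$ enters only through the mean) and that $K$ is invertible:

$$ \frac{\partial \mu(X | X, y)}{\partial \theta} = \frac{\partial m_\theta(X)}{\partial \theta} - K (K + \sigma I)^{-1} \frac{\partial m_\theta(X)}{\partial \theta} = \sigma (K + \sigma I)^{-1} \frac{\partial m_\theta(X)}{\partial \theta} \to 0 \quad \text{as } \sigma \to 0\text{.} $$

If the second term is instead treated as a constant (for example because it has been detached from the autograd graph), only $\partial m_\theta(X) / \partial \theta$ remains, which is nonzero in general.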

At least in the following specific case this does not happen. The cause seems to be an incorrect detach in one place (see below for a hypothetical location where this happens).

To reproduce

Code snippet to reproduce

from botorch.models import SingleTaskGP
import torch

#define a model for the mean
class LinearModel(torch.nn.Module):
    def __init__(self, D=1, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.beta = torch.nn.Parameter(torch.randn(1,D))

    def forward(self, X):
        return (X * self.beta).sum(-1)

lm = LinearModel(1)
opt = torch.optim.Adam(lm.parameters(), lr=0.424242) #only used for zeroing grads

#generate some data
N = 20 #number of observations (assumed value; any small N works for this test)
x_data = torch.linspace(0, 1, N).view(-1, 1)
y_data = 2*torch.sin(10*x_data) + 0.01*torch.randn_like(x_data)

#gp prediction manually
def gp_pred(xstar, obs_X, obs_Y, prior_mean, gp, detach_bug=False):
    samples_prior_mean = prior_mean(xstar)
    obs_prior_mean = prior_mean(obs_X)
    gp_y = obs_Y - obs_prior_mean.unsqueeze(-1)

    #gp pred
    K = gp.covar_module(obs_X, obs_X).to_dense().detach() 
    likelihood_additive_noise = gp.likelihood.noise_covar.raw_noise.detach()
    KplusNoise = K + likelihood_additive_noise * torch.eye(K.shape[0])
    Kstar = gp.covar_module(xstar, obs_X).to_dense().detach()
    KpNinv = torch.linalg.inv(KplusNoise)
    #pred_mean = Kstar @ KpNinv @ gp_y
    pred_mean = Kstar @ torch.linalg.solve(KplusNoise, gp_y)
    if detach_bug: #the bug is here
        pred_mean = pred_mean.detach()
    pred_mean_og_scale = pred_mean.squeeze() + samples_prior_mean

    pred_var = gp.covar_module.outputscale.detach() - (Kstar @ KpNinv @ Kstar.T).diag()
    return pred_mean_og_scale, pred_var

#define the Gpytorch model
import gpytorch
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, mean_module, covar_module):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = mean_module
        self.covar_module = covar_module

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

#botorch model
gp_botorch = SingleTaskGP(x_data, y_data, mean_module=lm)
gp_botorch.likelihood.noise_covar.noise = 0.0001 #to simulate near noiseless prediction

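#gpytorch model
gp_gpytorch = ExactGPModel(x_data, y_data.squeeze(), gp_botorch.likelihood, lm, gp_botorch.covar_module)

#switch to eval (posterior) mode; in train mode GPyTorch only accepts the training inputs
gp_botorch.eval()
gp_gpytorch.eval()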

#test that manual, botorch and gpytorch actually output same predictions
x_star = torch.linspace(0, 1, 100).view(-1, 1)
pred_mean, pred_var = gp_pred(x_star, x_data, y_data, lm, gp_botorch)
pred_botorch = gp_botorch(x_star)
pred_gpytorch = gp_gpytorch(x_star)
#check botorch and gpytorch are equal
print("botorch vs. gpytorch")
print(torch.abs(pred_botorch.mean - pred_gpytorch.mean).max())
print(torch.abs(pred_botorch.variance - pred_gpytorch.variance).max())
#check predictions botorch and manual are approximately equal
print("botorch vs manual")
print(torch.abs(pred_mean - gp_botorch(x_star).mean).max())
print(torch.abs(pred_var - gp_botorch(x_star).variance).max())

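#prediction at the training points
#each case backpropagates the summed predictive mean (an assumed choice of scalar) to fill lm.beta.grad
opt.zero_grad()
pred_mean, pred_var = gp_pred(x_data, x_data, y_data, lm, gp_botorch)
pred_mean.sum().backward()
print(lm.beta.grad) #correct: ≈ 0

opt.zero_grad()
pred_mean, pred_var = gp_pred(x_data, x_data, y_data, lm, gp_botorch, detach_bug=True)
pred_mean.sum().backward()
print(lm.beta.grad) #incorrect: ≠ 0

opt.zero_grad()
pred_botorch = gp_botorch(x_data)
pred_botorch.mean.sum().backward()
print(lm.beta.grad) #incorrect: ≠ 0

opt.zero_grad()
pred_gpytorch = gp_gpytorch(x_data)
pred_gpytorch.mean.sum().backward()
print(lm.beta.grad) #incorrect: ≠ 0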

The last three predictions agree with each other in both outputs and gradients (the gradients being, in my opinion, incorrect), whereas the first one matches them in outputs but produces the correct gradient.

Expected Behavior

The gradient of the predictive mean at the observation locations $X$ w.r.t. the mean module's parameters should be (approximately) $0$. See #correct in the above snippet.
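
The expected behavior can also be checked without BoTorch or GPyTorch. The following is a minimal sketch in plain PyTorch, assuming a hand-written RBF kernel with a fixed lengthscale and a linear mean like `LinearModel` above. The names `rbf`, `posterior_mean_at_train`, `grad_wrt_beta`, and `detach_cache` are hypothetical helpers for illustration only; this is not GPyTorch's actual code path.

```python
import torch

# Hypothetical helpers for illustration; not GPyTorch's implementation.

def rbf(a, b, lengthscale=0.1):
    """Squared-exponential kernel for column vectors a (n, 1) and b (m, 1)."""
    return torch.exp(-0.5 * (a - b.T) ** 2 / lengthscale**2)

X = torch.linspace(0, 1, 8, dtype=torch.float64).view(-1, 1)  # training inputs
y = 2 * torch.sin(10 * X).squeeze(-1)                          # training targets
beta = torch.tensor([[0.7]], dtype=torch.float64, requires_grad=True)  # mean parameter
noise = 1e-6  # near-noiseless observations

def posterior_mean_at_train(detach_cache):
    prior_mean = (X * beta).sum(-1)  # linear mean m_beta(X)
    K = rbf(X, X)
    K_noisy = K + noise * torch.eye(X.shape[0], dtype=X.dtype)
    # (K + sigma I)^{-1} (y - m_beta(X)); this term also depends on beta
    cache = torch.linalg.solve(K_noisy, y - prior_mean)
    if detach_cache:
        cache = cache.detach()  # cuts the gradient path through y - m_beta(X)
    return prior_mean + K @ cache

def grad_wrt_beta(detach_cache):
    total = posterior_mean_at_train(detach_cache).sum()
    (grad,) = torch.autograd.grad(total, beta)
    return grad

grad_full = grad_wrt_beta(detach_cache=False)
grad_detached = grad_wrt_beta(detach_cache=True)
print("full graph:    ", grad_full)
print("detached cache:", grad_detached)
```

With the full graph, the gradient is on the order of the noise level. With the solve term detached, it equals `X.sum()`, i.e. the gradient of the prior mean alone, which is the behavior reported above.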

System information

