First of all, thank you for developing such a versatile and efficient library! I suspect that I came across a bug when working with BoTorch, but I believe it is originating from GPyTorch.
When working with a custom mean function that depends on some parameters $\theta$ that we wish to optimize (e.g. NN feature extractor), the derivative of the predictive mean is wrong.
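The predictive mean $\mu(\cdot)$ (given a custom mean function $m_\theta$) at $x^*$ is given by
$$
\mu(x^* \mid X, y) = m_\theta(x^*) + K_*^\top (K + \sigma I)^{-1} \left(y - m_\theta(X)\right)\text{.}
$$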
An easy test for the derivative is to predict at the observed data points $X$, which gives (when the observation noise is small, $\sigma \approx 0$)
$$
\mu(X \mid X, y) \approx m_\theta(X) + \left(y - m_\theta(X)\right) = y
$$
whose derivative with respect to the mean module's parameters $\theta$ should therefore be (approximately) zero.
It appears that at least in the following specific case this does not happen, and seems to be related to incorrect detaching at one place (see below for a hypothetical location where this happens)
To reproduce
Code snippet to reproduce
from botorch.models import SingleTaskGP
import torch
# Seed PyTorch's global RNG so beta's random initialization and the data noise are reproducible.
torch.random.manual_seed(123)
# Learnable prior-mean model for the GP.
class LinearModel(torch.nn.Module):
    """Linear mean function without an intercept.

    The output is the sum over the last dimension of ``X * beta``, where
    ``beta`` is a learnable ``(1, D)`` weight row.

    Args:
        D: Number of input features.
    """

    def __init__(self, D=1, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Weights are drawn from the global torch RNG at construction time.
        self.beta = torch.nn.Parameter(torch.randn(1, D))

    def forward(self, X):
        """Return ``(X * beta).sum(-1)``, i.e. one scalar mean per input row."""
        weighted = torch.mul(X, self.beta)
        return weighted.sum(dim=-1)
lm = LinearModel(D=1)
# The optimizer is only used for opt.zero_grad(); opt.step() is never called in this script.
opt = torch.optim.Adam(params=lm.parameters(), lr=0.424242)
# Synthetic 1-D regression data: a noisy sine wave sampled on [0, 1].
N = 50
x_data = torch.linspace(0, 1, N).unsqueeze(-1)
y_data = 2 * torch.sin(10 * x_data) + 0.01 * torch.randn_like(x_data)
# GP prediction computed manually.
def gp_pred(xstar, obs_X, obs_Y, prior_mean, gp, detach_bug=False):
    """Compute the exact GP posterior mean and variance at ``xstar`` by hand.

    Mean:     m(xstar) + Kstar (K + noise * I)^{-1} (obs_Y - m(obs_X))
    Variance: diag k(xstar, xstar) - diag(Kstar (K + noise * I)^{-1} Kstar^T)

    All kernel and noise quantities are detached, so gradients flow only into
    the parameters of ``prior_mean``.

    Args:
        xstar: Test inputs, expected shape ``(m, d)``.
        obs_X: Training inputs, expected shape ``(n, d)``.
        obs_Y: Training targets, expected shape ``(n, 1)``.
        prior_mean: Callable mapping ``(k, d)`` inputs to ``(k,)`` prior means.
        gp: Model exposing ``covar_module`` and
            ``likelihood.noise_covar.noise`` (e.g. a GPyTorch/BoTorch exact GP).
        detach_bug: If True, detach the data-fit term before adding the prior
            mean at ``xstar``. This reproduces the gradient behaviour the
            report suspects in GPyTorch.

    Returns:
        Tuple ``(mean, variance)``, each of shape ``(m,)``.
    """
    samples_prior_mean = prior_mean(xstar)
    obs_prior_mean = prior_mean(obs_X)
    # Residuals of the observations w.r.t. the prior mean, shape (n, 1).
    gp_y = obs_Y - obs_prior_mean.unsqueeze(-1)

    # Kernel and noise terms are detached on purpose: only the mean module's
    # gradient is under test.
    K = gp.covar_module(obs_X, obs_X).to_dense().detach()
    # Use the constrained noise value. raw_noise is the unconstrained
    # parameter and only equals the noise when the constraint has no transform.
    noise = gp.likelihood.noise_covar.noise.detach()
    eye = torch.eye(K.shape[-1], dtype=K.dtype, device=K.device)
    KplusNoise = K + noise * eye
    Kstar = gp.covar_module(xstar, obs_X).to_dense().detach()

    # Solve the linear system instead of forming an explicit inverse.
    pred_mean = Kstar @ torch.linalg.solve(KplusNoise, gp_y)
    if detach_bug:
        # NOTE(review): GPyTorch's exact prediction appears to detach its
        # cached solve while gpytorch.settings.detach_test_caches is on (the
        # default), which would match this branch - confirm.
        pred_mean = pred_mean.detach()
    # squeeze(-1) (not squeeze()) keeps shape (m,) even when m == 1.
    pred_mean_og_scale = pred_mean.squeeze(-1) + samples_prior_mean

    # Only the diagonal of Kstar (K + noise I)^{-1} Kstar^T is needed, so
    # compute it row-wise instead of building the full (m, m) product.
    solved = torch.linalg.solve(KplusNoise, Kstar.transpose(-2, -1))
    explained_var = (Kstar * solved.transpose(-2, -1)).sum(dim=-1)
    # Prior variance from the kernel diagonal; this does not assume a
    # ScaleKernel with a unit-diagonal base kernel.
    prior_var = gp.covar_module(xstar, diag=True).detach()
    pred_var = prior_var - explained_var
    return pred_mean_og_scale, pred_var
#define the Gpytorch model
import gpytorch
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP whose prior is built from caller-supplied mean and covariance modules.

    Args:
        train_x: Training inputs.
        train_y: Training targets.
        likelihood: Likelihood passed to ``ExactGP``.
        mean_module: Module producing the prior mean at given inputs.
        covar_module: Module producing the prior covariance at given inputs.
    """

    def __init__(self, train_x, train_y, likelihood, mean_module, covar_module):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = mean_module
        self.covar_module = covar_module

    def forward(self, x):
        """Return the prior ``MultivariateNormal`` at inputs ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )
# BoTorch model that shares the custom mean module `lm`.
gp_botorch = SingleTaskGP(x_data, y_data, mean_module=lm)
gp_botorch.likelihood.noise_covar.noise = 0.0001  # simulate near-noiseless prediction
gp_botorch.eval()
gp_botorch.likelihood.eval()

# Plain GPyTorch model reusing the same likelihood, mean module and kernel.
# squeeze(-1) turns the (N, 1) targets into shape (N,) without collapsing N == 1.
gp_gpytorch = ExactGPModel(
    x_data, y_data.squeeze(-1), gp_botorch.likelihood, lm, gp_botorch.covar_module
)
gp_gpytorch.eval()

# Test that the manual, BoTorch and GPyTorch predictions agree.
x_star = torch.linspace(0, 1, 100).view(-1, 1)
pred_mean, pred_var = gp_pred(x_star, x_data, y_data, lm, gp_botorch)
pred_botorch = gp_botorch(x_star)
pred_gpytorch = gp_gpytorch(x_star)

# BoTorch and GPyTorch should be equal.
print("botorch vs. gpytorch")
print(torch.abs(pred_botorch.mean - pred_gpytorch.mean).max())
print(torch.abs(pred_botorch.variance - pred_gpytorch.variance).max())

# BoTorch and manual predictions should be approximately equal. Reuse
# pred_botorch rather than running two more BoTorch forward passes.
print("botorch vs manual")
print(torch.abs(pred_mean - pred_botorch.mean).max())
print(torch.abs(pred_var - pred_botorch.variance).max())
#plotting
# import matplotlib.pyplot as plt
# plt.scatter(x_data, y_data, label="data")
# plt.plot(x_star, pred_mean.detach(), label="manual")
# plt.plot(x_star, pred_botorch.mean.detach(), label="botorch")
# plt.plot(x_star, pred_gpytorch.mean.detach(), label="gpytorch")
# plt.legend()
# plt.show()
# Predictions at the training points: the predictive mean nearly interpolates
# y_data here, so its gradient w.r.t. lm.beta is expected to be about 0.
def _beta_grad(predictive_mean):
    """Return ``lm.beta.grad`` after backpropagating ``predictive_mean().mean()``.

    Gradients are zeroed first, so each check starts from a clean state.
    """
    opt.zero_grad()
    predictive_mean().mean().backward()
    return lm.beta.grad


# Manual prediction, no detach. Reported as correct: about 0.
print(_beta_grad(lambda: gp_pred(x_data, x_data, y_data, lm, gp_botorch)[0]))
# Manual prediction with the detach. Reported as incorrect: not 0.
print(_beta_grad(lambda: gp_pred(x_data, x_data, y_data, lm, gp_botorch, detach_bug=True)[0]))
# BoTorch prediction. Reported as incorrect: not 0.
print(_beta_grad(lambda: gp_botorch(x_data).mean))
# GPyTorch prediction. Reported as incorrect: not 0.
print(_beta_grad(lambda: gp_gpytorch(x_data).mean))
The last three predictions agree with each other in both outputs and gradients (the gradients are, in my opinion, incorrect). The first prediction matches them in outputs but produces the correct (≈ 0) gradient.
Expected Behavior
The gradient of the predictive mean at the observation locations $X$ should be (approximately) $0$. See `#correct` in the above snippet.
System information
Please complete the following information:
Gpytorch 1.11
PyTorch 2.1.0
macOS Sonoma 14.4.1
I could not locate the bug in the GPyTorch code, but hopefully this report helps you find it.
🐛 Bug
Hi,
First of all, thank you for developing such a versatile and efficient library! I suspect that I came across a bug when working with BoTorch, but I believe it is originating from GPyTorch.
When working with a custom mean function that depends on some parameters $\theta$ that we wish to optimize (e.g. NN feature extractor), the derivative of the predictive mean is wrong.
The predictive mean $\mu(\cdot)$ (given a custom mean function $m_\theta$) at $x^*$ is given by
$$ \mu(x^* \mid X, y) = m_\theta(x^*) + K_*^\top (K + \sigma I)^{-1} \left(y - m_\theta(X)\right)\text{.} $$
An easy test for the derivative is to predict at the observed data points $X$, which gives (when the observation noise is small, $\sigma \approx 0$)
$$ \mu(X \mid X, y) \approx m_\theta(X) + \left(y - m_\theta(X)\right) = y $$
whose derivative with respect to the mean module's parameters $\theta$ should therefore be (approximately) zero.
It appears that at least in the following specific case this does not happen, and seems to be related to incorrect detaching at one place (see below for a hypothetical location where this happens)
To reproduce
Code snippet to reproduce
The last three predictions agree with each other in both outputs and gradients (the gradients are, in my opinion, incorrect). The first prediction matches them in outputs but produces the correct (≈ 0) gradient.
Expected Behavior
The gradient of the predictive mean at the observation locations $X$ should be (approximately) $0$. See `#correct` in the above snippet.
System information
Please complete the following information:
macOS Sonoma 14.4.1
I could not locate the bug in the GPyTorch code, but hopefully this report helps you find it.