cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

[Bug] Sign error in Matern52KernelGrad #2563

Open ingolia opened 2 months ago

ingolia commented 2 months ago

🐛 Bug

I think the sign of the covariance between function values and first derivatives is flipped in Matern52KernelGrad.

On a simple test case, the sign of the function-to-derivative covariance terms for Matern52KernelGrad is flipped relative to RBFKernelGrad.

Intuitively, the function value at -1 should be negatively correlated with the derivative at 0, while the function value at +1 should be positively correlated with the derivative at 0 (think of the central finite-difference formula for the derivative at 0, written in terms of the kernel). The signs from RBFKernelGrad match that expectation; the signs from Matern52KernelGrad do not. A numerical check of this intuition is sketched below.
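To make the finite-difference intuition concrete, here is a minimal sanity check (my own sketch, not part of the original report): for a stationary kernel, cov( f(x), ∂f/∂x'(x') ) = ∂k(x, x')/∂x', which a central difference approximates.

import torch
from gpytorch.kernels import MaternKernel

# Central-difference estimate of cov(f(-1), f'(0)) = dk(x, x')/dx'
# at x = -1, x' = 0, for a 1-D Matern-5/2 kernel.
k = MaternKernel(nu=2.5)
x = torch.tensor([[-1.0]])
h = 1e-4
k_plus = k(x, torch.tensor([[h]])).to_dense()
k_minus = k(x, torch.tensor([[-h]])).to_dense()
print(((k_plus - k_minus) / (2 * h)).detach())  # negative, as argued above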

Based on the description in the code, # 2) First gradient block, cov(f^m, omega^n_d) carries a negative sign. In the derivations PDF on the pull request, however, this term corresponds to equation (6) for $\frac{\partial k}{\partial x_i^n}$, which has no overall negative sign: the negatives in $\partial k/\partial r$ and $\partial r/\partial x^n$ cancel. Conversely, # 3) Second gradient block, cov(omega^m_d, f^n) says "the - signs on -outer2 and -five_thirds cancel out", but I think that cancellation should not happen for $\partial r/\partial x^m$, since $x^m$ appears with a positive sign in the distance function $r$.
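For reference, writing out the chain rule for the Matérn-5/2 kernel (my own derivation, with unit lengthscale and $r = \lVert x^m - x^n \rVert$, so it may differ cosmetically from the PDF's notation) makes the cancellation explicit:

$$
k(r) = \Bigl(1 + \sqrt{5}\,r + \tfrac{5}{3}\,r^2\Bigr) e^{-\sqrt{5}\,r},
\qquad
\frac{\partial k}{\partial r} = -\tfrac{5}{3}\bigl(r + \sqrt{5}\,r^2\bigr) e^{-\sqrt{5}\,r},
\qquad
\frac{\partial r}{\partial x_i^n} = -\frac{x_i^m - x_i^n}{r},
$$

$$
\frac{\partial k}{\partial x_i^n}
= \frac{\partial k}{\partial r}\,\frac{\partial r}{\partial x_i^n}
= \tfrac{5}{3}\bigl(1 + \sqrt{5}\,r\bigr)\bigl(x_i^m - x_i^n\bigr)\,e^{-\sqrt{5}\,r}.
$$

The two negatives cancel, so cov(f^m, omega^n_d) carries no overall minus sign; differentiating with respect to $x_i^m$ instead uses $\partial r/\partial x_i^m = +(x_i^m - x_i^n)/r$, which keeps the minus sign.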

Finally, I tried using this code to predict derivative values from a GP trained on function values for some simple functions; the predictions are correct in magnitude but have the wrong sign.

The signs on the derivative-to-derivative covariance terms seem correct.

To reproduce

import torch

torch.set_printoptions(precision=4, sci_mode=False, linewidth=160)

from gpytorch.kernels import RBFKernel, RBFKernelGrad, MaternKernel, Matern52KernelGrad

matern = MaternKernel()  # nu=2.5 by default, matching Matern52KernelGrad
matern_grad = Matern52KernelGrad()

rbf = RBFKernel()
rbf_grad = RBFKernelGrad()

# Three 2-D input points: (-1, 0), (0, 0), (1, 0)
inputs = torch.tensor([[-1.0, 0.0], [0.0, 0.0], [1.0, 0.0]])

# Plain kernels give 3x3 matrices; the grad kernels give 9x9 matrices
# (one function value plus two partial derivatives per input point).
print(matern(inputs).to_dense().detach())
print(matern_grad(inputs).to_dense().detach())

print(rbf(inputs).to_dense().detach())
print(rbf_grad(inputs).to_dense().detach())

Running this code, Matérn52 gives

tensor([[ 1.0000,  0.0000,  0.0000,  0.3056,  0.5822,  0.0000,  0.0336,  0.0816,  0.0000],
...

whereas RBF gives

tensor([[ 1.0000,  0.0000,  0.0000,  0.3532, -0.7352,  0.0000,  0.0156, -0.0648,  0.0000],
...

Expected Behavior

The entry matern_grad(inputs)[0, 4] should give cov( f( (-1,0) ), ∂f/∂x^1( (0,0) ) ), and this should be negative, like the negative value of rbf_grad(inputs)[0, 4].
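For the 2-D inputs above, the grad kernels return a 9×9 matrix whose rows and columns are ordered per point as (f, ∂/∂x^1, ∂/∂x^2), so the entry in question can be read off directly (continuing the repro script above):

# Row 0 is f at (-1, 0); column 4 is ∂/∂x^1 at (0, 0).
K_matern = matern_grad(inputs).to_dense().detach()
K_rbf = rbf_grad(inputs).to_dense().detach()
print(K_matern[0, 4], K_rbf[0, 4])  # 0.5822 vs. -0.7352: opposite signs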

System information

Please complete the following information:
- GPyTorch version: 1.12
- Torch version: 2.1.1
- OS: macOS Ventura 13.6.7

ingolia commented 2 months ago

I realized I should reference #2512 and tag @m-julian on this.

lanzithinking commented 3 weeks ago

Thank you very much for reporting this bug! It resolved my confusion: I had been testing Matern52KernelGrad and could not reproduce the results of the demo.

m-julian commented 2 weeks ago

Thank you both for reporting this. @ingolia, sorry, I must have missed the original notification for this issue, so I'm responding a bit late. I'll have a look at it and submit a PR this week.

lanzithinking commented 2 weeks ago

As reported by @ingolia, the fix is two sign flips in the Matern52KernelGrad source: line 97, -five_thirds -> five_thirds; line 105, five_thirds -> -five_thirds. With those two changes everything works perfectly.
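A quick way to verify such a patch (a sketch assuming only those two signs change) is to compare the patched function-to-derivative entry against a central finite difference of the plain Matérn kernel, which has the same default hyperparameters:

import torch
from gpytorch.kernels import MaternKernel, Matern52KernelGrad

inputs = torch.tensor([[-1.0, 0.0], [0.0, 0.0], [1.0, 0.0]])
K = Matern52KernelGrad()(inputs).to_dense().detach()

# Finite-difference estimate of cov( f((-1,0)), ∂f/∂x^1((0,0)) ).
k = MaternKernel(nu=2.5)
h = 1e-4
fd = (k(inputs[0:1], torch.tensor([[h, 0.0]])).to_dense()
      - k(inputs[0:1], torch.tensor([[-h, 0.0]])).to_dense()) / (2 * h)
print(K[0, 4].item(), fd.item())  # should agree (both negative) once fixed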