cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

[Bug] DKL with InducingPointKernel produces a non-symmetric and non-positive-definite covariance matrix #2095

Status: Open · opened 2 years ago by stevenstetzler

stevenstetzler commented 2 years ago

🐛 Bug

Using deep kernel learning with InducingPointKernel produces a non-positive-definite covariance matrix on test data (not the training data). The issue appears after the DNN weights and GP hyperparameters have been trained for several epochs: the test covariance matrix first becomes non-symmetric, and eventually at least one of its eigenvalues becomes negative. The asymmetry and the negative eigenvalues are larger in magnitude (1e-7 to 1e-5) than I would expect from numerical error at double precision, given that the jitter added to make matrices PD during training is only ~1e-8. Moreover, the likelihood learns an observational noise term of ~1e-5 that is added to the diagonal of the test covariance matrix, which I would expect to absorb any remaining numerical instability.

Does this still look like a numerical-stability issue, or is there some property of the model or data that could produce an asymmetric/non-PD covariance matrix? Should we simply not expect a predictive variance smaller than ~1e-5? Are there settings to tweak that could improve stability during prediction?
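One setting I know of that might help (a sketch, assuming the recent GPyTorch settings API; the exact keyword arguments of cholesky_jitter have changed between versions) is raising the jitter used during prediction:

import torch
import gpytorch

# Sketch: increase the jitter GPyTorch adds to covariance matrices.
# Recent versions take per-dtype keywords (float/double/half); older
# versions take a single positional value. model, likelihood, and test_x
# are placeholders for the objects defined in the script below.
model.eval()
likelihood.eval()
with torch.no_grad(), gpytorch.settings.cholesky_jitter(double=1e-6):
    pred = likelihood(model(test_x))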

To reproduce

Code and data are available at: https://epyc.astro.washington.edu/~stevengs/gp/. Download the .pkl and .py files to the same directory and run python dkl_sgpr_example.py.

Here is the model:

import torch
import gpytorch

class SmallFeatureExtractor(torch.nn.Sequential):
    """DNN feature extractor: maps dim input features to latent_dim latent features."""
    def __init__(self, dim, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.dim = dim
        self.add_module('linear1', torch.nn.Linear(dim, 100))
        self.add_module('relu1', torch.nn.ReLU())
        self.add_module('linear2', torch.nn.Linear(100, 50))
        self.add_module('relu2', torch.nn.ReLU())
        self.add_module('linear3', torch.nn.Linear(50, 5))
        self.add_module('relu3', torch.nn.ReLU())
        self.add_module('linear4', torch.nn.Linear(5, latent_dim))

class DeepKernelSparseGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, feature_extractor, inducing_points):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        # ARD RBF base kernel over the latent features
        self.base_covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(ard_num_dims=inducing_points.shape[1])
        )
        # SGPR-style sparse approximation with learnable inducing points
        self.covar_module = gpytorch.kernels.InducingPointKernel(
            self.base_covar_module, inducing_points=inducing_points, likelihood=likelihood
        )
        self.scale_to_bounds = gpytorch.utils.grid.ScaleToBounds(-1., 1.)
        self.feature_extractor = feature_extractor

    def forward(self, x):
        # Pass inputs through the DNN, then rescale latent features to [-1, 1]
        x = self.scale_to_bounds(self.feature_extractor(x))
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
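For reference, the surrounding setup looks roughly like this (a sketch; the actual dimensions, inducing-point count, and training loop are in dkl_sgpr_example.py):

likelihood = gpytorch.likelihoods.GaussianLikelihood().double()
feature_extractor = SmallFeatureExtractor(dim=train_x.shape[1], latent_dim=2).double()
# 16 inducing points in the 2-d latent space, matching the parameter dumps below
inducing_points = torch.rand(16, 2, dtype=torch.float64) * 2. - 1.
model = DeepKernelSparseGP(train_x, train_y, likelihood, feature_extractor, inducing_points).double()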

Stack trace/error message

Below I print the training log, the model parameters, and the covariance matrix constructed for the test data, both before training and after the covariance matrix becomes asymmetric and develops negative eigenvalues.

This is also the contents of dkl_sgpr_example.log.

before training
loss=1.465227
/gscratch/dirac/stevengs/nuclei/nuclei_summer/dkl_sgpr_example.py:103: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at  ../aten/src/ATen/native/Copy.cpp:250.)
  eigvals = torch.linalg.eigvals(covar).double()
torch.allclose(covar, covar.T): True
torch.max(eigvals): tensor(0.7946, dtype=torch.float64)
torch.min(eigvals): tensor(0.5000, dtype=torch.float64)
all([e > 0 for e in eigvals]): True
covar:
 tensor([[0.5043, 0.0043, 0.0043,  ..., 0.0041, 0.0041, 0.0042],
        [0.0043, 0.5043, 0.0042,  ..., 0.0041, 0.0041, 0.0041],
        [0.0043, 0.0042, 0.5042,  ..., 0.0041, 0.0041, 0.0041],
        ...,
        [0.0041, 0.0041, 0.0041,  ..., 0.5042, 0.0042, 0.0042],
        [0.0041, 0.0041, 0.0041,  ..., 0.0042, 0.5042, 0.0042],
        [0.0042, 0.0041, 0.0041,  ..., 0.0042, 0.0042, 0.5042]],
       dtype=torch.float64, grad_fn=<AddBackward0>)
Parameters:
likelihood.noise_covar.noise
tensor([0.5000], dtype=torch.float64)

mean_module.constant
Parameter containing:
tensor([0.], dtype=torch.float64, requires_grad=True)

base_covar_module.outputscale
tensor(0.6931, dtype=torch.float64)

base_covar_module.base_kernel.lengthscale
tensor([[0.6931, 0.6931]], dtype=torch.float64)

covar_module.inducing_points
Parameter containing:
tensor([[ 0.0521, -0.6976],
        [ 0.0632, -0.4839],
        [ 0.2233,  0.3943],
        [ 0.1918, -0.7864],
        [ 0.7498,  0.8417],
        [-0.8966,  0.8190],
        [-0.8619,  0.4701],
        [-0.0591, -0.0799],
        [-0.4869,  0.7875],
        [ 0.1616, -0.0695],
        [-0.0354,  0.7248],
        [ 0.9909, -0.5697],
        [-0.8870, -0.6765],
        [-0.5430, -0.8170],
        [-0.3609, -0.3869],
        [-0.0071,  0.2852]], dtype=torch.float64, requires_grad=True)

epoch=0/100 batch=0 loss=1.465227
torch.allclose(covar, covar.T): True
torch.max(eigvals): tensor(0.6255, dtype=torch.float64)
torch.min(eigvals): tensor(0.4168, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=1/100 batch=0 loss=0.509937
torch.allclose(covar, covar.T): True
torch.max(eigvals): tensor(0.0650, dtype=torch.float64)
torch.min(eigvals): tensor(0.0024, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=2/100 batch=0 loss=-1.449859
torch.allclose(covar, covar.T): True
torch.max(eigvals): tensor(0.0025, dtype=torch.float64)
torch.min(eigvals): tensor(0.0006, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=3/100 batch=0 loss=-2.212336
torch.allclose(covar, covar.T): True
torch.max(eigvals): tensor(0.0006, dtype=torch.float64)
torch.min(eigvals): tensor(0.0003, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=4/100 batch=0 loss=-2.671866
torch.allclose(covar, covar.T): True
torch.max(eigvals): tensor(0.0003, dtype=torch.float64)
torch.min(eigvals): tensor(0.0002, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

// covariance matrix becomes asymmetric

epoch=5/100 batch=0 loss=-3.014810
torch.allclose(covar, covar.T): False
torch.max(eigvals): tensor(0.0003, dtype=torch.float64)
torch.min(eigvals): tensor(0.0001, dtype=torch.float64)
all([e > 0 for e in eigvals]): True
covar:
 tensor([[1.2720e-04, 1.7792e-05, 1.3826e-05,  ..., 1.6445e-06, 1.6768e-06,
         1.6631e-06],
        [1.7791e-05, 1.1988e-04, 1.1284e-05,  ..., 1.3325e-06, 1.3585e-06,
         1.3472e-06],
        [1.3826e-05, 1.1284e-05, 1.1415e-04,  ..., 1.0328e-06, 1.0529e-06,
         1.0441e-06],
        ...,
        [1.6338e-06, 1.3218e-06, 1.0222e-06,  ..., 1.0594e-04, 5.9573e-07,
         6.0441e-07],
        [1.6661e-06, 1.3479e-06, 1.0424e-06,  ..., 5.9571e-07, 1.0598e-04,
         6.2799e-07],
        [1.6526e-06, 1.3367e-06, 1.0336e-06,  ..., 6.0437e-07, 6.2797e-07,
         1.0601e-04]], dtype=torch.float64, grad_fn=<AddBackward0>)
Parameters:
likelihood.noise_covar.noise
tensor([8.8240e-05], dtype=torch.float64)

mean_module.constant
Parameter containing:
tensor([0.0057], dtype=torch.float64, requires_grad=True)

base_covar_module.outputscale
tensor(0.6027, dtype=torch.float64)

base_covar_module.base_kernel.lengthscale
tensor([[1.2276, 1.1023]], dtype=torch.float64)

covar_module.inducing_points
Parameter containing:
tensor([[ 0.0765, -0.8154],
        [ 0.1267, -0.3993],
        [ 0.0757,  0.3193],
        [ 0.3583, -0.5108],
        [ 0.7529,  0.7354],
        [-0.8866,  0.8077],
        [-0.8409,  0.4318],
        [ 0.5946, -0.1004],
        [-0.4840,  0.7056],
        [ 0.9833, -0.4459],
        [-0.0277,  0.7254],
        [ 0.6723, -0.4268],
        [-0.7652, -0.6188],
        [-0.4834, -0.6123],
        [-0.4852, -0.6137],
        [-0.1640,  0.1421]], dtype=torch.float64, requires_grad=True)

epoch=6/100 batch=0 loss=-3.216315
torch.allclose(covar, covar.T): True
torch.max(eigvals): tensor(0.0001, dtype=torch.float64)
torch.min(eigvals): tensor(8.4475e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=7/100 batch=0 loss=-3.349052
torch.allclose(covar, covar.T): True
torch.max(eigvals): tensor(9.9394e-05, dtype=torch.float64)
torch.min(eigvals): tensor(7.5150e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=8/100 batch=0 loss=-3.450857
torch.allclose(covar, covar.T): True
torch.max(eigvals): tensor(7.4041e-05, dtype=torch.float64)
torch.min(eigvals): tensor(6.7327e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=9/100 batch=0 loss=-3.533311
torch.allclose(covar, covar.T): True
torch.max(eigvals): tensor(6.7669e-05, dtype=torch.float64)
torch.min(eigvals): tensor(6.0395e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=10/100 batch=0 loss=-3.603859
torch.allclose(covar, covar.T): True
torch.max(eigvals): tensor(6.0371e-05, dtype=torch.float64)
torch.min(eigvals): tensor(5.4178e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=11/100 batch=0 loss=-3.656347
torch.allclose(covar, covar.T): True
torch.max(eigvals): tensor(5.6204e-05, dtype=torch.float64)
torch.min(eigvals): tensor(4.5188e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=12/100 batch=0 loss=-3.706605
torch.allclose(covar, covar.T): True
torch.max(eigvals): tensor(5.1626e-05, dtype=torch.float64)
torch.min(eigvals): tensor(4.6964e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=13/100 batch=0 loss=-3.747696
torch.allclose(covar, covar.T): False
torch.max(eigvals): tensor(5.2272e-05, dtype=torch.float64)
torch.min(eigvals): tensor(4.1187e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=14/100 batch=0 loss=-3.779977
torch.allclose(covar, covar.T): False
torch.max(eigvals): tensor(4.8281e-05, dtype=torch.float64)
torch.min(eigvals): tensor(4.0903e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=15/100 batch=0 loss=-3.819055
torch.allclose(covar, covar.T): True
torch.max(eigvals): tensor(4.7312e-05, dtype=torch.float64)
torch.min(eigvals): tensor(4.2271e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=16/100 batch=0 loss=-3.848625
torch.allclose(covar, covar.T): True
torch.max(eigvals): tensor(4.5553e-05, dtype=torch.float64)
torch.min(eigvals): tensor(3.3372e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=17/100 batch=0 loss=-3.870684
torch.allclose(covar, covar.T): False
torch.max(eigvals): tensor(5.6592e-05, dtype=torch.float64)
torch.min(eigvals): tensor(4.1854e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=18/100 batch=0 loss=-3.891119
torch.allclose(covar, covar.T): True
torch.max(eigvals): tensor(5.8252e-05, dtype=torch.float64)
torch.min(eigvals): tensor(4.1263e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=19/100 batch=0 loss=-3.909578
torch.allclose(covar, covar.T): False
torch.max(eigvals): tensor(6.4548e-05, dtype=torch.float64)
torch.min(eigvals): tensor(3.9948e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=20/100 batch=0 loss=-3.934292
torch.allclose(covar, covar.T): False
torch.max(eigvals): tensor(6.3655e-05, dtype=torch.float64)
torch.min(eigvals): tensor(3.8333e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=21/100 batch=0 loss=-3.961100
torch.allclose(covar, covar.T): False
torch.max(eigvals): tensor(3.8688e-05, dtype=torch.float64)
torch.min(eigvals): tensor(3.6147e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=22/100 batch=0 loss=-3.979790
torch.allclose(covar, covar.T): False
torch.max(eigvals): tensor(3.7945e-05, dtype=torch.float64)
torch.min(eigvals): tensor(2.4795e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=23/100 batch=0 loss=-3.999898
torch.allclose(covar, covar.T): False
torch.max(eigvals): tensor(3.7241e-05, dtype=torch.float64)
torch.min(eigvals): tensor(1.5308e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=24/100 batch=0 loss=-4.025196
torch.allclose(covar, covar.T): False
torch.max(eigvals): tensor(3.6644e-05, dtype=torch.float64)
torch.min(eigvals): tensor(2.0070e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=25/100 batch=0 loss=-4.050356
torch.allclose(covar, covar.T): True
torch.max(eigvals): tensor(7.6968e-05, dtype=torch.float64)
torch.min(eigvals): tensor(3.4462e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=26/100 batch=0 loss=-4.075860
torch.allclose(covar, covar.T): False
torch.max(eigvals): tensor(3.4989e-05, dtype=torch.float64)
torch.min(eigvals): tensor(3.0445e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=27/100 batch=0 loss=-4.097970
torch.allclose(covar, covar.T): False
torch.max(eigvals): tensor(5.3375e-05, dtype=torch.float64)
torch.min(eigvals): tensor(3.2858e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=28/100 batch=0 loss=-4.119830
torch.allclose(covar, covar.T): False
torch.max(eigvals): tensor(3.3981e-05, dtype=torch.float64)
torch.min(eigvals): tensor(2.9871e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): True

epoch=29/100 batch=0 loss=-4.138579
torch.allclose(covar, covar.T): False
torch.max(eigvals): tensor(3.2372e-05, dtype=torch.float64)
torch.min(eigvals): tensor(-2.1481e-05, dtype=torch.float64)
all([e > 0 for e in eigvals]): False
covar:
 tensor([[ 3.0744e-05, -5.9116e-07, -6.0014e-07,  ..., -5.9368e-07,
         -5.9529e-07, -5.9687e-07],
        [-5.9130e-07,  3.0728e-05, -6.0373e-07,  ..., -5.9206e-07,
         -5.9108e-07, -5.9004e-07],
        [-6.0042e-07, -6.0388e-07,  3.0718e-05,  ..., -5.9150e-07,
         -5.8834e-07, -5.8509e-07],
        ...,
        [-6.4287e-07, -6.4132e-07, -6.4081e-07,  ...,  3.0723e-05,
         -5.7952e-07, -5.5667e-07],
        [-6.4501e-07, -6.4087e-07, -6.3817e-07,  ..., -5.8023e-07,
          3.0776e-05, -5.1833e-07],
        [-6.4710e-07, -6.4034e-07, -6.3544e-07,  ..., -5.5809e-07,
         -5.1904e-07,  3.0846e-05]], dtype=torch.float64,
       grad_fn=<AddBackward0>)
Parameters:
likelihood.noise_covar.noise
tensor([1.4199e-05], dtype=torch.float64)

mean_module.constant
Parameter containing:
tensor([0.0069], dtype=torch.float64, requires_grad=True)

base_covar_module.outputscale
tensor(0.6304, dtype=torch.float64)

base_covar_module.base_kernel.lengthscale
tensor([[1.5524, 1.4203]], dtype=torch.float64)

covar_module.inducing_points
Parameter containing:
tensor([[ 0.1540, -0.6542],
        [-0.1453, -0.4734],
        [-0.0165,  0.2547],
        [ 0.3614, -0.3993],
        [ 0.7173,  0.5495],
        [-0.8548,  0.7846],
        [-0.8305,  0.3709],
        [ 0.8371, -0.3482],
        [-0.4750,  0.5513],
        [ 1.2503, -0.7863],
        [-0.0201,  0.7434],
        [ 0.8170, -0.5308],
        [-0.8348, -0.6516],
        [-0.6655, -0.5397],
        [-0.7762, -0.7725],
        [-0.1686, -0.0699]], dtype=torch.float64, requires_grad=True)

Traceback (most recent call last):
  File "/gscratch/dirac/stevengs/nuclei/nuclei_summer/dkl_sgpr_example.py", line 167, in <module>
    train(100)
  File "/gscratch/dirac/stevengs/nuclei/nuclei_summer/dkl_sgpr_example.py", line 158, in train
    debug()
  File "/gscratch/dirac/stevengs/nuclei/nuclei_summer/dkl_sgpr_example.py", line 132, in debug
    assert(eigvals_positive)
AssertionError

Expected Behavior

I expect the produced covariance matrix to be symmetric and positive-definite to within the available numerical precision.
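A sketch of how this expectation could be checked without the complex-cast warning that appears in the log above: torch.linalg.eigvalsh assumes a symmetric input and returns real eigenvalues in ascending order, so symmetry and positive-definiteness can be tested separately (covar stands for the dense test covariance):

import torch

def assert_symmetric_pd(covar, atol=1e-10):
    # Measure the worst-case asymmetry directly instead of via allclose
    asym = (covar - covar.T).abs().max().item()
    assert asym < atol, f"asymmetry {asym:.3e} exceeds tolerance"
    # eigvalsh is for symmetric matrices; it returns ascending real eigenvalues
    eigvals = torch.linalg.eigvalsh(0.5 * (covar + covar.T))
    assert eigvals[0].item() > 0, f"min eigenvalue {eigvals[0].item():.3e} is not positive"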


gpleiss commented 2 years ago

I don't think this is a bug in GPyTorch. It also seems like the issue is that your variance is collapsing to zero (the minimum eigenvalue is similar to the maximum eigenvalue, and both are tiny). When the whole spectrum is that small, rounding error from the kernel computations, which involve quantities many orders of magnitude larger, can dominate the eigenvalues themselves. This would cause the numerical issues that you're seeing.
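A common mitigation for this kind of variance collapse (a sketch, not something prescribed in this thread) is to put a lower bound on the learned likelihood noise so the predictive covariance cannot shrink arbitrarily close to zero:

import gpytorch

# Keep the learned Gaussian noise above 1e-4. The bound is an
# illustrative choice, not a value recommended in this issue.
likelihood = gpytorch.likelihoods.GaussianLikelihood(
    noise_constraint=gpytorch.constraints.GreaterThan(1e-4)
)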