cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

Fix training status of noise model of `HeteroskedasticNoise` after exceptions #2382

Closed fjzzq2002 closed 11 months ago

fjzzq2002 commented 11 months ago

In the current implementation of `HeteroskedasticNoise.forward`, `self.noise_model.train(training)` is called only after the output of `self.noise_model` has been received. If `self.noise_model()` raises an exception, this reset never runs, leaving `self.noise_model` in evaluation mode regardless of the mode it was in before the call. This patch fixes the problem by wrapping the call in a try-finally block.
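The pattern can be sketched in isolation as follows. This is a minimal illustration, not the actual GPyTorch code: `ToyModule`, `FailingNoiseModel`, `forward_without_finally`, and `forward_with_finally` are hypothetical names, and `ToyModule` is a pure-Python stand-in that only mimics the `training` flag and the `train()`/`eval()` methods of `torch.nn.Module`.

```python
class ToyModule:
    """Hypothetical stand-in for the train/eval flag of torch.nn.Module."""

    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)


class FailingNoiseModel(ToyModule):
    """Hypothetical noise model whose call always raises, like a numerical error."""

    def __call__(self, x):
        raise RuntimeError("simulated numerical error")


def forward_without_finally(noise_model, x):
    training = noise_model.training  # remember the original mode
    noise_model.eval()
    output = noise_model(x)          # if this raises, the next line is skipped
    noise_model.train(training)
    return output


def forward_with_finally(noise_model, x):
    training = noise_model.training  # remember the original mode
    noise_model.eval()
    try:
        return noise_model(x)
    finally:
        # runs on both normal return and exception
        noise_model.train(training)


for forward in (forward_without_finally, forward_with_finally):
    model = FailingNoiseModel()
    try:
        forward(model, 1.0)
    except RuntimeError:
        pass
    print(forward.__name__, "-> training restored:", model.training)
```

In this sketch, the version without `finally` leaves the toy model in evaluation mode after the exception, while the version with `finally` restores the original mode before the exception propagates to the caller.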

The following is a typical example of the error:

import gpytorch
import torch

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

train_x = torch.tensor([[1.0], [2.0]])
train_y = torch.tensor([0.0, 0.0])
test_x = torch.tensor([[3.0]])
likelihood = gpytorch.likelihoods.GaussianLikelihood()
noise_model = ExactGPModel(train_x, train_y, likelihood).to(torch.double)
noise_model(train_x)
final_likelihood = gpytorch.likelihoods.HeteroskedasticNoise(noise_model)
assert noise_model.training and final_likelihood.training

# under a normal lengthscale, our likelihood works as expected
noise_model.covar_module.base_kernel.raw_lengthscale.data[[0]] = 0
print(final_likelihood(test_x).to_dense())

# now assume that, due to an imperfect optimizer, the lengthscale became extremely small
noise_model.covar_module.base_kernel.raw_lengthscale.data[[0]] = -720
assert 0 < noise_model.covar_module.base_kernel.lengthscale < 1e-310

# as a result, we get a numerical error whenever we evaluate noise_model
noise_model.eval()
try:
    print(noise_model(test_x))
except Exception as e:
    print("Error:", e)
noise_model.train()

# now we call final_likelihood, which raises another error
try:
    print(final_likelihood(test_x).to_dense())
except Exception as e:
    print("Error:", e)

# after the call, noise_model is still in evaluation mode, so the cache is not cleared
assert final_likelihood.training and not noise_model.training

# even after resetting the lengthscale to a normal value, final_likelihood still cannot return the correct result
noise_model.covar_module.base_kernel.raw_lengthscale.data[[0]] = 0
try:
    print(final_likelihood(test_x).to_dense())
except Exception as e:
    print("Error:", e)

# works after calling train() to clear the cache
noise_model.train()
print(final_likelihood(test_x).to_dense())

We also believe this patch resolves https://github.com/pytorch/botorch/issues/1386: we reproduced the failure described in https://github.com/pytorch/botorch/issues/1386#issuecomment-1325351034 and confirmed that our patch fixes it.