cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

Prediction using full Bayesian GPs (loading from disk) #1100

Open natarajanmolecule opened 4 years ago

natarajanmolecule commented 4 years ago

🐛 Bug

I raised an issue earlier about loading previously saved full Bayesian GP models from disk. Thanks so much for that PR. I am now loading a previously saved model from disk and trying to make predictions with it, but the predictions seem to use only the first MCMC sample rather than all of them. Please see the code below, which modifies the example from the full Bayesian GP section to perform predictions. Am I missing something?

To reproduce

Code snippet to reproduce

import math
import torch
import gpytorch
import pyro
from pyro.infer.mcmc import NUTS, MCMC
from matplotlib import pyplot as plt
import pickle

# Training data is 6 points in [0,1] inclusive regularly spaced
train_x = torch.linspace(0, 1, 6)
# True function is sin(2*pi*x) with Gaussian noise
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2

# allobj = pickle.load(open('dummymodel_traindata.pkl','rb'))
# train_x = allobj['trainX']
# train_y = allobj['trainY']
# We will use the simplest form of GP model, exact inference
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.PeriodicKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.Positive())
model = ExactGPModel(train_x, train_y, likelihood)
# Load saved model from Full Bayesian GP example
state_dict = torch.load('dummymodel_state.pth')
model.load_strict_shapes(False)
model.pyro_load_from_samples(state_dict)
# model.load_state_dict(state_dict)

import os
smoke_test = ('CI' in os.environ)
num_samples = 2 if smoke_test else 100
warmup_steps = 2 if smoke_test else 200

model.eval()
test_x = torch.linspace(0, 1, 101).unsqueeze(-1)
test_y = torch.sin(test_x * (2 * math.pi))
expanded_test_x = test_x.unsqueeze(0).repeat(num_samples, 1, 1)
output = likelihood(model(expanded_test_x))

# Initialize plot
f, ax = plt.subplots(1, 1, figsize=(4, 3))

# Plot training data as black stars
ax.plot(train_x.numpy(), train_y.numpy(), 'k*', zorder=10)

for i in range(min(num_samples, 25)):
    # Plot predictive means as blue line
    ax.plot(test_x.numpy(), output.mean[i].detach().numpy(), 'b', linewidth=0.3)

# Shade between the lower and upper confidence bounds
# ax.fill_between(test_x.numpy(), lower.numpy(), upper.numpy(), alpha=0.5)
ax.set_ylim([-3, 3])
ax.legend(['Observed Data', 'Sampled Means'])

Stack trace/error message

Figure_1

Expected Behavior

Figure_2

jacobrgardner commented 4 years ago

@natarajanmolecule I'll look into this. It looks like the hyperparameters may all have been updated to the same value, or something similarly strange.
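
One way to check whether the loaded hyperparameters have collapsed to a single value is to count the distinct entries along the leading dimension of each parameter. The sketch below is only an illustration: summarize_sample_spread is a hypothetical helper, not part of GPyTorch, and it assumes that dim 0 of each parameter indexes the MCMC samples after loading.

```python
import torch


def summarize_sample_spread(module: torch.nn.Module) -> dict:
    """Map each parameter name to (shape, number of distinct slices along dim 0).

    Assumption: dim 0 of every non-scalar parameter indexes MCMC samples.
    """
    report = {}
    for name, param in module.named_parameters():
        values = param.detach()
        if values.dim() == 0:
            # A scalar parameter has no sample dimension at all.
            report[name] = (tuple(values.shape), 1)
            continue
        # Flatten everything except the assumed sample dimension.
        flat = values.reshape(values.shape[0], -1)
        distinct = torch.unique(flat, dim=0).shape[0]
        report[name] = (tuple(values.shape), distinct)
    return report
```

Calling summarize_sample_spread(model) on the loaded model and printing the result should show whether each parameter has a leading dimension equal to the number of samples, and whether more than one distinct value is present. A count of 1 for every parameter would be consistent with the "all the same" suspicion above.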

jacobrgardner commented 4 years ago

@natarajanmolecule I am able to get correct looking predictions if I save and load in the following way:

...
mcmc_run.run(train_x, train_y)
model.pyro_load_from_samples(mcmc_run.get_samples())
torch.save([mcmc_run.get_samples(), model.state_dict()], 'model_state.pth') 

## reinit model
likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.Positive())
model = ExactGPModel(train_x, train_y, likelihood)
model.mean_module.register_prior("mean_prior", UniformPrior(-1, 1), "constant")
model.covar_module.base_kernel.register_prior("lengthscale_prior", UniformPrior(0.01, 0.5), "lengthscale")
model.covar_module.base_kernel.register_prior("period_length_prior", UniformPrior(0.05, 2.5), "period_length")
model.covar_module.register_prior("outputscale_prior", UniformPrior(1, 2), "outputscale")
likelihood.register_prior("noise_prior", UniformPrior(0.05, 0.3), "noise")

mcmc_samples, model_state = torch.load('model_state.pth') 
model.pyro_load_from_samples(mcmc_samples)
model.load_state_dict(model_state)

Is this a workable solution for you? If so, it is probably a better way to load samples. I realized that the load_strict_shapes(False) approach doesn't work at all, because it doesn't update batch_shape for the underlying modules. So I'll probably remove it and add this as the proper tutorial.
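
Once the samples load correctly, output in the issue's script holds one predictive distribution per MCMC sample. As a rough sketch of how those could be collapsed into a single mean and variance for plotting, the helper below averages over the sample dimension using the law of total variance. summarize_mixture is a hypothetical name, not a GPyTorch function, and it assumes both inputs have shape [num_samples, num_test_points]. The mixture of Gaussians is not itself Gaussian, so a band of mean ± 2 standard deviations built from this variance is only an approximation.

```python
import torch


def summarize_mixture(means: torch.Tensor, variances: torch.Tensor):
    """Collapse per-sample Gaussian predictions into one mean and one variance.

    Assumption: means and variances have shape [num_samples, num_test_points].
    """
    mixture_mean = means.mean(dim=0)
    # Law of total variance: average within-sample variance
    # plus the spread of the per-sample means around the mixture mean.
    between = ((means - mixture_mean) ** 2).mean(dim=0)
    mixture_var = variances.mean(dim=0) + between
    return mixture_mean, mixture_var
```

With the issue's variables, this might be called as summarize_mixture(output.mean.detach(), output.variance.detach()), and the resulting mean and variance used to fill the commented-out fill_between call.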