cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

[Question] eval_cg_tolerance dramatic effect on variance #2148

Open irinaespejo opened 1 year ago

irinaespejo commented 1 year ago

Hello 👋

Changing gpytorch.settings.eval_cg_tolerance() during evaluation produces a dramatic change in the predicted variance of a multitask GP with a Gaussian likelihood. The only thing I changed was the tolerance, from the default 1E-2 to 1E-6, and I obtained the following posterior variance plots. Is this expected behavior?

[Screenshots from 2022-09-28: posterior variance plots for the two CG tolerances]
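One way to see how a looser tolerance can move the predictive variance is a standalone NumPy sketch, assuming plain unpreconditioned conjugate gradients (CG) that stops once the relative residual ||r|| / ||b|| drops below tol. This is not GPyTorch's solver, which is preconditioned and batched and is used in place of Cholesky only for matrices larger than gpytorch.settings.max_cholesky_size (800 by default). The RBF lengthscale, the noise of 0.1 and the toy data are assumptions. The sketch computes var(y*) = k(x*, x*) - k*^T K_hat^{-1} k* + noise, with K_hat = K + noise * I, once with an exact solve and once with CG at each tolerance.

```python
import numpy as np

def rbf(a, b, lengthscale=0.2):
    # Squared-exponential kernel with unit outputscale on 1-D inputs
    d = a[:, None] - b[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)

def cg(A, b, tol, max_iter=1000):
    # Plain conjugate gradients from x = 0; stops when ||r|| / ||b|| < tol
    x = np.zeros_like(b)
    r = b - A @ x
    p = r.copy()
    rs = r @ r
    b_norm = np.linalg.norm(b)
    for _ in range(max_iter):
        if np.sqrt(rs) / b_norm < tol:
            break
        Ap = A @ p
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

rng = np.random.default_rng(0)
x_train = rng.uniform(0.0, 1.0, size=200)
noise = 0.1
K_hat = rbf(x_train, x_train) + noise * np.eye(len(x_train))
x_test = np.linspace(0.0, 1.0, 50)
K_star = rbf(x_train, x_test)  # shape (n_train, n_test)

def predictive_variance(solve):
    # var(y*) = k(x*, x*) - k*^T K_hat^{-1} k* + noise, one solve per test point
    v = np.array([k_star @ solve(k_star) for k_star in K_star.T])
    return 1.0 + noise - v

exact = predictive_variance(lambda b: np.linalg.solve(K_hat, b))
loose = predictive_variance(lambda b: cg(K_hat, b, tol=1e-2))
tight = predictive_variance(lambda b: cg(K_hat, b, tol=1e-6))

print("max |error|, tol=1e-2:", np.abs(loose - exact).max())
print("max |error|, tol=1e-6:", np.abs(tight - exact).max())
```

In exact arithmetic, CG started from zero approaches k*^T K_hat^{-1} k* from below, so in this toy setting stopping early can only overstate the variance. Whether GPyTorch's preconditioned solver behaves the same way on the data in this issue is not established by the sketch.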

Reproduce the issue

The model

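import torch
import gpytorch
from gpytorch.models import ExactGP
from gpytorch.kernels import RBFKernel, IndexKernel
from gpytorch.distributions import MultivariateNormal
from gpytorch.priors import LKJCovariancePrior, SmoothedBoxPrior
from gpytorch.constraints import GreaterThan  # only needed for the commented-out constraint


class LinearMean(gpytorch.means.Mean):
    """Custom linear mean m(x) = x @ weights + bias, defined before the model uses it."""

    def __init__(self, input_size, batch_shape=torch.Size(), bias=True):
        super().__init__()
        self.register_parameter(name="weights", parameter=torch.nn.Parameter(torch.ones(*batch_shape, input_size, 1)))
        if bias:
            self.register_parameter(name="bias", parameter=torch.nn.Parameter(torch.ones(*batch_shape, 1)))
        else:
            self.bias = None

    def forward(self, x):
        res = x.matmul(self.weights).squeeze(-1)
        if self.bias is not None:
            res = res + self.bias
        return res


class MultitaskGPModel(ExactGP):
    """Hadamard multitask GP: inputs x of shape (n, 4) plus integer task indices i."""

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = LinearMean(input_size=4)
        self.covar_module = RBFKernel(
            # lengthscale_constraint=GreaterThan(torch.Tensor([0.2]))
        )
        # Covariance between the 2 tasks, full rank, with an LKJ prior
        self.task_covar_module = IndexKernel(
            num_tasks=2,
            rank=2,
            prior=LKJCovariancePrior(
                2, eta=0.3, sd_prior=SmoothedBoxPrior(0, 1), validate_args=False
            ),
        )

    def forward(self, x, i):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        covar_i = self.task_covar_module(i)
        # Elementwise (Hadamard) product of the input and task covariances
        covar = covar_x.mul(covar_i)
        return MultivariateNormal(mean_x, covar)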
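The forward pass above builds the Hadamard (elementwise) product of an RBF kernel over inputs and an IndexKernel over task labels. The NumPy sketch writes that matrix out explicitly, assuming the IndexKernel form documented by GPyTorch, B = W W^T + diag(v) with W of shape (num_tasks, rank). The points, task labels, W and v are hypothetical values, not the trained model's.

```python
import numpy as np

def rbf(a, b, lengthscale=0.5):
    # RBF kernel between rows of a (n x d) and rows of b (m x d)
    sq = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * sq / lengthscale ** 2)

def hadamard_multitask_covar(x, tasks, W, v, lengthscale=0.5):
    # Task covariance B = W W^T + diag(v), as in IndexKernel's documented form
    B = W @ W.T + np.diag(v)
    # K[i, j] = k_rbf(x_i, x_j) * B[t_i, t_j]
    return rbf(x, x, lengthscale) * B[np.ix_(tasks, tasks)]

rng = np.random.default_rng(0)
x = rng.uniform(size=(6, 4))          # 6 hypothetical points in 4-D
tasks = np.array([0, 0, 0, 0, 1, 1])  # hypothetical task labels
W = rng.normal(size=(2, 2))           # num_tasks=2, rank=2 (hypothetical values)
v = np.array([0.1, 0.2])              # hypothetical per-task variances
K = hadamard_multitask_covar(x, tasks, W, v)
# The elementwise product of two PSD matrices is PSD (Schur product theorem)
print("min eigenvalue:", np.linalg.eigvalsh(K).min())
```

Writing the product entrywise also shows that the prior variance of a point with task t is B[t, t], since the RBF kernel in this model has unit variance (there is no ScaleKernel).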

The code

Pre-trained model, with the GaussianLikelihood noise variance at 0.1. The data and the model state dict are attached.

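import torch
import gpytorch
import matplotlib.pyplot as plt
from gpytorch.likelihoods import GaussianLikelihood

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float64

likelihood = GaussianLikelihood()
# loadTrainData and stage are the author's own helper and setting (not shown in the issue)
full_train_x, full_train_i, full_train_y = loadTrainData(stage)
model = MultitaskGPModel((full_train_x, full_train_i), full_train_y, likelihood)
state_dict = torch.load("state_dict.pt")
model.load_state_dict(state_dict)
likelihood = likelihood.to(device, dtype)
likelihood.eval()
model = model.to(device, dtype)
model.eval()

# Values of the varied feature. The model expects (n, 4) inputs, so in the plots
# X is a 1-D slice with the other 3 features held fixed (that step is not shown)
X = torch.arange(500, 5000 + 1, 100).to(device, dtype)
# IndexKernel expects integer task indices; every point is evaluated as task index 1
task_vec = torch.ones(X.shape[0], 1, dtype=torch.long, device=device)

# The default eval_cg_tolerance is 1E-2
with torch.no_grad(), gpytorch.settings.eval_cg_tolerance(1E-6):
    posterior = likelihood(model(X, task_vec))
    mean = posterior.mean.cpu()
    variance = posterior.variance.cpu()
std = variance.sqrt()

x_plot = X[:, 0].cpu()
plt.plot(x_plot, mean, color='black', linestyle='dashed')
plt.fill_between(x_plot, mean, mean + std, alpha=0.4, color='blue')
plt.fill_between(x_plot, mean, mean - std, alpha=0.4, color='blue')
# same fill_between for more variance bands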

Expected behavior

I expected evaluation with the default CG tolerance (1E-2) to give a posterior variance consistent with the learned white-noise hyperparameter (0.1).
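That expectation can be turned into a concrete check. For an exact GP posterior under a Gaussian likelihood, likelihood(model(X)).variance is the latent posterior variance plus the noise, and the latent variance lies between 0 and the prior variance. A minimal NumPy sketch, assuming the variances have already been moved to the CPU; for this model the prior variance of task t would be the task covariance entry B[t, t] (for example from model.task_covar_module.covar_matrix), because RBFKernel without a ScaleKernel has unit variance. The helper name check_predictive_variance is hypothetical.

```python
import numpy as np

def check_predictive_variance(variance, noise, prior_var, atol=1e-6):
    # For an exact GP posterior with a Gaussian likelihood:
    #   noise <= var(y*) <= prior_var + noise
    # Returns a boolean mask marking the points that violate these bounds
    variance = np.asarray(variance, dtype=float)
    too_low = variance < noise - atol
    too_high = variance > prior_var + noise + atol
    return too_low | too_high

# Hypothetical values: noise 0.1 as in the issue, prior variance 1.0
print(check_predictive_variance([0.05, 0.5, 1.5], noise=0.1, prior_var=1.0))
```

Values outside these bounds point at an inaccurate solve rather than at the model, although values inside them do not by themselves prove that the solve was accurate.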

Thank you!

data.zip

gpleiss commented 1 year ago

What happens when you z-score your x values (so transform them to be zero mean, unit variance)? This usually fixes these kinds of numerical instabilities.
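A minimal NumPy sketch of z-scoring the inputs, assuming the per-feature mean and standard deviation are computed on the training inputs only and then reused for every test input. The helper names and the raw input range are hypothetical.

```python
import numpy as np

def fit_zscore(x_train, eps=1e-12):
    # Per-feature mean and standard deviation of the training inputs
    mean = x_train.mean(axis=0)
    std = x_train.std(axis=0)
    return mean, np.maximum(std, eps)  # guard against constant features

def apply_zscore(x, mean, std):
    return (x - mean) / std

rng = np.random.default_rng(0)
x_train = rng.uniform(500, 5000, size=(100, 4))  # hypothetical raw inputs
mean, std = fit_zscore(x_train)
z_train = apply_zscore(x_train, mean, std)
# Test inputs go through the same transform, with the training statistics
x_test = rng.uniform(500, 5000, size=(10, 4))
z_test = apply_zscore(x_test, mean, std)
```

Reusing the training statistics matters: standardizing the test grid with its own mean and standard deviation would place it in a different coordinate system from the one the hyperparameters were learned in.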

irinaespejo commented 1 year ago

Sorry, I forgot to mention that training and evaluation are done on data rescaled to the hypercube, not z-scored. I will try z-scoring. Thanks!

gpleiss commented 1 year ago

Hmm, I didn't realize the data are rescaled to the hypercube. How many data points are you using? And are the data 1-dimensional?

irinaespejo commented 1 year ago

Sorry for the late reply. The inputs X are 4-D and the target y is 1-D. There are approximately 6000 points for task #1 and 1000 points for task #2. The plots I showed are 1-D slices with 3 of the features of X held fixed.
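Such a 1-D slice still has to be passed to the model as an (n, 4) input, which is the shape LinearMean(input_size=4) and the RBF kernel expect. A minimal NumPy sketch of building one, where the varied feature index, its range and the fixed values of the other three features are all hypothetical; the helper name slice_grid is illustrative.

```python
import numpy as np

def slice_grid(fixed, vary_dim, lo, hi, num):
    # Returns a (num, d) grid: column `vary_dim` sweeps [lo, hi],
    # all other columns are held at the values given in `fixed`
    fixed = np.asarray(fixed, dtype=float)
    grid = np.tile(fixed, (num, 1))
    grid[:, vary_dim] = np.linspace(lo, hi, num)
    return grid

# Hypothetical: vary feature 0 across [0, 1] and fix the other features at 0.5
grid = slice_grid([0.5, 0.5, 0.5, 0.5], vary_dim=0, lo=0.0, hi=1.0, num=46)
# Integer task index 1 for every point, shaped (n, 1)
tasks = np.ones((grid.shape[0], 1), dtype=np.int64)
```

With the model from the issue, the grid could be converted with torch.as_tensor(grid, dtype=torch.float64) and the results plotted against grid[:, 0].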

gpleiss commented 1 year ago

I will try to take a look at the issue this weekend.

gpleiss commented 1 year ago

@irinaespejo can you please post a fully runnable code example, i.e. something that I can copy-paste into a script to reproduce the results that you see?