cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

AdditiveStructureKernel does not work in batch mode #513

Closed Lazloo closed 5 years ago

Lazloo commented 5 years ago

Hey,

I would like to perform "Scalable GP Regression" for a data set with multiple output variables. I use the following data set:

import math
import torch
import numpy as np
import gpytorch
from matplotlib import pyplot as plt

n = 20
train_x = torch.zeros(pow(n, 2), 2)
for i in range(n):
    for j in range(n):
        # Each coordinate varies from 0 to 1 in n = 20 steps
        train_x[i * n + j][0] = float(i) / (n - 1)
        train_x[i * n + j][1] = float(j) / (n - 1)

train_y_1 = (torch.sin(train_x[:, 0]) + torch.cos(train_x[:, 1]) * (2 * math.pi) + torch.randn_like(train_x[:, 0]).mul(
    0.01)) / 4
train_y_2 = torch.sin(train_x[:, 0]) + torch.cos(train_x[:, 1]) * (2 * math.pi) + torch.randn_like(train_x[:, 0]).mul(
    0.01)

train_y = torch.stack([train_y_1, train_y_2], -1)

test_x = torch.rand((n, train_x.shape[-1]))  # one column per input dimension
test_y_1 = (torch.sin(test_x[:, 0]) + torch.cos(test_x[:, 1]) * (2 * math.pi) + torch.randn_like(test_x[:, 0]).mul(
    0.01)) / 4
test_y_2 = torch.sin(test_x[:, 0]) + torch.cos(test_x[:, 1]) * (2 * math.pi) + torch.randn_like(test_x[:, 0]).mul(0.01)
test_y = torch.stack([test_y_1, test_y_2], -1)

train_x = train_x.unsqueeze(0).repeat(2, 1, 1)
train_y = train_y.transpose(-2, -1)
train_x = train_x.cuda()
train_y = train_y.cuda()
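
As a quick sanity check on the shapes this set-up produces, here is a minimal CPU-only sketch. It builds the same grid with `torch.linspace` and `torch.cartesian_prod` instead of the nested loop; the names `axis`, `grid`, `batch_x`, `targets`, and `batch_y` are illustrative and not part of the original post, and the targets are placeholders rather than the noisy functions above.

```python
import torch

n = 20

# Same 20 x 20 grid as the nested loop: row i * n + j holds (i / (n - 1), j / (n - 1)).
axis = torch.linspace(0, 1, n)
grid = torch.cartesian_prod(axis, axis)           # shape (400, 2)

# Repeat the inputs once per output, as in the post.
batch_x = grid.unsqueeze(0).repeat(2, 1, 1)       # shape (2, 400, 2)

# Placeholder targets (illustrative only): one column per output.
targets = torch.stack([grid.sum(-1), 2 * grid.sum(-1)], -1)   # shape (400, 2)
batch_y = targets.transpose(-2, -1)               # shape (2, 400)

print(batch_x.shape, batch_y.shape)
```

So the model receives a batch of 2 input sets with 400 points each, which is where the 400 in the error message below comes from.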

Based on the tutorial, I try to train the model using the following code:

class MultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(MultitaskGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()

        grid_size = gpytorch.utils.grid.choose_grid_size(train_x, kronecker_structure=False)
        self.covar_module = gpytorch.kernels.AdditiveStructureKernel(
            gpytorch.kernels.GridInterpolationKernel(
                gpytorch.kernels.ScaleKernel(
                    gpytorch.kernels.MaternKernel(nu=1.5),
                ), grid_size=int(grid_size), num_dims=1
            ), num_dims=train_x.shape[-1]
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.GaussianLikelihood(num_tasks=2).cuda()
model = MultitaskGPModel(train_x, train_y, likelihood).cuda()
# model.double()

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam([
    {'params': model.parameters()},  # Includes GaussianLikelihood parameters
], lr=0.0001)

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

n_iter = 50
for i in range(n_iter):
    optimizer.zero_grad()
    output = model(train_x)
    # loss = -mll(output, train_y)
    loss = -mll(output, train_y).sum()
    loss.backward()
    print('Iter %d/%d - Loss: %.3f' % (i + 1, n_iter, loss.item()))
    optimizer.step()

However, I get the following error message:

RuntimeError: left interp size (torch.Size([4, 400, 4])) is incompatible with base_lazy_tensor size (torch.Size([8, 400, 400])). Make sure the two have the same number of batch dimensions

Can someone help?

jacobrgardner commented 5 years ago

Sorry for the delay on this -- I'm still traveling so computer access is a bit sparse!

@gpleiss If AdditiveStructureKernel is intended to work in batch mode, it currently doesn't -- it's summing over both the batch portion and the dimension portion of the batch dimension.
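
To illustrate the kind of problem described here, the following is a sketch in plain PyTorch, not GPyTorch's actual implementation. It assumes the real batch and the per-dimension components are folded into a single leading axis; the sizes and the names `per_component`, `folded`, `mixed`, and `additive` are hypothetical.

```python
import torch

b, d, n = 2, 2, 5  # batch size, input dimensions, data points (hypothetical sizes)

# One n x n kernel matrix per (batch, dimension) pair.
per_component = torch.rand(b, d, n, n)

# Assumed layout: batch and dimension folded into one leading axis of size b * d.
folded = per_component.reshape(b * d, n, n)

# Summing the whole leading axis also sums across batches.
mixed = folded.sum(dim=0)                          # shape (n, n)

# Unfolding first and summing only the dimension axis keeps one matrix per batch.
additive = folded.reshape(b, d, n, n).sum(dim=1)   # shape (b, n, n)

print(mixed.shape, additive.shape)
```

Under this assumption, an additive kernel in batch mode should produce something like `additive`, with one summed matrix per batch, rather than `mixed`.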

@Lazloo In the meantime, with 2D data you really shouldn't want additive structure. This model works fine on that data:

class MultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(MultitaskGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()

        grid_size = gpytorch.utils.grid.choose_grid_size(train_x, kronecker_structure=True)
        print(grid_size)
        self.covar_module = gpytorch.kernels.GridInterpolationKernel(
                gpytorch.kernels.ScaleKernel(
                    gpytorch.kernels.MaternKernel(nu=1.5),
                ), grid_size=int(grid_size), num_dims=train_x.shape[-1]
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

If your actual data is higher dimensional, I would recommend using either DKL or InducingPointKernel.

gpleiss commented 5 years ago

I have a feeling this bug will go away with multi-batch LazyTensors (and subsequent updates to models and kernels). There is a plan (but not a short-term plan) to fix this all, but it does require some major refactoring within GPyTorch.

I agree with Jake that, for the time being, additive structure shouldn't be necessary for 2D data, so you can get away with not using it.