cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License
3.57k stars, 562 forks

Grid Kernel implementation takes more time than vanilla Kernel #1734

Open srinathdama opened 3 years ago

srinathdama commented 3 years ago

Problem

I am trying to use grid kernels instead of the default kernels to speed up GPR training. My understanding is that the grid kernel implementation exploits tensor (Kronecker) algebra and drastically reduces the computational complexity, because the Cholesky decomposition is applied to the individual matrices in the Kronecker product. I disabled fast computations with a context manager so that the Cholesky decomposition is used instead of conjugate gradients. With fast computations disabled, each training step takes significantly longer with the grid kernel than with the default kernel. When I enable fast computations (the default settings), the grid kernel is faster, as expected.

It would be helpful if someone could show me how to use the Cholesky decomposition and still get a speed-up with the grid kernel.

Code to reproduce

The code below is taken from the Grid_GP_Regression tutorial. I use the same data for both models to compare computation times.

```python
import math
import timeit

import gpytorch
import torch


def train_GPR(model, likelihood, train_x, train_y, training_iter=10, chol_flag=True):
    # Find optimal model hyperparameters
    model.train()
    likelihood.train()

    # Use the Adam optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

    # "Loss" for GPs - the marginal log likelihood
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    for i in range(training_iter):
        # Zero gradients from previous iteration
        optimizer.zero_grad()

        start_time = timeit.default_timer()

        if chol_flag:
            # Use Cholesky (direct) computations for matrices up to 11000 x 11000
            # instead of the iterative CG / Lanczos routines
            with gpytorch.settings.max_cholesky_size(11000), \
                    gpytorch.settings.fast_computations(covar_root_decomposition=False, log_prob=False, solves=False):
                # Output from model
                output = model(train_x)
                # Calc loss
                loss = -mll(output, train_y)
        else:
            # Default settings (fast computations enabled)
            output = model(train_x)
            loss = -mll(output, train_y)

        # Backprop gradients and take an optimizer step
        loss.backward()
        optimizer.step()

        time_taken = timeit.default_timer() - start_time

        print('Iter %d/%d - step time: %.6f s' % (i + 1, training_iter, time_taken))

#################################
### GRID GPR data
#################################

grid_bounds = [(0, 1), (0, 2)]
grid_size = 50
grid = torch.zeros(grid_size, len(grid_bounds))
for i in range(len(grid_bounds)):
    grid_diff = float(grid_bounds[i][1] - grid_bounds[i][0]) / (grid_size - 2)
    grid[:, i] = torch.linspace(grid_bounds[i][0] - grid_diff, grid_bounds[i][1] + grid_diff, grid_size)

# 50 x 50 grid -> 2500 training points
train_x = gpytorch.utils.grid.create_data_from_grid(grid)
train_y = torch.sin((train_x[:, 0] + train_x[:, 1]) * (2 * math.pi)) + torch.randn_like(train_x[:, 0]).mul(0.01)

### Model

class GridGPRegressionModel(gpytorch.models.ExactGP):
    def __init__(self, grid, train_x, train_y, likelihood):
        super(GridGPRegressionModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.GridKernel(gpytorch.kernels.RBFKernel(), grid=grid)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = GridGPRegressionModel(grid, train_x, train_y, likelihood)

training_iter = 10

print('Train GPR model using grid kernel')
train_GPR(model, likelihood, train_x, train_y, training_iter=training_iter, chol_flag=True)

#################################
### GPR data
#################################

# Reuse the same train_x / train_y as the grid GPR model

### Model

class GPRegressionModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = GPRegressionModel(train_x, train_y, likelihood)

print('Train GPR model without using grid kernel')
train_GPR(model, likelihood, train_x, train_y, training_iter=training_iter, chol_flag=True)
```

Output with fast computations disabled (chol_flag=True when calling train_GPR):

```text
Train GPR model using grid kernel
Iter 1/10 - step time: 1.335075 s
Iter 2/10 - step time: 0.775082 s
Iter 3/10 - step time: 0.743960 s
Iter 4/10 - step time: 0.787821 s
Iter 5/10 - step time: 0.789339 s
Iter 6/10 - step time: 0.783592 s
Iter 7/10 - step time: 0.755363 s
Iter 8/10 - step time: 0.752616 s
Iter 9/10 - step time: 0.762754 s
Iter 10/10 - step time: 0.766521 s

Train GPR model without using grid kernel
Iter 1/10 - step time: 0.290277 s
Iter 2/10 - step time: 0.278644 s
Iter 3/10 - step time: 0.275472 s
Iter 4/10 - step time: 0.295164 s
Iter 5/10 - step time: 0.275632 s
Iter 6/10 - step time: 0.303502 s
Iter 7/10 - step time: 0.278693 s
Iter 8/10 - step time: 0.294948 s
Iter 9/10 - step time: 0.276134 s
Iter 10/10 - step time: 0.299402 s
```

Output with fast computations enabled (chol_flag=False when calling train_GPR):

```text
Train GPR model using grid kernel
Iter 1/10 - step time: 0.627863 s
Iter 2/10 - step time: 0.024604 s
Iter 3/10 - step time: 0.024323 s
Iter 4/10 - step time: 0.024040 s
Iter 5/10 - step time: 0.024330 s
Iter 6/10 - step time: 0.023434 s
Iter 7/10 - step time: 0.023406 s
Iter 8/10 - step time: 0.023706 s
Iter 9/10 - step time: 0.023484 s
Iter 10/10 - step time: 0.023516 s

Train GPR model without using grid kernel
Iter 1/10 - step time: 0.086213 s
Iter 2/10 - step time: 0.074629 s
Iter 3/10 - step time: 0.073672 s
Iter 4/10 - step time: 0.070543 s
Iter 5/10 - step time: 0.072324 s
Iter 6/10 - step time: 0.070381 s
Iter 7/10 - step time: 0.073170 s
Iter 8/10 - step time: 0.073682 s
Iter 9/10 - step time: 0.072639 s
Iter 10/10 - step time: 0.072873 s
```

wjmaddox commented 3 years ago

A priori this isn't that surprising to me, although with the grid kernel you do want to use conjugate gradients so that the structure of the kernel itself can be exploited.

In your example you have 50^2 = 2500 data points, so without the grid kernel you need to take a Cholesky decomposition of a 2500 x 2500 matrix. If you instead use the grid kernel with conjugate gradients, you don't need to perform this decomposition, because the structure of the kernel matrix is exploited (which is not going to be the case when computing the Cholesky).
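To make the conjugate-gradient point concrete, here is a minimal NumPy sketch, not gpytorch's implementation. CG only needs matrix-vector products, and for a Kronecker-structured matrix those can be computed from the two 50 x 50 factors without ever forming the 2500 x 2500 matrix. The RBF lengthscale, the noise value, the row-major grid ordering and the helper names (`rbf`, `kron_mvm`, `cg_solve`) are illustrative assumptions.

```python
import numpy as np


def rbf(x, lengthscale=0.3):
    """Dense RBF kernel matrix for a 1-D set of points (illustrative helper)."""
    d = x[:, None] - x[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)


def kron_mvm(A, B, v):
    """Compute kron(A, B) @ v without forming the Kronecker product.

    With row-major flattening, kron(A, B) @ X.ravel() == (A @ X @ B.T).ravel(),
    where X has shape (A.shape[1], B.shape[1]).
    """
    X = v.reshape(A.shape[1], B.shape[1])
    return (A @ X @ B.T).ravel()


def cg_solve(matvec, b, rtol=1e-8, max_iter=2000):
    """Plain conjugate gradients for a symmetric positive-definite operator."""
    x = np.zeros_like(b)
    r = b.copy()  # residual b - A x with x = 0
    p = r.copy()
    rs = r @ r
    b_norm = np.sqrt(rs)
    for _ in range(max_iter):
        if np.sqrt(rs) <= rtol * b_norm:
            break
        Ap = matvec(p)
        step = rs / (p @ Ap)
        x += step * p
        r -= step * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


# 50 x 50 grid as in the issue, so kron(K1, K2) would be 2500 x 2500.
x1 = np.linspace(0, 1, 50)
x2 = np.linspace(0, 2, 50)
K1, K2 = rbf(x1), rbf(x2)
noise = 0.1
y = np.random.default_rng(0).standard_normal(2500)

# Each CG iteration costs two 50 x 50 matrix products instead of a dense 2500 x 2500 MVM.
v = cg_solve(lambda u: kron_mvm(K1, K2, u) + noise * u, y)
```

The dense matrix is never built during the solve; it is only needed if you want to check the result.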

More specifically, why do you want to use a Cholesky decomposition here?

srinathdama commented 3 years ago

The inverse of the Gram matrix K can be computed without expanding the Kronecker product. A grid-based implementation ideally only needs the SVDs of two 50 x 50 matrices instead of one 2500 x 2500 matrix, which significantly reduces the computational complexity [references: Wilson et al.; Ch. 5 of Saatci's PhD thesis].
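A minimal NumPy sketch of that idea follows. It assumes a row-major 2-D grid and a noisy covariance `kron(K1, K2) + noise * I`. Adding `noise * I` only shifts the eigenvalues, so the solve needs just the eigendecompositions of the two 50 x 50 factors. For symmetric positive semi-definite matrices these coincide with their SVDs. The helper names and hyperparameters are hypothetical, and this is not gpytorch's code.

```python
import numpy as np


def rbf(x, lengthscale=0.3):
    """Dense RBF kernel matrix for a 1-D set of points (illustrative helper)."""
    d = x[:, None] - x[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)


def kron_eig_solve(K1, K2, noise, y):
    """Solve (kron(K1, K2) + noise * I) v = y from the eigendecompositions of K1 and K2.

    If K1 = Q1 diag(l1) Q1^T and K2 = Q2 diag(l2) Q2^T, then
    kron(K1, K2) + noise * I = kron(Q1, Q2) diag(kron(l1, l2) + noise) kron(Q1, Q2)^T.
    """
    l1, Q1 = np.linalg.eigh(K1)  # small n1 x n1 eigendecomposition
    l2, Q2 = np.linalg.eigh(K2)  # small n2 x n2 eigendecomposition
    Y = y.reshape(len(l1), len(l2))
    Z = Q1.T @ Y @ Q2  # kron(Q1, Q2)^T y, reshaped to n1 x n2
    Z = Z / (np.outer(l1, l2) + noise)  # divide by the eigenvalues of the noisy matrix
    return (Q1 @ Z @ Q2.T).ravel()  # rotate back: kron(Q1, Q2) z


x1 = np.linspace(0, 1, 50)
x2 = np.linspace(0, 2, 50)
K1, K2 = rbf(x1), rbf(x2)
noise = 0.01
y = np.random.default_rng(0).standard_normal(2500)

# Two 50 x 50 eigendecompositions (O(50^3) each) instead of one 2500 x 2500 factorization.
v = kron_eig_solve(K1, K2, noise, y)
```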

When I used conjugate gradients with the grid kernel on another dataset of mine, I ran into convergence issues even after increasing the CG iterations/tolerance. So I tried the other existing method in gpytorch for computing the inverse, the Cholesky decomposition. I now realize that the papers cited above compute the inverse of K using the SVD, not the Cholesky decomposition, of the individual matrices in the Kronecker product. I am wondering whether this is why the Kronecker structure is not exploited when the Cholesky decomposition is used. Even so, I feel the grid kernel method should not take more time than the default kernel method.
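On the Cholesky question, a small NumPy check illustrates the relevant identity. It uses hypothetical 10-point grids, and 0.1 is added to each factor's diagonal to keep it well-conditioned. The Cholesky factor of a Kronecker product is the Kronecker product of the factors' Cholesky factors. That stops being true once observation noise is added to the full matrix. This is presumably why the cited approach relies on eigendecompositions, which handle the added diagonal, rather than on per-factor Cholesky decompositions.

```python
import numpy as np


def rbf(x, lengthscale=0.3):
    """Dense RBF kernel matrix for a 1-D set of points (illustrative helper)."""
    d = x[:, None] - x[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)


# Small hypothetical grids; 0.1 is added to each factor's diagonal to keep it well-conditioned.
K1 = rbf(np.linspace(0, 1, 10)) + 0.1 * np.eye(10)
K2 = rbf(np.linspace(0, 2, 10)) + 0.1 * np.eye(10)

# Noise-free case: chol(kron(K1, K2)) == kron(chol(K1), chol(K2)),
# so two 10 x 10 Choleskys replace one 100 x 100 Cholesky.
L_kron = np.kron(np.linalg.cholesky(K1), np.linalg.cholesky(K2))
L_dense = np.linalg.cholesky(np.kron(K1, K2))
print(np.allclose(L_kron, L_dense))  # True

# With observation noise the matrix is kron(K1, K2) + noise * I, which is in general
# not a Kronecker product, so its Cholesky factor no longer splits per dimension.
noise = 0.01
L_noisy = np.linalg.cholesky(np.kron(K1, K2) + noise * np.eye(100))
print(np.allclose(L_noisy, L_kron))  # False
```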

wjmaddox commented 3 years ago

I'll look into this some more; my original response may have been a bit incorrect. Looking at the code, GridKernel should return a KroneckerProductLazyTensor, which should use efficient Kronecker solves, as you pointed out.

wjmaddox commented 3 years ago

Okay, this reproduces. On my laptop the timings are ~2.4 s/it for Cholesky + grid, ~0.008 s/it for non-Cholesky + grid, ~1.4 s/it for Cholesky + non-grid, and ~0.13 s/it for non-Cholesky + non-grid.

Interestingly enough, these correspond to the following settings:

I checked manually, and it turns out that non-Cholesky + grid (ultimately the default settings) does cause the Kronecker solves to be called. It is so much faster because we exploit Kronecker algebra and only have to call symeig on two small matrices.
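A NumPy sketch can illustrate why this is cheap. It assumes a row-major grid, a covariance `kron(K1, K2) + noise * I`, and hypothetical helper names, with `np.linalg.eigh` standing in for symeig; it is not gpytorch's code. Both terms of the exact log marginal likelihood, the quadratic form and the log-determinant, follow from the eigendecompositions of the two small factors. They match a reference value computed from one dense Cholesky of the full matrix.

```python
import numpy as np


def rbf(x, lengthscale=0.3):
    """Dense RBF kernel matrix for a 1-D set of points (illustrative helper)."""
    d = x[:, None] - x[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)


def kron_log_marginal(K1, K2, noise, y):
    """Gaussian log marginal likelihood log N(y | 0, kron(K1, K2) + noise * I).

    The noisy Kronecker matrix has eigenvalues outer(l1, l2) + noise and
    eigenvectors kron(Q1, Q2), so only two small eigendecompositions are needed.
    """
    l1, Q1 = np.linalg.eigh(K1)
    l2, Q2 = np.linalg.eigh(K2)
    evals = np.outer(l1, l2) + noise
    Z = Q1.T @ y.reshape(len(l1), len(l2)) @ Q2  # y expressed in the eigenbasis
    quad = np.sum(Z ** 2 / evals)  # y^T (K + noise I)^{-1} y
    logdet = np.sum(np.log(evals))  # log det(K + noise I)
    return -0.5 * (quad + logdet + y.size * np.log(2 * np.pi))


def dense_log_marginal(K, noise, y):
    """Reference value from one dense Cholesky of the full n x n matrix."""
    L = np.linalg.cholesky(K + noise * np.eye(y.size))
    a = np.linalg.solve(L, y)  # a = L^{-1} y, so a @ a = y^T (K + noise I)^{-1} y
    return -0.5 * (a @ a + 2 * np.sum(np.log(np.diag(L))) + y.size * np.log(2 * np.pi))


x1 = np.linspace(0, 1, 30)
x2 = np.linspace(0, 2, 30)
K1, K2 = rbf(x1), rbf(x2)
noise = 0.01
# Same target function as in the issue, evaluated on a row-major grid
y = np.sin(2 * np.pi * np.add.outer(x1, x2)).ravel()
y = y + 0.01 * np.random.default_rng(0).standard_normal(y.size)

fast = kron_log_marginal(K1, K2, noise, y)  # two 30 x 30 eigendecompositions
slow = dense_log_marginal(np.kron(K1, K2), noise, y)  # one 900 x 900 Cholesky
```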

Forcing the maximum Cholesky size to be very large overrides our Kronecker algebra internally. @Balandat, maybe we should make this clearer in the settings?

jacobrgardner commented 3 years ago

Kind of related, but note that some functions currently run slowly with fast_computations off. In particular, KroneckerProductAddedDiagLazyTensor neither overrides _cholesky nor inherits from a lazy tensor that does (like KroneckerProductLazyTensor). This causes problems with InvMatmul, which currently just calls lazy_tsr.cholesky explicitly, regardless of whether root_decomposition has been overridden:

https://github.com/cornellius-gp/gpytorch/blob/7648de148691635d634f1179cc80e7311b1d1864/gpytorch/functions/_inv_matmul.py#L16-L17

So I think we'd get really slow behavior with fast_computations off when computing predictions, either way.

jacobrgardner commented 3 years ago

I wonder if a solution here might be to refactor lazy tensors (or linear operators) to have _iterative_solve and _direct_solve, so that it is more obvious exactly what is happening in every situation. Then "fast computations" (which feels a bit preachy anyway) should be refactored into a setting that represents what it actually is: should solves use an iterative method or a direct method? If a direct method is chosen, we would still always do it the best way we can.
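Here is a rough sketch of what that proposal could look like, in NumPy with hypothetical class names (`LinearOp`, `KroneckerAddedDiagOp`). It is not gpytorch's or linear_operator's API. Every solve goes through one explicit `method` switch, and the base class supplies generic CG and dense-Cholesky fallbacks. A structured subclass overrides `_direct_solve`, so the direct path still exploits the Kronecker structure.

```python
import numpy as np


class LinearOp:
    """Hypothetical base class: all solves go through one explicit strategy switch."""

    def matvec(self, v):
        raise NotImplementedError

    def to_dense(self):
        raise NotImplementedError

    def _iterative_solve(self, b, rtol=1e-10, max_iter=1000):
        # Generic conjugate gradients: only needs matrix-vector products.
        x = np.zeros_like(b)
        r = b.copy()
        p = r.copy()
        rs = r @ r
        b_norm = np.sqrt(rs)
        for _ in range(max_iter):
            if np.sqrt(rs) <= rtol * b_norm:
                break
            Ap = self.matvec(p)
            step = rs / (p @ Ap)
            x += step * p
            r -= step * Ap
            rs_new = r @ r
            p = r + (rs_new / rs) * p
            rs = rs_new
        return x

    def _direct_solve(self, b):
        # Generic fallback: dense Cholesky of the full matrix, O(n^3).
        L = np.linalg.cholesky(self.to_dense())
        return np.linalg.solve(L.T, np.linalg.solve(L, b))

    def solve(self, b, method="direct"):
        # One explicit "iterative" vs "direct" switch; every caller goes through here,
        # so a subclass override of either strategy is always honored.
        if method == "iterative":
            return self._iterative_solve(b)
        if method == "direct":
            return self._direct_solve(b)
        raise ValueError(f"unknown solve method: {method!r}")


class KroneckerAddedDiagOp(LinearOp):
    """kron(K1, K2) + noise * I with a structure-aware direct solve (row-major ordering)."""

    def __init__(self, K1, K2, noise):
        self.K1, self.K2, self.noise = K1, K2, noise

    def matvec(self, v):
        X = v.reshape(self.K1.shape[0], self.K2.shape[0])
        return (self.K1 @ X @ self.K2.T).ravel() + self.noise * v

    def to_dense(self):
        n = self.K1.shape[0] * self.K2.shape[0]
        return np.kron(self.K1, self.K2) + self.noise * np.eye(n)

    def _direct_solve(self, b):
        # Still direct, done "the best way we can": two small eigendecompositions.
        l1, Q1 = np.linalg.eigh(self.K1)
        l2, Q2 = np.linalg.eigh(self.K2)
        Z = Q1.T @ b.reshape(len(l1), len(l2)) @ Q2
        Z = Z / (np.outer(l1, l2) + self.noise)
        return (Q1 @ Z @ Q2.T).ravel()


x1, x2 = np.linspace(0, 1, 20), np.linspace(0, 2, 20)
K1 = np.exp(-0.5 * ((x1[:, None] - x1[None, :]) / 0.3) ** 2)
K2 = np.exp(-0.5 * ((x2[:, None] - x2[None, :]) / 0.3) ** 2)
op = KroneckerAddedDiagOp(K1, K2, noise=0.1)
b = np.random.default_rng(0).standard_normal(op.K1.shape[0] * op.K2.shape[0])

x_direct = op.solve(b, method="direct")  # eigendecomposition path
x_iterative = op.solve(b, method="iterative")  # CG path, matvecs only
```

Callers only go through `solve`, so the structured override cannot be bypassed by a hard-coded call to a dense Cholesky.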

Right now, I feel like there are a lot of gotchas, and even different functions have different behaviors under different settings (e.g., inv_matmul could currently be slow even when inv_quad_logdet is fast).

srinathdama commented 3 years ago

@wjmaddox, thanks for checking the issue! From now on I will be more careful when disabling fast_computations.

@jacobrgardner, I agree with your suggestion to refactor the code so that it is easy to tell whether a direct or an iterative method is being used. For the time being, I am using an existing grid-based direct method implemented with scipy.

Thanks again for making this cool software open source!