cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch

[Bug] Gradient computation with KeOps kernel only works with low number of training data, fails otherwise #1885

Closed abdolrezat closed 2 years ago

abdolrezat commented 2 years ago

🐛 Bug

Hi,

First of all, I would like to thank all the developers for the effort you have put into both the research and this great package.

Consider a GP with a KeOps kernel (e.g. gpytorch.kernels.keops.RBFKernel). If I train it with N = 100 points, the gradient of the predictive mean can be obtained with torch.autograd.grad or .backward(); with N = 500, however, an error is thrown saying the input tensor was not used in the graph. I have tested the script on two separate machines and a Colab instance. The standard GPyTorch kernels do not run into this issue. I spent a good deal of time narrowing this down from larger chunks of code, and this appears to be the root cause: the gradient graph seems to be cut off between the covariance output of the KeOps kernel and the input (covar.x1).

Below is a minimal example that should quickly demonstrate this somewhat strange behavior. It contains two test cases, N = 100 (which passes) and N = 500 (which fails). The code is taken from the GPyTorch regression examples; I have only changed the GP kernel and added a few lines at the end for computing gradients.

To reproduce

import math
import torch
import gpytorch
import time

# We will use the simplest form of GP model, exact inference with gpytorch.kernels.keops.RBFKernel
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.keops.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

def train_and_eval_GP(N=100):
    """Train an exact GP on N training points and check gradients of the predictive mean.

    Args:
        N (int): Number of training points
    """
    # Training data is N points regularly spaced in [0, 1] inclusive
    train_x = torch.linspace(0, 1, N)
    # True function is sin(2*pi*x) with Gaussian noise
    train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * math.sqrt(0.04)
    # normalize features
    mean = train_x.mean()
    std = train_x.std() + 1e-6 # prevent dividing by 0
    train_x = (train_x - mean) / std

    # normalize labels
    mean, std = train_y.mean(),train_y.std()
    train_y = (train_y - mean) / std

    # make contiguous
    train_x, train_y = train_x.contiguous(), train_y.contiguous()

    output_device = torch.device('cuda:0')

    train_x, train_y = train_x.to(output_device), train_y.to(output_device)

    # initialize likelihood and model
    likelihood = gpytorch.likelihoods.GaussianLikelihood().to(output_device)
    model = ExactGPModel(train_x, train_y, likelihood).to(output_device)

    # Find optimal model hyperparameters
    model.train()
    likelihood.train()

    # Use the adam optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

    # "Loss" for GPs - the marginal log likelihood
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    training_iter = 20
    for i in range(training_iter):
        # Zero gradients from previous iteration
        optimizer.zero_grad()
        # Output from model
        output = model(train_x)
        # Calc loss and backprop gradients
        loss = -mll(output, train_y)
        loss.backward()
        optimizer.step()
    print('GP model trained.')

    # Get into evaluation (predictive posterior) mode
    model.eval()
    likelihood.eval()

    # Test points are regularly spaced along [0,1]
    test_x = torch.linspace(0, 1, 51, requires_grad=True).to(output_device).contiguous()

    # Make predictions by feeding model through likelihood
    with gpytorch.settings.fast_pred_var():
        observed_pred = likelihood(model(test_x))
        assert torch.autograd.grad(observed_pred.mean.sum(), test_x, retain_graph=True) is not None
        print('gradient exists:')
        print(torch.autograd.grad(observed_pred.mean.sum(), test_x, retain_graph=True))

if __name__ == "__main__":
    Ns = [100, 500]  # test cases
    for n in Ns:
        try:
            print(f'testing with {n} points...')
            train_and_eval_GP(N = n) 
            print('success!')
        except Exception as e:
            print('failed.')
            print(e)

Stack trace/error message

testing with 100 points...
GP model trained.
gradient exists:
(tensor([-2.6629e+00, -2.6507e+00, -2.6344e+00, -2.6138e+00, -2.5891e+00,
        ...,
         9.1844e-01], device='cuda:0'),)
success!
testing with 500 points...
GP model trained.
failed.
One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

Expected Behavior

With the standard, non-KeOps kernel (gpytorch.kernels.RBFKernel), the gradients are computed in both cases, as shown in the output below. However, I can't simply switch to it, since I'm working on a larger dataset that will run out of memory with the standard kernel.
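For reference, the only change for that comparison is the kernel definition in the model's __init__; everything else in the script above stays the same (here keeping the ScaleKernel wrapper and swapping only the inner kernel):

self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())  # standard, non-KeOps kernel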

testing with 100 points...
GP model trained.
gradient exists:
(tensor([-2.5885, -2.5870, -2.5819, -2.5732, -2.5609, -2.5449, -2.5254, -2.5022,
        ...,
         0.6977,  0.7880,  0.8762], device='cuda:0'),)
success!
testing with 500 points...
GP model trained.
gradient exists:
(tensor([-2.4751, -2.4741, -2.4698, -2.4622, -2.4514, -2.4373, -2.4199, -2.3993,
        ...,
         0.4727,  0.5600,  0.6462], device='cuda:0'),)
success!


Additional context

I realize this issue might not originate in GPyTorch itself, since it clearly stems from the KeOps kernel. It is, however, difficult to track down, so I thought I'd report it here.

wjmaddox commented 2 years ago

Can you try making test_x contiguous as well? I see that it isn't explicitly made contiguous, while train_x / train_y are.

abdolrezat commented 2 years ago

Can you try making test_x contiguous as well? I see that it isn't explicitly made contiguous, while train_x / train_y are.

Hi @wjmaddox, thanks for the quick reply. I have modified the code above to make test_x contiguous as well :) It still runs into the same problem.

The only difference between the two test cases is the number of training points, N = 100 vs. N = 500. One works and the other doesn't... any idea why?

wjmaddox commented 2 years ago

Ugh, I was hoping it'd be that simple. I wonder if it's something to do with the KeOps internal chunking -- can you try doing a binary search to see where it breaks? (possibly around N = 200, but I'm not sure)

abdolrezat commented 2 years ago

Ugh, I was hoping it'd be that simple. I wonder if it's something to do with the KeOps internal chunking -- can you try doing a binary search to see where it breaks? (possibly around N = 200, but I'm not sure)

:)) Yes, it seems that N = 462 is the point where it breaks.
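For completeness, this is roughly the bisection I used to find it; a minimal sketch that just reuses train_and_eval_GP from the script above and assumes a failing case raises (as it does in the repro, via a RuntimeError from torch.autograd.grad):

# Bisect over N between a known-good and a known-bad size
lo, hi = 100, 500
while hi - lo > 1:
    mid = (lo + hi) // 2
    try:
        train_and_eval_GP(N=mid)
        lo = mid  # gradient computation still works
    except RuntimeError:
        hi = mid  # gradient computation failed
print(f'first failing N: {hi}')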

wjmaddox commented 2 years ago

Ugh... that's not helpful, is it. Do you mind trying to pull the actual matmul out of the code to see where exactly it breaks? That should help us open an issue over on the KeOps GitHub.

wjmaddox commented 2 years ago

The obvious thought is that our cached predictive mean (here https://github.com/cornellius-gp/gpytorch/blob/61f643eb8b487aef332c818f661fbcdb1df576ca/gpytorch/models/exact_prediction_strategies.py#L438) is potentially not contiguous and that's destroying gradients. But I'm not sure about that.

abdolrezat commented 2 years ago

Ugh... that's not helpful, is it. Do you mind trying to pull the actual matmul out of the code to see where exactly it breaks? That should help us open an issue over on the KeOps GitHub.

I'd be more than happy to help get to the bottom of this. I'm not an expert though, and I'm not sure which matmul to pull out :) so if you could let me know exactly what to do and how...

The obvious thought is that our cached predictive mean (here https://github.com/cornellius-gp/gpytorch/blob/61f643eb8b487aef332c818f661fbcdb1df576ca/gpytorch/models/exact_prediction_strategies.py#L438) is potentially not contiguous and that's destroying gradients. But I'm not sure about that.

I checked there and also at https://github.com/cornellius-gp/gpytorch/blob/61f643eb8b487aef332c818f661fbcdb1df576ca/gpytorch/models/exact_prediction_strategies.py#L266

mean_cache.is_contiguous() # returns True
self.mean_cache.is_contiguous() # returns True
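For reference, this is roughly how I checked it from the trained model; a sketch only, assuming the prediction strategy (and its cached predictive mean) is populated after the first prediction in eval mode:

# After model.eval() and one call to likelihood(model(test_x)):
strategy = model.prediction_strategy
print(strategy.mean_cache.is_contiguous())  # returned True here
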
wjmaddox commented 2 years ago

Grr... I just took a deeper look and it looks like the problem is probably on our end:

import torch
from gpytorch.kernels.keops import RBFKernel
from gpytorch.settings import max_cholesky_size

train_x = torch.linspace(0, 1, 500).cuda().contiguous()
train_y = 3. * train_x

test_x = torch.linspace(0, 1, 500, requires_grad = True).cuda().contiguous()

kernel = RBFKernel().cuda()

kxx = kernel(train_x)

with max_cholesky_size(10):
    solve = kxx.add_jitter(0.1).inv_matmul(train_y)

with max_cholesky_size(10):
    value = torch.autograd.grad(solve.sum(), kernel.raw_lengthscale)
    print(value) # produces a scalar as expected

with max_cholesky_size(10):
    pred_mean = kernel(test_x, train_x).matmul(solve)

    value = torch.autograd.grad(pred_mean.sum(), test_x) # produces a non-zero value as well

I'll need to dig deeper into our prediction strategy code to see what exactly is happening.

abdolrezat commented 2 years ago

Thanks for putting in the effort. With your pointers I could trace the problem further. If it helps with the analysis you asked for: the breaking point also depends on the number of test points. I used 51 test points in my example, and it failed at 462 training points (462 + 51 = 513), so it seems to fail when N_train + N_test > 512. Could it relate to some size threshold on the full/joint covariance matrix?

I traced both the failing and the succeeding test case. In https://github.com/cornellius-gp/gpytorch/blob/61f643eb8b487aef332c818f661fbcdb1df576ca/gpytorch/models/exact_prediction_strategies.py#L253 there is an if condition:

        if joint_covar.size(-1) <= settings.max_eager_kernel_size.value():
            test_covar = joint_covar[..., self.num_train :, :].evaluate()
            test_test_covar = test_covar[..., self.num_train :]
            test_train_covar = test_covar[..., : self.num_train]
        else:
            test_test_covar = joint_covar[..., self.num_train :, self.num_train :]
            test_train_covar = joint_covar[..., self.num_train :, : self.num_train]

where the value of settings.max_eager_kernel_size.value() is exactly 512. The case that passes the test goes through the true branch, and the case that fails goes through the else branch.
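I haven't verified that this is an intended use of the setting, but one way to test the hypothesis might be to raise that threshold above n_train + n_test for the failing case and re-run the prediction, e.g.:

# Hypothetical check: force the eager branch above by raising the threshold past n_train + n_test (500 + 51)
with gpytorch.settings.max_eager_kernel_size(1024), gpytorch.settings.fast_pred_var():
    observed_pred = likelihood(model(test_x))
    print(torch.autograd.grad(observed_pred.mean.sum(), test_x, retain_graph=True))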

wjmaddox commented 2 years ago

Sorry for losing track of this during my ICML crunch, but I took a look at it today and figured out that it doesn't occur when the scale kernel is not included, e.g. covar_module = gpytorch.kernels.keops.RBFKernel(), so hopefully that's progress towards figuring out what's going on. I'll keep looking tomorrow.

edit: the unscaled covariance does fail, but with a clear error message, and the failure goes away when n_train exceeds 800, i.e. the max Cholesky size threshold.

wjmaddox commented 2 years ago

So, this snippet should work:

test_x = torch.linspace(0, 1, 51, device=output_device).requires_grad_()

# Make predictions by feeding model through likelihood
with gpytorch.settings.max_cholesky_size(50), gpytorch.settings.fast_pred_var():
    observed_pred = likelihood(model(test_x))
    assert torch.autograd.grad(observed_pred.mean.sum(), test_x, retain_graph=True) is not None
    print('gradient exists:')
    print(torch.autograd.grad(observed_pred.mean.sum(), test_x, retain_graph=True))

I haven't quite figured out what the specific issue in https://github.com/cornellius-gp/gpytorch/blob/5f54dbf1fb514fa5bf31d2468f1d85b2fb66a44f/gpytorch/kernels/keops/rbf_kernel.py#L23 is, but the problem is that in L34-36 below it, KeOps is (properly) not used when at least one of the kernel inputs is beneath the max Cholesky size. Forcing a KeOps forward does resolve the issue and allows gradients to propagate.

I'll put up a PR once I figure out what's getting detached.

edit: I'm pretty sure the reason eager_kernel_size also resolves the gradient issue is that it forces KeOps not to be used at all, since you probably don't want to use KeOps on small datasets anyway due to speed issues.

abdolrezat commented 2 years ago

Awesome!