cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

[Bug] Potential Bug in GP Regression with KeOps Kernels #2453

Open kayween opened 10 months ago

kayween commented 10 months ago

🐛 Bug

During the backward pass of the exact marginal log likelihood, GPyTorch throws a CUDA error. This happens when using the KeOps kernel on a large dataset with 1 million data points.

The same KeOps kernel works fine on smaller datasets (for example, 500,000 data points).

To reproduce

import torch

import gpytorch

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.keops.RBFKernel()
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

if __name__ == "__main__":
    # n = 500_000  # works okay
    n = 1_000_000  # error
    d = 10

    device = "cuda:0"

    train_x = torch.randn(n, d, device=device)
    train_y = torch.randn(n, device=device)

    likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)
    # likelihood.noise_covar.initialize(noise=10.)

    model = ExactGPModel(train_x, train_y, likelihood).to(device)

    model.train()
    likelihood.train()

    # Use the adam optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    training_iter = 50
    for i in range(training_iter):
        # Zero gradients from previous iteration
        optimizer.zero_grad()
        # Output from model
        output = model(train_x)
        # Calc loss and backprop gradients
        loss = -mll(output, train_y)
        loss.backward()
        print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
            i + 1, training_iter, loss.item(),
            model.covar_module.base_kernel.lengthscale.item(),
            model.likelihood.noise.item()
        ))
        optimizer.step()

Stack trace/error message

[KeOps] Generating code for formula Sum_Reduction(Exp(-Sum((Var(0,10,0)-Var(1,10,1))**2)/2)*Var(2,11,1),0) ... OK
[KeOps] Generating code for formula Sum_Reduction(-((2*(Var(0,10,0)-Var(1,10,1)))*(((Var(3,11,0)|Var(2,11,1))*Exp(-Sum((Var(0,10,0)-Var(1,10,1))**2)/2))/2)),0) ... OK
[KeOps] Generating code for formula Sum_Reduction(-(-((2*(Var(0,10,0)-Var(1,10,1)))*(((Var(3,11,0)|Var(2,11,1))*Exp(-Sum((Var(0,10,0)-Var(1,10,1))**2)/2))/2))),1) ... OK
Traceback (most recent call last):
  File "test.py", line 51, in <module>
    loss.backward()
  File "/home/kaiwen/anaconda3/envs/altproj/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/home/kaiwen/anaconda3/envs/altproj/lib/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling `cublasStrsmBatched( handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount)`

Expected Behavior

The backward pass should be executed successfully.

System information

Please complete the following information:
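
A minimal sketch for gathering the version details this section asks for. It assumes the packages were installed under the distribution names `torch`, `gpytorch`, and `pykeops`; the helper name `collect_versions` is hypothetical and not part of any of these libraries.

```python
import platform
from importlib import metadata


def collect_versions(packages=("torch", "gpytorch", "pykeops")):
    """Map each distribution name to its installed version string.

    Packages that cannot be found are reported as 'not installed'
    instead of raising, so the report can always be printed.
    """
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


if __name__ == "__main__":
    for name, version in collect_versions().items():
        print(f"{name}: {version}")
```

The operating system and the GPU model are not covered by this sketch and would need to be added by hand.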

Additional context

This issue appears to be related to a recent commit that was intended to fix a bug in the KeOps kernel. I ran into it while testing the KeOps kernel after that commit.

Weirdly enough, the above code works fine when $n = 500,000$.
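
To narrow down where the failure starts between the two sizes, one could bisect on `n`. This is only a sketch: `runs_ok` is a hypothetical callable that would run one forward and backward pass of the script above for a given `n` and return `False` when the CUDA error is raised. It also assumes the failure is monotone in `n`, which has not been verified here.

```python
def first_failing_size(runs_ok, lo, hi):
    """Return the smallest n in (lo, hi] for which runs_ok(n) is False.

    Assumes runs_ok(lo) is True, runs_ok(hi) is False, and that once the
    run fails for some n it also fails for every larger n.
    """
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if runs_ok(mid):
            lo = mid  # mid still works, so the threshold is above it
        else:
            hi = mid  # mid fails, so the threshold is at or below it
    return hi


if __name__ == "__main__":
    # Stand-in predicate for illustration only; a real check would wrap
    # the training loop in try/except RuntimeError.
    fake_runs_ok = lambda n: n < 700_000
    print(first_failing_size(fake_runs_ok, 500_000, 1_000_000))
```

Each probe needs a full run at size `mid`, so roughly 19 runs are enough to pin down the threshold between 500,000 and 1,000,000.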

Note that GPyTorch 1.6.0 does not have this bug.