activatedgeek / simplex-gp

Lattice kernels for scalable Gaussian processes in GPyTorch (Simplex-GPs)
https://go.sanyamkapoor.com/simplex-gp
Apache License 2.0

ValueError in Interpolation with large sample size/dimension #3

Open kevinli1324 opened 1 year ago

kevinli1324 commented 1 year ago

Hi! I was trying to use the library and ran into the error that reads "ValueError: left interp size (torch.Size([20000, 1, 1])) is incompatible with base lazy tensor size (torch.Size([20000, 20000])). Make sure the two have the same number of batch dimensions".

This only happens when I run on data with a large sample size/dimension. I've modified the code in notebooks/bi_gp_ls.ipynb to replicate the error, though the error also occurs with different kernel settings. Is there an easy way to fix this, or are there extra steps when dealing with larger datasets? Thanks!

from tqdm.auto import tqdm
import torch
import gpytorch as gp
import altair as alt
import pandas as pd
import numpy as np
from collections import defaultdict

from gpytorch_lattice_kernel import RBFLattice as BilateralKernel

class BilateralGPModel(gp.models.ExactGP):
    def __init__(self, train_x, train_y):
        likelihood = gp.likelihoods.GaussianLikelihood()
        super().__init__(train_x, train_y, likelihood)
        # self.mean_module = gp.means.ConstantMean()
        # self.covar_module = gp.kernels.ScaleKernel(BilateralKernel(ard_num_dims=train_x.size(-1)))
        self.mean_module = gp.means.ZeroMean()
        self.covar_module = BilateralKernel()

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gp.distributions.MultivariateNormal(mean_x, covar_x)

def train(x, y, model, mll, optim):
    model.train()

    optim.zero_grad()

    output = model(x)

    loss = -mll(output, y)

    loss.backward()

    optim.step()

    return { 'train/ll': -loss.detach().item() }

def test(x, y, model, lanc_iter=100, pre_size=0):
    model.eval()

    with torch.no_grad():
#        gp.settings.max_preconditioner_size(pre_size), \
#        gp.settings.max_root_decomposition_size(lanc_iter), \
#        gp.settings.fast_pred_var():
        pred_y = model.likelihood(model(x))
        rmse = (pred_y.mean - y).pow(2).mean(0).sqrt()

    return { 'test/rmse': rmse.item() }

def train_util(model, x, y, lr=0.1, epochs=100):
    mll = gp.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    for _ in tqdm(range(epochs), leave=False):
        train_dict = train(x, y, model, mll, optim)

    return train_dict

n = 20000
d = 4
x = 2. * torch.rand(n, d) - 1.

with torch.no_grad():
  covar_module = gp.kernels.ScaleKernel(gp.kernels.RBFKernel())
  params = covar_module.state_dict()
  params['raw_outputscale'] = torch.tensor(1.0).log()
  params['base_kernel.raw_lengthscale'] = torch.Tensor([[1.5]]).log()
  covar_module.load_state_dict(params)

  covar = gp.distributions.MultivariateNormal(torch.zeros(n), covariance_matrix=covar_module(x))

rperm = torch.randperm(n)[:n//2]
train_x = x[rperm]
train_y = (covar.sample() + 0.1 * torch.randn(x.size(0)))[rperm]

results = defaultdict(list)  # collect learned hyperparameters and training metrics for each restart

for _ in tqdm(range(10)):
  bigp = BilateralGPModel(train_x, train_y).float()

  with gp.settings.max_root_decomposition_size(50):
    train_dict = train_util(bigp, train_x, train_y)

  for name, p in bigp.named_parameters():
    results[name].append(p)
  results['kind'].append('Bilateral GP')

  for k, v in train_dict.items():
    results[k].append(v)
chriscamano commented 5 months ago

Hello, just wanted to mention that I am experiencing this exact issue in my testing as well, for datasets past roughly 1500 samples.

I have been able to fix this problem by manually evaluating the kernel in the forward pass of my GP class and adding noise to keep it PSD:

K = covar_x.to_dense()
return MultivariateNormal(mean_x, K + noise * torch.eye(K.shape[0]).to(self.device))

However, this likely scales poorly and loses the benefits of the LazyTensor abstraction.
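
For reference, here is a minimal sketch of that dense fallback written out as a full model, reusing the BilateralKernel import and model structure from the snippet above; the DenseFallbackGPModel name and the jitter value are placeholders I picked, not part of the library.

class DenseFallbackGPModel(gp.models.ExactGP):
    # Same structure as BilateralGPModel above, but the kernel is materialized
    # densely, trading the lattice/LazyTensor speedups for a plain n x n matrix.
    def __init__(self, train_x, train_y, jitter=1e-4):
        likelihood = gp.likelihoods.GaussianLikelihood()
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gp.means.ZeroMean()
        self.covar_module = BilateralKernel()
        self.jitter = jitter  # arbitrary PSD-restoring noise level

    def forward(self, x):
        mean_x = self.mean_module(x)
        K = self.covar_module(x).to_dense()  # force dense evaluation of the lazy kernel
        K = K + self.jitter * torch.eye(K.shape[-1], dtype=K.dtype, device=K.device)
        return gp.distributions.MultivariateNormal(mean_x, K)

As noted above, though, this materializes the full n x n kernel, so it is only a stopgap.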

mfinzi commented 5 months ago

Hmm, @activatedgeek, is this perhaps because of a change in GPyTorch? Do we need to put version restrictions on it, if we have not done so already?
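
As a first data point, it would help to record the exact versions in a failing environment, e.g.:

import torch
import gpytorch

# Versions installed in the environment where the ValueError shows up.
print("torch:", torch.__version__)
print("gpytorch:", gpytorch.__version__)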

mfinzi commented 5 months ago

I had a look into it. I think it may come from an implicit assumption in the GPyTorch internals that LazyTensors can be extended to batch objects of batch size 1 via _unsqueeze_batch, but I'm checking with those folks now. I believe it only triggers at large sample sizes because CG + SLQ is only used instead of Cholesky on sufficiently large problems, and _unsqueeze_batch is called when constructing the preconditioner used with CG.
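
In the meantime, one way to probe that hypothesis from the user side (a diagnostic sketch, not a fix, and I haven't verified it addresses the root cause) would be to disable the preconditioner, or force the Cholesky path, via GPyTorch's settings context managers around the training call from the snippet above:

import gpytorch as gp

# If the preconditioner built for CG is really what triggers _unsqueeze_batch,
# turning it off may sidestep the error (diagnostic only):
with gp.settings.max_preconditioner_size(0):
    train_dict = train_util(bigp, train_x, train_y)

# Alternatively, raising the Cholesky threshold avoids CG + SLQ entirely, at
# O(n^3) cost, so it is only practical for moderately sized problems:
with gp.settings.max_cholesky_size(25000):
    train_dict = train_util(bigp, train_x, train_y)

Neither of these keeps the lattice speedups on the Cholesky path, so they are mainly useful for confirming where the failure comes from.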

It appears that making the following changes to SquareLazyLattice in bilateral_kernel.py at least gets the code to run (the reshapes handle the case where x has a batch dimension). Note that with the dimensional assert statements commented back in, the error is caught by the _unsqueeze_batch functionality.

class SquareLazyLattice(LazyTensor):
    def __init__(self, x, dkernel=None):
        super().__init__(x, dkernel=dkernel)
        # assert x.ndim == 2, f"No batch (even of size 1) inputs supported, got {x.ndim} with shape {x.shape}"
        self.x = x.reshape(*x.shape[-2:])  # drop any leading batch axes
        self.orig_shape = x.shape
        self.dkernel = dkernel

    def _matmul(self, V):
        # assert V.ndim <= 2
        out = LatticeFilterGeneral.apply(V.reshape(*V.shape[-2:]), self.x, self.dkernel)
        return out.reshape(V.shape)  # unflatten if there were batch axes

    def _size(self):
        return torch.Size((*self.orig_shape[:-1], self.x.shape[-2]))

    def _transpose_nonbatch(self):
        return self

    def diag(self):
        return torch.ones_like(self.x[..., 0])
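
As a quick sanity check of the reshape logic, here is a toy stand-in for LatticeFilterGeneral.apply (the real filter isn't needed just to exercise the flatten/unflatten shapes):

import torch

def fake_filter(V2d):
    # stand-in for LatticeFilterGeneral.apply(V2d, x, dkernel); any (n, t) -> (n, t)
    # map is enough to exercise the reshapes used in _matmul above
    return 2.0 * V2d

V = torch.randn(5000, 3)
out_unbatched = fake_filter(V.reshape(*V.shape[-2:])).reshape(V.shape)

Vb = V.unsqueeze(0)  # shape (1, n, t), as produced by _unsqueeze_batch
out_batched = fake_filter(Vb.reshape(*Vb.shape[-2:])).reshape(Vb.shape)

assert torch.allclose(out_unbatched, out_batched.squeeze(0))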

Once we figure out what's really going on, we can push out an update. Let me know if this band-aid solves the problem for you guys.