cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

[Bug] Type error in multitask, large input for GridInterpolationKernel #1958

Open mekty2012 opened 2 years ago

mekty2012 commented 2 years ago

🐛 Bug

I was trying to extend the MultitaskKernel example by adding a GridInterpolationKernel with multi-dimensional input. However, when I use all three features together (multitask, grid interpolation kernel, multi-dimensional input), I get an error from inside the library.

To reproduce

Code snippet to reproduce

import math
import torch
import gpytorch

train_x = torch.cartesian_prod(torch.linspace(-1, 1, 50), torch.linspace(-1, 1, 50))

train_y = torch.stack([
    torch.sin(train_x[:,0] * (2 * math.pi)) * torch.sin(train_x[:, 1] * (2 * math.pi)) + torch.randn(train_x[:,0].size()) * 0.2,
    torch.sin(train_x[:,0] * (2 * math.pi)) * torch.cos(train_x[:, 1] * (2 * math.pi)) + torch.randn(train_x[:,0].size()) * 0.2,
    torch.cos(train_x[:,0] * (2 * math.pi)) * torch.sin(train_x[:, 1] * (2 * math.pi)) + torch.randn(train_x[:,0].size()) * 0.2,
    torch.cos(train_x[:,0] * (2 * math.pi)) * torch.cos(train_x[:, 1] * (2 * math.pi)) + torch.randn(train_x[:,0].size()) * 0.2,
], -1)

class MultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(MultitaskGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(), num_tasks=4
        )
        self.covar_module = gpytorch.kernels.MultitaskKernel(
            gpytorch.kernels.GridInterpolationKernel(
                gpytorch.kernels.RBFKernel(),
                grid_size=10, num_dims=2
            ), num_tasks=4, rank=0
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=4)
model = MultitaskGPModel(train_x, train_y, likelihood)

# Find optimal model hyperparameters
training_iterations = 2
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(training_iterations):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iterations, loss.item()))
    optimizer.step()

Stack trace/error message

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_13748/1593994832.py in <module>
     46     optimizer.zero_grad()
     47     output = model(train_x)
---> 48     loss = -mll(output, train_y)
     49     loss.backward()
     50     print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iterations, loss.item()))

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\module.py in __call__(self, *inputs, **kwargs)
     28 
     29     def __call__(self, *inputs, **kwargs):
---> 30         outputs = self.forward(*inputs, **kwargs)
     31         if isinstance(outputs, list):
     32             return [_validate_module_outputs(output) for output in outputs]

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\mlls\exact_marginal_log_likelihood.py in forward(self, function_dist, target, *params)
     60         # Get the log prob of the marginal distribution
     61         output = self.likelihood(function_dist, *params)
---> 62         res = output.log_prob(target)
     63         res = self._add_other_terms(res, params)
     64 

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\distributions\multitask_multivariate_normal.py in log_prob(self, value)
    209             new_shape = value.shape[:-2] + value.shape[:-3:-1]
    210             value = value.view(new_shape).transpose(-1, -2).contiguous()
--> 211         return super().log_prob(value.view(*value.shape[:-2], -1))
    212 
    213     @property

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\distributions\multivariate_normal.py in log_prob(self, value)
    167         # Get log determininant and first part of quadratic form
    168         covar = covar.evaluate_kernel()
--> 169         inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
    170 
    171         res = -0.5 * sum([inv_quad, logdet, diff.size(-1) * math.log(2 * math.pi)])

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\kronecker_product_added_diag_lazy_tensor.py in inv_quad_logdet(self, inv_quad_rhs, logdet, reduce_inv_quad)
     60     def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
     61         if inv_quad_rhs is not None:
---> 62             inv_quad_term, _ = super().inv_quad_logdet(
     63                 inv_quad_rhs=inv_quad_rhs, logdet=False, reduce_inv_quad=reduce_inv_quad
     64             )

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\lazy_tensor.py in inv_quad_logdet(self, inv_quad_rhs, logdet, reduce_inv_quad)
   1332         func = InvQuadLogDet.apply
   1333 
-> 1334         inv_quad_term, logdet_term = func(
   1335             self.representation_tree(),
   1336             self.dtype,

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\functions\_inv_quad_log_det.py in forward(ctx, representation_tree, dtype, device, matrix_shape, batch_shape, inv_quad, logdet, probe_vectors, probe_vector_norms, *args)
    158 
    159         else:
--> 160             solves = lazy_tsr._solve(rhs, preconditioner, num_tridiag=0)
    161 
    162         # Final values to return

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\kronecker_product_added_diag_lazy_tensor.py in _solve(self, rhs, preconditioner, num_tridiag)
    187             # https://papers.nips.cc/paper/2013/file/59c33016884a62116be975a9bb8257e3-Paper.pdf
    188 
--> 189             dlt_inv_root, evals_p_i, evecs = _symmetrize_kpadlt_constructor(lt, dlt)
    190 
    191             res1 = evecs._transpose_nonbatch().matmul(dlt_inv_root.matmul(rhs))

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\kronecker_product_added_diag_lazy_tensor.py in _symmetrize_kpadlt_constructor(lt, dlt)
     37         *[d.matmul(k).matmul(d) for k, d in zip(lt.lazy_tensors, dlt_inv_root.lazy_tensors)]
     38     )
---> 39     evals, evecs = symm_prod.diagonalization()
     40     evals_plus_i = DiagLazyTensor(evals + 1.0)
     41 

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\kronecker_product_lazy_tensor.py in diagonalization(self, method)
    142         if method is None:
    143             method = "symeig"
--> 144         return super().diagonalization(method=method)
    145 
    146     @cached

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\utils\memoize.py in g(self, *args, **kwargs)
     57         kwargs_pkl = pickle.dumps(kwargs)
     58         if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59             return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
     60         return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)
     61 

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\lazy_tensor.py in diagonalization(self, method)
   1634 
   1635         elif method == "symeig":
-> 1636             evals, evecs = self.symeig(eigenvectors=True)
   1637         else:
   1638             raise RuntimeError(f"Unknown diagonalization method '{method}'")

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\utils\memoize.py in g(self, *args, **kwargs)
     57         kwargs_pkl = pickle.dumps(kwargs)
     58         if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59             return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
     60         return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)
     61 

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\lazy_tensor.py in symeig(self, eigenvectors)
   1911         except CachingError:
   1912             pass
-> 1913         return self._symeig(eigenvectors=eigenvectors)
   1914 
   1915     def to(self, *args, **kwargs):

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\kronecker_product_lazy_tensor.py in _symeig(self, eigenvectors, return_evals_as_lazy)
    292         evals, evecs = [], []
    293         for lt in self.lazy_tensors:
--> 294             evals_, evecs_ = lt.symeig(eigenvectors=eigenvectors)
    295             evals.append(evals_)
    296             evecs.append(evecs_)

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\utils\memoize.py in g(self, *args, **kwargs)
     57         kwargs_pkl = pickle.dumps(kwargs)
     58         if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59             return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
     60         return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)
     61 

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\lazy_tensor.py in symeig(self, eigenvectors)
   1911         except CachingError:
   1912             pass
-> 1913         return self._symeig(eigenvectors=eigenvectors)
   1914 
   1915     def to(self, *args, **kwargs):

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\lazy_tensor.py in _symeig(self, eigenvectors)
   2243         # potentially perform decomposition in double precision for numerical stability
   2244         dtype = self.dtype
-> 2245         evals, evecs = torch.linalg.eigh(self.evaluate().to(dtype=settings._linalg_dtype_symeig.value()))
   2246         # chop any negative eigenvalues.
   2247         # TODO: warn if evals are significantly negative

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\utils\memoize.py in g(self, *args, **kwargs)
     57         kwargs_pkl = pickle.dumps(kwargs)
     58         if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59             return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
     60         return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)
     61 

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\matmul_lazy_tensor.py in evaluate(self)
    114     @cached
    115     def evaluate(self):
--> 116         return torch.matmul(self.left_lazy_tensor.evaluate(), self.right_lazy_tensor.evaluate())

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\utils\memoize.py in g(self, *args, **kwargs)
     57         kwargs_pkl = pickle.dumps(kwargs)
     58         if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59             return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
     60         return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)
     61 

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\matmul_lazy_tensor.py in evaluate(self)
    114     @cached
    115     def evaluate(self):
--> 116         return torch.matmul(self.left_lazy_tensor.evaluate(), self.right_lazy_tensor.evaluate())

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\utils\memoize.py in g(self, *args, **kwargs)
     57         kwargs_pkl = pickle.dumps(kwargs)
     58         if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59             return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
     60         return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)
     61 

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\lazy_tensor.py in evaluate(self)
   1146             eye = torch.eye(num_cols, dtype=self.dtype, device=self.device)
   1147             eye = eye.expand(*self.batch_shape, num_cols, num_cols)
-> 1148             res = self.matmul(eye)
   1149         return res
   1150 

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\interpolated_lazy_tensor.py in matmul(self, tensor)
    407         # right_interp^T * tensor
    408         base_size = self.base_lazy_tensor.size(-1)
--> 409         right_interp_res = left_t_interp(self.right_interp_indices, self.right_interp_values, tensor, base_size)
    410 
    411         # base_lazy_tensor * right_interp^T * tensor

~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\utils\interpolation.py in left_t_interp(interp_indices, interp_values, rhs, output_dim)
    226     else:
    227         cls = getattr(torch.sparse, type_name)
--> 228     summing_matrix = cls(summing_matrix_indices, summing_matrix_values, size)
    229 
    230     # Sum up the values appropriately by performing sparse matrix multiplication

RuntimeError: expected scalar type Long but found Double

Expected Behavior

A very similar implementation using batch-independent multitask GPs works well, so I'll include it below for comparison, along with some visualizations of the data.

import math
import torch
import gpytorch
from matplotlib import pyplot as plt

train_x = torch.cartesian_prod(torch.linspace(-1, 1, 50), torch.linspace(-1, 1, 50))

train_y = torch.stack([
    torch.sin(train_x[:,0] * (2 * math.pi)) * torch.sin(train_x[:, 1] * (2 * math.pi)) + torch.randn(train_x[:,0].size()) * 0.2,
    torch.sin(train_x[:,0] * (2 * math.pi)) * torch.cos(train_x[:, 1] * (2 * math.pi)) + torch.randn(train_x[:,0].size()) * 0.2,
    torch.cos(train_x[:,0] * (2 * math.pi)) * torch.sin(train_x[:, 1] * (2 * math.pi)) + torch.randn(train_x[:,0].size()) * 0.2,
    torch.cos(train_x[:,0] * (2 * math.pi)) * torch.cos(train_x[:, 1] * (2 * math.pi)) + torch.randn(train_x[:,0].size()) * 0.2,
], -1)

class BatchIndependentMultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([4]))
        self.covar_module = gpytorch.kernels.GridInterpolationKernel(
            gpytorch.kernels.SpectralMixtureKernel(
                num_mixtures=10,
                ard_num_dims=2,
                batch_shape=torch.Size([4])
            ),
            grid_size=10,
            num_dims=2,
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal.from_batch_mvn(
            gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
        )

likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=4)
model = BatchIndependentMultitaskGPModel(train_x, train_y, likelihood)

# this is for running the notebook in our testing framework
import os
smoke_test = ('CI' in os.environ)
training_iterations = 2 if smoke_test else 50

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(training_iterations):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iterations, loss.item()))
    optimizer.step()

Visualization of each task's output

System information


mekty2012 commented 2 years ago

While testing this bug, I found that it happens whenever there are too many data points, even without multi-dimensional input.

  1. If we change train_x to train_x = torch.cartesian_prod(torch.linspace(-1, 1, 5), torch.linspace(-1, 1, 5)), the code works fine.
  2. Conversely, if we use one-dimensional data but a large dataset, we get the same bug:
    
import math
import torch
import gpytorch

train_x = torch.linspace(-1, 1, 500)

train_y = torch.stack([
    torch.sin(train_x * (2 * math.pi)) * torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    torch.sin(train_x * (2 * math.pi)) * torch.cos(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    torch.cos(train_x * (2 * math.pi)) * torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    torch.cos(train_x * (2 * math.pi)) * torch.cos(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
], -1)

class MultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(MultitaskGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(), num_tasks=4
        )
        self.covar_module = gpytorch.kernels.MultitaskKernel(
            gpytorch.kernels.GridInterpolationKernel(
                gpytorch.kernels.RBFKernel(),
                grid_size=10, num_dims=1
            ), num_tasks=4, rank=0
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=4)
model = MultitaskGPModel(train_x, train_y, likelihood)

# Find optimal model hyperparameters
training_iterations = 2
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(training_iterations):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iterations, loss.item()))
    optimizer.step()


So I'll update the issue title from "multi-dimensional input" to "large input".
famura commented 2 years ago

I just encountered the same problem with 2-dim input and 2-dim output for 5000 training points. I got the grid size from gpytorch.utils.grid.choose_grid_size(inputs_trn, 1.0) and created the covar_module like this

base_kernel = gpytorch.kernels.MaternKernel()
kernel = gpytorch.kernels.GridInterpolationKernel(base_kernel, grid_size, num_dims=num_tasks)
covar_module = gpytorch.kernels.MultitaskKernel(kernel, num_tasks, kernel_rank)
gpleiss commented 1 year ago

Yeah - this looks like a bug in GPyTorch/LinearOperator. I do not have the bandwidth to investigate this at the moment, so I'm looking for someone who can put up a PR to fix this.

CY-Zhang commented 1 year ago

I took a look into this error and found that the root cause of the type error is in KroneckerProductAddedDiagLinearOperator._solve in LinearOperator. On lines 177-181, the lt and dlt matrices are converted to torch.float64:

# again we perform the solve in double precision for numerical stability issues
# TODO: Use fp64 registry once #1213 is addressed        
rhs = rhs.to(symeig_dtype)
lt = self.linear_op.to(symeig_dtype)
dlt = self.diag_tensor.to(symeig_dtype)

But this conversion is problematic for InterpolatedLinearOperator, because it also casts the operator's index matrices (left_interp_indices, right_interp_indices) to float64, which then causes a type error when a sparse matrix is built from those indices.
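
For reference, the underlying constraint can be reproduced in a standalone snippet (this only illustrates PyTorch's requirement that sparse indices be integer tensors, it is not GPyTorch code):

import torch

# Sparse COO construction requires integer (Long) index tensors.
indices = torch.tensor([[0, 1], [1, 0]])   # dtype: torch.int64
values = torch.tensor([1.0, 2.0])
ok = torch.sparse_coo_tensor(indices, values, (2, 2))  # works

# Casting the indices to a floating dtype (as a blanket .to(torch.float64) does)
# breaks the construction with a dtype error analogous to the one in the traceback.
try:
    torch.sparse_coo_tensor(indices.to(torch.float64), values, (2, 2))
except (RuntimeError, TypeError) as e:
    print(e)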

As for why the error depends on input size: the type error is avoided whenever the inverse-quadratic and logdet terms in _linear_operator.inv_quad_logdet() are computed via Cholesky, i.e. when settings.fast_computations.log_prob.off() is True or when input_size * output_dimension <= settings.max_cholesky_size.value(), which defaults to 800.
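
If that analysis is correct, one temporary user-side workaround (just a sketch reusing the names from the repro script above, and it gives up the scalability that SKI is meant to provide) is to raise the Cholesky threshold above input_size * output_dimension so the failing solve path is never taken:

import gpytorch

# The repro has 2500 inputs x 4 tasks = 10000, so the threshold must be at least that.
with gpytorch.settings.max_cholesky_size(10000):
    for i in range(training_iterations):
        optimizer.zero_grad()
        output = model(train_x)
        loss = -mll(output, train_y)
        loss.backward()
        print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iterations, loss.item()))
        optimizer.step()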

I'd like to collect suggestions from people with more experience with this package on the best way to fix this issue. One straightforward way would be to avoid converting the index matrices to a floating dtype when to() is called in _linear_operator, but there might be an easier fix.

Balandat commented 1 year ago

I'd like to collect suggestions from people with more experience with this package on the best way to fix this issue. One straightforward way would be to avoid converting the index matrices to a floating dtype when to() is called in _linear_operator, but there might be an easier fix.

Yeah it seems like not converting the index matrices in InterpolatedLinearOperator when calling to() is something that we'd want anyway. We can basically just overwrite the to() method on InterpolatedLinearOperator to only move the base linear operator.
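
A toy, standalone sketch of that idea (the class, attribute, and constructor names here are invented for illustration; the real change would live on InterpolatedLinearOperator in the linear_operator repo):

import torch

class ToyInterpolatedOperator:
    """Minimal stand-in showing a dtype-preserving to(); not the library class."""

    def __init__(self, base, interp_indices, interp_values):
        self.base = base                      # stand-in for the base linear operator
        self.interp_indices = interp_indices  # integer indices: must stay Long
        self.interp_values = interp_values    # floating-point interpolation weights

    def to(self, device=None, dtype=None):
        # Cast only the floating-point pieces; move the indices across devices
        # but leave their integer dtype untouched.
        return ToyInterpolatedOperator(
            self.base.to(device=device, dtype=dtype),
            self.interp_indices.to(device=device),
            self.interp_values.to(device=device, dtype=dtype),
        )

op = ToyInterpolatedOperator(
    torch.randn(10, 10),
    torch.randint(0, 10, (25, 4)),
    torch.rand(25, 4),
)
op64 = op.to(dtype=torch.float64)
assert op64.interp_indices.dtype == torch.long    # indices keep their dtype
assert op64.interp_values.dtype == torch.float64  # values are cast as requested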

Would you be willing to help out and put up a PR for this on the linear_operator repo?

CY-Zhang commented 1 year ago

Yeah it seems like not converting the index matrices in InterpolatedLinearOperator when calling to() is something that we'd want anyway. We can basically just overwrite the to() method on InterpolatedLinearOperator to only move the base linear operator.

Would you be willing to help out and put up a PR for this on the linear_operator repo?

Sure, I will work on it and add a unit test for that.