cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch

[Bug] The kernel ScaleKernel is not equipped to handle and diag. #2411

cookbook-ms closed this issue 9 months ago

cookbook-ms commented 9 months ago

🐛 Bug

I have a self-defined kernel: essentially an exponential kernel, but operating in the eigenspace.

To reproduce

Code snippet to reproduce

import torch
from gpytorch.kernels import Kernel
from gpytorch.constraints import Interval
from linear_operator.operators import DiagLinearOperator


class EdgeDiffusionKernelOceanFlow(Kernel):
    """
    Edge diffusion kernel for simplicial complexes.

    Parameters
    ----------
    eigenpairs : tuple of torch.Tensor
        Eigenvectors and eigenvalues of the Laplacians of the simplicial complex
    kappa_bounds : tuple of float
        Constraint interval for the diffusion parameters kappa
    """
    def __init__(self, eigenpairs, kappa_bounds=(1e-5,1e5)): 
        super().__init__()
        # self.eig_h, self.eig_g, self.eig_c, self.eigvec_h, self.eigvec_g, self.eigvec_c = eigenpairs
        self.eigvecs, self.eigvals = eigenpairs
        # register the raw parameters
        self.register_parameter(
            name='raw_kappa_down', parameter=torch.nn.Parameter(torch.zeros(1,1))
        )
        self.register_parameter(
            name='raw_kappa_up', parameter=torch.nn.Parameter(torch.zeros(1,1))
        )
        # set the kappa constraints
        self.register_constraint(
            'raw_kappa_down', Interval(*kappa_bounds)
        )
        self.register_constraint(
            'raw_kappa_up', Interval(*kappa_bounds)
        )
        # we do not set priors on the parameters

    # set up the actual parameters 
    @property
    def kappa_down(self):
        return self.raw_kappa_down_constraint.transform(self.raw_kappa_down)

    @kappa_down.setter
    def kappa_down(self, value):
        self._set_kappa_down(value)

    def _set_kappa_down(self, value):
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_kappa_down)
        self.initialize(raw_kappa_down=self.raw_kappa_down_constraint.inverse_transform(value))

    @property
    def kappa_up(self):
        return self.raw_kappa_up_constraint.transform(self.raw_kappa_up)

    @kappa_up.setter
    def kappa_up(self, value):
        self._set_kappa_up(value)

    def _set_kappa_up(self, value):
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_kappa_up)
        self.initialize(raw_kappa_up=self.raw_kappa_up_constraint.inverse_transform(value))

    def _eval_covar_matrix(self):
        """Build the full covariance (kernel) matrix once; it is exposed below as a
        property to avoid repeated computation of the kernel matrix."""
        # K1 = torch.linalg.matrix_exp(- (self.kappa_down*self.L1_down + self.kappa_up*self.L1_up))
        k = (self.kappa_down * self.eigvals).squeeze()
        K1 = self.eigvecs @ DiagLinearOperator(k) @ self.eigvecs.T
        # K2 = torch.linalg.matrix_exp(-self.kappa_up*self.L1_up)
        # This is equivalent to K1 + K2 - h_0 * I (remove the repeated identity part)
        return K1

    @property
    def covar_matrix(self):
        return self._eval_covar_matrix()

    # define the kernel function
    def forward(self, x1, x2=None, **params):
        # fall back to x1 before converting, so x2=None does not crash
        if x2 is None:
            x2 = x1
        x1 = x1.squeeze(-1).long()
        x2 = x2.squeeze(-1).long()
        # index the precomputed kernel matrix at the requested indices
        return self.covar_matrix[x1, :][:, x2]

My covariance module:

self.covar_module = gpytorch.kernels.ScaleKernel(kernel, outputscale_constraint=Interval(1e-5, 1e5))
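For context, the report does not show the surrounding model. A minimal sketch of a plausible exact GP regression setup (the class name EdgeGPModel and the other names below are illustrative assumptions, not from the original report):

import gpytorch
from gpytorch.constraints import Interval

class EdgeGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, kernel):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ZeroMean()
        # wrap the custom kernel in a ScaleKernel, as in the snippet above
        self.covar_module = gpytorch.kernels.ScaleKernel(
            kernel, outputscale_constraint=Interval(1e-5, 1e5)
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)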

When computing the variance and the MSLL, the following error is raised:

pred_mean, pred_var = observed_pred.mean, observed_pred.variance

Stack trace/error message

{
    "name": "RuntimeError",
    "message": "The kernel ScaleKernel is not equipped to handle and diag. Expected size torch.Size([9512]). Got size torch.Size([9512, 9512])",
    "stack": "---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/home/mmaosheng/kernel_methods/Edge-GPRegression_OceanFlow.ipynb Cell 35 line 1
----> 1 pred_mean, pred_var = observed_pred.mean, observed_pred.variance

File ~/miniconda3/lib/python3.11/site-packages/gpytorch/distributions/multivariate_normal.py:309, in MultivariateNormal.variance(self)
    305 @property
    306 def variance(self) -> Tensor:
    307     if self.islazy:
    308         # overwrite this since torch MVN uses unbroadcasted_scale_tril for this
--> 309         diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
    310         diag = diag.view(diag.shape[:-1] + self._event_shape)
    311         variance = diag.expand(self._batch_shape + self._event_shape)

File ~/miniconda3/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py:1411, in LinearOperator.diagonal(self, offset, dim1, dim2)
   1409 elif not self.is_square:
   1410     raise RuntimeError(\"LinearOperator#diagonal is only implemented for square operators.\")
-> 1411 return self._diagonal()

File ~/miniconda3/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py:29, in SumLinearOperator._diagonal(self)
     28 def _diagonal(self: Float[LinearOperator, \"... M N\"]) -> Float[torch.Tensor, \"... N\"]:
---> 29     return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)

File ~/miniconda3/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py:29, in <genexpr>(.0)
     28 def _diagonal(self: Float[LinearOperator, \"... M N\"]) -> Float[torch.Tensor, \"... N\"]:
---> 29     return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)

File ~/miniconda3/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py:29, in SumLinearOperator._diagonal(self)
     28 def _diagonal(self: Float[LinearOperator, \"... M N\"]) -> Float[torch.Tensor, \"... N\"]:
---> 29     return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)

File ~/miniconda3/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py:29, in <genexpr>(.0)
     28 def _diagonal(self: Float[LinearOperator, \"... M N\"]) -> Float[torch.Tensor, \"... N\"]:
---> 29     return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)

File ~/miniconda3/lib/python3.11/site-packages/gpytorch/utils/memoize.py:59, in _cached.<locals>.g(self, *args, **kwargs)
     57 kwargs_pkl = pickle.dumps(kwargs)
     58 if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59     return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
     60 return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)

File ~/miniconda3/lib/python3.11/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:25, in recall_grad_state.<locals>.wrapped(self, *args, **kwargs)
     22 @functools.wraps(method)
     23 def wrapped(self, *args, **kwargs):
     24     with torch.set_grad_enabled(self._is_grad_enabled):
---> 25         output = method(self, *args, **kwargs)
     26     return output

File ~/miniconda3/lib/python3.11/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:126, in LazyEvaluatedKernelTensor._diagonal(self)
    124     expected_shape = self.shape[:-1]
    125     if res.shape != expected_shape:
--> 126         raise RuntimeError(
    127             \"The kernel {} is not equipped to handle and diag. Expected size {}. \"
    128             \"Got size {}\".format(self.kernel.__class__.__name__, expected_shape, res.shape)
    129         )
    131 if isinstance(res, LinearOperator):
    132     res = res.to_dense()

RuntimeError: The kernel ScaleKernel is not equipped to handle and diag. Expected size torch.Size([9512]). Got size torch.Size([9512, 9512])"
}

Expected Behavior

I expect to get the variance and the MSLL from the trained model and likelihood.

System information


Additional context

I noticed the same kind of error in other issues as well. While trying to figure out the cause, I believe it is because I construct the full matrix K1 = self.eigvecs @ self.eigvecs.T (with or without a diagonal matrix in between) and my forward always returns a full two-dimensional slice of it.

gpleiss commented 9 months ago

@cookbook-ms the issue is that your forward function doesn't accept a diag=True keyword argument, which is necessary to obtain variance estimates. See the RBFKernel or LinearKernel implementations for an example.

I realize that the custom kernel tutorial in the documentation also does not include this option. This option may become obsolete with #2342, so I'm not going to suggest fixing the tutorial at the moment.
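Concretely, a minimal sketch of the change described above, adapting the reporter's forward (untested; shown only to illustrate the diag handling):

def forward(self, x1, x2=None, diag=False, **params):
    if x2 is None:
        x2 = x1
    x1 = x1.squeeze(-1).long()
    x2 = x2.squeeze(-1).long()
    covar = self.covar_matrix[x1, :][:, x2]
    if diag:
        # GPyTorch passes diag=True when it only needs pointwise
        # variances; return a 1D tensor of size n, not an n x n matrix
        return covar.diagonal(dim1=-2, dim2=-1)
    return covar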

cookbook-ms commented 9 months ago

Please refer to how LinearKernel implements this: https://docs.gpytorch.ai/en/stable/_modules/gpytorch/kernels/linear_kernel.html#LinearKernel
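For reference, the relevant pattern from the linked source, paraphrased and simplified here (batch handling omitted): the kernel builds its covariance as a lazy LinearOperator and returns only the diagonal when diag=True is passed.

from linear_operator.operators import MatmulLinearOperator

def forward(self, x1, x2, diag=False, **params):
    x1_ = x1 * self.variance.sqrt()
    x2_ = x2 * self.variance.sqrt()
    prod = MatmulLinearOperator(x1_, x2_.transpose(-2, -1))
    if diag:
        return prod.diagonal(dim1=-1, dim2=-2)  # 1D tensor of variances
    return prod  # lazy n x n covariance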