cornellius-gp / linear_operator

A LinearOperator implementation to wrap the numerical nuts and bolts of GPyTorch
MIT License

[Bug] Parameters missing from graph when KeOpsLinearOperator is used #55

Open m-julian opened 1 year ago

m-julian commented 1 year ago

🐛 Bug

I have implemented a KeOps periodic kernel in https://github.com/cornellius-gp/gpytorch/pull/2296; however, no gradients are computed for the raw_lengthscale parameter (see https://github.com/cornellius-gp/gpytorch/pull/2296#issuecomment-1466327296). I have tracked the issue down to the matrix multiplication implemented in LinearOperator.matmul (see Expected Behavior below). This matrix multiplication is called when doing LinearOperator.sum, as in the example below.
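For reference, summing across the last dimension is equivalent to a matrix multiplication with a ones vector, which is the path LinearOperator.sum takes; a plain-tensor sketch of that equivalence (not linear_operator internals):

import torch

# Plain-tensor sketch (not linear_operator code): reducing over the last
# dimension is the same as multiplying by a ones vector, so the gradient of
# the sum has to flow through that matrix multiplication.
A = torch.randn(5, 7, dtype=torch.float64, requires_grad=True)

ones = torch.ones(A.size(-1), 1, dtype=A.dtype)
via_matmul = (A @ ones).squeeze(-1)  # what LinearOperator.sum effectively does
via_sum = A.sum(dim=-1)              # direct reduction

assert torch.allclose(via_matmul, via_sum)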

To reproduce

Code snippet to reproduce

import torch
from gpytorch.kernels.keops import PeriodicKernel as KeOpsPeriodicKernel  # implementation from pull request 2296
import gpytorch

torch.manual_seed(7)

M, N, D = 1000, 2000, 3
x = torch.randn(M, D).double()
y = torch.randn(N, D).double()
k = KeOpsPeriodicKernel(ard_num_dims=3).double()
k.lengthscale = torch.tensor(1.0).double()
k.period_length = torch.tensor(1.0).double()

# context manager used so that type(covar) is KeOpsLinearOperator, not LazyEvaluatedKernelTensor
with gpytorch.settings.lazily_evaluate_kernels(False):
    covar = k(x, y)
    print(type(covar))
    # Calls `LinearOperator.sum`, which subsequently calls `LinearOperator.matmul`.
    # `LinearOperator.matmul` uses a custom torch.autograd.Function for the matrix multiplication.
    res2 = covar.sum(dim=1)  # res2 is a torch.Tensor here
    res2 = res2.sum()
    print(res2)
    g_x = torch.autograd.grad(res2, [k.raw_lengthscale, k.raw_period_length])
    print(g_x)

Stack trace/error message

<class 'linear_operator.operators.keops_linear_operator.KeOpsLinearOperator'>
tensor(202237.5145, dtype=torch.float64, grad_fn=<SumBackward0>)
Traceback (most recent call last):
  File "/home/julian/Desktop/test/keops_periodic_low_level/issue_keops_linear_operator.py", line 23, in <module>
    g_x = torch.autograd.grad(res2, [k.raw_lengthscale, k.raw_period_length])
  File "/home/julian/.venv/ichor/lib/python3.10/site-packages/torch/autograd/__init__.py", line 300, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

where k.raw_lengthscale is the parameter causing the issue.
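For illustration, here is a minimal, pure-PyTorch sketch of how this error arises when a custom torch.autograd.Function is involved (this is not linear_operator's code, and theta merely stands in for raw_lengthscale): only the tensors explicitly passed to forward() get wired into the autograd graph.

import torch

# Pure-PyTorch sketch of the failure mode: a custom autograd Function only
# wires gradients back to the tensors explicitly passed to forward(), so a
# parameter that enters the computation some other way never appears in the graph.
theta = torch.tensor(2.0, dtype=torch.float64, requires_grad=True)

class ClosureMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, rhs):
        # theta is used here but is not an input to forward(), so autograd
        # creates no edge from the output back to theta.
        return theta.detach() * rhs

    @staticmethod
    def backward(ctx, grad_output):
        # Only a gradient for rhs is returned; theta is invisible to autograd.
        return grad_output * theta.detach()

rhs = torch.ones(3, dtype=torch.float64, requires_grad=True)
out = ClosureMatmul.apply(rhs).sum()

try:
    torch.autograd.grad(out, [theta])
except RuntimeError as e:
    # "One of the differentiated Tensors appears to not have been used in the graph."
    print(e)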

Expected Behavior

Gradients should be computed for both k.raw_lengthscale and k.raw_period_length. The exact place where the issue occurs is https://github.com/cornellius-gp/linear_operator/blob/92f7e3332664f8b1e2ec41552f00f163ab28112e/linear_operator/operators/_linear_operator.py#L2366, which calls LinearOperator.matmul, and that matmul loses track of the gradients. I am summing across the columns in the example, but the same issue occurs when summing across the rows. Adding the following check

# Case: summing across columns
if dim == (self.dim() - 1):
    from .keops_linear_operator import KeOpsLinearOperator

    # Let KeOps perform the reduction on its own LazyTensor so the kernel
    # parameters stay in the autograd graph instead of going through Matmul.
    if isinstance(self, KeOpsLinearOperator):
        return self.covar_mat.sum(dim=1)

    ones = torch.ones(self.size(-1), 1, dtype=self.dtype, device=self.device)
    return (self @ ones).squeeze(-1)

gives gradients for both raw_lengthscale and raw_period_length, since the custom Matmul function is never called:

<class 'linear_operator.operators.keops_linear_operator.KeOpsLinearOperator'>
tensor(202237.5145, dtype=torch.float64, grad_fn=<SumBackward0>)
(tensor([[70682.4975, 70796.7652, 70631.9364]], dtype=torch.float64), tensor([[ 70.2535,  19.3231, -47.2902]], dtype=torch.float64))

This is probably not the best solution; perhaps the Matmul forward/backward methods could be changed so that the gradients are computed correctly?
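One possible shape for such a change, sketched in plain PyTorch rather than as a patch against linear_operator's actual Matmul: make every tensor the result depends on an explicit input of forward() and return a gradient for each of them in backward().

import torch

# Rough sketch of that direction (not linear_operator's actual Matmul): every
# tensor the result depends on is an explicit input of forward(), and
# backward() returns a gradient for each of them, so nothing drops out of the graph.
class ExplicitMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, theta, rhs):
        ctx.save_for_backward(theta, rhs)
        return theta * rhs

    @staticmethod
    def backward(ctx, grad_output):
        theta, rhs = ctx.saved_tensors
        grad_theta = (grad_output * rhs).sum()  # d(theta * rhs)/d(theta), summed
        grad_rhs = grad_output * theta
        return grad_theta, grad_rhs

theta = torch.tensor(2.0, dtype=torch.float64, requires_grad=True)
rhs = torch.ones(3, dtype=torch.float64, requires_grad=True)
out = ExplicitMatmul.apply(theta, rhs).sum()
print(torch.autograd.grad(out, [theta, rhs]))  # both gradients are defined now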

As a check, the same numbers are returned if the normal (non-KeOps) periodic kernel is used:

<class 'linear_operator.operators.dense_linear_operator.DenseLinearOperator'>
tensor(202237.5145, dtype=torch.float64, grad_fn=<SumBackward0>)
(tensor([[70682.4975, 70796.7652, 70631.9364]], dtype=torch.float64), tensor([[ 70.2535,  19.3231, -47.2902]], dtype=torch.float64))
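For completeness, the comparison above was produced by swapping the KeOps kernel for the standard gpytorch.kernels.PeriodicKernel; roughly (a sketch with the same settings as the reproduction snippet, not the exact script):

import torch
import gpytorch
from gpytorch.kernels import PeriodicKernel  # the non-KeOps implementation

torch.manual_seed(7)

M, N, D = 1000, 2000, 3
x = torch.randn(M, D).double()
y = torch.randn(N, D).double()
k = PeriodicKernel(ard_num_dims=3).double()
k.lengthscale = torch.tensor(1.0).double()
k.period_length = torch.tensor(1.0).double()

with gpytorch.settings.lazily_evaluate_kernels(False):
    covar = k(x, y)  # DenseLinearOperator for the non-KeOps kernel
    res2 = covar.sum(dim=1).sum()
    g_x = torch.autograd.grad(res2, [k.raw_lengthscale, k.raw_period_length])
    print(type(covar), res2, g_x, sep="\n")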

System information

linear_operator version: 0.3.0
torch version: 1.13.1+cu117


gpleiss commented 1 year ago

Hi @m-julian, I've found the bug (on our end) and I'm hoping to put up a PR with a fix later today or tomorrow.

m-julian commented 1 year ago

Thanks, looking forward to it!

m-julian commented 1 year ago

Hi @gpleiss, when are you planning to put up the PR? I am happy to test out the fix and check that the gradients are computed correctly.

gpleiss commented 1 year ago

Sorry, I got a bit backlogged on this PR. I'll try to have something up on Friday or this weekend.

gpleiss commented 1 year ago

This bug will be fixed in the next LinearOperator release.