KernelLinearOperator was throwing errors when computing the diagonal of a KeOps kernel. (This computation happens during preconditioning, which requires the diagonal of the already-formed kernel LinearOperator.) The error occurred because KeopsLinearOperator.diagonal calls to_dense on the output of a batch kernel operation, and to_dense is not defined for KeOps LazyTensors.
This PR is in some sense a hacky fix for this bug (a less hacky fix will require changes to KernelLinearOperator), but it is also a useful refactor that improves KeOps kernels in general.
The fixes:
KeOpsKernels now define only a forward method, which is used both when we want to use KeOps and when we want to bypass it.
KeOpsKernels now use a _lazify_inputs helper method, which either wraps the inputs as KeOps LazyTensors or leaves them as torch Tensors.
The KeOps wrapping happens unless we want to bypass KeOps, which occurs when either (1) the matrix is small (below the Cholesky size threshold) or (2) the user has turned off the gpytorch.settings.use_keops option (NEW IN THIS PR).
Why this is beneficial:
KeOps kernels now follow the same API as non-KeOps kernels (they define a forward method).
The user now only has to define one forward method, which works in both the KeOps and non-KeOps cases.
The diagonal call in KeopsLinearOperator constructs a batch of 1x1 matrices, which are small enough to bypass KeOps and thus avoid the current bug. (This is why the solution is currently a hack; it could become less hacky with a small modification to KernelLinearOperator and/or the to_dense method in LinearOperator.)
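The following pure-Python sketch illustrates why the batch-of-1x1 diagonal sidesteps the bug. LazyKernel, rbf, to_dense, and the threshold are hypothetical stand-ins, not the gpytorch or KeOps APIs: a large kernel comes back "lazy" and cannot be densified (mirroring the original error), while each 1x1 kernel is small enough to take the dense path. The real code computes all 1x1 matrices in one batched call; the sketch loops for clarity.

```python
import math
from typing import List, Sequence, Union

Point = Sequence[float]
Dense = List[List[float]]
MAX_CHOLESKY_SIZE = 800  # hypothetical bypass threshold


class LazyKernel:
    """Hypothetical stand-in for a KeOps LazyTensor: it has no dense form."""

    def __init__(self, x1: List[Point], x2: List[Point]):
        self.x1, self.x2 = x1, x2


def rbf(x1: List[Point], x2: List[Point]) -> Union[Dense, LazyKernel]:
    """RBF kernel (lengthscale 1) that bypasses the lazy path for small inputs."""
    if len(x1) >= MAX_CHOLESKY_SIZE or len(x2) >= MAX_CHOLESKY_SIZE:
        return LazyKernel(x1, x2)
    return [
        [math.exp(-0.5 * sum((a - b) ** 2 for a, b in zip(p, q))) for q in x2]
        for p in x1
    ]


def to_dense(k: Union[Dense, LazyKernel]) -> Dense:
    if isinstance(k, LazyKernel):
        # Mirrors the original bug: dense conversion is unavailable for lazy kernels.
        raise TypeError("to_dense is not defined for LazyKernel")
    return k


def diagonal(x: List[Point]) -> List[float]:
    # Each point yields its own 1x1 kernel matrix, which is always small enough
    # to take the dense path, so to_dense never receives a LazyKernel.
    return [to_dense(rbf([p], [p]))[0][0] for p in x]
```

Calling to_dense on the full 1000x1000 kernel raises, while diagonal on the same points succeeds.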
Other changes:
Fix numerical stability issues (NaNs) in the KeOps MaternKernel.
Introduce a gpytorch.settings.use_keops feature flag.
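As a rough illustration of how a boolean feature flag like this can behave, here is a hypothetical sketch in the context-manager style that gpytorch.settings objects use elsewhere (e.g. `with gpytorch.settings.max_cholesky_size(...)`). KeOpsFlag and its on() method are invented names for illustration; this is not the actual gpytorch.settings.use_keops implementation.

```python
from typing import ClassVar, Optional


class KeOpsFlag:
    """Hypothetical global on/off flag that can be overridden within a `with` block."""

    _state: ClassVar[bool] = True  # default: KeOps enabled

    def __init__(self, state: bool = True):
        self.state = state
        self.prev: Optional[bool] = None

    @classmethod
    def on(cls) -> bool:
        return cls._state

    def __enter__(self) -> "KeOpsFlag":
        # Remember the previous value so nested blocks restore correctly.
        self.prev = type(self)._state
        type(self)._state = self.state
        return self

    def __exit__(self, *exc) -> bool:
        type(self)._state = self.prev
        return False  # do not suppress exceptions
```

A helper such as _lazify_inputs would then check `KeOpsFlag.on()` and skip the KeOps wrapping whenever it returns False.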
[Fixes #2363]