This repository hosts the code for SKIing on Simplices: Kernel Interpolation on the Permutohedral Lattice for Scalable Gaussian Processes (Simplex-GPs) by Sanyam Kapoor, Marc Finzi, Ke Alexander Wang, Andrew Gordon Wilson.
Fast matrix-vector multiplies (MVMs) are the cornerstone of modern scalable Gaussian processes. By building upon the approximation proposed by Structured Kernel Interpolation (SKI), and leveraging advances in fast high-dimensional image filtering, Simplex-GPs approximate the computation of the kernel matrices by tiling the space using a sparse permutohedral lattice, instead of a rectangular grid.
The matrix-vector product implied by the kernel operations in SKI is approximated via three stages: splat (projection onto the permutohedral lattice), blur (applying the blur operation as a matrix-vector product), and slice (re-projecting back into the original space).
This alleviates the curse of dimensionality associated with SKI operations, allowing them to scale beyond ~5 dimensions, and provides advantages in runtime and memory costs, at little expense to downstream performance. See our manuscript for complete details.
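As a rough illustration of that structure (a minimal NumPy/SciPy sketch, not the lattice implementation in this repository; the splat weights and blur stencil below are arbitrary stand-ins for the permutohedral versions), a single approximate kernel MVM composes the three sparse stages:

```python
import numpy as np
import scipy.sparse as sp

n, m = 1000, 50                 # n data points, m lattice points
rng = np.random.default_rng(0)

# Splat: each data point interpolates onto a handful of neighboring lattice
# vertices, giving a sparse (n x m) matrix of interpolation weights.
rows = np.repeat(np.arange(n), 3)
cols = rng.integers(0, m, size=3 * n)
vals = rng.random(3 * n)
S = sp.csr_matrix((vals, (rows, cols)), shape=(n, m))

# Blur: a cheap smoothing operator over lattice values (a toy 1-D stencil here).
B = sp.diags([0.25, 0.5, 0.25], offsets=[-1, 0, 1], shape=(m, m))

# Slice: project lattice values back to the data points via the transpose interpolation.
v = rng.standard_normal(n)
Kv_approx = S @ (B @ (S.T @ v))   # splat (S^T v), then blur (B @ .), then slice (S @ .)
```

Every factor is sparse, so the matrix-vector product costs roughly O(n + m) rather than the O(n^2) of a dense kernel matrix.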
The lattice kernels are packaged as GPyTorch modules and can be used as fast approximations to either the RBFKernel or the MaternKernel. The corresponding replacement modules are RBFLattice and MaternLattice. The RBFLattice kernel can be dropped in by changing a single line of code:
import gpytorch as gp
from gpytorch_lattice_kernel import RBFLattice

class SimplexGPModel(gp.models.ExactGP):
    def __init__(self, train_x, train_y):
        likelihood = gp.likelihoods.GaussianLikelihood()
        super().__init__(train_x, train_y, likelihood)

        self.mean_module = gp.means.ConstantMean()
        self.covar_module = gp.kernels.ScaleKernel(
-           gp.kernels.RBFKernel(ard_num_dims=train_x.size(-1))
+           RBFLattice(ard_num_dims=train_x.size(-1), order=1)
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gp.distributions.MultivariateNormal(mean_x, covar_x)
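Once defined, the model trains like any other GPyTorch ExactGP. The loop below is a standard GPyTorch training sketch on synthetic data, reusing the SimplexGPModel class above; the data, dimensionality, and hyperparameters are placeholders, not settings from the paper:

```python
import torch
import gpytorch as gp

# Placeholder data: 500 points in 6 dimensions.
train_x = torch.randn(500, 6)
train_y = torch.sin(train_x.sum(-1)) + 0.1 * torch.randn(500)

model = SimplexGPModel(train_x, train_y)
mll = gp.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

model.train()
for _ in range(50):
    optimizer.zero_grad()
    loss = -mll(model(train_x), train_y)
    loss.backward()
    optimizer.step()
```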
The GPyTorch Regression Tutorial provides a simpler example on toy data, where this kernel can be used as a drop-in replacement.
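MaternLattice is swapped in the same way; the snippet below assumes its nu argument mirrors gpytorch.kernels.MaternKernel and its order argument mirrors RBFLattice (check the module signature in your installed version):

```python
import gpytorch as gp
from gpytorch_lattice_kernel import MaternLattice

num_dims = 4  # dimensionality of the training inputs (placeholder)

covar_module = gp.kernels.ScaleKernel(
    MaternLattice(nu=1.5, ard_num_dims=num_dims, order=1)
)
```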
To use the kernel in your code, install the package as:
pip install gpytorch-lattice-kernel
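Because the extension is compiled lazily on first use (see the note below), one quick sanity check after installing is to build the kernel and evaluate it on a small batch. This is only a suggestion, and the .evaluate() call matches older GPyTorch versions (newer releases use .to_dense()):

```python
import torch
from gpytorch_lattice_kernel import RBFLattice

# Constructing and evaluating the kernel triggers the lazy build of the extension.
x = torch.randn(10, 3)
kernel = RBFLattice(ard_num_dims=3, order=1)
K = kernel(x).evaluate()   # use .to_dense() on newer GPyTorch versions
print(K.shape)             # expected: torch.Size([10, 10])
```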
NOTE: The kernel is compiled lazily from source using CMake. If the compilation fails, you may need to install a more recent version of CMake. Additionally, ninja is required for compilation. One way to install both is:
conda install -c conda-forge cmake ninja
For a local development setup, create the conda environment:
$ conda env create -f environment.yml
Remember to add the root of the project to the PYTHONPATH, if it is not already there:
$ export PYTHONPATH="$(pwd):${PYTHONPATH}"
To verify that the code is working as expected, a simple test file is provided that compares the training marginal likelihood achieved by Simplex-GPs and Exact-GPs. Run it as:
python tests/train_snelson.py
The Snelson 1-D toy dataset is used. A copy is available in snelson.csv.
The proposed kernel can be used with GPyTorch as usual. An example script to reproduce experimental results is:
python experiments/train_simplexgp.py --dataset=elevators --data-dir=<path/to/uci/data/mat/files>
We use Fire to handle CLI arguments.
All arguments of the main function are therefore valid arguments to the CLI.
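For example, if main were defined roughly as below (a hypothetical signature shown only to illustrate Fire's behavior; the actual arguments live in experiments/train_simplexgp.py), each keyword argument becomes a CLI flag:

```python
import fire

def main(dataset: str = "elevators", data_dir: str = ".", epochs: int = 100, lr: float = 0.1):
    # ... training code ...
    print(dataset, data_dir, epochs, lr)

if __name__ == "__main__":
    # Fire maps keyword arguments to flags, e.g.
    #   python experiments/train_simplexgp.py --dataset=elevators --lr=0.01
    fire.Fire(main)
```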
All figures in the paper can be reproduced via notebooks.
NOTE: The UCI dataset .mat files are available here.
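If you want to inspect one of these files, SciPy can load them; the data key and last-column-target layout below are typical of these UCI benchmark dumps, but verify against your copy:

```python
from pathlib import Path
from scipy.io import loadmat

# Replace with the directory passed as --data-dir above.
data_dir = Path("<path/to/uci/data/mat/files>")

mat = loadmat(str(data_dir / "elevators.mat"))
data = mat["data"]                  # assumed key; check mat.keys() for your file
X, y = data[:, :-1], data[:, -1]    # last column assumed to be the regression target
print(X.shape, y.shape)
```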
Apache 2.0