Lattice kernels for scalable Gaussian processes in GPyTorch (Simplex-GPs)
https://go.sanyamkapoor.com/simplex-gp

Simplex-GPs

This repository hosts the code for SKIing on Simplices: Kernel Interpolation on the Permutohedral Lattice for Scalable Gaussian Processes (Simplex-GPs) by Sanyam Kapoor, Marc Finzi, Ke Alexander Wang, Andrew Gordon Wilson.

The Idea

Fast matrix-vector multiplies (MVMs) are the cornerstone of modern scalable Gaussian processes. By building upon the approximation proposed by Structured Kernel Interpolation (SKI), and leveraging advances in fast high-dimensional image filtering, Simplex-GPs approximate the computation of the kernel matrices by tiling the space using a sparse permutohedral lattice, instead of a rectangular grid.

The matrix-vector product implied by the kernel operations in SKI is then approximated via three stages: splat (projection onto the permutohedral lattice), blur (applying the blur operation as a matrix-vector product), and slice (re-projecting back into the original space).

This alleviates the curse of dimensionality associated with SKI operations, allowing them to scale beyond ~5 dimensions, and provides competitive advantages in runtime and memory cost with little loss in downstream performance. See our manuscript for complete details.
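
Schematically, a kernel MVM K v is replaced by W^T (B (W v)), where W is the sparse splat/slice interpolation matrix onto the lattice vertices and B is the sparse blur matrix on the lattice. The sketch below shows only this structure; the matrices W and B are stand-ins, not the package's actual lattice construction.

import torch

def lattice_mvm(v, W, B):
    # Sketch of the splat-blur-slice MVM that replaces a dense kernel MVM.
    # W: (num_lattice_vertices, num_points) sparse interpolation matrix
    # B: (num_lattice_vertices, num_lattice_vertices) sparse blur matrix
    splatted = W @ v        # splat: interpolate data values onto lattice vertices
    blurred = B @ splatted  # blur: filter along the lattice directions
    return W.t() @ blurred  # slice: interpolate back to the original points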

Usage

The lattice kernels are packaged as GPyTorch modules, and can be used as a fast approximation to either the RBFKernel or the MaternKernel. The corresponding replacement modules are RBFLattice and MaternLattice.

The RBFLattice kernel can be used by changing a single line of code:

import gpytorch as gp
from gpytorch_lattice_kernel import RBFLattice

class SimplexGPModel(gp.models.ExactGP):
  def __init__(self, train_x, train_y):
    likelihood = gp.likelihoods.GaussianLikelihood()
    super().__init__(train_x, train_y, likelihood)

    self.mean_module = gp.means.ConstantMean()
    self.covar_module = gp.kernels.ScaleKernel(
-      gp.kernels.RBFKernel(ard_num_dims=train_x.size(-1))
+      RBFLattice(ard_num_dims=train_x.size(-1), order=1)
    )

  def forward(self, x):
    mean_x = self.mean_module(x)
    covar_x = self.covar_module(x)
    return gp.distributions.MultivariateNormal(mean_x, covar_x)

The GPyTorch Regression Tutorial provides a simple example on toy data where this kernel can be used as a drop-in replacement; a minimal training loop in that style is sketched below.
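
The following sketch uses illustrative toy data and the SimplexGPModel class defined above; it follows the standard GPyTorch exact-GP training pattern.

import torch
import gpytorch as gp

train_x = torch.rand(1000, 3)                                   # illustrative toy inputs
train_y = torch.sin(train_x.sum(-1)) + 0.1 * torch.randn(1000)  # illustrative toy targets

model = SimplexGPModel(train_x, train_y)
model.train()
model.likelihood.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gp.mlls.ExactMarginalLogLikelihood(model.likelihood, model)

for _ in range(50):
    optimizer.zero_grad()
    loss = -mll(model(train_x), train_y)
    loss.backward()
    optimizer.step()

model.eval()
model.likelihood.eval()
with torch.no_grad():
    test_x = torch.rand(100, 3)
    pred = model.likelihood(model(test_x))  # posterior predictive at new inputs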

Install

To use the kernel in your code, install the package as:

pip install gpytorch-lattice-kernel

NOTE: The kernel is compiled lazily from source using CMake. If compilation fails, you may need a more recent version of CMake. Additionally, ninja is required for compilation. One way to install both is:

conda install -c conda-forge cmake ninja
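
A quick way to check the installation is to import the kernels; note that, depending on when the lazy build is triggered, compilation may only happen on first use of the kernel rather than at import time.

# If CMake or ninja is missing, the lazy source build will fail here
# or on first use of the kernel.
from gpytorch_lattice_kernel import RBFLattice, MaternLattice
print(RBFLattice, MaternLattice)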

Local Setup

For a local development setup, create the conda environment:

$ conda env create -f environment.yml

Remember to add the root of the project to PYTHONPATH if it is not already there.

$ export PYTHONPATH="$(pwd):${PYTHONPATH}"

Test

To verify the code is working as expected, a simple test file is provided that compares the training marginal likelihood achieved by Simplex-GPs and Exact-GPs. Run it as:

python tests/train_snelson.py

The Snelson 1-D toy dataset is used. A copy is available in snelson.csv.

Results

The proposed kernel can be used with GPyTorch as usual. An example script to reproduce the results is:

python experiments/train_simplexgp.py --dataset=elevators --data-dir=<path/to/uci/data/mat/files>

We use Fire to handle CLI arguments. All arguments of the main function are therefore valid arguments to the CLI.
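
For illustration, this is roughly how Fire exposes a function's arguments as CLI flags; the signature below is hypothetical and not the actual signature of experiments/train_simplexgp.py.

import fire

def main(dataset: str, data_dir: str, seed: int = 0):
    # Fire turns each parameter into a CLI flag, e.g.
    #   python script.py --dataset=elevators --data-dir=/path/to/mats --seed=1
    print(dataset, data_dir, seed)

if __name__ == '__main__':
    fire.Fire(main)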

All figures in the paper can be reproduced via the provided notebooks.

NOTE: The UCI dataset mat files are available here.

License

Apache 2.0