AaltoML / kalman-jax

Approximate inference for Markov Gaussian processes using iterated Kalman smoothing, in JAX
Apache License 2.0
approximate-bayesian-inference gaussian-processes kalman-smoother machine-learning signal-processing state-space-models

Note: kalman-jax is now obsolete. A significantly improved version of this code is available at https://github.com/AaltoML/BayesNewton/

kalman-jax

Approximate inference for Markov (i.e., temporal) Gaussian processes using iterated Kalman filtering and smoothing. Developed and maintained by William Wilkinson. The Bernoulli likelihood was implemented by Paul Chang. We are based in Arno Solin's machine learning group at Aalto University, Finland.
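The core computation behind iterated Kalman filtering and smoothing is a forward filter followed by a backward (Rauch-Tung-Striebel) smoother over a state-space form of the GP prior. The following is a minimal sketch, not the kalman-jax implementation. It assumes a Matérn-1/2 prior, whose state is a scalar, and Gaussian observation noise, and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 keeps the recursions accurate


def matern12_transition(dt, variance, lengthscale):
    """Exact discretisation of a Matern-1/2 (Ornstein-Uhlenbeck) GP prior."""
    A = jnp.exp(-dt / lengthscale)   # state transition over a step of length dt
    Q = variance * (1.0 - A ** 2)    # process-noise variance accumulated over dt
    return A, Q


def kalman_filter(t, y, obs_var, variance, lengthscale):
    """Forward pass for a scalar state with Gaussian observation noise obs_var."""
    dts = jnp.diff(t, prepend=t[0])  # first step has dt = 0, so it predicts the prior

    def step(carry, inputs):
        m, P = carry
        dt, y_k = inputs
        A, Q = matern12_transition(dt, variance, lengthscale)
        m_pred, P_pred = A * m, A * P * A + Q   # predict
        K = P_pred / (P_pred + obs_var)         # Kalman gain
        m_new = m_pred + K * (y_k - m_pred)     # update with the observation
        P_new = (1.0 - K) * P_pred
        return (m_new, P_new), (m_new, P_new)

    init = (jnp.zeros(()), variance * jnp.ones(()))  # stationary prior moments
    _, (m_f, P_f) = jax.lax.scan(step, init, (dts, y))
    return m_f, P_f


def rts_smoother(t, m_f, P_f, variance, lengthscale):
    """Backward Rauch-Tung-Striebel pass over the filtered moments."""
    dts = jnp.diff(t)  # dts[k] = t[k+1] - t[k]

    def step(carry, inputs):
        m_next, P_next = carry   # smoothed moments at step k+1
        m, P, dt = inputs        # filtered moments at step k
        A, Q = matern12_transition(dt, variance, lengthscale)
        P_pred = A * P * A + Q
        G = P * A / P_pred       # smoother gain
        m_s = m + G * (m_next - A * m)
        P_s = P + G * (P_next - P_pred) * G
        return (m_s, P_s), (m_s, P_s)

    init = (m_f[-1], P_f[-1])    # at the last step, smoothed = filtered
    _, (m_s, P_s) = jax.lax.scan(step, init, (m_f[:-1], P_f[:-1], dts), reverse=True)
    return jnp.append(m_s, m_f[-1]), jnp.append(P_s, P_f[-1])


t = jnp.linspace(0.0, 5.0, 50)
y = jnp.sin(2.0 * t)
m_f, P_f = kalman_filter(t, y, obs_var=0.1, variance=1.0, lengthscale=0.7)
mean, var = rts_smoother(t, m_f, P_f, variance=1.0, lengthscale=0.7)
```

Both passes run through `jax.lax.scan`, so the cost grows linearly with the number of time points. With a Gaussian likelihood, a single pass gives the exact GP posterior marginals. For non-Gaussian likelihoods, the same passes are repeated with updated approximate likelihood terms.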

This project aims to implement an XLA JIT-compilable framework for inference in (non-conjugate) Markov Gaussian processes, with automatic differentiation provided by JAX.
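The JIT-plus-autodiff design can be illustrated with a small example. The Kalman filter's prediction-error decomposition yields the exact log marginal likelihood, which `jax.jit` can compile and `jax.value_and_grad` can differentiate with respect to the hyperparameters. This is a sketch under assumed conditions: a Matérn-1/2 prior, Gaussian noise, log-parameterised hyperparameters, and a hypothetical function name. It is not the package's API.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def neg_log_marginal_likelihood(log_params, t, y):
    """-log p(y) for a Matern-1/2 GP plus Gaussian noise, via the Kalman filter."""
    variance, lengthscale, obs_var = jnp.exp(log_params)  # log-parameters keep these positive
    dts = jnp.diff(t, prepend=t[0])

    def step(carry, inputs):
        m, P = carry
        dt, y_k = inputs
        A = jnp.exp(-dt / lengthscale)
        m_pred = A * m
        P_pred = A * P * A + variance * (1.0 - A ** 2)
        S = P_pred + obs_var                  # predictive variance of y_k
        resid = y_k - m_pred
        term = 0.5 * (jnp.log(2.0 * jnp.pi * S) + resid ** 2 / S)  # -log N(y_k | m_pred, S)
        K = P_pred / S
        return (m_pred + K * resid, (1.0 - K) * P_pred), term

    _, terms = jax.lax.scan(step, (jnp.zeros(()), variance), (dts, y))
    return jnp.sum(terms)


# Trace and compile the value-and-gradient function once; later calls with
# same-shaped inputs reuse the compiled XLA executable.
objective = jax.jit(jax.value_and_grad(neg_log_marginal_likelihood))

t = jnp.linspace(0.0, 5.0, 50)
y = jnp.sin(2.0 * t)
log_params = jnp.log(jnp.array([1.0, 0.7, 0.1]))
value, grad = objective(log_params, t, y)
```

The returned gradient can be passed to any gradient-based optimiser to learn the prior and noise hyperparameters.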

The methodology is outlined in the following paper:

More details about the variational inference method are given in the following paper:

If you use this code in your research, please cite the paper as follows:

@inproceedings{wilkinson2020,
  title={State Space Expectation Propagation: Efficient Inference Schemes for Temporal {G}aussian Processes},
  author={Wilkinson, William J. and Chang, Paul E. and Andersen, Michael Riis and Solin, Arno},
  booktitle={International Conference on Machine Learning},
  year={2020}
}

Spatio-temporal GP classification

Getting started

Info

We combine two recent advances in the field of probabilistic machine learning:

Code structure

Each approximate inference algorithm calls the same underlying Kalman filter and smoother methods. The algorithms differ only in how they compute the approximate likelihood terms.
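This structure pairs one shared filter/smoother with a pluggable rule for the approximate likelihood terms (sites), and can be sketched as follows. The sketch makes several assumptions: a scalar Matérn-1/2 prior with fixed hyperparameters, a Poisson likelihood, and a Newton (Laplace-style) site update chosen for brevity. None of the function names come from kalman-jax, and the README does not state that this particular update is one of its methods. Each iteration recomputes Gaussian sites from the current posterior mean and reruns the same smoother.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

VARIANCE, LENGTHSCALE = 1.0, 1.0  # hypothetical fixed Matern-1/2 prior hyperparameters


def transition(dt):
    A = jnp.exp(-dt / LENGTHSCALE)
    return A, VARIANCE * (1.0 - A ** 2)


def kalman_smoother(t, site_mean, site_var):
    """Shared filter + RTS smoother; sites act as Gaussian pseudo-observations."""
    def filter_step(carry, inputs):
        m, P = carry
        dt, y_k, r_k = inputs
        A, Q = transition(dt)
        m, P = A * m, A * P * A + Q                # predict
        K = P / (P + r_k)                          # gain for a site with variance r_k
        m, P = m + K * (y_k - m), (1.0 - K) * P    # update
        return (m, P), (m, P)

    init = (jnp.zeros(()), VARIANCE * jnp.ones(()))
    xs = (jnp.diff(t, prepend=t[0]), site_mean, site_var)
    _, (m_f, P_f) = jax.lax.scan(filter_step, init, xs)

    def smoother_step(carry, inputs):
        m_next, P_next = carry
        m, P, dt = inputs
        A, Q = transition(dt)
        P_pred = A * P * A + Q
        G = P * A / P_pred
        m_s, P_s = m + G * (m_next - A * m), P + G * (P_next - P_pred) * G
        return (m_s, P_s), (m_s, P_s)

    xs = (m_f[:-1], P_f[:-1], jnp.diff(t))
    _, (m_s, P_s) = jax.lax.scan(smoother_step, (m_f[-1], P_f[-1]), xs, reverse=True)
    return jnp.append(m_s, m_f[-1]), jnp.append(P_s, P_f[-1])


def poisson_newton_sites(y, f):
    """Newton (Laplace-style) sites for log p(y | f) = y f - exp(f) - log(y!)."""
    g = y - jnp.exp(f)         # first derivative of log p(y | f)
    W = jnp.exp(f)             # minus the second derivative
    return f + g / W, 1.0 / W  # pseudo-observation and its noise variance


def iterated_smoother(t, y, num_iters=30):
    """Alternate site updates and the shared smoother, starting from f = 0."""
    def body(_, m):
        site_mean, site_var = poisson_newton_sites(y, m)
        return kalman_smoother(t, site_mean, site_var)[0]

    m = jax.lax.fori_loop(0, num_iters, body, jnp.zeros_like(y))
    return kalman_smoother(t, *poisson_newton_sites(y, m))


t = jnp.linspace(0.0, 10.0, 40)
y = jnp.array([0, 1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 3, 2, 2, 1, 0, 1, 0, 0, 1] * 2, dtype=jnp.float64)
mean, var = iterated_smoother(t, y)
```

With this site rule, each pass is one Newton step towards the Laplace mode. In this sketch, an EP- or variational-style update would replace only `poisson_newton_sites` and leave `kalman_smoother` untouched.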

Approximate inference algorithms

Likelihoods

Priors

License

This software is provided under the Apache License 2.0. See the accompanying LICENSE file for details.