rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

MCX

XLA-rated Bayesian inference

MCX is a probabilistic programming library with a laser focus on sampling methods. MCX transforms model definitions into log-density (logpdf) and sampling functions. These functions are JIT-compiled with JAX; they support batching and can be executed transparently on CPU, GPU or TPU.
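To make "generate logpdf functions" concrete: for the `linear_regression` model in the example further down, the joint log-density should amount to the sum of the log-densities of `scale`, `coefs` and `preds`. The NumPy sketch below writes that function out by hand. It is an illustration only: the helper names (`exponential_logpdf`, `normal_logpdf`, `linear_regression_logpdf`) are hypothetical rather than MCX internals, and it assumes `Exponential(lmbda)` is parameterized by its rate.

```python
import numpy as np


def exponential_logpdf(x, rate):
    # log p(x) = log(rate) - rate * x on x >= 0, and -inf outside the support
    return np.where(x >= 0, np.log(rate) - rate * x, -np.inf)


def normal_logpdf(x, loc, scale):
    # Elementwise Gaussian log-density
    z = (x - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


def linear_regression_logpdf(scale, coefs, x, preds, lmbda=1.0):
    """Hand-written joint log-density of the linear_regression model (hypothetical helper)."""
    lp = exponential_logpdf(scale, lmbda)                   # scale <~ Exponential(lmbda)
    lp = lp + np.sum(normal_logpdf(coefs, 0.0, 1.0))        # coefs <~ Normal(0, 1)
    lp = lp + np.sum(normal_logpdf(preds, x @ coefs, scale))  # preds <~ Normal(x . coefs, scale)
    return lp
```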

The project is currently in its infancy. It is a moonshot towards providing sequential inference as a first-class citizen, as well as performant sampling methods for Bayesian deep learning.

MCX's philosophy

  1. Knowing how to express a graphical model and how to manipulate NumPy arrays should be enough to define a model.
  2. Models should be modular and re-usable.
  3. Inference should be performant and should leverage GPUs.

See the documentation for more information. See this issue for an updated roadmap for v0.1.

Current API

Note that there are still many moving pieces in mcx and the API may change slightly.

import arviz as az
import jax
import jax.numpy as jnp
import numpy as np

import mcx
from mcx.distributions import Exponential, Normal
from mcx.inference import HMC

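rng_key = jax.random.PRNGKey(0)
# Split the key so that prior prediction, sampling and posterior prediction use independent randomness
prior_key, sample_key, predict_key = jax.random.split(rng_key, 3)

x_data = np.random.normal(0, 5, size=(1000, 1))
# 1-D targets, so that they match the shape of jnp.dot(x, coefs) in the model
y_data = 3 * x_data[:, 0] + np.random.normal(size=x_data.shape[0])

@mcx.model
def linear_regression(x, lmbda=1.):
    scale <~ Exponential(lmbda)
    coefs <~ Normal(jnp.zeros(jnp.shape(x)[-1]), 1)
    preds <~ Normal(jnp.dot(x, coefs), scale)
    return preds

# Draw from the prior predictive distribution
prior_predictive = mcx.prior_predict(prior_key, linear_regression, (x_data,))

# Condition on the observed targets and sample the posterior with HMC
posterior = mcx.sampler(
    sample_key,
    linear_regression,
    (x_data,),
    {'preds': y_data},
    HMC(100),
).run()

az.plot_trace(posterior)

# Draw from the posterior predictive distribution
posterior_predictive = mcx.posterior_predict(predict_key, linear_regression, (x_data,), posterior)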
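`prior_predict` runs the model in its generative direction. Conceptually this is ancestral sampling: draw `scale` and `coefs` from their priors, then draw `preds` given those draws. The NumPy sketch below does this by hand for `linear_regression`; `linear_regression_prior_predict` is a hypothetical name, not the MCX implementation, and it assumes `Exponential(lmbda)` takes a rate.

```python
import numpy as np


def linear_regression_prior_predict(rng, x, lmbda=1.0, num_samples=1000):
    """Ancestral sampling: draw each variable in model order, conditioning on earlier draws."""
    scale = rng.exponential(1.0 / lmbda, size=num_samples)         # NumPy's exponential takes scale = 1 / rate
    coefs = rng.normal(0.0, 1.0, size=(num_samples, x.shape[-1]))
    loc = coefs @ x.T                                                # shape (num_samples, num_points)
    preds = rng.normal(loc, scale[:, None])                          # one noise scale per prior draw
    return {"scale": scale, "coefs": coefs, "preds": preds}
```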

MCX's future

We are currently considering the future directions:

You are more than welcome to contribute to these discussions, or suggest potential future directions.

Linear sampling

Like most PPLs, MCX implements a batch sampling runtime:

sampler = mcx.sampler(
    rng_key,
    linear_regression,
    (x_data,),          # model arguments, passed as a tuple
    {'preds': y_data},  # observations
    HMC(100),           # kernel
)

posterior = sampler.run()
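A batch runtime runs the whole chain and only then hands back the samples. The sketch below shows that shape using random-walk Metropolis, swapped in for HMC to keep it short; `rwm_step` and `run_batch` are hypothetical names, and this is not how MCX implements its runtime.

```python
import numpy as np


def rwm_step(rng, position, logpdf, step_size):
    """One random-walk Metropolis transition (a simple stand-in for an HMC kernel)."""
    proposal = position + step_size * rng.normal(size=position.shape)
    log_ratio = logpdf(proposal) - logpdf(position)
    if np.log(rng.uniform()) < log_ratio:
        return proposal
    return position


def run_batch(rng, logpdf, init, num_samples, step_size=0.5):
    """Batch runtime: run the whole chain, then return every sample at once."""
    samples = np.empty((num_samples,) + np.shape(init))
    position = np.asarray(init, dtype=float)
    for i in range(num_samples):
        position = rwm_step(rng, position, logpdf, step_size)
        samples[i] = position
    return samples
```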

The warmup trace is discarded by default but you can obtain it by running:

warmup_posterior = sampler.warmup()
posterior = sampler.run()
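Warmup is the phase in which a sampler tunes its own parameters before drawing the samples that are kept. As an illustration only, the sketch below tunes the proposal scale of a random-walk Metropolis sampler with a Robbins-Monro update toward a target acceptance rate, and returns the warmup draws separately so they can be discarded or inspected. HMC warmup schemes typically adapt the step size and mass matrix instead; the `warmup` function and its defaults are hypothetical.

```python
import numpy as np


def warmup(rng, logpdf, init, num_warmup, step_size=1.0, target_accept=0.4):
    """Tune the step size; return the warmup trace, final position and tuned step size."""
    position = np.asarray(init, dtype=float)
    log_step = np.log(step_size)
    trace = np.empty((num_warmup,) + position.shape)
    for i in range(num_warmup):
        proposal = position + np.exp(log_step) * rng.normal(size=position.shape)
        # Clip the log ratio at 0 before exponentiating to avoid overflow
        accept_prob = np.exp(min(0.0, logpdf(proposal) - logpdf(position)))
        if rng.uniform() < accept_prob:
            position = proposal
        # Robbins-Monro update: grow the step when accepting too often, shrink it otherwise
        log_step += (accept_prob - target_accept) / np.sqrt(i + 1)
        trace[i] = position
    return trace, position, np.exp(log_step)
```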

You can extract more samples from the chain after a run and combine the two traces:

posterior += sampler.run()
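Adding two traces makes sense because the second `run()` draws more samples from the same chain, so the draws can be concatenated along the sample axis. A toy, hypothetical `Trace` class (not `mcx.Trace`) showing that concatenation:

```python
import numpy as np


class Trace:
    """Toy trace: a dict of arrays whose first axis indexes draws (hypothetical, not mcx.Trace)."""

    def __init__(self, samples=None):
        self.samples = {k: np.asarray(v) for k, v in (samples or {}).items()}

    def __len__(self):
        return len(next(iter(self.samples.values()))) if self.samples else 0

    def __add__(self, other):
        # `a += b` falls back to __add__ when __iadd__ is not defined
        if not self.samples:
            return Trace(other.samples)
        if not other.samples:
            return Trace(self.samples)
        return Trace({k: np.concatenate([self.samples[k], other.samples[k]]) for k in self.samples})
```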

By default, MCX samples in interactive mode: it draws samples in a Python for loop and displays a progress bar and various diagnostics. For faster sampling, use:

posterior = sampler.run(compile=True)

In a notebook, you can combine the two modes: a short interactive run gives a lower bound on the sampling rate, which helps you decide how many samples to draw before switching to compile=True.
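A plain-Python sketch of that workflow: time a few draws, then convert the measured rate into a sample count for a given time budget. `estimate_num_samples` is a hypothetical helper and `step` stands for any function that draws one sample; because compiled sampling is faster than the interactive loop, the resulting count is conservative.

```python
import time


def estimate_num_samples(step, num_probe, time_budget_s):
    """Time `num_probe` calls to `step` and return how many samples fit in the budget."""
    start = time.perf_counter()
    for _ in range(num_probe):
        step()
    elapsed = max(time.perf_counter() - start, 1e-9)  # guard against a zero duration
    rate = num_probe / elapsed  # samples per second in the slow, interactive mode
    # The interactive rate is a lower bound, so this count underestimates what fits
    return int(rate * time_budget_s)
```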

Interactive sampling

Sampling the posterior is an iterative process, yet most libraries only provide batch sampling. MCX already implements a generator runtime, which opens many possibilities, such as:

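sampler = mcx.sampler(
    rng_key,
    linear_regression,
    (x_data,),
    {'preds': y_data},
    HMC(100),
)

# Consume samples one at a time and store them
trace = mcx.Trace()
for sample in sampler:
    trace.append(sample)

# The sampler also follows Python's iterator protocol
iter(sampler)
next(sampler)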
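A generator runtime maps naturally onto an ordinary Python generator: each `next()` call advances the chain by one transition and yields the new state. A minimal sketch with random-walk Metropolis as the transition; `sample_generator` is hypothetical and says nothing about MCX's internals.

```python
import itertools

import numpy as np


def sample_generator(rng, logpdf, init, step_size=1.0):
    """Yield one random-walk Metropolis draw at a time, forever (hypothetical stand-in)."""
    position = np.asarray(init, dtype=float)
    logp = logpdf(position)
    while True:
        proposal = position + step_size * rng.normal(size=position.shape)
        logp_proposal = logpdf(proposal)
        if np.log(rng.uniform()) < logp_proposal - logp:
            position, logp = proposal, logp_proposal
        yield position.copy()


gen = sample_generator(np.random.default_rng(0), lambda x: -0.5 * np.sum(x**2), np.zeros(2))
first = next(gen)                         # one draw, on demand
batch = list(itertools.islice(gen, 100))  # the next 100 draws from the same chain
```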

Note that the performance of the interactive mode is significantly lower than that of the batch sampler. However, both can be used successively:

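trace = mcx.Trace()
for i, sample in enumerate(sampler):
    print(do_something(sample))  # do_something: placeholder for your own per-sample code
    trace.append(sample)
    if i % 10 == 0:
        # Every 10 interactive samples, extend the trace with a large compiled batch
        trace += sampler.run(100_000, compile=True)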
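Mixing both modes works when they share one chain state. The hypothetical `Sampler` class below (not `mcx.sampler`) implements the iterator protocol and a `run()` method on the same state, with random-walk Metropolis as the kernel and no compilation, so it only illustrates the control flow of the loop above.

```python
import numpy as np


class Sampler:
    """Toy sampler sharing one chain state between next() and run() (hypothetical)."""

    def __init__(self, rng, logpdf, init, step_size=1.0):
        self.rng, self.logpdf, self.step_size = rng, logpdf, step_size
        self.position = np.asarray(init, dtype=float)

    def __iter__(self):
        return self

    def __next__(self):
        # One random-walk Metropolis transition, advancing the shared state
        proposal = self.position + self.step_size * self.rng.normal(size=self.position.shape)
        if np.log(self.rng.uniform()) < self.logpdf(proposal) - self.logpdf(self.position):
            self.position = proposal
        return self.position.copy()

    def run(self, num_samples):
        # Batch mode: the same transitions, collected into one array
        return np.stack([next(self) for _ in range(num_samples)])


sampler = Sampler(np.random.default_rng(0), lambda x: -0.5 * np.sum(x**2), np.zeros(1))
draws = []
for i, draw in enumerate(sampler):
    draws.append(draw)
    if i % 10 == 0:
        draws.extend(sampler.run(50))  # batch draws continue from the interactive state
    if i >= 20:
        break
```

Because `run()` calls the same transition as `next()`, the batch draws pick up exactly where the interactive ones left off.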

Important note

MCX takes a lot of inspiration from other probabilistic programming languages and libraries: Stan (NUTS and the very knowledgeable community), PyMC3 (for its simple API), TensorFlow Probability (for its shape system and inference vectorization), (Num)Pyro (for the use of JAX in the backend), Gen.jl and Turing.jl (for composable inference), Soss.jl (generative model API), Anglican, and many others that I forget.