BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
Adding some basic VI approximation and fitting routine #397

Open junpenglao opened 1 year ago

junpenglao commented 1 year ago

Copying over from

After #392, we should add the two most basic VI algorithms: mean-field and full-rank ADVI [1]. Below is a working example of mean-field ADVI:

import jax
import jax.flatten_util  # makes jax.flatten_util.ravel_pytree available
import jax.numpy as jnp
from jax.scipy import stats

def gen_meanfield_logprob(params):
    mu_param, rho_param = params
    # rho is the log standard deviation, so sigma is always positive
    sigma_param = jax.tree_util.tree_map(jnp.exp, rho_param)
    def meanfield_logprob(position):
        logq_pytree = jax.tree_util.tree_map(
            stats.norm.logpdf, position, mu_param, sigma_param
        )
        logq = jax.tree_util.tree_map(jnp.sum, logq_pytree)
        return jax.tree_util.tree_reduce(jnp.add, logq)
    return meanfield_logprob

# gen_meanfield_logprob(init_params)(init_position)

def meanfield_sample(
    rng_key, meanfield_param, num_samples: int
):
    # num_samples may also be a tuple; () returns a single unbatched draw
    if not isinstance(num_samples, tuple):
        num_samples = (num_samples,)
    mu_param, rho_param = meanfield_param
    sigma_param = jax.tree_util.tree_map(jnp.exp, rho_param)
    mu_flatten, unravel_fn = jax.flatten_util.ravel_pytree(mu_param)
    sigma_flatten, _ = jax.flatten_util.ravel_pytree(sigma_param)
    # reparameterized draw: mu + sigma * eps with eps ~ N(0, I)
    flatten_sample = jax.random.normal(
        rng_key, num_samples + mu_flatten.shape
        ) * sigma_flatten + mu_flatten
    if len(num_samples) == 0:
        return unravel_fn(flatten_sample)
    return jax.vmap(unravel_fn)(flatten_sample)

# meanfield_sample(rng, init_params, ())

def meanfield_approximate(rng, init_params, log_prob_fn, optimizer, sample_size=5, num_steps=200):
    def meanfield_approximate_step(
        state, rng_key_sample
    ):
        params, opt_state = state
        def kl_fn(params):
            sample = meanfield_sample(rng_key_sample, params, sample_size)
            # vmap so that logq has one entry per sample, matching logp
            logq = jax.vmap(gen_meanfield_logprob(params))(sample)
            logp = log_prob_fn(sample)
            return (logq - logp).mean()
        # Monte Carlo estimate of KL(q || p) up to a constant, i.e. the negative ELBO
        neg_elbo, grad = jax.value_and_grad(kl_fn)(params)
        updates, opt_state = optimizer.update(grad, opt_state, params)
        params = jax.tree_util.tree_map(
            lambda p, u: p + u, params, updates
        )
        return (params, opt_state), neg_elbo

    def run_optimization(init_params):
        opt_state = optimizer.init(init_params)
        state = (init_params, opt_state)
        rng_key = jax.random.split(rng, num_steps)
        return jax.lax.scan(
            meanfield_approximate_step, state, rng_key
        )

    return run_optimization(init_params)

Fitting a model looks like:

import matplotlib.pyplot as plt
import numpy as np

import optax
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions

rng = jax.random.PRNGKey(0)

seed0, seed1, rng = jax.random.split(rng, 3)
X = jax.random.normal(seed0, (100, 98))
y = X @ np.arange(98) + jax.random.normal(seed1, (100,))

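# experimental_pin below needs a joint distribution, not a bare generator
@tfd.JointDistributionCoroutineAutoBatched
def model():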
    sigma = yield tfd.HalfNormal(5.0, name='sigma')
    mu = yield tfd.Normal(0.0, 1.0, name='mu')
    beta = yield tfd.Sample(tfd.Normal(mu, sigma), X.shape[-1], name='beta')
    yield tfd.Normal(X @ beta, 1.0, name="y")

# init_position = model.sample(seed=rng)
pinned = model.experimental_pin(y=y)
init_position = pinned.sample_unpinned(seed=rng)

bijectors = pinned.experimental_default_event_space_bijector()
def log_prob_fn(unbound_param):
    param = bijectors.forward(unbound_param)
    log_det_jacobian = bijectors.forward_log_det_jacobian(unbound_param)
    return pinned.unnormalized_log_prob(param) + log_det_jacobian
# This is just one way to do it. We could also use a flattened array to represent mu and rho
mu_param = jax.tree_util.tree_map(jnp.zeros_like, init_position)
rho_param = jax.tree_util.tree_map(jnp.zeros_like, init_position)
init_params = (mu_param, rho_param)

optimizer = optax.chain(optax.clip(10.), optax.adam(1.))
output = meanfield_approximate(rng, init_params, log_prob_fn, optimizer)


rlouf commented 1 year ago

In #392 we define a VIAlgorithm, and here we would need to define a new base type, ParametrizedVIAlgorithm.

xidulu commented 1 year ago

@rlouf I think I can take charge of this, but just to be sure:

We assume that the log_prob_fn (i.e. log p(x, z)) in BlackJAX takes a real-valued flattened array (rather than a dict or something on a constrained space), right?

rlouf commented 1 year ago

Great! No, there is no such assumption in the library (or at least there shouldn't be); we try to support PyTree states as much as we can.
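
To make the PyTree point concrete, here is a minimal sketch in plain JAX. The toy target, the dict keys and the name `logdensity_fn` are made up for illustration and are not part of BlackJAX. It shows a log-density that takes a dict position, and how `ravel_pytree` gives a flat view when one is needed:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy import stats


# Hypothetical toy target: the position is a dict of arrays, not a flat vector.
def logdensity_fn(position):
    log_scale = position["log_scale"]  # unconstrained scalar
    coefs = position["coefs"]          # vector of coefficients
    logp_scale = stats.norm.logpdf(log_scale, 0.0, 1.0)
    logp_coefs = stats.norm.logpdf(coefs, 0.0, jnp.exp(log_scale)).sum()
    return logp_scale + logp_coefs


position = {"log_scale": jnp.array(0.0), "coefs": jnp.zeros(3)}

# Gradients come back with the same PyTree structure as the position.
grads = jax.grad(logdensity_fn)(position)

# If an algorithm needs a flat vector, ravel_pytree converts back and forth.
flat_position, unravel_fn = ravel_pytree(position)
flat_logdensity_fn = lambda flat: logdensity_fn(unravel_fn(flat))
```

Since the gradient already mirrors the structure of the position, an algorithm written with `jax.tree_util.tree_map` never has to flatten anything.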

xidulu commented 1 year ago


To follow the design principles of BlackJAX, I believe VI should also have an API of the following form:

new_state, info = kernel(rng_key, state)

which would perform one optimization step for the ELBO.
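
For reference, the objective such a step would optimize can be written as follows. This is the usual mean-field ADVI notation, not something taken from this thread: $S$ is the number of Monte Carlo samples, and $\log p(x, z)$ is the joint log-density on the unconstrained space, including the log-determinant of the Jacobian.

```latex
\mathrm{ELBO}(\mu, \rho)
  = \mathbb{E}_{q_{\mu,\rho}(z)}\big[\log p(x, z) - \log q_{\mu,\rho}(z)\big]
  \approx \frac{1}{S} \sum_{s=1}^{S} \big[\log p(x, z_s) - \log q_{\mu,\rho}(z_s)\big],
\qquad
z_s = \mu + \exp(\rho) \odot \epsilon_s, \quad \epsilon_s \sim \mathcal{N}(0, I).
```

The `kl_fn` in the first comment returns the negative of this estimate, so minimizing it maximizes the ELBO.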

rlouf commented 1 year ago

As you can see with the Pathfinder implementation, BlackJAX treats VI differently from MCMC algorithms.

The idea is that you first fit an approximation to the target density, and then sample from this approximation with something like (in pseudo-code):

approx, info = approximate(rng_key, position)
samples = sample(sample_key, approx, num_samples)

I think that at a high level the API will always look more or less like this. We can consider a lower-level, kernel-like interface for some algorithms if it makes sense. But again, I am no VI expert and I am open to suggestions.
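
To illustrate how the two levels could fit together, here is a rough sketch in plain JAX. It is not the BlackJAX implementation: `MFState`, `MFInfo`, `step`, the toy target and plain SGD in place of a real optimizer are all assumptions made for the example. A kernel-like `step` does one gradient update on the negative ELBO, and `approximate` and `sample` are thin wrappers with the signatures from the pseudo-code above:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy import stats


class MFState(NamedTuple):  # hypothetical state of a mean-field Gaussian
    mu: jnp.ndarray
    rho: jnp.ndarray  # log standard deviation


class MFInfo(NamedTuple):  # hypothetical per-step diagnostics
    loss: jnp.ndarray


def logdensity_fn(z):
    # Toy target (assumption): independent normals, means (1, -1), stds (0.5, 2).
    return stats.norm.logpdf(z, jnp.array([1.0, -1.0]), jnp.array([0.5, 2.0])).sum()


def sample(rng_key, state, num_samples):
    # Reparameterized draws: mu + sigma * eps with eps ~ N(0, I).
    eps = jax.random.normal(rng_key, (num_samples,) + state.mu.shape)
    return state.mu + jnp.exp(state.rho) * eps


def step(rng_key, state, num_samples=32, learning_rate=0.05):
    """Kernel-like lower level: one SGD step on the negative ELBO."""
    def loss_fn(state):
        z = sample(rng_key, state, num_samples)
        logq = stats.norm.logpdf(z, state.mu, jnp.exp(state.rho)).sum(axis=-1)
        logp = jax.vmap(logdensity_fn)(z)
        return (logq - logp).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state)
    new_state = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, state, grads
    )
    return new_state, MFInfo(loss)


def approximate(rng_key, position, num_steps=2000):
    """Higher level: fit the approximation, starting at `position`."""
    init_state = MFState(mu=position, rho=jnp.zeros_like(position))
    keys = jax.random.split(rng_key, num_steps)
    return jax.lax.scan(lambda s, k: step(k, s), init_state, keys)


fit_key, sample_key = jax.random.split(jax.random.PRNGKey(0))
approx, info = approximate(fit_key, jnp.zeros(2))
samples = sample(sample_key, approx, 1000)
```

Here `info.loss` holds the negative ELBO estimate at every step, so it can be plotted to check convergence. A real implementation would presumably take the log-density and an optimizer as arguments instead of hard-coding them.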

LarsKarbach commented 1 year ago

Could someone then give a minimal working example for mean-field VI? I believe this would also be helpful for the refactoring of the Pathfinder API in #465 and for the implementation of the full-rank approach.

rlouf commented 1 year ago

MFVI is implemented here, and full rank is being implemented in a PR. The refactoring of Pathfinder is a bit involved, but up for grabs :)

LarsKarbach commented 1 year ago

I understand. Although I would argue that such an example would help people get a foot in the door if they want to help develop VI further. It could be as simple as taking a multivariate normal and evaluating mean field and full rank on it. With the current implementation, I don't immediately see how the pseudo-code you provided is implemented in the library.

xidulu commented 1 year ago

@LarsKarbach I understand your point, and I really wish there were a template for implementing VI variants (e.g. as simple as providing a log_q function and a sampling function), but the APIs are still at a very early stage. At the moment, there is still a lot of boilerplate code in the implementation... Once the full-rank VI PR is merged, we could probably start working on simplifying the VI implementation process.
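
As a rough idea of what such a template could look like, here is a sketch in plain JAX. Nothing in it is an existing BlackJAX API: `fit_vi`, `scale_tril`, `sample_full_rank` and `log_q_full_rank` are hypothetical names, and plain SGD stands in for a proper optimizer. The loop only needs a `log_q` function and a reparameterized sampling function. It is exercised here with a full-rank Gaussian on a correlated 2D normal, in the spirit of the multivariate normal test suggested above:

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats


def fit_vi(rng_key, params, log_q, sample_q, logdensity_fn,
           num_steps=3000, num_samples=32, learning_rate=0.02):
    """Generic reparameterized VI loop (hypothetical helper, plain SGD)."""
    def loss_fn(params, key):
        z = sample_q(key, params, num_samples)
        logq = jax.vmap(lambda zi: log_q(params, zi))(z)
        logp = jax.vmap(logdensity_fn)(z)
        return (logq - logp).mean()  # negative ELBO estimate

    def one_step(params, key):
        loss, grads = jax.value_and_grad(loss_fn)(params, key)
        params = jax.tree_util.tree_map(
            lambda p, g: p - learning_rate * g, params, grads
        )
        return params, loss

    return jax.lax.scan(one_step, params, jax.random.split(rng_key, num_steps))


# Full-rank Gaussian family: params = (mu, raw), with the Cholesky factor
# built from `raw` so that its diagonal stays positive.
def scale_tril(raw):
    return jnp.tril(raw, -1) + jnp.diag(jnp.exp(jnp.diag(raw)))


def sample_full_rank(key, params, num_samples):
    mu, raw = params
    eps = jax.random.normal(key, (num_samples,) + mu.shape)
    return mu + eps @ scale_tril(raw).T


def log_q_full_rank(params, z):
    mu, raw = params
    L = scale_tril(raw)
    return stats.multivariate_normal.logpdf(z, mu, L @ L.T)


# Toy target (assumption): zero-mean 2D normal with correlation 0.8.
target_cov = jnp.array([[1.0, 0.8], [0.8, 1.0]])

def logdensity_fn(z):
    return stats.multivariate_normal.logpdf(z, jnp.zeros(2), target_cov)


init_params = (jnp.zeros(2), jnp.zeros((2, 2)))
(mu, raw), losses = fit_vi(
    jax.random.PRNGKey(0), init_params, log_q_full_rank, sample_full_rank, logdensity_fn
)
fitted_cov = scale_tril(raw) @ scale_tril(raw).T
```

A mean-field variant would only swap in its own `log_q` and sampling function. It cannot represent the off-diagonal term of `target_cov`, which is what makes this target a useful comparison case.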