dfm / emcee

The Python ensemble sampling toolkit for affine-invariant MCMC
https://emcee.readthedocs.io
MIT License

JAX implementation of emcee #499

Open amifalk opened 10 months ago

amifalk commented 10 months ago

Greetings!

I've ported a subset of emcee functionality to the NumPyro project under the sampler name AIES.

(For the uninitiated, NumPyro uses JAX in the backend, a library with a NumPy-like interface and additional features such as JIT compilation and GPU support. The upshot is that if you're currently using emcee, switching to NumPyro may give you a dramatic inference speedup!)
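As a minimal sketch of the JAX features mentioned above (plain JAX, not NumPyro internals), jax.vmap turns a single-walker log density into a function over the whole ensemble, and jax.jit compiles the result with XLA. The standard-normal log_prob is a hypothetical target chosen for illustration.

```python
import jax
import jax.numpy as jnp


def log_prob(x):
    # Written exactly as with numpy, but using jax.numpy.
    return -0.5 * jnp.sum(x ** 2)


# vmap maps the single-walker function over the leading (walker) axis;
# jit compiles the batched function for CPU, GPU or TPU.
batched_log_prob = jax.jit(jax.vmap(log_prob))

walkers = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
lp = batched_log_prob(walkers)  # one log density per walker, shape (100,)
```

Evaluating every walker in one compiled call, instead of one Python call per walker, is the kind of batching that chain_method='vectorized' relies on.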

I've tried my best to match the existing API. You can use either the NumPyro model specification language

import jax
import jax.numpy as jnp

import numpyro
from numpyro.infer import MCMC, AIES
import numpyro.distributions as dist

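n_dim, num_chains = 5, 100
mu, sigma = jnp.zeros(n_dim), jnp.ones(n_dim)

# Target: an independent Normal(mu, sigma) in each of the n_dim dimensions.
def model(mu, sigma):
    with numpyro.plate('n_dim', n_dim):
        numpyro.sample("x", dist.Normal(mu, sigma))

# Mix the differential-evolution and stretch moves, weighted 0.5 each.
kernel = AIES(model, moves={AIES.DEMove() : 0.5,
                            AIES.StretchMove() : 0.5})

# Ensemble samplers update all walkers (chains) jointly, so the chains
# are run with chain_method='vectorized'.
mcmc = MCMC(kernel,
            num_warmup=1000,
            num_samples=2000,
            num_chains=num_chains,
            chain_method='vectorized')

mcmc.run(jax.random.PRNGKey(0), mu, sigma)
mcmc.print_summary()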

or provide your own potential function.

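# Potential = negative log density of the same Normal target, up to a constant.
# mu and sigma are taken from the enclosing scope.
def potential_fn(z):
    return 0.5 * jnp.sum(((z - mu) / sigma) ** 2)

kernel = AIES(potential_fn=potential_fn,
              moves={AIES.DEMove() : 0.5,
                     AIES.StretchMove() : 0.5})
mcmc = MCMC(kernel,
            num_warmup=1000,
            num_samples=2000,
            num_chains=num_chains,
            chain_method='vectorized')

# With no model to initialize from, supply one starting point per walker,
# i.e. an array of shape (num_chains, n_dim).
init_params = jax.random.normal(jax.random.PRNGKey(0),
                                (num_chains, n_dim))

mcmc.run(jax.random.PRNGKey(1), mu, sigma, init_params=init_params)
mcmc.print_summary()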
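A quick way to check a hand-written potential is to compare it with the exact negative log density: NumPyro's potential is the negative of what emcee's log-probability function returns, and any additive constant is harmless because MCMC acceptance only uses differences. This sketch reuses the target above; neg_log_density is a hypothetical helper built on jax.scipy.stats.norm.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

n_dim, num_chains = 5, 100
mu, sigma = jnp.zeros(n_dim), jnp.ones(n_dim)


def potential_fn(z):
    return 0.5 * jnp.sum(((z - mu) / sigma) ** 2)


def neg_log_density(z):
    # Exact -log p(z) for independent Normal(mu, sigma) components.
    return -jnp.sum(norm.logpdf(z, mu, sigma))


init_params = jax.random.normal(jax.random.PRNGKey(0), (num_chains, n_dim))

# Evaluate both functions for every walker at once.
gap = jax.vmap(potential_fn)(init_params) - jax.vmap(neg_log_density)(init_params)

# The gap is identical for every walker: minus the Gaussian normalizing constant.
log_norm_const = jnp.sum(jnp.log(sigma)) + 0.5 * n_dim * jnp.log(2 * jnp.pi)
```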

Hope this is helpful to some folks!
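For readers curious what the StretchMove in the moves dictionary does, here is a simplified JAX sketch of the Goodman & Weare stretch move as emcee describes it. It is not NumPyro's AIES source: it uses a fixed split of the ensemble into two halves, omits the DE move and the weighted choice between moves, and the names log_prob, stretch_update and stretch_step are hypothetical.

```python
import jax
import jax.numpy as jnp


def log_prob(x):
    # Standard normal target, matching mu = 0, sigma = 1 in the examples above.
    return -0.5 * jnp.sum(x ** 2)


def stretch_update(key, active, complement, a=2.0):
    """Stretch-move update of every walker in `active`, using `complement` as partners."""
    n, d = active.shape
    k_pick, k_z, k_acc = jax.random.split(key, 3)
    # Each active walker picks a random partner from the complementary half.
    idx = jax.random.randint(k_pick, (n,), 0, complement.shape[0])
    partners = complement[idx]
    # Draw z from g(z) proportional to 1/sqrt(z) on [1/a, a] via the inverse CDF.
    u = jax.random.uniform(k_z, (n,))
    z = ((a - 1.0) * u + 1.0) ** 2 / a
    # Propose along the line through the partner and the current position.
    proposal = partners + z[:, None] * (active - partners)
    # Metropolis-Hastings ratio including the z**(d - 1) factor.
    log_ratio = ((d - 1) * jnp.log(z)
                 + jax.vmap(log_prob)(proposal)
                 - jax.vmap(log_prob)(active))
    accept = jnp.log(jax.random.uniform(k_acc, (n,))) < log_ratio
    return jnp.where(accept[:, None], proposal, active), accept


@jax.jit
def stretch_step(key, walkers):
    # Update each half of the ensemble using the other (already updated) half.
    half = walkers.shape[0] // 2
    k1, k2 = jax.random.split(key)
    first, acc1 = stretch_update(k1, walkers[:half], walkers[half:])
    second, acc2 = stretch_update(k2, walkers[half:], first)
    return jnp.concatenate([first, second]), jnp.concatenate([acc1, acc2])


walkers = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
key = jax.random.PRNGKey(1)
n_steps, n_accepted = 200, 0
for _ in range(n_steps):
    key, sub = jax.random.split(key)
    walkers, accepted = stretch_step(sub, walkers)
    n_accepted += int(accepted.sum())
acceptance_rate = n_accepted / (n_steps * walkers.shape[0])
```

Updating one half against the other keeps the partner walkers fixed during each update, which is what keeps the move valid; the affine invariance comes from proposals built only from differences between walkers.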

dfm commented 10 months ago

Very cool! Thanks for sharing.

jcblemai commented 7 months ago

@amifalk Do you have any idea of the speedup?

amifalk commented 7 months ago

It depends on how many chains you run, whether or not you have a GPU, the amount of native Python code in your model, etc., but it can often be a few orders of magnitude faster.
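Since the speedup depends on these factors, it is worth measuring on your own model and hardware. A minimal timing harness, assuming a hypothetical standard-normal log_prob and hypothetical helpers time_call and python_loop, compares a per-walker Python loop with a jit-compiled, vmapped evaluation for two ensemble sizes; no particular timings are implied.

```python
import time

import jax
import jax.numpy as jnp


def log_prob(x):
    return -0.5 * jnp.sum(x ** 2)


batched = jax.jit(jax.vmap(log_prob))


def python_loop(ws):
    # One Python-level call per walker, as in a non-vectorized sampler.
    return jnp.stack([log_prob(w) for w in ws])


def time_call(fn, arg, repeats=10):
    # Warm-up call triggers compilation for this input shape.
    fn(arg).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        out = fn(arg)
    # JAX dispatches asynchronously, so wait for the result before stopping the clock.
    out.block_until_ready()
    return (time.perf_counter() - start) / repeats


results = {}
for n_walkers in (10, 100):
    ws = jax.random.normal(jax.random.PRNGKey(n_walkers), (n_walkers, 5))
    results[n_walkers] = (time_call(python_loop, ws, repeats=3), time_call(batched, ws))
    print(f"{n_walkers} walkers: loop {results[n_walkers][0]:.2e} s, "
          f"vectorized {results[n_walkers][1]:.2e} s")
```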

kaiserls commented 1 week ago

@amifalk Is emcee's ability to compute and store additional data at each step (the "blobs" feature) supported? In emcee you can return a tuple from the log-probability function, and the rest is handled mostly automatically. I found the get_extra_fields function in NumPyro, but I am not sure how to generate these extra fields.

amifalk commented 1 week ago

get_extra_fields currently only allows you to extract internal metadata generated by the sampler at each step. If you want to access some intermediate value of your likelihood function, you can recompute it from the posterior samples after the run.
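Recomputing after the run can be done in one vectorized pass over the draws. In this sketch, samples is a random stand-in for flattened posterior draws (for example the "x" entry returned by mcmc.get_samples() in the model-based example), and per_dim_terms is a hypothetical intermediate quantity standing in for whatever you would have returned as an emcee blob.

```python
import jax
import jax.numpy as jnp

mu, sigma = jnp.zeros(5), jnp.ones(5)


def per_dim_terms(z):
    # Hypothetical intermediate value: per-dimension terms of the potential.
    return 0.5 * ((z - mu) / sigma) ** 2


# Stand-in for posterior draws with shape (num_draws, n_dim).
samples = jax.random.normal(jax.random.PRNGKey(2), (2000, 5))

# Recompute the intermediate value for every draw at once.
extra = jax.jit(jax.vmap(per_dim_terms))(samples)  # shape (2000, 5)
```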

By the way, the Pyro forum is a great place to ask these kinds of questions.