rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Compiled inference by default #69

Closed. jeremiecoullon closed this issue 3 years ago.

jeremiecoullon commented 3 years ago

Following the suggestions in issue #65, here are some details on how we could add the progress bar.

Based on the JAX discussion (which I wrote up in a blog post), we can define a decorator that adds a progress bar to the body_fun used in lax.scan:


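from jax import jit, lax
from jax.experimental import host_callback  # deprecated in newer JAX releases

def _print_consumer(arg, transform):
    # Host-side consumer for id_tap: receives the tapped value and the applied transforms
    iter_num, n_iter = arg
    print(f"Iteration {iter_num}/{n_iter}")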

@jit
def _progress_bar(arg, result):
    """
    Print progress of a scan/loop only if the iteration number is a multiple of the print_rate
    Usage: carry = progress_bar((iter_num, n_iter, print_rate), carry)
    """
    iter_num, n_iter, print_rate = arg
    result = lax.cond(
        iter_num % print_rate == 0,
        # Tap to the host only on multiples of print_rate; id_tap returns `result` unchanged
        lambda _: host_callback.id_tap(_print_consumer, (iter_num, n_iter), result=result),
        lambda _: result,
        operand=None)
    return result

def progress_bar(num_samples):
    """
    Decorator that adds a progress bar to `body_fun` used in `lax.scan`. 
    Note that `body_fun` must be looping over `jnp.arange(num_samples)`.
    This means that `iter_num` is the current iteration number.
    """
    def pbar_factory(func):
        # max(1, ...) keeps print_rate non-zero, so the modulo is safe when num_samples < 10
        print_rate = max(1, num_samples // 10)
        def wrapper_progress_bar(carry, iter_num):
            iter_num = _progress_bar((iter_num, num_samples, print_rate), iter_num)
            return func(carry, iter_num)
        return wrapper_progress_bar
    return pbar_factory
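A minimal sketch of the same decorator pattern, assuming a recent JAX release. Since jax.experimental.host_callback has been deprecated in newer JAX versions, this version swaps host_callback.id_tap for jax.debug.callback and keeps print_rate non-zero. The names progress_bar_sketch, _report, body_fun and run are illustrative only, not part of mcx. As the docstring above requires, the body scans over jnp.arange(num_samples):

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax

# Iteration numbers seen by the host callback, kept so they can be inspected.
reported = []


def _report(iter_num, n_iter):
    # Runs on the host with concrete values.
    reported.append(int(iter_num))
    print(f"Iteration {int(iter_num)}/{n_iter}")


def progress_bar_sketch(num_samples):
    """Hypothetical decorator: report every num_samples // 10 iterations of a scan body."""
    # max(1, ...) avoids a modulo by zero when num_samples < 10.
    print_rate = max(1, num_samples // 10)
    report = functools.partial(_report, n_iter=num_samples)

    def pbar_factory(func):
        @functools.wraps(func)
        def wrapper(carry, iter_num):
            # Only trigger the host callback on multiples of print_rate.
            lax.cond(
                iter_num % print_rate == 0,
                lambda i: jax.debug.callback(report, i),
                lambda i: None,
                iter_num,
            )
            return func(carry, iter_num)

        return wrapper

    return pbar_factory


num_samples = 50


@progress_bar_sketch(num_samples)
def body_fun(carry, iter_num):
    # Toy body: running sum of the iteration numbers.
    carry = carry + iter_num
    return carry, carry


@jax.jit
def run(init):
    # Scan over jnp.arange(num_samples) so the body sees the iteration number.
    return lax.scan(body_fun, init, jnp.arange(num_samples))


total, running_sums = run(jnp.int32(0))
jax.effects_barrier()  # wait until all pending host callbacks have run
```

Unlike id_tap, jax.debug.callback returns nothing, so there is no result= value to thread through the branches; the callback is kept alive because it is a side effect. With num_samples = 50 the host is only called at iterations 0, 5, ..., 45.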

To use it, we simply add the progress_bar decorator to the update_scan function in mcx/sample.py. The code would then become:


    # Excerpt from mcx/sample.py: rng_keys, num_chains and kernel come from the enclosing function
    @jax.jit
    @progress_bar(rng_keys.shape[0])
    def update_scan(carry, key):
        state, parameters = carry
        keys = jax.random.split(key, num_chains)
        state, info = jax.vmap(kernel, in_axes=(0, 0, 0))(keys, parameters, state)
        return (state, parameters), (state, info)

Problem:

The problem is that this progress bar only works if update_scan has access to the current iteration number. However, update_scan loops over the keys rather than over jnp.arange(num_samples). I see two possible solutions:

  1. We change the code so that lax.scan loops over jnp.arange(num_samples) instead of the keys. However, that means changing more code elsewhere to split the keys at each iteration.
  2. We simply scan over both jnp.arange(rng_keys.shape[0]) and rng_keys (i.e. the JAX version of Python's enumerate!). This is probably a much simpler fix and should require fewer changes.
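Option 2 amounts to passing a tuple as the xs argument of lax.scan: scan slices every leaf along its leading axis, so the body receives a matching (iteration number, key) pair at each step. A minimal sketch, assuming a toy Gaussian random walk in place of the real mcx kernel; this update_scan is a simplified, hypothetical version, not the mcx code:

```python
import jax
import jax.numpy as jnp
from jax import lax

num_samples = 100
num_chains = 4

# One key per sampling step, as in the original scan over rng_keys.
rng_keys = jax.random.split(jax.random.PRNGKey(0), num_samples)


def update_scan(carry, xs):
    # Each step receives a matching (iteration number, key) pair,
    # the JAX analogue of `for iter_num, key in enumerate(rng_keys)`.
    iter_num, key = xs
    keys = jax.random.split(key, num_chains)
    # Toy stand-in for the vmapped MCMC kernel: one Gaussian step per chain.
    step = jax.vmap(jax.random.normal)(keys)
    carry = carry + step
    return carry, (iter_num, carry)


init = jnp.zeros(num_chains)
final, (iters, trace) = lax.scan(
    update_scan, init, (jnp.arange(num_samples), rng_keys)
)
```

With this layout, a progress-bar wrapper would take its iteration number from xs[0] before calling the original body, and the key handling stays unchanged.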

I'll try out this second option. Any thoughts?

Note that even once this works, the progress bar is still very basic and not as nice as tqdm.

rlouf commented 3 years ago

Implemented in #70, closing.