rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Contribute! #65

Open rlouf opened 3 years ago

rlouf commented 3 years ago

Sometimes we'd like to help with an open source library but aren't quite sure what to work on, or don't want to bother the maintainer. I've been there. That is why I've put together a list of the projects that are up for grabs on MCX. You can take any of these, open an issue to signal that you are working on it, and run with it.

Many of these projects are huge in scope, but fear not! We will start by narrowing each one down to something manageable for a first PR!

Improvements to the DSL

New distributions

MCX will always need new distribution implementations! Every addition is welcome.
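
To give a rough idea of the shape of a contribution, here is a minimal sketch of a distribution written on top of JAX. The sample/logpdf interface shown is an assumption; check the existing implementations in mcx/distributions for the exact base class and method signatures:

import jax.numpy as jnp
from jax import random

class Laplace:
    # Hypothetical Laplace(loc, scale) distribution, for illustration only

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def sample(self, rng_key, sample_shape=()):
        # jax.random.laplace draws from the standard Laplace(0, 1)
        return self.loc + self.scale * random.laplace(rng_key, shape=sample_shape)

    def logpdf(self, x):
        return -jnp.log(2 * self.scale) - jnp.abs(x - self.loc) / self.scale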

The maintainers of NumPyro are open to extracting the distributions out of NumPyro into a new dedicated repository, probably under the umbrella of the blackjax-devs organization. We would then have a repository of distributions that can be shared among PPLs or used in other projects. Please DM me directly on Twitter if you're interested, and I'll include you in the discussions.

Done:

In progress:

Mixture distributions

Allow working with mixture distributions by implementing a Mixture distribution, cf. the implementation in PyMC3.
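
A hypothetical usage sketch; the Mixture signature below does not exist yet and is only loosely modeled on PyMC3's pm.Mixture(w, comp_dists):

import jax.numpy as jnp
import mcx
import mcx.distributions as dist

@mcx.model
def gaussian_mixture():
    # Everything about the Mixture line is a proposal, not existing API
    w <~ dist.Dirichlet(jnp.ones(2))
    loc <~ dist.Normal(jnp.zeros(2), 10.0 * jnp.ones(2))
    y <~ dist.Mixture(w, dist.Normal(loc, 1.0))
    return y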

Neural networks

I successfully experimented with subclassing the Trax deep learning library to build Bayesian neural network layers. This would provide a surprisingly readable and flexible interface for implementing neural network models in MCX. For example, a hierarchical model for an image classifier:

import mcx
import mcx.distributions as dist
import mcx.layers as ml

@mcx.model
def mnist(image):
    # Shared prior scale for all layer weights (hierarchical model)
    sigma <~ dist.Exponential(2)
    nn <~ ml.Serial(
        ml.dense(400, dist.Normal(0, sigma)),
        ml.dense(400, dist.Normal(0, sigma)),
        ml.dense(10, dist.Normal(0, sigma)),
        ml.softmax(),
    )
    p = nn(image)
    cat <~ dist.Categorical(p)
    return cat

Tools for causal inference

MCX's internal representation is a mix of a symbolic graph and a probabilistic graphical model. It can be manipulated dynamically, which means the library could also be used for causal inference.

do operator

Add the do operator to MCX:

mcx.do(model, var=np.array([1, 2]))

which returns a new model.
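
As a hedged sketch of the intended semantics (the models below are made up for illustration): intervening with do clamps a variable to the given value and severs the edges into it, so the two models would be equivalent:

import mcx
import mcx.distributions as dist

@mcx.model
def linear():
    a <~ dist.Normal(0, 1)
    b <~ dist.Normal(a, 1)
    return b

# mcx.do(linear, a=1.0) should behave like:
@mcx.model
def linear_after_do():
    a = 1.0                  # prior on a severed, value clamped
    b <~ dist.Normal(a, 1)
    return b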

d-separation
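
Implement a d-separation query on the model graph, i.e. answer "are X and Y conditionally independent given Z?". Since the intermediate representation is a NetworkX graph, a first version could lean on NetworkX's built-in test. A minimal sketch, assuming the directed graph is reachable on the model (the graph attribute name is hypothetical):

import networkx as nx

def d_separated(model, x, y, given=()):
    # nx.d_separated expects a DAG and sets of nodes
    return nx.d_separated(model.graph, {x}, {y}, set(given))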

Performance

Simplify the logpdf

MCX has two intermediate representations: the first is an augmented graphical model that results from parsing the model. At this level you can reason in terms of distributions and how they relate to each other, which can lead to some simplifications (see the collapsing of conjugate distributions below). The second is closer to the actual computations performed: JAX's jaxprs. When it comes to computing logpdfs, many optimizations can be done at the symbolic level. I have so far identified the following two, but there are likely many others.

Eliminate constants in the logpdf (in progress)

When we sample from the posterior distribution we only need to know the log-likelihood up to a constant. In practice, however, we often carry these constants all the way through the computations. The goal of this project is to eliminate the constants from the logpdf before sampling from the posterior. For this you will need to implement a custom Jaxpr interpreter that modifies the logpdf at a lower level:

denormalized_logpdf = denormalize(logpdf)
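
As a starting point, here is a minimal sketch of inspecting a function's jaxpr, the first step of any custom interpreter; the actual denormalize transformation (deciding which equations only contribute additive constants and dropping them) is the real work:

import jax
import jax.numpy as jnp

def print_equations(fn, *example_args):
    # Trace fn to a jaxpr and list its primitive equations
    closed_jaxpr = jax.make_jaxpr(fn)(*example_args)
    for eqn in closed_jaxpr.jaxpr.eqns:
        print(eqn.primitive, eqn.invars, "->", eqn.outvars)

# The second term is a constant we would like to eliminate
logpdf = lambda x: -0.5 * x**2 - 0.5 * jnp.log(2 * jnp.pi)
print_equations(logpdf, 1.0)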

Simplify the logpdf of products of random variables from exponential families

When a random variable is conditioned on many data points we compute np.sum(np.log(array)) where each element of the array is proportional to exp(something), which leads to a lot of wasteful computation. The goal here is to write a custom Jaxpr interpreter so that we end up computing np.sum(new_array) where the new array is filled with something + const, as illustrated below. This also applies to vectors and matrices of random variables.
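
A tiny numerical illustration of the intended rewrite, in plain JAX rather than MCX code, for a Normal likelihood:

import jax.numpy as jnp

x, mu, sigma = jnp.linspace(-1.0, 1.0, 5), 0.0, 1.0
z = -0.5 * ((x - mu) / sigma) ** 2
const = -jnp.log(sigma * jnp.sqrt(2 * jnp.pi))

# Naive: exponentiate, then take the log again for every data point
naive = jnp.sum(jnp.log(jnp.exp(z) / (sigma * jnp.sqrt(2 * jnp.pi))))

# Rewritten: log(exp(z)) == z, so sum z + const directly
simplified = jnp.sum(z + const)

assert jnp.allclose(naive, simplified)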

Expect substantial improvements for large datasets and for neural networks whose weight priors are in the exponential family.

Collapsing conjugate distributions

MCX's intermediate representation is a NetworkX graph whose nodes represent random variable assignments and deterministic operations. It is thus possible to identify conjugate pairs (say, Beta and Binomial) and collapse them into a single (BetaBinomial) distribution, which would accelerate computations. This project would consist of:

  1. Writing code that identifies conjugacy relationships in the intermediate representation (a sketch follows below);
  2. Implementing a few collapsed distributions;
  3. Writing code that modifies the graph to collapse the conjugacies;
  4. Printing an info message to the user when sampling.

A list of conjugate priors can be found on Wikipedia.
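
A minimal sketch of step 1, assuming each node of the graph carries an attribute naming its distribution (the distribution attribute is hypothetical; the real intermediate representation may store this differently):

import networkx as nx

def find_beta_binomial_pairs(graph: nx.DiGraph):
    # A Beta prior feeding a Binomial likelihood is the simplest
    # collapsible pair (-> BetaBinomial)
    pairs = []
    for u, v in graph.edges:
        if (graph.nodes[u].get("distribution") == "Beta"
                and graph.nodes[v].get("distribution") == "Binomial"):
            pairs.append((u, v))
    return pairs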

Inference

Algorithms

If you would like to implement new inference algorithms you can now contribute to BlackJAX! You can then create an interface to these algorithms in MCX by implementing a new class, as was done for HMC.
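
For reference, a minimal sketch of driving a BlackJAX NUTS kernel on a toy log-density; the BlackJAX API has evolved across versions, so treat the exact signatures as an assumption, and note that the MCX-side wrapper class is out of scope here:

import jax
import jax.numpy as jnp
import blackjax

# Toy target: standard normal log-density
logdensity = lambda x: -0.5 * jnp.sum(x ** 2)

nuts = blackjax.nuts(logdensity, step_size=0.5, inverse_mass_matrix=jnp.ones(2))
state = nuts.init(jnp.zeros(2))
rng_key = jax.random.PRNGKey(0)
state, info = nuts.step(rng_key, state)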

Sequential inference

Sequential inference will be a cornerstone feature. If you have been frustrated by the near-impossibility of doing Bayesian updating with most PPLs, or of sampling with very large datasets, please help! This is a big and ambitious project, so please DM me directly on Twitter if you're interested so we can start bouncing ideas off each other.

zoj613 commented 3 years ago

I noticed that the Chi Square distribution has not yet been added. Is it still part of the plans?

rlouf commented 3 years ago

Sure! Here's a tentative list of distributions that we still have to implement: https://github.com/rlouf/mcx/issues/74

zoj613 commented 3 years ago

> Sure! Here's a tentative list of distributions that we still have to implement: #74

Okay, I will try and take a stab at a few of those.