rlouf opened 3 years ago
I noticed that the Chi Square distribution has not yet been added. Is it still part of the plans?
Sure! Here's a tentative list of distributions that we still have to implement: https://github.com/rlouf/mcx/issues/74
Okay, I will take a stab at a few of those.
Sometimes we'd like to help with an open source library but are not quite sure what to work on, or don't want to bother the maintainer. I've been there. That is why I've put together a list of the projects that are up for grabs on MCX. You can take any of these; open an issue to signal you are working on it and run with it.
Many of these are huge in scope, but fear not! We will start by narrowing the scope down to something manageable for a first PR!
Improvements to the DSL
New distributions
MCX will always need new distribution implementations! Every addition is welcome.
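For instance, the Chi-square distribution mentioned at the top of this thread is still missing. Here is a minimal sketch of its logpdf in JAX; the function name and signature are illustrative, not an existing MCX interface:

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln

def chisquare_logpdf(x, k):
    """Log-density of the Chi-square distribution with k degrees of freedom:
    log p(x) = (k/2 - 1) log(x) - x/2 - (k/2) log(2) - log Gamma(k/2).
    """
    return (k / 2 - 1) * jnp.log(x) - x / 2 - (k / 2) * jnp.log(2.0) - gammaln(k / 2)
```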
The maintainers of Numpyro are open to extracting the distributions out of Numpyro into a new dedicated repository, probably under the umbrella of the blackjax-devs organization. We would thus have a repository of distributions that can be shared among PPLs or used in other projects. Please DM me directly on Twitter if you're interested, and I'll include you in the discussions.

Done:
In progress:
Mixture distributions
Allow working with mixture distributions by implementing a `Mixture` distribution, cf. the implementation in PyMC3; a minimal sketch of a mixture's logpdf follows.
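A minimal sketch of how a finite mixture's logpdf can be computed with `logsumexp`; the function below is illustrative, not MCX's actual interface:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def mixture_logpdf(x, log_weights, component_logpdfs):
    """Finite mixture: log p(x) = logsumexp_i(log w_i + log p_i(x))."""
    terms = jnp.array([lw + logpdf(x) for lw, logpdf in zip(log_weights, component_logpdfs)])
    return logsumexp(terms)
```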
Neural networks
I successfully experimented with subclassing the Trax deep learning library to build Bayesian neural network layers. This would provide a surprisingly readable and flexible interface to implement neural network models in MCX, for example a hierarchical model for an image classifier; a hypothetical sketch follows.
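A purely illustrative sketch, assuming a hypothetical `BayesianDense` layer that attaches a prior to its weights; neither the class nor the `weight_prior` argument exist in Trax or MCX:

```python
import trax.layers as tl

class BayesianDense(tl.Dense):
    """Hypothetical Dense layer carrying a prior over its weights."""

    def __init__(self, n_units, weight_prior=None):
        super().__init__(n_units)
        # Placeholder for a distribution object. In a hierarchical model the
        # priors' hyperparameters would themselves be random variables shared
        # across layers.
        self.weight_prior = weight_prior

def image_classifier(num_classes):
    return tl.Serial(
        BayesianDense(512, weight_prior="Normal(0, 1)"),
        tl.Relu(),
        BayesianDense(num_classes, weight_prior="Normal(0, 1)"),
        tl.LogSoftmax(),
    )
```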
Tools for causal inference
MCX's internal representation is a mix of a symbolic graph and a probabilistic graphical model. It can be manipulated dynamically, and the library could thus be used for causal inference.
do operator
Add the `do` operator to MCX, which returns a new model. A minimal sketch of the intended graph semantics follows.
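Since MCX's intermediate representation is a NetworkX graph (see "Collapsing conjugate distributions" below), here is a minimal sketch of the semantics; the node attributes used are assumptions about the IR's layout, not MCX's actual internals:

```python
import networkx as nx

def do(graph: nx.DiGraph, **interventions):
    """Return a new graph where each intervened node is set to a constant
    and its incoming edges are removed (Pearl's do-operator)."""
    new_graph = graph.copy()
    for name, value in interventions.items():
        # Cut the node off from its parents: its value no longer depends
        # on the rest of the model.
        new_graph.remove_edges_from(list(new_graph.in_edges(name)))
        new_graph.nodes[name]["distribution"] = None  # assumed attribute
        new_graph.nodes[name]["value"] = value        # assumed attribute
    return new_graph
```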
d-separation
Performance
Simplify the logpdf
MCX has two intermediate representations. The first is an augmented graphical model that results from parsing the model; at this level you can reason in terms of distributions and how they are related to each other, which can lead to some simplifications (see the collapsing of conjugate distributions below). The second is closer to the actual computations performed: JAX's Jaxprs. When it comes to computing logpdfs, many optimizations can be done at the symbolic level. I've identified the following two so far, but there are likely many others.
Eliminate constants in the logpdf (in progress)
When we sample from the posterior distribution we only need to know the log-likelihood up to a constant. However, in practice we often carry these constants all the way through the computations. The goal of this project is to eliminate the constants from the logpdf before sampling from the posterior. For this you will need to implement a custom Jaxpr interpreter that modifies the logpdf at a lower level:
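A minimal sketch of the standard jaxpr-interpreter pattern from the JAX documentation; the constant-elimination logic itself (deciding which equations to drop or rewrite) is left as the actual project:

```python
import jax
from jax import core

def eval_jaxpr(jaxpr, consts, *args):
    """Evaluate a jaxpr equation by equation. A constant-eliminating
    interpreter would skip or rewrite equations here instead of
    executing them verbatim."""
    env = {}

    def read(var):
        return var.val if isinstance(var, core.Literal) else env[var]

    def write(var, val):
        env[var] = val

    for var, arg in zip(jaxpr.invars, args):
        write(var, arg)
    for var, const in zip(jaxpr.constvars, consts):
        write(var, const)

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for var, val in zip(eqn.outvars, outvals):
            write(var, val)

    return [read(v) for v in jaxpr.outvars]

# Usage: closed = jax.make_jaxpr(logpdf)(x)
#        eval_jaxpr(closed.jaxpr, closed.consts, x)
```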
Simplify the logpdf of products of random variables from exponential families
When a random variable is conditioned on many data points we compute `np.sum(np.log(array))` where each element of the array is proportional to `exp(something)`, which leads to a lot of wasteful computation. The goal here is to write a custom Jaxpr interpreter to end up with `np.sum(array')` where this new array is filled with `something + const`. This also applies to vectors and matrices of random variables. Expect substantial improvements for large datasets and for neural networks with priors on weights in the exponential family.
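To illustrate the simplification on a Normal log-likelihood (purely illustrative, not MCX code); both functions compute the same value, but the second never exponentiates:

```python
import jax.numpy as jnp

def naive_loglik(x, mu):
    # Each term is log(exp(-0.5 * (x - mu)**2) / sqrt(2 * pi)):
    # we exponentiate only to take the log again.
    density = jnp.exp(-0.5 * (x - mu) ** 2) / jnp.sqrt(2 * jnp.pi)
    return jnp.sum(jnp.log(density))

def simplified_loglik(x, mu):
    # log cancels exp; the normalization collapses into one constant.
    return jnp.sum(-0.5 * (x - mu) ** 2) - x.size * 0.5 * jnp.log(2 * jnp.pi)
```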
Collapsing conjugate distributions
MCX's intermediate representation is a NetworkX graph whose nodes represent random variable assignments and deterministic operations. It is thus possible to identify conjugate pairs (say `Beta` and `Binomial`) and possibly collapse them into a single (`BetaBinomial`) distribution, which would accelerate computations. A list of conjugate priors can be found on Wikipedia.
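A minimal sketch of how conjugate pairs could be detected on the graph; the `"distribution"` node attribute is an assumption about the IR's layout, not MCX's actual internals:

```python
import networkx as nx

# Map of (prior, likelihood) pairs to their collapsed distribution.
CONJUGATE_PAIRS = {("Beta", "Binomial"): "BetaBinomial"}

def find_conjugate_pairs(graph: nx.DiGraph):
    """Return (prior_node, likelihood_node, collapsed_name) triples for
    every edge whose endpoints form a known conjugate pair."""
    matches = []
    for u, v in graph.edges():
        pair = (graph.nodes[u].get("distribution"),
                graph.nodes[v].get("distribution"))
        if pair in CONJUGATE_PAIRS:
            matches.append((u, v, CONJUGATE_PAIRS[pair]))
    return matches
```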
Inference
Algorithms
If you would like to implement new inference algorithms you can now contribute to BlackJAX! You can then create an interface to these algorithms in MCX by implementing a new class, as was done for HMC; a hypothetical sketch of such a wrapper follows.
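A hypothetical sketch of such a wrapper class, assuming BlackJAX's `nuts` interface; the class layout mirrors no actual MCX code:

```python
import blackjax

class NUTS:
    """Hypothetical MCX-style evaluator wrapping BlackJAX's NUTS sampler."""

    def __init__(self, step_size=1e-3, inverse_mass_matrix=None):
        self.step_size = step_size
        self.inverse_mass_matrix = inverse_mass_matrix

    def kernel(self, logdensity_fn):
        # blackjax.nuts returns an object exposing `init` and `step`.
        algo = blackjax.nuts(logdensity_fn, self.step_size, self.inverse_mass_matrix)
        return algo.init, algo.step
```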
Sequential inference
Sequential inference will be a cornerstone feature. If you have been frustrated by the near-impossibility of doing Bayesian updating with most PPLs, or of sampling with very large datasets, please help! This is a big and ambitious project, so please DM me directly on Twitter if you're interested so we can start bouncing ideas off each other.