aesara-devs / aemcmc

AeMCMC is a Python library that automates the construction of samplers for Aesara graphs representing statistical models.
https://aemcmc.readthedocs.io/en/latest/
MIT License

Create a framework for matching models to samplers #3

Closed: brandonwillard closed this issue 2 years ago

brandonwillard commented 2 years ago

We need a framework for mapping samplers to Aesara model graphs.

In general, one of the main interfaces to the sampler implementations in this project would accept a model graph (i.e. an Aesara graph that represents the statistical model) and return a graph that represents the sampler steps that generate posterior samples for the model.

For example, our current Horseshoe Gibbs sampler implementation applies to normal regression models with Horseshoe prior regression parameters. We want to be able to parse a given model graph and determine whether or not those elements are present so that the Gibbs sampler can be proposed/applied.

Parsing such models is straightforward with Aesara's basic rewriting and unification functionality, but we need a good way to organize and deploy the process. For instance, how should we store the mappings between a unification "pattern" and a set of samplers?
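
One rough way to picture a pattern-to-samplers mapping follows. It is a sketch under loud assumptions: a toy term language of nested `(op, *args)` tuples with a hand-rolled `Var`/`unify` stands in for Aesara graphs, etuples, and the unification library, and the registry, op strings, and sampler names are all hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    """A logic variable in a toy term language of nested (op, *args) tuples."""
    name: str


def walk(term, subs):
    """Follow bindings in `subs` until reaching an unbound variable or a non-variable."""
    while isinstance(term, Var) and term in subs:
        term = subs[term]
    return term


def unify(u, v, subs):
    """Return `subs` extended so that `u` and `v` are equal, or None on failure.

    No occurs check is performed; this is a toy, not the unification used
    with Aesara graphs/etuples.
    """
    u, v = walk(u, subs), walk(v, subs)
    if u == v:
        return subs
    if isinstance(u, Var):
        return {**subs, u: v}
    if isinstance(v, Var):
        return {**subs, v: u}
    if isinstance(u, tuple) and isinstance(v, tuple) and len(u) == len(v):
        for a, b in zip(u, v):
            subs = unify(a, b, subs)
            if subs is None:
                return None
        return subs
    return None


X, sigma, h, a, b = (Var(n) for n in ("X", "sigma", "h", "a", "b"))
HORSESHOE_BETA = ("normal", 0, ("mul", ("halfcauchy", 0, a), ("halfcauchy", 0, b)))

# Hypothetical registry: each pattern maps to the samplers that can be
# proposed when a model graph unifies with it.
SAMPLER_REGISTRY = [
    (("normal", ("dot", X, HORSESHOE_BETA), sigma), ["horseshoe_gibbs"]),
    (("nbinom", h, ("dot", X, HORSESHOE_BETA)), ["nbinom_horseshoe_gibbs"]),
]


def find_samplers(model):
    """Return (sampler names, bindings) for every registry pattern the model unifies with."""
    matches = []
    for pattern, samplers in SAMPLER_REGISTRY:
        subs = unify(pattern, model, {})
        if subs is not None:
            matches.append((samplers, subs))
    return matches


model = (
    "normal",
    ("dot", "X_data", ("normal", 0, ("mul", ("halfcauchy", 0, 1), ("halfcauchy", 0, 1)))),
    1.0,
)
matches = find_samplers(model)  # one match: ["horseshoe_gibbs"], with X bound to "X_data"
```

A flat list of `(pattern, samplers)` pairs is the simplest possible store, but it tries every pattern against every model; that cost is what a DB-like index would avoid.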

Also, this framework needs to be flexible enough to identify and apply rewrites that could make a model graph amenable to samplers. The primary example is our current normal scale-mixture expansion of the negative-binomial (via the Pólya-gamma): after we identify the Horseshoe regression portion of the model (e.g. X.dot(beta_rv), where beta_rv = at.random.normal(0, eta_rv * lambda2_rv) and eta_rv = at.random.halfcauchy(...), etc.), a negative-binomial observation model shouldn't prevent the Horseshoe Gibbs sampler from being applied, because it can be converted into a normal scale-mixture.
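
For reference, the conversion rests on the standard Pólya-gamma identity (Polson, Scott and Windle, 2013). The derivation below assumes a negative-binomial with dispersion h and log-odds ψ = xᵀβ; the parameterization used in aemcmc may differ.

```latex
% Polya-gamma identity, for b > 0, with p(\omega \mid b, 0) the PG(b, 0) density:
\frac{(e^{\psi})^{a}}{(1 + e^{\psi})^{b}}
  = 2^{-b} e^{\kappa \psi} \int_0^{\infty} e^{-\omega \psi^{2} / 2} \, p(\omega \mid b, 0) \, d\omega,
  \qquad \kappa = a - \tfrac{b}{2}.

% Negative-binomial likelihood in terms of the log-odds \psi = x^\top \beta:
p(y \mid h, \psi) \propto \frac{(e^{\psi})^{y}}{(1 + e^{\psi})^{y + h}}
  \quad\Rightarrow\quad a = y,\; b = y + h,\; \kappa = \tfrac{y - h}{2}.

% Conditional on \omega, the \psi-dependent factor is a Gaussian kernel:
e^{\kappa \psi - \omega \psi^{2} / 2}
  \propto \exp\!\Big(-\tfrac{\omega}{2}\,(z - \psi)^{2}\Big),
  \qquad z = \kappa / \omega.
% i.e. a normal likelihood for \psi with pseudo-observation z and variance 1/\omega
```

Conditional on the auxiliary ω, the negative-binomial regression thus has the form of a normal regression with known, observation-specific variances 1/ω, which is the form a normal-regression Horseshoe Gibbs step can target.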

This framework will clearly involve a rewrite process similar to AePPL's, but there are also some unique organizational aspects we need to address.

brandonwillard commented 2 years ago

After https://github.com/aesara-devs/aemcmc/pull/28 is merged, we'll be able to start work on this. Things are still extremely simple right now, and this work is primarily motivated by the need to simplify the current, very redundant "matching" for the supported negative-binomial samplers.

Currently, each negative-binomial sampler step addresses a distinct part of a somewhat generalized negative-binomial regression model, but it does so in an overly rigid and piecemeal fashion. We need to make the logic of each distinct sampling step as independent and "reusable" as possible, and then add a means of incrementally identifying sample steps for each component of a model.

For example, we shouldn't need <model>_match and <model>_gibbs functions that simply match/sample expansions or elaborations of essentially the same form of model. We don't want to write separate functions for a negative-binomial model with a fixed dispersion term and one with a non-fixed dispersion term. Instead, we need to walk graphs, determine whether sub-graphs match an implemented sample step, and, from those matches, arrive at the kinds of sample-step combinations that we currently have.
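
A minimal sketch of that graph walk follows, again assuming toy `(op, *args)` tuples instead of Aesara graphs. Each hypothetical registry entry covers one model *component*, so the fixed- and non-fixed-dispersion models are handled by the same walk instead of two dedicated match/sample function pairs. The step names are illustrative strings only.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    """A pattern variable in a toy term language of nested (op, *args) tuples."""
    name: str


def match(pattern, term, subs):
    """One-way pattern matching: bind variables in `pattern` to parts of `term`."""
    if isinstance(pattern, Var):
        if pattern in subs:
            return subs if subs[pattern] == term else None
        return {**subs, pattern: term}
    if isinstance(pattern, tuple) and isinstance(term, tuple) and len(pattern) == len(term):
        for p, t in zip(pattern, term):
            subs = match(p, t, subs)
            if subs is None:
                return None
        return subs
    return subs if pattern == term else None


def subterms(term):
    """Yield every sub-graph (tuple) of `term` in pre-order, skipping op names."""
    if isinstance(term, tuple):
        yield term
        for arg in term[1:]:
            yield from subterms(arg)


V = Var  # short alias for building patterns
# Hypothetical registry: one reusable sample step per model component.
STEP_REGISTRY = [
    (("nbinom", V("h"), ("dot", V("X"), V("beta"))), "nbinom_regression_step"),
    (("nbinom", ("gamma", V("a"), V("b")), V("eta")), "nbinom_dispersion_step"),
    (("normal", 0, ("mul", ("halfcauchy", 0, V("s1")), ("halfcauchy", 0, V("s2")))), "horseshoe_step"),
]


def assign_sample_steps(model):
    """Walk the model graph and collect a sample step for each matching sub-graph."""
    steps = []
    for node in subterms(model):
        for pattern, step in STEP_REGISTRY:
            if match(pattern, node, {}) is not None:
                steps.append(step)
    return steps


beta = ("normal", 0, ("mul", ("halfcauchy", 0, 1), ("halfcauchy", 0, 1)))
fixed_h = ("nbinom", 10.0, ("dot", "X", beta))
random_h = ("nbinom", ("gamma", 1.0, 1.0), ("dot", "X", beta))

print(assign_sample_steps(fixed_h))
# ['nbinom_regression_step', 'horseshoe_step']
print(assign_sample_steps(random_h))
# ['nbinom_regression_step', 'nbinom_dispersion_step', 'horseshoe_step']
```

The dispersion step is only added when the dispersion sub-graph is itself random, which is the kind of combination that currently needs its own dedicated functions.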

One basic requirement is a mapping between (sub-)model graph structures and their sample steps. For now, we could use the sub-graph "templates" used for unification (e.g. the etuples in nbinom_horseshoe_match) as keys in a map to their corresponding sample step constructors (e.g. nbinom_horseshoe_gibbs). In other words, we can build a DB-like "index" out of the forms used for unification, and perhaps even find a fairly efficient approach to performing such unifications, e.g. similar to what's done in kanren.facts.Relation.
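
As a sketch of such an index, assuming the same toy tuple terms, patterns could be bucketed by their head op and arity so that a lookup only attempts matches against patterns that could possibly apply. This is only loosely in the spirit of kanren.facts.Relation; `PatternIndex` and the constructors are hypothetical, and the scheme requires every pattern to have a concrete head op.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str


def match(pattern, term, subs):
    """One-way pattern matching: bind variables in `pattern` to parts of `term`."""
    if isinstance(pattern, Var):
        if pattern in subs:
            return subs if subs[pattern] == term else None
        return {**subs, pattern: term}
    if isinstance(pattern, tuple) and isinstance(term, tuple) and len(pattern) == len(term):
        for p, t in zip(pattern, term):
            subs = match(p, t, subs)
            if subs is None:
                return None
        return subs
    return subs if pattern == term else None


class PatternIndex:
    """Hypothetical index from unification patterns to sample step constructors.

    Patterns are bucketed by (head op, arity); patterns must have a concrete head op.
    """

    def __init__(self):
        self._buckets = defaultdict(list)

    @staticmethod
    def _key(term):
        return (term[0], len(term) - 1)

    def add(self, pattern, constructor):
        self._buckets[self._key(pattern)].append((pattern, constructor))

    def candidates(self, term):
        """Return only the (pattern, constructor) pairs sharing `term`'s head op and arity."""
        if not isinstance(term, tuple):
            return []
        return self._buckets.get(self._key(term), [])

    def lookup(self, term):
        """Return (constructor, bindings) for every indexed pattern matching `term`."""
        results = []
        for pattern, constructor in self.candidates(term):
            subs = match(pattern, term, {})
            if subs is not None:
                results.append((constructor, subs))
        return results


h, X, beta, sigma = Var("h"), Var("X"), Var("beta"), Var("sigma")


def make_nbinom_step(subs):
    """Hypothetical sample step constructor for a negative-binomial regression."""
    return ("nbinom_step", subs[h], subs[beta])


def make_normal_step(subs):
    """Hypothetical sample step constructor for a normal regression."""
    return ("normal_step", subs[sigma], subs[beta])


index = PatternIndex()
index.add(("nbinom", h, ("dot", X, beta)), make_nbinom_step)
index.add(("normal", ("dot", X, beta), sigma), make_normal_step)

model = ("nbinom", 10.0, ("dot", "X", "beta_rv"))
print(len(index.candidates(model)))  # 1: the "normal" pattern is never tried
[(constructor, subs)] = index.lookup(model)
print(constructor(subs))  # ('nbinom_step', 10.0, 'beta_rv')
```

Indexing on the head op alone is the crudest choice; indexing on deeper argument positions would prune more candidates at the cost of a more complex index.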

There are two ways we can approach this and leverage existing code:

  1. If we use kanren for the graph walking, then things like multiple-results handling, Relations, and some aspects of condition-checking and unification (e.g. constraints/disequality) are all immediately available; however, some basic aspects of the graph relations aren't very scalable and would need to be addressed almost immediately.

  2. If we use Aesara's graph rewriting framework, then we might have a more scalable approach; however, we would need to implement support for generating multiple rewrite results and the aforementioned "index"-like matching capabilities. Also, the relational framework would not be available, which means that we would end up reimplementing aspects of the relations we're ultimately using (e.g. applying one "direction" of an equality in one place and the opposite direction in another).
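
To make "generating multiple rewrite results" concrete, here is a minimal sketch that assumes rewrite rules are Python generators yielding zero or more alternatives, explored breadth-first as a lazy stream (similar in spirit to kanren's result streams). It uses toy tuple terms rather than Aesara's rewriting API, and both rules are hypothetical. `commute_mul` applies an equality in both directions, and `standardize_halfcauchy` uses the fact that HalfCauchy(0, s) has the same distribution as s * HalfCauchy(0, 1).

```python
from collections import deque


def commute_mul(term):
    """Equality a * b == b * a; yields the swapped product."""
    if isinstance(term, tuple) and term[0] == "mul":
        yield ("mul", term[2], term[1])


def standardize_halfcauchy(term):
    """HalfCauchy(0, s) is distributed as s * HalfCauchy(0, 1)."""
    if isinstance(term, tuple) and term[0] == "halfcauchy" and term[2] != 1:
        yield ("mul", term[2], ("halfcauchy", 0, 1))


RULES = [commute_mul, standardize_halfcauchy]


def one_step(term, rules):
    """Yield every term obtained by applying one rule at one position."""
    if not isinstance(term, tuple):
        return
    for rule in rules:
        yield from rule(term)
    for i in range(1, len(term)):
        for new_arg in one_step(term[i], rules):
            yield term[:i] + (new_arg,) + term[i + 1:]


def rewrites(term, rules):
    """Lazily yield `term` and every distinct rewrite of it, breadth-first."""
    seen = {term}
    queue = deque([term])
    while queue:
        current = queue.popleft()
        yield current
        for new in one_step(current, rules):
            if new not in seen:
                seen.add(new)
                queue.append(new)


start = ("mul", ("halfcauchy", 0, "tau"), "lam")
for result in rewrites(start, RULES):
    print(result)  # prints 6 distinct, equivalent graphs, starting with `start`
```

Because `rewrites` is a generator, a caller can stop as soon as one alternative matches a sampler pattern; the `seen` set keeps the two directions of `commute_mul` from looping forever.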