aesara-devs / aemcmc

AeMCMC is a Python library that automates the construction of samplers for Aesara graphs representing statistical models.
https://aemcmc.readthedocs.io/en/latest/
MIT License

Add a function that constructs samplers #45

Closed brandonwillard closed 2 years ago

brandonwillard commented 2 years ago

This PR provides an initial implementation of #3 and of the related interface function mentioned in #26. It adds a general sampler constructor, aemcmc.basic.construct_sampler, that returns a dict mapping RandomVariables to their sample steps.

The current approach uses a Feature called SamplerTracker to maintain a dict from RandomVariables to all of their discovered sample steps, even when there is more than one potential sampler for the same RandomVariable. Sample steps are discovered by walking the graph with standard local rewriters that write their results to the dict in SamplerTracker. This allows us to maintain the original observation variable graphs in relation to every other unobserved variable (e.g. so we can see when a variable is in a particular hierarchical relationship with another variable).
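The tracking idea above can be sketched in plain Python without Aesara. This is a toy illustration only: `ToyRV`, `ToySamplerTracker`, `walk`, `discover_samplers`, and both rules are hypothetical names made up for this sketch, not part of AeMCMC or Aesara, and the "sample steps" are just descriptive strings. It shows rules recording every candidate step for a variable in a side dict while the graph itself stays untouched.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, Iterator, List, Tuple


@dataclass(frozen=True)
class ToyRV:
    """A toy random variable: a name, a distribution label, and parent RVs."""

    name: str
    dist: str
    parents: Tuple["ToyRV", ...] = ()


@dataclass
class ToySamplerTracker:
    """Records every sample step discovered for each variable (hypothetical)."""

    rvs_to_samplers: Dict[str, List[str]] = field(default_factory=dict)

    def add(self, rv: ToyRV, step: str) -> None:
        # Append instead of overwrite, so several candidates can coexist.
        self.rvs_to_samplers.setdefault(rv.name, []).append(step)


def walk(root: ToyRV) -> Iterator[ToyRV]:
    """Yield each node reachable from `root` exactly once, root first."""
    seen, stack = set(), [root]
    while stack:
        node = stack.pop()
        if node not in seen:
            seen.add(node)
            yield node
            stack.extend(node.parents)


Rule = Callable[[ToyRV, ToySamplerTracker], None]


def gamma_poisson_rule(node: ToyRV, tracker: ToySamplerTracker) -> None:
    """Match a Poisson node with a Gamma parent and record a step for the parent."""
    if node.dist == "poisson":
        for parent in node.parents:
            if parent.dist == "gamma":
                tracker.add(parent, f"conjugate update given {node.name}")


def fallback_rule(node: ToyRV, tracker: ToySamplerTracker) -> None:
    """Record a second, generic candidate for any Gamma node."""
    if node.dist == "gamma":
        tracker.add(node, "generic fallback step")


def discover_samplers(observed: ToyRV, rules: List[Rule]) -> ToySamplerTracker:
    """Apply every rule at every node; rules only write to the tracker."""
    tracker = ToySamplerTracker()
    for node in walk(observed):
        for rule in rules:
            rule(node, tracker)
    return tracker


lam = ToyRV("lam", "gamma")
Y = ToyRV("Y", "poisson", parents=(lam,))
tracker = discover_samplers(Y, [gamma_poisson_rule, fallback_rule])
print(tracker.rvs_to_samplers)
```

Because the toy rules never modify or replace nodes, `Y` still points at `lam` after the walk, so a later rule could still inspect that parent/child relationship when choosing among the recorded candidates.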

To get around some DimShuffle annoyances during unification/pattern-matching, this PR adds a SubsumingElemwise Op, which is used to replace Elemwise(DimShuffle(x), ...) graphs with SubsumingElemwise(x, ...) graphs (i.e. ones that subsume the DimShuffles). Since SubsumingElemwise inherits from OpFromGraph, those nodes can be expanded later on to reproduce the original Elemwise + DimShuffle sub-graphs.
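A rough way to picture the subsume/expand round trip is with a toy expression tree. Everything below is an assumption made for illustration: the `Var`, `DimShuffle`, `Elemwise`, and `ToySubsumingElemwise` dataclasses are stand-ins, not the Aesara classes, and where the real Op keeps an inner graph via OpFromGraph, this sketch simply remembers each input's DimShuffle order.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class DimShuffle:
    order: Tuple[object, ...]  # e.g. ("x", 0) adds a leading broadcast dim
    arg: object


@dataclass(frozen=True)
class Elemwise:
    op: str
    args: Tuple[object, ...]


@dataclass(frozen=True)
class ToySubsumingElemwise:
    """Toy stand-in: takes the raw inputs and remembers the hidden orders."""

    op: str
    args: Tuple[object, ...]
    orders: Tuple[Optional[Tuple[object, ...]], ...]


def subsume(expr):
    """Rewrite Elemwise(DimShuffle(x), ...) into ToySubsumingElemwise(x, ...).

    Only the top-level node is handled; nested expressions are left as-is.
    """
    if isinstance(expr, Elemwise) and any(
        isinstance(a, DimShuffle) for a in expr.args
    ):
        inner = tuple(a.arg if isinstance(a, DimShuffle) else a for a in expr.args)
        orders = tuple(
            a.order if isinstance(a, DimShuffle) else None for a in expr.args
        )
        return ToySubsumingElemwise(expr.op, inner, orders)
    return expr


def expand(expr):
    """Rebuild the original Elemwise + DimShuffle form from the stored orders."""
    if isinstance(expr, ToySubsumingElemwise):
        args = tuple(
            a if order is None else DimShuffle(order, a)
            for a, order in zip(expr.args, expr.orders)
        )
        return Elemwise(expr.op, args)
    return expr


x, y = Var("x"), Var("y")
original = Elemwise("add", (DimShuffle(("x", 0), x), y))
subsumed = subsume(original)

print(subsumed.args)                  # the pattern matcher now sees x and y directly
print(expand(subsumed) == original)   # round trip restores the original form
```

With the DimShuffle hidden, a pattern written against `add(x, y)` can match `subsumed.args` directly, and `expand` recovers the explicit form once matching is done.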

codecov[bot] commented 2 years ago

Codecov Report

Merging #45 (65c7a37) into main (20611eb) will decrease coverage by 2.54%. The diff coverage is 97.59%.

:exclamation: Current head 65c7a37 differs from pull request most recent head 39ac1a5. Consider uploading reports for the commit 39ac1a5 to get more accurate results

@@            Coverage Diff             @@
##             main      #45      +/-   ##
==========================================
- Coverage   99.74%   97.20%   -2.55%     
==========================================
  Files           7        9       +2     
  Lines         391      572     +181     
  Branches       31       62      +31     
==========================================
+ Hits          390      556     +166     
- Misses          0        5       +5     
- Partials        1       11      +10     
| Impacted Files | Coverage Δ |
|---|---|
| aemcmc/gibbs.py | 91.87% <93.05%> (-8.13%) :arrow_down: |
| aemcmc/opt.py | 98.67% <98.67%> (ø) |
| aemcmc/basic.py | 100.00% <100.00%> (ø) |
| aemcmc/conjugates.py | 100.00% <100.00%> (ø) |
| aemcmc/dists.py | 100.00% <100.00%> (ø) |


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 20611eb...39ac1a5.

brandonwillard commented 2 years ago

We now have a complete working example in the test test_basic.py:test_create_gibbs.

brandonwillard commented 2 years ago

I've finished refactoring the sampler steps and filled out the docstrings, so this should be ready to merge when/if it passes.