aesara-devs / aemcmc

AeMCMC is a Python library that automates the construction of samplers for Aesara graphs representing statistical models.
https://aemcmc.readthedocs.io/en/latest/
MIT License
39 stars 11 forks

Add a function to generate prior samples #109

Closed rlouf closed 1 year ago

rlouf commented 1 year ago

This PR adds a utility function that generates prior samples for any variable present in a model. The function should return samples in a format that is convenient for users. That format is still to be decided, although the de facto standard in the Python PPL world seems to be xarray (for compatibility with ArviZ).

import aesara
import aesara.tensor as at

from aemcmc.sample import sample_prior

srng = at.random.RandomStream(0)
mu_rv = srng.normal(0, 1)
Y_rv = srng.normal(mu_rv, 1.0)

num_samples = at.scalar()
samples, updates = sample_prior(srng, num_samples, Y_rv, mu_rv)
sample_fn = aesara.function([num_samples], samples, updates=updates)
sample_fn(10)

Note: in this case it is not necessary to pass the updates to aesara.function (although we still need to return them as outputs of Scan's inner function). Could we avoid returning them altogether to simplify the interface further?

Related to #101
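For readers unfamiliar with prior sampling, here is a minimal sketch of what the example above is expected to compute, written in plain Python rather than Aesara. The function name `sample_prior_sketch` and its `seed` parameter are assumptions made for illustration. It is not the implementation proposed in this PR. It performs ancestral sampling: draw `mu` from its prior, then draw `Y` conditional on that value.

```python
import random


def sample_prior_sketch(num_samples, seed=0):
    """Draw joint prior samples of (Y, mu) by ancestral sampling.

    Hypothetical stand-in for illustration only; it does not use Aesara
    and is not the `sample_prior` function proposed in this PR.
    """
    rng = random.Random(seed)
    y_samples, mu_samples = [], []
    for _ in range(num_samples):
        mu = rng.gauss(0.0, 1.0)  # mu ~ Normal(0, 1)
        y = rng.gauss(mu, 1.0)  # Y | mu ~ Normal(mu, 1)
        mu_samples.append(mu)
        y_samples.append(y)
    # Returned in the same order as the variables in the example call: Y, then mu.
    return [y_samples, mu_samples]


y_draws, mu_draws = sample_prior_sketch(10)
```

Each call returns one list of draws per requested variable, which mirrors the "return a list" option discussed later in this thread.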

codecov[bot] commented 1 year ago

Codecov Report

Attention: Patch coverage is 86.66667% with 4 lines in your changes missing coverage. Please review.

Project coverage is 98.03%. Comparing base (64b0e50) to head (fd670d5). Report is 19 commits behind head on main.

| Files | Patch % | Lines |
|---|---|---|
| aemcmc/utils.py | 80.00% | 1 Missing and 3 partials :warning: |

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #109      +/-   ##
==========================================
- Coverage   98.50%   98.03%   -0.47%
==========================================
  Files          10       11       +1
  Lines         737      765      +28
  Branches       63       69       +6
==========================================
+ Hits          726      750      +24
- Misses          4        5       +1
- Partials        7       10       +3
```


rlouf commented 1 year ago

Decisions regarding the output format will be informed by the constraints that a general-purpose sampling function imposes.

Here are some thoughts:

dgerlanc commented 1 year ago
  • Returning a dictionary indexed by variable name requires users to name variables. How do we handle missing names?

We could automatically generate sequential names.

Alternatively, we could return a list and provide one or more functions that convert it to other formats (dict, xarray).

brandonwillard commented 1 year ago

> We could automatically generate sequential names.

Definitely, and we can use the Variable.auto_name values for that.

> Alternatively, we could return a list and provide one or more functions that convert it to other formats (dict, xarray).

Also a very viable approach!
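
To make the two options above concrete, here is a minimal sketch that returns a plain list of samples and uses a separate converter to build a name-indexed dictionary, generating sequential names for unnamed variables. The `Var` class, the `samples_to_dict` function, and the `var_` prefix are all hypothetical names chosen for illustration. In Aesara, the generated names could instead come from `Variable.auto_name`, as suggested above. A counter is used here only to keep the example self-contained.

```python
from dataclasses import dataclass
from itertools import count
from typing import Optional


@dataclass
class Var:
    """Hypothetical stand-in for a graph variable with an optional name."""

    name: Optional[str] = None


def samples_to_dict(variables, samples, prefix="var_"):
    """Convert a list of per-variable samples into a name-indexed dict.

    Unnamed variables receive sequential names that do not collide with
    any user-provided name.
    """
    if len(variables) != len(samples):
        raise ValueError("Expected one array of samples per variable.")

    explicit_names = {v.name for v in variables if v.name is not None}
    counter = count()
    result = {}
    for var, draws in zip(variables, samples):
        name = var.name
        if name is None:
            # Skip generated names that the user has already taken.
            name = f"{prefix}{next(counter)}"
            while name in explicit_names:
                name = f"{prefix}{next(counter)}"
        elif name in result:
            raise ValueError(f"Duplicate variable name: {name!r}")
        result[name] = draws
    return result


variables = [Var("Y"), Var(), Var("var_0")]
samples = [[0.1, 0.2], [1.0, 2.0], [5.0, 6.0]]
named = samples_to_dict(variables, samples)
```

Keeping the sampler's return value a list and layering converters on top means that other formats such as xarray could be added later without changing the sampling function itself.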